Compare commits

..

No commits in common. "main" and "v0.2.0" have entirely different histories.
main ... v0.2.0

28 changed files with 475 additions and 863 deletions

View File

@ -1,15 +0,0 @@
on: [push, pull_request]
name: Test
jobs:
test:
strategy:
matrix:
go-version: [1.22.x]
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
- uses: actions/checkout@v4
- run: go build ./cmd/mjwt/
- run: go test ./...

View File

@ -1,3 +1,3 @@
# MJWT
A simple wrapper for JWT. Contains an AccessToken and RefreshToken model.
A simple wrapper for JWT. Contains an AccessToken and RefreshToken model.

View File

@ -2,13 +2,14 @@ package auth
import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"time"
)
// AccessTokenClaims contains the JWT claims for an access token
type AccessTokenClaims struct {
Perms *PermStorage `json:"per"`
Perms *claims.PermStorage `json:"per"`
}
func (a AccessTokenClaims) Valid() error { return nil }
@ -16,11 +17,11 @@ func (a AccessTokenClaims) Valid() error { return nil }
func (a AccessTokenClaims) Type() string { return "access-token" }
// CreateAccessToken creates an access token with the default 15 minute duration
func CreateAccessToken(p *mjwt.Issuer, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms)
}
// CreateAccessTokenWithDuration creates an access token with a custom duration
func CreateAccessTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
}

View File

@ -1,27 +1,29 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/golang-jwt/jwt/v4"
"github.com/1f349/mjwt/claims"
"github.com/stretchr/testify/assert"
"testing"
)
func TestCreateAccessToken(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ps := NewPermStorage()
ps := claims.NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
s := mjwt.NewMJwtSigner("mjwt.test", key)
accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)

View File

@ -2,19 +2,20 @@ package auth
import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"time"
)
// CreateTokenPair creates an access and refresh token pair using the default
// 15 minute and 7 day durations respectively
func CreateTokenPair(p *mjwt.Issuer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms)
}
// CreateTokenPairWithDuration creates an access and refresh token pair using
// custom durations for the access and refresh tokens
func CreateTokenPairWithDuration(p *mjwt.Issuer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms)
if err != nil {
return "", "", err

View File

@ -1,27 +1,29 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/golang-jwt/jwt/v4"
"github.com/1f349/mjwt/claims"
"github.com/stretchr/testify/assert"
"testing"
)
func TestCreateTokenPair(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ps := NewPermStorage()
ps := claims.NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key2", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
s := mjwt.NewMJwtSigner("mjwt.test", key)
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
@ -29,7 +31,7 @@ func TestCreateTokenPair(t *testing.T) {
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b2.Subject)
assert.Equal(t, "test2", b2.ID)

View File

@ -16,11 +16,11 @@ func (r RefreshTokenClaims) Valid() error { return nil }
func (r RefreshTokenClaims) Type() string { return "refresh-token" }
// CreateRefreshToken creates a refresh token with the default 7 day duration
func CreateRefreshToken(p *mjwt.Issuer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud)
}
// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
func CreateRefreshTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
}

View File

@ -1,23 +1,24 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"testing"
)
func TestCreateRefreshToken(t *testing.T) {
t.Parallel()
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
s := mjwt.NewMJwtSigner("mjwt.test", key)
refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)

View File

@ -1,12 +1,10 @@
package auth
package claims
import (
"bufio"
"encoding/json"
"github.com/becheran/wildmatch-go"
"gopkg.in/yaml.v3"
"sort"
"strings"
)
type PermStorage struct {
@ -17,16 +15,6 @@ func NewPermStorage() *PermStorage {
return new(PermStorage).setup()
}
func ParsePermStorage(perms string) *PermStorage {
ps := NewPermStorage()
sc := bufio.NewScanner(strings.NewReader(perms))
sc.Split(bufio.ScanWords)
for sc.Scan() {
ps.Set(sc.Text())
}
return ps
}
func (p *PermStorage) setup() *PermStorage {
if p.values == nil {
p.values = make(map[string]struct{})
@ -57,7 +45,7 @@ func (p *PermStorage) OneOf(o *PermStorage) bool {
}
func (p *PermStorage) Dump() []string {
a := make([]string, 0, len(p.values))
var a []string
for i := range p.values {
a = append(a, i)
}
@ -76,16 +64,6 @@ func (p *PermStorage) Search(v string) []string {
return a
}
func (p *PermStorage) Filter(match []string) *PermStorage {
out := NewPermStorage()
for _, i := range match {
for _, j := range p.Search(i) {
out.Set(j)
}
}
return out
}
func (p *PermStorage) prepare(a []string) {
for _, i := range a {
p.Set(i)

View File

@ -1,20 +1,11 @@
package auth
package claims
import (
"bytes"
"github.com/stretchr/testify/assert"
"sort"
"testing"
)
func TestParsePermStorage(t *testing.T) {
t.Parallel()
ps := ParsePermStorage("mjwt:test mjwt:test2")
if _, ok := ps.values["mjwt:test"]; !ok {
assert.Fail(t, "perm not set")
}
}
func TestPermStorage_Set(t *testing.T) {
t.Parallel()
ps := NewPermStorage()
@ -60,22 +51,6 @@ func TestPermStorage_OneOf(t *testing.T) {
assert.False(t, ps.OneOf(o))
}
func TestPermStorage_Search(t *testing.T) {
ps := ParsePermStorage("mjwt:test mjwt:test2 mjwt:other")
a := ps.Search("mjwt:test*")
sort.Strings(a)
assert.Equal(t, []string{"mjwt:test", "mjwt:test2"}, a)
assert.Equal(t, []string{"mjwt:other"}, ps.Search("mjwt:other"))
}
func TestPermStorage_Filter(t *testing.T) {
ps := ParsePermStorage("mjwt:test mjwt:test2 mjwt:other mjwt:other2 mjwt:another")
a := ps.Filter([]string{"mjwt:test*", "mjwt:other*"}).Dump()
sort.Strings(a)
assert.Equal(t, []string{"mjwt:other", "mjwt:other2", "mjwt:test", "mjwt:test2"}, a)
assert.Equal(t, []string{"mjwt:another"}, ps.Filter([]string{"mjwt:another"}).Dump())
}
func TestPermStorage_MarshalJSON(t *testing.T) {
t.Parallel()
ps := NewPermStorage()

View File

@ -1,101 +0,0 @@
package mjwt
import (
"fmt"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type testClaims struct{ TestValue string }
func (t testClaims) Valid() error {
if t.TestValue != "hello" && t.TestValue != "world" {
return fmt.Errorf("TestValue should be hello")
}
return nil
}
func (t testClaims) Type() string { return "testClaims" }
type testClaims2 struct{ TestValue2 string }
func (t testClaims2) Valid() error {
if t.TestValue2 != "world" {
return fmt.Errorf("TestValue2 should be world")
}
return nil
}
func (t testClaims2) Type() string { return "testClaims2" }
func TestExtractClaims(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
t.Run("TestNoKID", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
a, _, err := ExtractClaims[testClaims](kStore, token)
assert.NoError(t, err)
kid, _ := a.Header["kid"].(string)
assert.Equal(t, "key1", kid)
})
t.Run("TestKID", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
s2, err := NewIssuerWithKeyStore("mjwt.test", "key3", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
token2, err := s2.GenerateJwt("2", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
k1, _, err := ExtractClaims[testClaims](kStore, token1)
assert.NoError(t, err)
k2, _, err := ExtractClaims[testClaims](kStore, token2)
assert.NoError(t, err)
assert.NotEqual(t, k1.Header["kid"], k2.Header["kid"])
})
}
func TestExtractClaimsFail(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
t.Run("TestInvalidClaims", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims2](kStore, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
})
t.Run("TestKIDNonExist", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
kStore.RemoveKey("key2")
assert.NotContains(t, kStore.ListKeys(), "key2")
_, _, err = ExtractClaims[testClaims](kStore, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPublicKey)
})
}

View File

@ -2,11 +2,14 @@ package main
import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"github.com/google/subcommands"
"os"
@ -15,7 +18,7 @@ import (
)
type accessCmd struct {
issuer, subject, id, audience, duration, kID string
issuer, subject, id, audience, duration string
}
func (s *accessCmd) Name() string { return "access" }
@ -23,7 +26,7 @@ func (s *accessCmd) Synopsis() string {
return "Generates an access token with permissions using the private key"
}
func (s *accessCmd) Usage() string {
return `sign [-iss <issuer>] [-sub <subject>] [-id <id>] [-aud <audience>] [-dur <duration>] [-kid <name>] <private key path> <space separated permissions>
return `sign [-iss <issuer>] [-sub <subject>] [-id <id>] [-aud <audience>] [-dur <duration>] <private key path> <space separated permissions>
Output a signed MJWT token with the specified permissions.
`
}
@ -34,7 +37,6 @@ func (s *accessCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&s.id, "id", "", "MJWT ID")
f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT")
f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)")
f.StringVar(&s.kID, "kid", "", "The Key ID of the signing key")
}
func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
@ -44,13 +46,13 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
}
args := f.Args()
key, err := rsaprivate.Read(args[0])
key, err := s.parseKey(args[0])
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to parse private key: ", err)
return subcommands.ExitFailure
}
ps := auth.NewPermStorage()
ps := claims.NewPermStorage()
for i := 1; i < len(args); i++ {
ps.Set(args[i])
}
@ -65,17 +67,8 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
return subcommands.ExitFailure
}
var token string
kStore := mjwt.NewKeyStore()
kStore.LoadPrivateKey(s.kID, key)
issuer, err := mjwt.NewIssuerWithKeyStore(s.issuer, s.kID, jwt.SigningMethodRS512, kStore)
if err != nil {
panic("this should not fail")
}
token, err = issuer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
signer := mjwt.NewMJwtSigner(s.issuer, key)
token, err := signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err)
return subcommands.ExitFailure
@ -84,3 +77,13 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
fmt.Println(token)
return subcommands.ExitSuccess
}
func (s *accessCmd) parseKey(privKeyFile string) (*rsa.PrivateKey, error) {
b, err := os.ReadFile(privKeyFile)
if err != nil {
return nil, err
}
p, _ := pem.Decode(b)
return x509.ParsePKCS1PrivateKey(p.Bytes)
}

View File

@ -3,10 +3,10 @@ package main
import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/google/subcommands"
"math/rand"
"os"
@ -49,14 +49,29 @@ func (g *genCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) s
}
func (g *genCmd) gen(privPath, pubPath string) error {
createPriv, err := os.OpenFile(privPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer createPriv.Close()
createPub, err := os.OpenFile(pubPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer createPub.Close()
key, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), g.bits)
if err != nil {
return err
}
err = rsaprivate.Write(privPath, key)
keyBytes := x509.MarshalPKCS1PrivateKey(key)
pubBytes := x509.MarshalPKCS1PublicKey(&key.PublicKey)
err = pem.Encode(createPriv, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes})
if err != nil {
return err
}
return rsapublic.Write(pubPath, &key.PublicKey)
err = pem.Encode(createPub, &pem.Block{Type: "RSA PUBLIC KEY", Bytes: pubBytes})
return err
}

View File

@ -1,8 +0,0 @@
package mjwt
// EmptyClaims contains no claims
type EmptyClaims struct{}
func (e EmptyClaims) Valid() error { return nil }
func (e EmptyClaims) Type() string { return "empty-claims" }

15
go.mod
View File

@ -1,28 +1,17 @@
module github.com/1f349/mjwt
go 1.22
toolchain go1.22.3
go 1.19
require (
github.com/1f349/rsa-helper v0.0.2
github.com/becheran/wildmatch-go v1.0.0
github.com/go-jose/go-jose/v4 v4.0.4
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/google/subcommands v1.2.0
github.com/pkg/errors v0.9.1
github.com/spf13/afero v1.11.0
github.com/stretchr/testify v1.9.0
golang.org/x/sync v0.7.0
github.com/stretchr/testify v1.8.4
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/crypto v0.25.0 // indirect
golang.org/x/text v0.16.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
)

33
go.sum
View File

@ -1,45 +1,18 @@
github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9U=
github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-jose/go-jose/v4 v4.0.4 h1:VsjPI33J0SB9vQM6PLmNjoHqMQNGPiZ0rHL7Ni7Q6/E=
github.com/go-jose/go-jose/v4 v4.0.4/go.mod h1:NKb5HO1EZccyMpiZNbdUw/14tiXNyUJh188dfnMCAfc=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

23
interfaces.go Normal file
View File

@ -0,0 +1,23 @@
package mjwt
import (
"crypto/rsa"
"github.com/golang-jwt/jwt/v4"
"time"
)
// Signer is used to generate MJWT tokens.
// Signer can also be used as a Verifier.
type Signer interface {
Verifier
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
SignJwt(claims jwt.Claims) (string, error)
Issuer() string
PrivateKey() *rsa.PrivateKey
}
// Verifier is used to verify the validity MJWT tokens and extract the claim values.
type Verifier interface {
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
PublicKey() *rsa.PublicKey
}

View File

@ -1,62 +0,0 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/golang-jwt/jwt/v4"
"time"
)
// Issuer provides the signing for a PrivateKey identified by the KID in the
// provided KeyStore
type Issuer struct {
issuer string
kid string
signing jwt.SigningMethod
keystore *KeyStore
}
// NewIssuer creates an Issuer with an empty KeyStore
func NewIssuer(name, kid string, signing jwt.SigningMethod) (*Issuer, error) {
return NewIssuerWithKeyStore(name, kid, signing, NewKeyStore())
}
// NewIssuerWithKeyStore creates an Issuer with a provided KeyStore
func NewIssuerWithKeyStore(name, kid string, signing jwt.SigningMethod, keystore *KeyStore) (*Issuer, error) {
i := &Issuer{name, kid, signing, keystore}
if i.keystore.HasPrivateKey(kid) {
return i, nil
}
key, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, err
}
i.keystore.LoadPrivateKey(kid, key)
return i, i.keystore.SaveSingleKey(kid)
}
// GenerateJwt produces a signed JWT in string form
func (i *Issuer) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
return i.SignJwt(wrapClaims[Claims](sub, id, i.issuer, aud, dur, claims))
}
// SignJwt produces a signed JWT in string form from a raw jwt.Claims structure
func (i *Issuer) SignJwt(wrapped jwt.Claims) (string, error) {
key, err := i.PrivateKey()
if err != nil {
return "", err
}
token := jwt.NewWithClaims(i.signing, wrapped)
token.Header["kid"] = i.kid
return token.SignedString(key)
}
// PrivateKey outputs the rsa.PrivateKey from the KID of the Issuer
func (i *Issuer) PrivateKey() (*rsa.PrivateKey, error) {
return i.keystore.GetPrivateKey(i.kid)
}
// KeyStore outputs the underlying KeyStore used by the Issuer
func (i *Issuer) KeyStore() *KeyStore {
return i.keystore
}

View File

@ -1,59 +0,0 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/golang-jwt/jwt/v4"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"testing"
)
func TestNewIssuer(t *testing.T) {
t.Parallel()
t.Run("generate missing key for issuer", func(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
issuer, err := NewIssuerWithKeyStore("Test", "test", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
assert.True(t, kStore.HasPrivateKey("test"))
assert.True(t, kStore.HasPublicKey("test"))
assert.Equal(t, "Test", issuer.issuer)
assert.Equal(t, "test", issuer.kid)
})
t.Run("use existing issuer key", func(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
kStore.LoadPrivateKey("test", key)
issuer, err := NewIssuerWithKeyStore("Test", "test", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
assert.True(t, kStore.HasPrivateKey("test"))
assert.True(t, kStore.HasPublicKey("test"))
assert.Equal(t, "Test", issuer.issuer)
assert.Equal(t, "test", issuer.kid)
privateKey, err := issuer.PrivateKey()
assert.NoError(t, err)
assert.True(t, key.Equal(privateKey))
})
t.Run("generate missing key in filesystem", func(t *testing.T) {
t.Parallel()
dir := afero.NewMemMapFs()
kStore := NewKeyStoreWithDir(dir)
issuer, err := NewIssuerWithKeyStore("Test", "test", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
assert.True(t, kStore.HasPrivateKey("test"))
assert.True(t, kStore.HasPublicKey("test"))
assert.Equal(t, "Test", issuer.issuer)
assert.Equal(t, "test", issuer.kid)
privKeyFile, err := dir.Open("test.private.pem")
assert.NoError(t, err)
privKey, err := rsaprivate.Decode(privKeyFile)
assert.NoError(t, err)
key, err := issuer.PrivateKey()
assert.NoError(t, err)
assert.True(t, key.Equal(privKey))
})
}

31
jwks.go
View File

@ -1,31 +0,0 @@
package mjwt
import (
"encoding/json"
"github.com/go-jose/go-jose/v4"
"io"
)
// WriteJwkSetJson outputs the public keys used by the Issuers
func WriteJwkSetJson(w io.Writer, issuers []*Issuer) error {
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
var j jose.JSONWebKeySet
for _, issuer := range issuers {
// get public key from private key
key, err := issuer.PrivateKey()
if err != nil {
return err
}
pubKey := &key.PublicKey
// format as JWK
j.Keys = append(j.Keys, jose.JSONWebKey{
Algorithm: issuer.signing.Alg(),
Use: "sig",
KeyID: issuer.kid,
Key: pubKey,
})
}
return enc.Encode(j)
}

View File

@ -1,265 +0,0 @@
package mjwt
import (
"crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/golang-jwt/jwt/v4"
"github.com/spf13/afero"
"golang.org/x/sync/errgroup"
"io/fs"
"path/filepath"
"runtime"
"strings"
"sync"
)
var ErrMissingPrivateKey = errors.New("missing private key")
var ErrMissingPublicKey = errors.New("missing public key")
var ErrMissingKeyPair = errors.New("missing key pair")
const PrivateStr = ".private"
const PublicStr = ".public"
const PemExt = ".pem"
const PrivatePemExt = PrivateStr + PemExt
const PublicPemExt = PublicStr + PemExt
// KeyStore provides a store for a collection of private/public keypair structs
type KeyStore struct {
mu *sync.RWMutex
store map[string]*keyPair
dir afero.Fs
}
// NewKeyStore creates an empty KeyStore
func NewKeyStore() *KeyStore {
return &KeyStore{
mu: new(sync.RWMutex),
store: make(map[string]*keyPair),
}
}
// NewKeyStoreWithDir creates an empty KeyStore with an underlying afero.Fs
// filesystem for saving the internal store data
func NewKeyStoreWithDir(dir afero.Fs) *KeyStore {
keyStore := NewKeyStore()
keyStore.dir = dir
return keyStore
}
// NewKeyStoreFromPath creates an empty KeyStore. The provided path is walked to
// load the private/public keys. See implementation in NewKeyStoreFromDir.
func NewKeyStoreFromPath(dir string) (*KeyStore, error) {
abs, err := filepath.Abs(dir)
if err != nil {
return nil, err
}
return NewKeyStoreFromDir(afero.NewBasePathFs(afero.NewOsFs(), abs))
}
// NewKeyStoreFromDir creates an empty KeyStore. The provided afero.Fs is walked
// to find all private/public keys in files named `.private.pem` and
// `.public.pem` respectively. The keys are loaded into the KeyStore and any
// errors are returned immediately.
func NewKeyStoreFromDir(dir afero.Fs) (*KeyStore, error) {
keyStore := NewKeyStoreWithDir(dir)
err := afero.Walk(dir, ".", func(path string, d fs.FileInfo, err error) error {
// maybe this is "name.private.pem"
name := filepath.Base(path)
ext := filepath.Ext(name)
if ext != PemExt {
return nil
}
name = strings.TrimSuffix(name, ext)
ext = filepath.Ext(name)
name = strings.TrimSuffix(name, ext)
switch ext {
case PrivateStr:
open, err := dir.Open(path)
if err != nil {
return err
}
decode, err := rsaprivate.Decode(open)
if err != nil {
return err
}
keyStore.LoadPrivateKey(name, decode)
return nil
case PublicStr:
open, err := dir.Open(path)
if err != nil {
return err
}
decode, err := rsapublic.Decode(open)
if err != nil {
return err
}
keyStore.LoadPublicKey(name, decode)
return nil
}
// still invalid
return nil
})
return keyStore, err
}
type keyPair struct {
private *rsa.PrivateKey
public *rsa.PublicKey
}
// LoadPrivateKey sets the rsa.PrivateKey/rsa.PublicKey for the KID
func (k *KeyStore) LoadPrivateKey(kid string, key *rsa.PrivateKey) {
k.mu.Lock()
if k.store[kid] == nil {
k.store[kid] = &keyPair{}
}
k.store[kid].private = key
k.store[kid].public = &key.PublicKey
k.mu.Unlock()
}
// LoadPublicKey sets the rsa.PublicKey for the KID
func (k *KeyStore) LoadPublicKey(kid string, key *rsa.PublicKey) {
k.mu.Lock()
if k.store[kid] == nil {
k.store[kid] = &keyPair{}
}
k.store[kid].public = key
k.mu.Unlock()
}
// RemoveKey deletes the KID keypair from the KeyStore
func (k *KeyStore) RemoveKey(kid string) {
k.mu.Lock()
delete(k.store, kid)
k.mu.Unlock()
}
// ListKeys provides a slice of the KIDs for all keys loaded in the KeyStore
func (k *KeyStore) ListKeys() []string {
k.mu.RLock()
defer k.mu.RUnlock()
keys := make([]string, 0, len(k.store))
for k, _ := range k.store {
keys = append(keys, k)
}
return keys
}
// GetPrivateKey outputs the rsa.PrivateKey for the KID from the KeyStore
func (k *KeyStore) GetPrivateKey(kid string) (*rsa.PrivateKey, error) {
k.mu.RLock()
defer k.mu.RUnlock()
if !k.internalHasPrivateKey(kid) {
return nil, ErrMissingPrivateKey
}
return k.store[kid].private, nil
}
// GetPublicKey outputs the rsa.PublicKey for the KID from the KeyStore
func (k *KeyStore) GetPublicKey(kid string) (*rsa.PublicKey, error) {
k.mu.RLock()
defer k.mu.RUnlock()
if !k.internalHasPublicKey(kid) {
return nil, ErrMissingPublicKey
}
return k.store[kid].public, nil
}
// ClearKeys clears the internal map and makes a new map to release used memory
func (k *KeyStore) ClearKeys() {
k.mu.Lock()
clear(k.store)
k.store = make(map[string]*keyPair)
k.mu.Unlock()
}
// HasPrivateKey outputs true if the KID is found in the KeyStore
func (k *KeyStore) HasPrivateKey(kid string) bool {
k.mu.RLock()
defer k.mu.RUnlock()
return k.internalHasPrivateKey(kid)
}
func (k *KeyStore) internalHasPrivateKey(kid string) bool {
v := k.store[kid]
return v != nil && v.private != nil
}
// HasPublicKey outputs true if the KID is found in the KeyStore
func (k *KeyStore) HasPublicKey(kid string) bool {
k.mu.RLock()
defer k.mu.RUnlock()
return k.internalHasPublicKey(kid)
}
func (k *KeyStore) internalHasPublicKey(kid string) bool {
v := k.store[kid]
return v != nil && v.public != nil
}
// VerifyJwt parses the provided token string and validates it against the KID
// using the KeyStore. An error is returned if the token fails to parse or if
// there is no matching KID in the KeyStore.
func (k *KeyStore) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
kid, ok := token.Header["kid"].(string)
if !ok {
return nil, ErrMissingPublicKey
}
return k.GetPublicKey(kid)
})
if err != nil {
return nil, err
}
return withClaims, claims.Valid()
}
// SaveSingleKey writes the rsa.PrivateKey/rsa.PublicKey for the requested KID to
// the underlying afero.Fs.
func (k *KeyStore) SaveSingleKey(kid string) error {
if k.dir == nil {
return nil
}
k.mu.RLock()
pair := k.store[kid]
k.mu.RUnlock()
if pair == nil {
return ErrMissingKeyPair
}
return writeSingleKey(k.dir, kid, pair)
}
// SaveKeys writes the rsa.PrivateKey/rsa.PublicKey for the requested KID to the
// underlying afero.Fs.
func (k *KeyStore) SaveKeys() error {
k.mu.RLock()
defer k.mu.RUnlock()
workers := new(errgroup.Group)
workers.SetLimit(runtime.NumCPU())
for kid, pair := range k.store {
workers.Go(func() error {
return writeSingleKey(k.dir, kid, pair)
})
}
return workers.Wait()
}
func writeSingleKey(dir afero.Fs, kid string, pair *keyPair) error {
var errs []error
if pair.private != nil {
errs = append(errs, afero.WriteFile(dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
}
if pair.public != nil {
errs = append(errs, afero.WriteFile(dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
}
return errors.Join(errs...)
}

View File

@ -1,175 +0,0 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"sort"
"testing"
)
const kst_prvExt = "prv"
const kst_pubExt = "pub"
func setupTestDirKeyStore(t *testing.T, genKeys bool) afero.Fs {
tempDir := afero.NewMemMapFs()
if genKeys {
key1, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key1.private.pem", rsaprivate.Encode(key1), 0600)
assert.NoError(t, err)
key2, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key2.private.pem", rsaprivate.Encode(key2), 0600)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key2.public.pem", rsapublic.Encode(&key2.PublicKey), 0600)
assert.NoError(t, err)
key3, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key3.public.pem", rsapublic.Encode(&key3.PublicKey), 0600)
assert.NoError(t, err)
}
return tempDir
}
func commonSubTestsKeyStore(t *testing.T, kStore *KeyStore) {
key4, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
key5, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
const extraKID1 = "key4"
const extraKID2 = "key5"
t.Run("TestSetKey", func(t *testing.T) {
kStore.LoadPrivateKey(extraKID1, key4)
assert.Contains(t, kStore.ListKeys(), extraKID1)
})
t.Run("TestSetKeyPublic", func(t *testing.T) {
kStore.LoadPublicKey(extraKID2, &key5.PublicKey)
assert.Contains(t, kStore.ListKeys(), extraKID2)
})
t.Run("TestGetPrivateKey", func(t *testing.T) {
oKey, err := kStore.GetPrivateKey(extraKID1)
assert.NoError(t, err)
assert.Same(t, key4, oKey)
pKey, err := kStore.GetPrivateKey(extraKID2)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPrivateKey)
assert.Nil(t, pKey)
aKey, err := kStore.GetPrivateKey("key1")
assert.NoError(t, err)
assert.NotNil(t, aKey)
bKey, err := kStore.GetPrivateKey("key2")
assert.NoError(t, err)
assert.NotNil(t, bKey)
cKey, err := kStore.GetPrivateKey("key3")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPrivateKey)
assert.Nil(t, cKey)
wKey, err := kStore.GetPrivateKey("key1337")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPrivateKey)
assert.Nil(t, wKey)
})
t.Run("TestGetPublicKey", func(t *testing.T) {
oKey, err := kStore.GetPublicKey(extraKID1)
assert.NoError(t, err)
assert.Same(t, &key4.PublicKey, oKey)
pKey, err := kStore.GetPublicKey(extraKID2)
assert.NoError(t, err)
assert.Same(t, &key5.PublicKey, pKey)
aKey, err := kStore.GetPublicKey("key1")
assert.NoError(t, err)
assert.NotNil(t, aKey)
bKey, err := kStore.GetPublicKey("key2")
assert.NoError(t, err)
assert.NotNil(t, bKey)
cKey, err := kStore.GetPublicKey("key3")
assert.NoError(t, err)
assert.NotNil(t, cKey)
wKey, err := kStore.GetPublicKey("key1337")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPublicKey)
assert.Nil(t, wKey)
})
t.Run("TestRemoveKey", func(t *testing.T) {
kStore.RemoveKey(extraKID1)
assert.NotContains(t, kStore.ListKeys(), extraKID1)
oKey1, err := kStore.GetPrivateKey(extraKID1)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPrivateKey)
assert.Nil(t, oKey1)
oKey2, err := kStore.GetPublicKey(extraKID1)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPublicKey)
assert.Nil(t, oKey2)
})
t.Run("TestClearKeys", func(t *testing.T) {
kStore.ClearKeys()
assert.Empty(t, kStore.ListKeys())
})
}
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
t.Parallel()
tempDir := setupTestDirKeyStore(t, true)
kStore, err := NewKeyStoreFromDir(tempDir)
assert.NoError(t, err)
assert.Len(t, kStore.ListKeys(), 3)
kIDsToFind := []string{"key1", "key2", "key3"}
for _, k := range kIDsToFind {
assert.Contains(t, kStore.ListKeys(), k)
}
assert.True(t, kStore.HasPrivateKey("key1"))
assert.True(t, kStore.HasPublicKey("key1")) // loading a private key also loads the public key
assert.True(t, kStore.HasPrivateKey("key2"))
assert.True(t, kStore.HasPublicKey("key2"))
assert.False(t, kStore.HasPrivateKey("key3"))
assert.True(t, kStore.HasPublicKey("key3"))
commonSubTestsKeyStore(t, kStore)
}
func TestExportKeyStore(t *testing.T) {
t.Parallel()
tempDir := setupTestDirKeyStore(t, true)
tempDir2 := setupTestDirKeyStore(t, false)
kStore, err := NewKeyStoreFromDir(tempDir)
assert.NoError(t, err)
// internally swap directory
kStore.dir = tempDir2
err = kStore.SaveKeys()
assert.NoError(t, err)
kStore2, err := NewKeyStoreFromDir(tempDir2)
assert.NoError(t, err)
kidList1 := kStore.ListKeys()
kidList2 := kStore2.ListKeys()
sort.Strings(kidList1)
sort.Strings(kidList2)
assert.Equal(t, kidList1, kidList2)
commonSubTestsKeyStore(t, kStore2)
}

View File

@ -10,11 +10,11 @@ import (
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
now := time.Now()
return (&BaseTypeClaims[T]{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: issuer,
Issuer: p.Issuer(),
Subject: sub,
Audience: aud,
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
@ -28,12 +28,12 @@ func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time
// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
// token and BaseTypeClaims
func ExtractClaims[T Claims](ks *KeyStore, token string) (*jwt.Token, BaseTypeClaims[T], error) {
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
b := BaseTypeClaims[T]{
RegisteredClaims: jwt.RegisteredClaims{},
Claims: *new(T),
}
tok, err := ks.VerifyJwt(token, &b)
tok, err := p.VerifyJwt(token, &b)
return tok, b, err
}
@ -63,9 +63,6 @@ func (b *BaseTypeClaims[T]) init() *BaseTypeClaims[T] {
// Valid checks the InternalClaimType matches and the type claim type
func (b *BaseTypeClaims[T]) Valid() error {
if err := b.RegisteredClaims.Valid(); err != nil {
return err
}
if b.ClaimType != b.InternalClaimType() {
return ErrClaimTypeMismatch
}

61
mjwt_test.go Normal file
View File

@ -0,0 +1,61 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"fmt"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type testClaims struct{ TestValue string }
func (t testClaims) Valid() error {
if t.TestValue != "hello" && t.TestValue != "world" {
return fmt.Errorf("TestValue should be hello")
}
return nil
}
func (t testClaims) Type() string { return "testClaims" }
type testClaims2 struct{ TestValue2 string }
func (t testClaims2) Valid() error {
if t.TestValue2 != "world" {
return fmt.Errorf("TestValue2 should be world")
}
return nil
}
func (t testClaims2) Type() string { return "testClaims2" }
func TestExtractClaims(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
s := NewMJwtSigner("mjwt.test", key)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
m := NewMJwtVerifier(&key.PublicKey)
_, _, err = ExtractClaims[testClaims](m, token)
assert.NoError(t, err)
}
func TestExtractClaimsFail(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
s := NewMJwtSigner("mjwt.test", key)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
m := NewMJwtVerifier(&key.PublicKey)
_, _, err = ExtractClaims[testClaims2](m, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
}

143
signer.go Normal file
View File

@ -0,0 +1,143 @@
package mjwt
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"github.com/golang-jwt/jwt/v4"
"io"
"os"
"time"
)
// defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name
// to generate MJWT tokens
type defaultMJwtSigner struct {
issuer string
key *rsa.PrivateKey
verify *defaultMJwtVerifier
}
var _ Signer = &defaultMJwtSigner{}
var _ Verifier = &defaultMJwtSigner{}
// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
return &defaultMJwtSigner{
issuer: issuer,
key: key,
verify: newMJwtVerifier(&key.PublicKey),
}
}
// NewMJwtSignerFromFileOrCreate creates a new defaultMJwtSigner using the path
// of a rsa.PrivateKey file. If the file does not exist then the file is created
// and a new private key is generated.
func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits int) (Signer, error) {
privateKey, err := readOrCreatePrivateKey(file, random, bits)
if err != nil {
return nil, err
}
return NewMJwtSigner(issuer, privateKey), nil
}
// NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a
// rsa.PrivateKey file.
func NewMJwtSignerFromFile(issuer, file string) (Signer, error) {
// read file
raw, err := os.ReadFile(file)
if err != nil {
return nil, err
}
// decode pem block
block, _ := pem.Decode(raw)
if block == nil || block.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("invalid rsa private key pem block")
}
// parse private key from pem block
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
// create signer using rsa.PrivateKey
return NewMJwtSigner(issuer, key), nil
}
// Issuer returns the name of the issuer
func (d *defaultMJwtSigner) Issuer() string { return d.issuer }
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims))
}
// SignJwt signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
return token.SignedString(d.key)
}
// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
return d.verify.VerifyJwt(token, claims)
}
func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey { return d.key }
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey { return d.verify.pub }
// readOrCreatePrivateKey returns the private key it the file already exists,
// generates a new private key and saves it to the file, or returns an error if
// reading or generating failed.
func readOrCreatePrivateKey(file string, random io.Reader, bits int) (*rsa.PrivateKey, error) {
// read the file or return nil
f, err := readOrEmptyFile(file)
if err != nil {
return nil, err
}
if f == nil {
// generate a new key
key, err := rsa.GenerateKey(random, bits)
if err != nil {
return nil, err
}
keyBytes := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
if err != nil {
return nil, err
}
// write the key to the file
err = os.WriteFile(file, keyBytes, 0600)
return key, err
} else {
// decode pem block
block, _ := pem.Decode(f)
if block == nil || block.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("invalid rsa private key pem block")
}
// try to parse the private key
return x509.ParsePKCS1PrivateKey(block.Bytes)
}
}
// readOrEmptyFile returns bytes and errors from os.ReadFile or (nil, nil) if the
// file does not exist.
func readOrEmptyFile(file string) ([]byte, error) {
raw, err := os.ReadFile(file)
if err == nil {
return raw, nil
}
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}

69
signer_test.go Normal file
View File

@ -0,0 +1,69 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestNewMJwtSigner(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
NewMJwtSigner("Test", key)
}
func TestNewMJwtSignerFromFile(t *testing.T) {
t.Parallel()
tempKey, err := os.CreateTemp("", "key-test-*.pem")
assert.NoError(t, err)
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
_, err = tempKey.Write(b)
assert.NoError(t, err)
assert.NoError(t, tempKey.Close())
signer, err := NewMJwtSignerFromFile("Test", tempKey.Name())
assert.NoError(t, err)
assert.NoError(t, os.Remove(tempKey.Name()))
_, err = NewMJwtSignerFromFile("Test", tempKey.Name())
assert.Error(t, err)
assert.True(t, os.IsNotExist(err))
assert.True(t, signer.(*defaultMJwtSigner).key.Equal(key))
}
func TestNewMJwtSignerFromFileOrCreate(t *testing.T) {
t.Parallel()
tempKey, err := os.CreateTemp("", "key-test-*.pem")
assert.NoError(t, err)
assert.NoError(t, tempKey.Close())
assert.NoError(t, os.Remove(tempKey.Name()))
signer, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048)
assert.NoError(t, err)
signer2, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048)
assert.NoError(t, err)
assert.True(t, signer.PrivateKey().Equal(signer2.PrivateKey()))
}
func TestReadOrCreatePrivateKey(t *testing.T) {
t.Parallel()
tempKey, err := os.CreateTemp("", "key-test-*.pem")
assert.NoError(t, err)
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
_, err = tempKey.Write(b)
assert.NoError(t, err)
assert.NoError(t, tempKey.Close())
key2, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048)
assert.NoError(t, err)
assert.True(t, key.Equal(key2))
assert.NoError(t, os.Remove(tempKey.Name()))
key3, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048)
assert.NoError(t, err)
assert.NoError(t, key3.Validate())
}

61
verifier.go Normal file
View File

@ -0,0 +1,61 @@
package mjwt
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"github.com/golang-jwt/jwt/v4"
"os"
)
// defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate
// MJWT tokens
type defaultMJwtVerifier struct {
pub *rsa.PublicKey
}
var _ Verifier = &defaultMJwtVerifier{}
// NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey
func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
return newMJwtVerifier(key)
}
func newMJwtVerifier(key *rsa.PublicKey) *defaultMJwtVerifier {
return &defaultMJwtVerifier{pub: key}
}
// NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a
// rsa.PublicKey file
func NewMJwtVerifierFromFile(file string) (Verifier, error) {
// read file
f, err := os.ReadFile(file)
if err != nil {
return nil, err
}
// decode pem block
block, _ := pem.Decode(f)
// parse public key from pem block
pub, err := x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return nil, err
}
// create verifier using rsa.PublicKey
return NewMJwtVerifier(pub), nil
}
// VerifyJwt validates and parses MJWT tokens and returns the claims
func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return d.pub, nil
})
if err != nil {
return nil, err
}
return withClaims, claims.Valid()
}
func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey { return d.pub }

34
verifier_test.go Normal file
View File

@ -0,0 +1,34 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"github.com/stretchr/testify/assert"
"os"
"testing"
"time"
)
func TestNewMJwtVerifierFromFile(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
s := NewMJwtSigner("mjwt.test", key)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"})
assert.NoError(t, err)
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey)})
temp, err := os.CreateTemp("", "this-is-a-test-file.pem")
assert.NoError(t, err)
_, err = temp.Write(b)
assert.NoError(t, err)
file, err := NewMJwtVerifierFromFile(temp.Name())
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims](file, token)
assert.NoError(t, err)
err = os.Remove(temp.Name())
assert.NoError(t, err)
}