Compare commits

...

26 Commits
v0.2.2 ... main

Author SHA1 Message Date
87774ec45e
More documentation, and alias function for loading a KeyStore from a filepath path 2024-08-12 21:29:16 +01:00
7eaf420bb9
Add Issuer.KeyStore() method 2024-07-27 20:03:02 +01:00
4e2c18918f
Add JSONWebKeySet generator 2024-07-27 19:27:13 +01:00
1fc34736a2
Add signing method parameter 2024-07-27 19:25:56 +01:00
cd2d80cb09
Rewrite mjwt library to better support keystores 2024-07-27 17:14:21 +01:00
5d1bd6f8fd
Update rsa-helper
Add read limit for key loader in signer
2024-06-10 17:51:11 +01:00
690b9f9512
Pedantic: Remove defensive programming on receivers. 2024-06-09 21:31:01 +01:00
3201964fec
Fix pendantic negation issue. 2024-06-09 21:23:28 +01:00
9a1029861c
Update go mod sum
Fix up KeyStore
2024-06-09 21:20:46 +01:00
5e627ed024
Update version of rsa-helper. 2024-06-09 21:18:23 +01:00
fe2d905236
Fix error joining issues in KeyStore. 2024-06-09 21:15:28 +01:00
a94ed7a2e5
Fix up KeyStore directory read. 2024-06-09 21:00:18 +01:00
a0d03c0dfb
Fix pedantic string check for nul in access token generator command. 2024-06-09 20:46:25 +01:00
d76a534346
Fix up new tests for auth kID support. 2024-06-09 20:40:13 +01:00
dc95ed754c
Add kID support to auth. 2024-06-09 20:31:53 +01:00
ce5eccfb3c
Finish up tests (mjwt_test).
Fix MJwt func naming.
Move seperate errors.New to a global var
2024-06-09 19:31:12 +01:00
407f8510b6
Add extra tests for signer_test and verifier_test 2024-06-09 18:40:43 +01:00
6fbc9e3c1f
Fix key_store tests
Add key_store support to signer and verifier
2024-06-09 16:49:57 +01:00
32cfa7a30d
Fix key store issues + tests. 2024-06-09 01:06:19 +01:00
3a7b3dd250
In progress. 2024-06-09 01:06:19 +01:00
545b688391
Upgrade to using github.com/1f349/rsa-helper
Add key_store implementation.
2024-06-09 01:06:19 +01:00
6a34395d8e
Add test workflow 2024-06-09 01:01:58 +01:00
ca4e4b7cae
Add empty claims type 2024-02-24 16:03:06 +00:00
ab84ded3a1
Add tests for search and filter 2024-02-14 23:58:52 +00:00
1792211ca2
Filter permissions that match multiple wildcard inputs 2024-02-14 20:04:55 +00:00
82d4a4a414
ParsePermStorage from string 2024-02-14 19:55:19 +00:00
28 changed files with 859 additions and 474 deletions

15
.github/workflows/test.yml vendored Normal file
View File

@ -0,0 +1,15 @@
on: [push, pull_request]
name: Test
jobs:
test:
strategy:
matrix:
go-version: [1.22.x]
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
- uses: actions/checkout@v4
- run: go build ./cmd/mjwt/
- run: go test ./...

View File

@ -1,3 +1,3 @@
# MJWT
A simple wrapper for JWT. Contains an AccessToken and RefreshToken model.
A simple wrapper for JWT. Contains an AccessToken and RefreshToken model.

View File

@ -2,14 +2,13 @@ package auth
import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"time"
)
// AccessTokenClaims contains the JWT claims for an access token
type AccessTokenClaims struct {
Perms *claims.PermStorage `json:"per"`
Perms *PermStorage `json:"per"`
}
func (a AccessTokenClaims) Valid() error { return nil }
@ -17,11 +16,11 @@ func (a AccessTokenClaims) Valid() error { return nil }
func (a AccessTokenClaims) Type() string { return "access-token" }
// CreateAccessToken creates an access token with the default 15 minute duration
func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
func CreateAccessToken(p *mjwt.Issuer, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms)
}
// CreateAccessTokenWithDuration creates an access token with a custom duration
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
func CreateAccessTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
}

View File

@ -1,29 +1,27 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"testing"
)
func TestCreateAccessToken(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ps := claims.NewPermStorage()
ps := NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")
s := mjwt.NewMJwtSigner("mjwt.test", key)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)

View File

@ -2,20 +2,19 @@ package auth
import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"time"
)
// CreateTokenPair creates an access and refresh token pair using the default
// 15 minute and 7 day durations respectively
func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
func CreateTokenPair(p *mjwt.Issuer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms)
}
// CreateTokenPairWithDuration creates an access and refresh token pair using
// custom durations for the access and refresh tokens
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
func CreateTokenPairWithDuration(p *mjwt.Issuer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms)
if err != nil {
return "", "", err

View File

@ -1,29 +1,27 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"testing"
)
func TestCreateTokenPair(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ps := claims.NewPermStorage()
ps := NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")
s := mjwt.NewMJwtSigner("mjwt.test", key)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key2", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
@ -31,7 +29,7 @@ func TestCreateTokenPair(t *testing.T) {
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b2.Subject)
assert.Equal(t, "test2", b2.ID)

View File

@ -1,10 +1,12 @@
package claims
package auth
import (
"bufio"
"encoding/json"
"github.com/becheran/wildmatch-go"
"gopkg.in/yaml.v3"
"sort"
"strings"
)
type PermStorage struct {
@ -15,6 +17,16 @@ func NewPermStorage() *PermStorage {
return new(PermStorage).setup()
}
func ParsePermStorage(perms string) *PermStorage {
ps := NewPermStorage()
sc := bufio.NewScanner(strings.NewReader(perms))
sc.Split(bufio.ScanWords)
for sc.Scan() {
ps.Set(sc.Text())
}
return ps
}
func (p *PermStorage) setup() *PermStorage {
if p.values == nil {
p.values = make(map[string]struct{})
@ -64,6 +76,16 @@ func (p *PermStorage) Search(v string) []string {
return a
}
func (p *PermStorage) Filter(match []string) *PermStorage {
out := NewPermStorage()
for _, i := range match {
for _, j := range p.Search(i) {
out.Set(j)
}
}
return out
}
func (p *PermStorage) prepare(a []string) {
for _, i := range a {
p.Set(i)

View File

@ -1,11 +1,20 @@
package claims
package auth
import (
"bytes"
"github.com/stretchr/testify/assert"
"sort"
"testing"
)
func TestParsePermStorage(t *testing.T) {
t.Parallel()
ps := ParsePermStorage("mjwt:test mjwt:test2")
if _, ok := ps.values["mjwt:test"]; !ok {
assert.Fail(t, "perm not set")
}
}
func TestPermStorage_Set(t *testing.T) {
t.Parallel()
ps := NewPermStorage()
@ -51,6 +60,22 @@ func TestPermStorage_OneOf(t *testing.T) {
assert.False(t, ps.OneOf(o))
}
func TestPermStorage_Search(t *testing.T) {
ps := ParsePermStorage("mjwt:test mjwt:test2 mjwt:other")
a := ps.Search("mjwt:test*")
sort.Strings(a)
assert.Equal(t, []string{"mjwt:test", "mjwt:test2"}, a)
assert.Equal(t, []string{"mjwt:other"}, ps.Search("mjwt:other"))
}
func TestPermStorage_Filter(t *testing.T) {
ps := ParsePermStorage("mjwt:test mjwt:test2 mjwt:other mjwt:other2 mjwt:another")
a := ps.Filter([]string{"mjwt:test*", "mjwt:other*"}).Dump()
sort.Strings(a)
assert.Equal(t, []string{"mjwt:other", "mjwt:other2", "mjwt:test", "mjwt:test2"}, a)
assert.Equal(t, []string{"mjwt:another"}, ps.Filter([]string{"mjwt:another"}).Dump())
}
func TestPermStorage_MarshalJSON(t *testing.T) {
t.Parallel()
ps := NewPermStorage()

View File

@ -16,11 +16,11 @@ func (r RefreshTokenClaims) Valid() error { return nil }
func (r RefreshTokenClaims) Type() string { return "refresh-token" }
// CreateRefreshToken creates a refresh token with the default 7 day duration
func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
func CreateRefreshToken(p *mjwt.Issuer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud)
}
// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
func CreateRefreshTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
}

View File

@ -1,24 +1,23 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"testing"
)
func TestCreateRefreshToken(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
s := mjwt.NewMJwtSigner("mjwt.test", key)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)

View File

@ -10,11 +10,11 @@ import (
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
now := time.Now()
return (&BaseTypeClaims[T]{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: p.Issuer(),
Issuer: issuer,
Subject: sub,
Audience: aud,
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
@ -28,12 +28,12 @@ func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur ti
// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
// token and BaseTypeClaims
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
func ExtractClaims[T Claims](ks *KeyStore, token string) (*jwt.Token, BaseTypeClaims[T], error) {
b := BaseTypeClaims[T]{
RegisteredClaims: jwt.RegisteredClaims{},
Claims: *new(T),
}
tok, err := p.VerifyJwt(token, &b)
tok, err := ks.VerifyJwt(token, &b)
return tok, b, err
}

101
claims_test.go Normal file
View File

@ -0,0 +1,101 @@
package mjwt
import (
"fmt"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type testClaims struct{ TestValue string }
func (t testClaims) Valid() error {
if t.TestValue != "hello" && t.TestValue != "world" {
return fmt.Errorf("TestValue should be hello")
}
return nil
}
func (t testClaims) Type() string { return "testClaims" }
type testClaims2 struct{ TestValue2 string }
func (t testClaims2) Valid() error {
if t.TestValue2 != "world" {
return fmt.Errorf("TestValue2 should be world")
}
return nil
}
func (t testClaims2) Type() string { return "testClaims2" }
func TestExtractClaims(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
t.Run("TestNoKID", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
a, _, err := ExtractClaims[testClaims](kStore, token)
assert.NoError(t, err)
kid, _ := a.Header["kid"].(string)
assert.Equal(t, "key1", kid)
})
t.Run("TestKID", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
s2, err := NewIssuerWithKeyStore("mjwt.test", "key3", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
token2, err := s2.GenerateJwt("2", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
k1, _, err := ExtractClaims[testClaims](kStore, token1)
assert.NoError(t, err)
k2, _, err := ExtractClaims[testClaims](kStore, token2)
assert.NoError(t, err)
assert.NotEqual(t, k1.Header["kid"], k2.Header["kid"])
})
}
func TestExtractClaimsFail(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
t.Run("TestInvalidClaims", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims2](kStore, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
})
t.Run("TestKIDNonExist", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
kStore.RemoveKey("key2")
assert.NotContains(t, kStore.ListKeys(), "key2")
_, _, err = ExtractClaims[testClaims](kStore, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPublicKey)
})
}

View File

@ -2,14 +2,11 @@ package main
import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth"
"github.com/1f349/mjwt/claims"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/golang-jwt/jwt/v4"
"github.com/google/subcommands"
"os"
@ -18,7 +15,7 @@ import (
)
type accessCmd struct {
issuer, subject, id, audience, duration string
issuer, subject, id, audience, duration, kID string
}
func (s *accessCmd) Name() string { return "access" }
@ -26,7 +23,7 @@ func (s *accessCmd) Synopsis() string {
return "Generates an access token with permissions using the private key"
}
func (s *accessCmd) Usage() string {
return `sign [-iss <issuer>] [-sub <subject>] [-id <id>] [-aud <audience>] [-dur <duration>] <private key path> <space separated permissions>
return `sign [-iss <issuer>] [-sub <subject>] [-id <id>] [-aud <audience>] [-dur <duration>] [-kid <name>] <private key path> <space separated permissions>
Output a signed MJWT token with the specified permissions.
`
}
@ -37,6 +34,7 @@ func (s *accessCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&s.id, "id", "", "MJWT ID")
f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT")
f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)")
f.StringVar(&s.kID, "kid", "", "The Key ID of the signing key")
}
func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
@ -46,13 +44,13 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
}
args := f.Args()
key, err := s.parseKey(args[0])
key, err := rsaprivate.Read(args[0])
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to parse private key: ", err)
return subcommands.ExitFailure
}
ps := claims.NewPermStorage()
ps := auth.NewPermStorage()
for i := 1; i < len(args); i++ {
ps.Set(args[i])
}
@ -67,8 +65,17 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
return subcommands.ExitFailure
}
signer := mjwt.NewMJwtSigner(s.issuer, key)
token, err := signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
var token string
kStore := mjwt.NewKeyStore()
kStore.LoadPrivateKey(s.kID, key)
issuer, err := mjwt.NewIssuerWithKeyStore(s.issuer, s.kID, jwt.SigningMethodRS512, kStore)
if err != nil {
panic("this should not fail")
}
token, err = issuer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err)
return subcommands.ExitFailure
@ -77,13 +84,3 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
fmt.Println(token)
return subcommands.ExitSuccess
}
func (s *accessCmd) parseKey(privKeyFile string) (*rsa.PrivateKey, error) {
b, err := os.ReadFile(privKeyFile)
if err != nil {
return nil, err
}
p, _ := pem.Decode(b)
return x509.ParsePKCS1PrivateKey(p.Bytes)
}

View File

@ -3,10 +3,10 @@ package main
import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/google/subcommands"
"math/rand"
"os"
@ -49,29 +49,14 @@ func (g *genCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) s
}
func (g *genCmd) gen(privPath, pubPath string) error {
createPriv, err := os.OpenFile(privPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer createPriv.Close()
createPub, err := os.OpenFile(pubPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer createPub.Close()
key, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), g.bits)
if err != nil {
return err
}
keyBytes := x509.MarshalPKCS1PrivateKey(key)
pubBytes := x509.MarshalPKCS1PublicKey(&key.PublicKey)
err = pem.Encode(createPriv, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes})
err = rsaprivate.Write(privPath, key)
if err != nil {
return err
}
err = pem.Encode(createPub, &pem.Block{Type: "RSA PUBLIC KEY", Bytes: pubBytes})
return err
return rsapublic.Write(pubPath, &key.PublicKey)
}

8
empty-claims.go Normal file
View File

@ -0,0 +1,8 @@
package mjwt
// EmptyClaims contains no claims
type EmptyClaims struct{}
func (e EmptyClaims) Valid() error { return nil }
func (e EmptyClaims) Type() string { return "empty-claims" }

15
go.mod
View File

@ -1,17 +1,28 @@
module github.com/1f349/mjwt
go 1.19
go 1.22
toolchain go1.22.3
require (
github.com/1f349/rsa-helper v0.0.2
github.com/becheran/wildmatch-go v1.0.0
github.com/go-jose/go-jose/v4 v4.0.4
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/google/subcommands v1.2.0
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.8.4
github.com/spf13/afero v1.11.0
github.com/stretchr/testify v1.9.0
golang.org/x/sync v0.7.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/crypto v0.25.0 // indirect
golang.org/x/text v0.16.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
)

33
go.sum
View File

@ -1,18 +1,45 @@
github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9U=
github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-jose/go-jose/v4 v4.0.4 h1:VsjPI33J0SB9vQM6PLmNjoHqMQNGPiZ0rHL7Ni7Q6/E=
github.com/go-jose/go-jose/v4 v4.0.4/go.mod h1:NKb5HO1EZccyMpiZNbdUw/14tiXNyUJh188dfnMCAfc=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,23 +0,0 @@
package mjwt
import (
"crypto/rsa"
"github.com/golang-jwt/jwt/v4"
"time"
)
// Signer is used to generate MJWT tokens.
// Signer can also be used as a Verifier.
type Signer interface {
Verifier
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
SignJwt(claims jwt.Claims) (string, error)
Issuer() string
PrivateKey() *rsa.PrivateKey
}
// Verifier is used to verify the validity MJWT tokens and extract the claim values.
type Verifier interface {
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
PublicKey() *rsa.PublicKey
}

62
issuer.go Normal file
View File

@ -0,0 +1,62 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/golang-jwt/jwt/v4"
"time"
)
// Issuer provides the signing for a PrivateKey identified by the KID in the
// provided KeyStore
type Issuer struct {
issuer string
kid string
signing jwt.SigningMethod
keystore *KeyStore
}
// NewIssuer creates an Issuer with an empty KeyStore
func NewIssuer(name, kid string, signing jwt.SigningMethod) (*Issuer, error) {
return NewIssuerWithKeyStore(name, kid, signing, NewKeyStore())
}
// NewIssuerWithKeyStore creates an Issuer with a provided KeyStore
func NewIssuerWithKeyStore(name, kid string, signing jwt.SigningMethod, keystore *KeyStore) (*Issuer, error) {
i := &Issuer{name, kid, signing, keystore}
if i.keystore.HasPrivateKey(kid) {
return i, nil
}
key, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, err
}
i.keystore.LoadPrivateKey(kid, key)
return i, i.keystore.SaveSingleKey(kid)
}
// GenerateJwt produces a signed JWT in string form
func (i *Issuer) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
return i.SignJwt(wrapClaims[Claims](sub, id, i.issuer, aud, dur, claims))
}
// SignJwt produces a signed JWT in string form from a raw jwt.Claims structure
func (i *Issuer) SignJwt(wrapped jwt.Claims) (string, error) {
key, err := i.PrivateKey()
if err != nil {
return "", err
}
token := jwt.NewWithClaims(i.signing, wrapped)
token.Header["kid"] = i.kid
return token.SignedString(key)
}
// PrivateKey outputs the rsa.PrivateKey from the KID of the Issuer
func (i *Issuer) PrivateKey() (*rsa.PrivateKey, error) {
return i.keystore.GetPrivateKey(i.kid)
}
// KeyStore outputs the underlying KeyStore used by the Issuer
func (i *Issuer) KeyStore() *KeyStore {
return i.keystore
}

59
issuer_test.go Normal file
View File

@ -0,0 +1,59 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/golang-jwt/jwt/v4"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"testing"
)
func TestNewIssuer(t *testing.T) {
t.Parallel()
t.Run("generate missing key for issuer", func(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
issuer, err := NewIssuerWithKeyStore("Test", "test", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
assert.True(t, kStore.HasPrivateKey("test"))
assert.True(t, kStore.HasPublicKey("test"))
assert.Equal(t, "Test", issuer.issuer)
assert.Equal(t, "test", issuer.kid)
})
t.Run("use existing issuer key", func(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
kStore.LoadPrivateKey("test", key)
issuer, err := NewIssuerWithKeyStore("Test", "test", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
assert.True(t, kStore.HasPrivateKey("test"))
assert.True(t, kStore.HasPublicKey("test"))
assert.Equal(t, "Test", issuer.issuer)
assert.Equal(t, "test", issuer.kid)
privateKey, err := issuer.PrivateKey()
assert.NoError(t, err)
assert.True(t, key.Equal(privateKey))
})
t.Run("generate missing key in filesystem", func(t *testing.T) {
t.Parallel()
dir := afero.NewMemMapFs()
kStore := NewKeyStoreWithDir(dir)
issuer, err := NewIssuerWithKeyStore("Test", "test", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)
assert.True(t, kStore.HasPrivateKey("test"))
assert.True(t, kStore.HasPublicKey("test"))
assert.Equal(t, "Test", issuer.issuer)
assert.Equal(t, "test", issuer.kid)
privKeyFile, err := dir.Open("test.private.pem")
assert.NoError(t, err)
privKey, err := rsaprivate.Decode(privKeyFile)
assert.NoError(t, err)
key, err := issuer.PrivateKey()
assert.NoError(t, err)
assert.True(t, key.Equal(privKey))
})
}

31
jwks.go Normal file
View File

@ -0,0 +1,31 @@
package mjwt
import (
"encoding/json"
"github.com/go-jose/go-jose/v4"
"io"
)
// WriteJwkSetJson outputs the public keys used by the Issuers
func WriteJwkSetJson(w io.Writer, issuers []*Issuer) error {
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
var j jose.JSONWebKeySet
for _, issuer := range issuers {
// get public key from private key
key, err := issuer.PrivateKey()
if err != nil {
return err
}
pubKey := &key.PublicKey
// format as JWK
j.Keys = append(j.Keys, jose.JSONWebKey{
Algorithm: issuer.signing.Alg(),
Use: "sig",
KeyID: issuer.kid,
Key: pubKey,
})
}
return enc.Encode(j)
}

265
keystore.go Normal file
View File

@ -0,0 +1,265 @@
package mjwt
import (
"crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/golang-jwt/jwt/v4"
"github.com/spf13/afero"
"golang.org/x/sync/errgroup"
"io/fs"
"path/filepath"
"runtime"
"strings"
"sync"
)
var ErrMissingPrivateKey = errors.New("missing private key")
var ErrMissingPublicKey = errors.New("missing public key")
var ErrMissingKeyPair = errors.New("missing key pair")
const PrivateStr = ".private"
const PublicStr = ".public"
const PemExt = ".pem"
const PrivatePemExt = PrivateStr + PemExt
const PublicPemExt = PublicStr + PemExt
// KeyStore provides a store for a collection of private/public keypair structs
type KeyStore struct {
mu *sync.RWMutex
store map[string]*keyPair
dir afero.Fs
}
// NewKeyStore creates an empty KeyStore
func NewKeyStore() *KeyStore {
return &KeyStore{
mu: new(sync.RWMutex),
store: make(map[string]*keyPair),
}
}
// NewKeyStoreWithDir creates an empty KeyStore with an underlying afero.Fs
// filesystem for saving the internal store data
func NewKeyStoreWithDir(dir afero.Fs) *KeyStore {
keyStore := NewKeyStore()
keyStore.dir = dir
return keyStore
}
// NewKeyStoreFromPath creates an empty KeyStore. The provided path is walked to
// load the private/public keys. See implementation in NewKeyStoreFromDir.
func NewKeyStoreFromPath(dir string) (*KeyStore, error) {
abs, err := filepath.Abs(dir)
if err != nil {
return nil, err
}
return NewKeyStoreFromDir(afero.NewBasePathFs(afero.NewOsFs(), abs))
}
// NewKeyStoreFromDir creates an empty KeyStore. The provided afero.Fs is walked
// to find all private/public keys in files named `.private.pem` and
// `.public.pem` respectively. The keys are loaded into the KeyStore and any
// errors are returned immediately.
func NewKeyStoreFromDir(dir afero.Fs) (*KeyStore, error) {
keyStore := NewKeyStoreWithDir(dir)
err := afero.Walk(dir, ".", func(path string, d fs.FileInfo, err error) error {
// maybe this is "name.private.pem"
name := filepath.Base(path)
ext := filepath.Ext(name)
if ext != PemExt {
return nil
}
name = strings.TrimSuffix(name, ext)
ext = filepath.Ext(name)
name = strings.TrimSuffix(name, ext)
switch ext {
case PrivateStr:
open, err := dir.Open(path)
if err != nil {
return err
}
decode, err := rsaprivate.Decode(open)
if err != nil {
return err
}
keyStore.LoadPrivateKey(name, decode)
return nil
case PublicStr:
open, err := dir.Open(path)
if err != nil {
return err
}
decode, err := rsapublic.Decode(open)
if err != nil {
return err
}
keyStore.LoadPublicKey(name, decode)
return nil
}
// still invalid
return nil
})
return keyStore, err
}
type keyPair struct {
private *rsa.PrivateKey
public *rsa.PublicKey
}
// LoadPrivateKey sets the rsa.PrivateKey/rsa.PublicKey for the KID
func (k *KeyStore) LoadPrivateKey(kid string, key *rsa.PrivateKey) {
k.mu.Lock()
if k.store[kid] == nil {
k.store[kid] = &keyPair{}
}
k.store[kid].private = key
k.store[kid].public = &key.PublicKey
k.mu.Unlock()
}
// LoadPublicKey sets the rsa.PublicKey for the KID
func (k *KeyStore) LoadPublicKey(kid string, key *rsa.PublicKey) {
k.mu.Lock()
if k.store[kid] == nil {
k.store[kid] = &keyPair{}
}
k.store[kid].public = key
k.mu.Unlock()
}
// RemoveKey deletes the KID keypair from the KeyStore
func (k *KeyStore) RemoveKey(kid string) {
k.mu.Lock()
delete(k.store, kid)
k.mu.Unlock()
}
// ListKeys provides a slice of the KIDs for all keys loaded in the KeyStore
func (k *KeyStore) ListKeys() []string {
k.mu.RLock()
defer k.mu.RUnlock()
keys := make([]string, 0, len(k.store))
for k, _ := range k.store {
keys = append(keys, k)
}
return keys
}
// GetPrivateKey outputs the rsa.PrivateKey for the KID from the KeyStore
func (k *KeyStore) GetPrivateKey(kid string) (*rsa.PrivateKey, error) {
k.mu.RLock()
defer k.mu.RUnlock()
if !k.internalHasPrivateKey(kid) {
return nil, ErrMissingPrivateKey
}
return k.store[kid].private, nil
}
// GetPublicKey outputs the rsa.PublicKey for the KID from the KeyStore
func (k *KeyStore) GetPublicKey(kid string) (*rsa.PublicKey, error) {
k.mu.RLock()
defer k.mu.RUnlock()
if !k.internalHasPublicKey(kid) {
return nil, ErrMissingPublicKey
}
return k.store[kid].public, nil
}
// ClearKeys clears the internal map and makes a new map to release used memory
func (k *KeyStore) ClearKeys() {
k.mu.Lock()
clear(k.store)
k.store = make(map[string]*keyPair)
k.mu.Unlock()
}
// HasPrivateKey outputs true if the KID is found in the KeyStore
func (k *KeyStore) HasPrivateKey(kid string) bool {
k.mu.RLock()
defer k.mu.RUnlock()
return k.internalHasPrivateKey(kid)
}
func (k *KeyStore) internalHasPrivateKey(kid string) bool {
v := k.store[kid]
return v != nil && v.private != nil
}
// HasPublicKey outputs true if the KID is found in the KeyStore
func (k *KeyStore) HasPublicKey(kid string) bool {
k.mu.RLock()
defer k.mu.RUnlock()
return k.internalHasPublicKey(kid)
}
func (k *KeyStore) internalHasPublicKey(kid string) bool {
v := k.store[kid]
return v != nil && v.public != nil
}
// VerifyJwt parses the provided token string and validates it against the KID
// using the KeyStore. An error is returned if the token fails to parse or if
// there is no matching KID in the KeyStore.
func (k *KeyStore) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
kid, ok := token.Header["kid"].(string)
if !ok {
return nil, ErrMissingPublicKey
}
return k.GetPublicKey(kid)
})
if err != nil {
return nil, err
}
return withClaims, claims.Valid()
}
// SaveSingleKey writes the rsa.PrivateKey/rsa.PublicKey for the requested KID to
// the underlying afero.Fs.
func (k *KeyStore) SaveSingleKey(kid string) error {
if k.dir == nil {
return nil
}
k.mu.RLock()
pair := k.store[kid]
k.mu.RUnlock()
if pair == nil {
return ErrMissingKeyPair
}
return writeSingleKey(k.dir, kid, pair)
}
// SaveKeys writes the rsa.PrivateKey/rsa.PublicKey for the requested KID to the
// underlying afero.Fs.
func (k *KeyStore) SaveKeys() error {
k.mu.RLock()
defer k.mu.RUnlock()
workers := new(errgroup.Group)
workers.SetLimit(runtime.NumCPU())
for kid, pair := range k.store {
workers.Go(func() error {
return writeSingleKey(k.dir, kid, pair)
})
}
return workers.Wait()
}
func writeSingleKey(dir afero.Fs, kid string, pair *keyPair) error {
var errs []error
if pair.private != nil {
errs = append(errs, afero.WriteFile(dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
}
if pair.public != nil {
errs = append(errs, afero.WriteFile(dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
}
return errors.Join(errs...)
}

175
keystore_test.go Normal file
View File

@ -0,0 +1,175 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"sort"
"testing"
)
const kst_prvExt = "prv"
const kst_pubExt = "pub"
func setupTestDirKeyStore(t *testing.T, genKeys bool) afero.Fs {
tempDir := afero.NewMemMapFs()
if genKeys {
key1, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key1.private.pem", rsaprivate.Encode(key1), 0600)
assert.NoError(t, err)
key2, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key2.private.pem", rsaprivate.Encode(key2), 0600)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key2.public.pem", rsapublic.Encode(&key2.PublicKey), 0600)
assert.NoError(t, err)
key3, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key3.public.pem", rsapublic.Encode(&key3.PublicKey), 0600)
assert.NoError(t, err)
}
return tempDir
}
func commonSubTestsKeyStore(t *testing.T, kStore *KeyStore) {
key4, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
key5, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
const extraKID1 = "key4"
const extraKID2 = "key5"
t.Run("TestSetKey", func(t *testing.T) {
kStore.LoadPrivateKey(extraKID1, key4)
assert.Contains(t, kStore.ListKeys(), extraKID1)
})
t.Run("TestSetKeyPublic", func(t *testing.T) {
kStore.LoadPublicKey(extraKID2, &key5.PublicKey)
assert.Contains(t, kStore.ListKeys(), extraKID2)
})
t.Run("TestGetPrivateKey", func(t *testing.T) {
oKey, err := kStore.GetPrivateKey(extraKID1)
assert.NoError(t, err)
assert.Same(t, key4, oKey)
pKey, err := kStore.GetPrivateKey(extraKID2)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPrivateKey)
assert.Nil(t, pKey)
aKey, err := kStore.GetPrivateKey("key1")
assert.NoError(t, err)
assert.NotNil(t, aKey)
bKey, err := kStore.GetPrivateKey("key2")
assert.NoError(t, err)
assert.NotNil(t, bKey)
cKey, err := kStore.GetPrivateKey("key3")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPrivateKey)
assert.Nil(t, cKey)
wKey, err := kStore.GetPrivateKey("key1337")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPrivateKey)
assert.Nil(t, wKey)
})
t.Run("TestGetPublicKey", func(t *testing.T) {
oKey, err := kStore.GetPublicKey(extraKID1)
assert.NoError(t, err)
assert.Same(t, &key4.PublicKey, oKey)
pKey, err := kStore.GetPublicKey(extraKID2)
assert.NoError(t, err)
assert.Same(t, &key5.PublicKey, pKey)
aKey, err := kStore.GetPublicKey("key1")
assert.NoError(t, err)
assert.NotNil(t, aKey)
bKey, err := kStore.GetPublicKey("key2")
assert.NoError(t, err)
assert.NotNil(t, bKey)
cKey, err := kStore.GetPublicKey("key3")
assert.NoError(t, err)
assert.NotNil(t, cKey)
wKey, err := kStore.GetPublicKey("key1337")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPublicKey)
assert.Nil(t, wKey)
})
t.Run("TestRemoveKey", func(t *testing.T) {
kStore.RemoveKey(extraKID1)
assert.NotContains(t, kStore.ListKeys(), extraKID1)
oKey1, err := kStore.GetPrivateKey(extraKID1)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPrivateKey)
assert.Nil(t, oKey1)
oKey2, err := kStore.GetPublicKey(extraKID1)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPublicKey)
assert.Nil(t, oKey2)
})
t.Run("TestClearKeys", func(t *testing.T) {
kStore.ClearKeys()
assert.Empty(t, kStore.ListKeys())
})
}
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
t.Parallel()
tempDir := setupTestDirKeyStore(t, true)
kStore, err := NewKeyStoreFromDir(tempDir)
assert.NoError(t, err)
assert.Len(t, kStore.ListKeys(), 3)
kIDsToFind := []string{"key1", "key2", "key3"}
for _, k := range kIDsToFind {
assert.Contains(t, kStore.ListKeys(), k)
}
assert.True(t, kStore.HasPrivateKey("key1"))
assert.True(t, kStore.HasPublicKey("key1")) // loading a private key also loads the public key
assert.True(t, kStore.HasPrivateKey("key2"))
assert.True(t, kStore.HasPublicKey("key2"))
assert.False(t, kStore.HasPrivateKey("key3"))
assert.True(t, kStore.HasPublicKey("key3"))
commonSubTestsKeyStore(t, kStore)
}
func TestExportKeyStore(t *testing.T) {
t.Parallel()
tempDir := setupTestDirKeyStore(t, true)
tempDir2 := setupTestDirKeyStore(t, false)
kStore, err := NewKeyStoreFromDir(tempDir)
assert.NoError(t, err)
// internally swap directory
kStore.dir = tempDir2
err = kStore.SaveKeys()
assert.NoError(t, err)
kStore2, err := NewKeyStoreFromDir(tempDir2)
assert.NoError(t, err)
kidList1 := kStore.ListKeys()
kidList2 := kStore2.ListKeys()
sort.Strings(kidList1)
sort.Strings(kidList2)
assert.Equal(t, kidList1, kidList2)
commonSubTestsKeyStore(t, kStore2)
}

View File

@ -1,61 +0,0 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"fmt"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type testClaims struct{ TestValue string }
func (t testClaims) Valid() error {
if t.TestValue != "hello" && t.TestValue != "world" {
return fmt.Errorf("TestValue should be hello")
}
return nil
}
func (t testClaims) Type() string { return "testClaims" }
type testClaims2 struct{ TestValue2 string }
func (t testClaims2) Valid() error {
if t.TestValue2 != "world" {
return fmt.Errorf("TestValue2 should be world")
}
return nil
}
func (t testClaims2) Type() string { return "testClaims2" }
func TestExtractClaims(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
s := NewMJwtSigner("mjwt.test", key)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
m := NewMJwtVerifier(&key.PublicKey)
_, _, err = ExtractClaims[testClaims](m, token)
assert.NoError(t, err)
}
func TestExtractClaimsFail(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
s := NewMJwtSigner("mjwt.test", key)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
m := NewMJwtVerifier(&key.PublicKey)
_, _, err = ExtractClaims[testClaims2](m, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
}

143
signer.go
View File

@ -1,143 +0,0 @@
package mjwt
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"github.com/golang-jwt/jwt/v4"
"io"
"os"
"time"
)
// defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name
// to generate MJWT tokens
type defaultMJwtSigner struct {
issuer string
key *rsa.PrivateKey
verify *defaultMJwtVerifier
}
var _ Signer = &defaultMJwtSigner{}
var _ Verifier = &defaultMJwtSigner{}
// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
return &defaultMJwtSigner{
issuer: issuer,
key: key,
verify: newMJwtVerifier(&key.PublicKey),
}
}
// NewMJwtSignerFromFileOrCreate creates a new defaultMJwtSigner using the path
// of a rsa.PrivateKey file. If the file does not exist then the file is created
// and a new private key is generated.
func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits int) (Signer, error) {
privateKey, err := readOrCreatePrivateKey(file, random, bits)
if err != nil {
return nil, err
}
return NewMJwtSigner(issuer, privateKey), nil
}
// NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a
// rsa.PrivateKey file.
func NewMJwtSignerFromFile(issuer, file string) (Signer, error) {
// read file
raw, err := os.ReadFile(file)
if err != nil {
return nil, err
}
// decode pem block
block, _ := pem.Decode(raw)
if block == nil || block.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("invalid rsa private key pem block")
}
// parse private key from pem block
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
// create signer using rsa.PrivateKey
return NewMJwtSigner(issuer, key), nil
}
// Issuer returns the name of the issuer
func (d *defaultMJwtSigner) Issuer() string { return d.issuer }
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims))
}
// SignJwt signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
return token.SignedString(d.key)
}
// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
return d.verify.VerifyJwt(token, claims)
}
func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey { return d.key }
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey { return d.verify.pub }
// readOrCreatePrivateKey returns the private key it the file already exists,
// generates a new private key and saves it to the file, or returns an error if
// reading or generating failed.
func readOrCreatePrivateKey(file string, random io.Reader, bits int) (*rsa.PrivateKey, error) {
// read the file or return nil
f, err := readOrEmptyFile(file)
if err != nil {
return nil, err
}
if f == nil {
// generate a new key
key, err := rsa.GenerateKey(random, bits)
if err != nil {
return nil, err
}
keyBytes := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
if err != nil {
return nil, err
}
// write the key to the file
err = os.WriteFile(file, keyBytes, 0600)
return key, err
} else {
// decode pem block
block, _ := pem.Decode(f)
if block == nil || block.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("invalid rsa private key pem block")
}
// try to parse the private key
return x509.ParsePKCS1PrivateKey(block.Bytes)
}
}
// readOrEmptyFile returns bytes and errors from os.ReadFile or (nil, nil) if the
// file does not exist.
func readOrEmptyFile(file string) ([]byte, error) {
raw, err := os.ReadFile(file)
if err == nil {
return raw, nil
}
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}

View File

@ -1,69 +0,0 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestNewMJwtSigner(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
NewMJwtSigner("Test", key)
}
func TestNewMJwtSignerFromFile(t *testing.T) {
t.Parallel()
tempKey, err := os.CreateTemp("", "key-test-*.pem")
assert.NoError(t, err)
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
_, err = tempKey.Write(b)
assert.NoError(t, err)
assert.NoError(t, tempKey.Close())
signer, err := NewMJwtSignerFromFile("Test", tempKey.Name())
assert.NoError(t, err)
assert.NoError(t, os.Remove(tempKey.Name()))
_, err = NewMJwtSignerFromFile("Test", tempKey.Name())
assert.Error(t, err)
assert.True(t, os.IsNotExist(err))
assert.True(t, signer.(*defaultMJwtSigner).key.Equal(key))
}
func TestNewMJwtSignerFromFileOrCreate(t *testing.T) {
t.Parallel()
tempKey, err := os.CreateTemp("", "key-test-*.pem")
assert.NoError(t, err)
assert.NoError(t, tempKey.Close())
assert.NoError(t, os.Remove(tempKey.Name()))
signer, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048)
assert.NoError(t, err)
signer2, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048)
assert.NoError(t, err)
assert.True(t, signer.PrivateKey().Equal(signer2.PrivateKey()))
}
func TestReadOrCreatePrivateKey(t *testing.T) {
t.Parallel()
tempKey, err := os.CreateTemp("", "key-test-*.pem")
assert.NoError(t, err)
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
_, err = tempKey.Write(b)
assert.NoError(t, err)
assert.NoError(t, tempKey.Close())
key2, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048)
assert.NoError(t, err)
assert.True(t, key.Equal(key2))
assert.NoError(t, os.Remove(tempKey.Name()))
key3, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048)
assert.NoError(t, err)
assert.NoError(t, key3.Validate())
}

View File

@ -1,61 +0,0 @@
package mjwt
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"github.com/golang-jwt/jwt/v4"
"os"
)
// defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate
// MJWT tokens
type defaultMJwtVerifier struct {
pub *rsa.PublicKey
}
var _ Verifier = &defaultMJwtVerifier{}
// NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey
func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
return newMJwtVerifier(key)
}
func newMJwtVerifier(key *rsa.PublicKey) *defaultMJwtVerifier {
return &defaultMJwtVerifier{pub: key}
}
// NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a
// rsa.PublicKey file
func NewMJwtVerifierFromFile(file string) (Verifier, error) {
// read file
f, err := os.ReadFile(file)
if err != nil {
return nil, err
}
// decode pem block
block, _ := pem.Decode(f)
// parse public key from pem block
pub, err := x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return nil, err
}
// create verifier using rsa.PublicKey
return NewMJwtVerifier(pub), nil
}
// VerifyJwt validates and parses MJWT tokens and returns the claims
func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return d.pub, nil
})
if err != nil {
return nil, err
}
return withClaims, claims.Valid()
}
func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey { return d.pub }

View File

@ -1,34 +0,0 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"github.com/stretchr/testify/assert"
"os"
"testing"
"time"
)
func TestNewMJwtVerifierFromFile(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
s := NewMJwtSigner("mjwt.test", key)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"})
assert.NoError(t, err)
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey)})
temp, err := os.CreateTemp("", "this-is-a-test-file.pem")
assert.NoError(t, err)
_, err = temp.Write(b)
assert.NoError(t, err)
file, err := NewMJwtVerifierFromFile(temp.Name())
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims](file, token)
assert.NoError(t, err)
err = os.Remove(temp.Name())
assert.NoError(t, err)
}