mirror of
https://github.com/1f349/mjwt.git
synced 2025-04-15 15:27:56 +01:00
Compare commits
26 Commits
Author | SHA1 | Date | |
---|---|---|---|
87774ec45e | |||
7eaf420bb9 | |||
4e2c18918f | |||
1fc34736a2 | |||
cd2d80cb09 | |||
5d1bd6f8fd | |||
690b9f9512 | |||
3201964fec | |||
9a1029861c | |||
5e627ed024 | |||
fe2d905236 | |||
a94ed7a2e5 | |||
a0d03c0dfb | |||
d76a534346 | |||
dc95ed754c | |||
ce5eccfb3c | |||
407f8510b6 | |||
6fbc9e3c1f | |||
32cfa7a30d | |||
3a7b3dd250 | |||
545b688391 | |||
6a34395d8e | |||
ca4e4b7cae | |||
ab84ded3a1 | |||
1792211ca2 | |||
82d4a4a414 |
15
.github/workflows/test.yml
vendored
Normal file
15
.github/workflows/test.yml
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
on: [push, pull_request]
|
||||
name: Test
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: [1.22.x]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
- uses: actions/checkout@v4
|
||||
- run: go build ./cmd/mjwt/
|
||||
- run: go test ./...
|
@ -1,3 +1,3 @@
|
||||
# MJWT
|
||||
|
||||
A simple wrapper for JWT. Contains an AccessToken and RefreshToken model.
|
||||
A simple wrapper for JWT. Contains an AccessToken and RefreshToken model.
|
||||
|
@ -2,14 +2,13 @@ package auth
|
||||
|
||||
import (
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AccessTokenClaims contains the JWT claims for an access token
|
||||
type AccessTokenClaims struct {
|
||||
Perms *claims.PermStorage `json:"per"`
|
||||
Perms *PermStorage `json:"per"`
|
||||
}
|
||||
|
||||
func (a AccessTokenClaims) Valid() error { return nil }
|
||||
@ -17,11 +16,11 @@ func (a AccessTokenClaims) Valid() error { return nil }
|
||||
func (a AccessTokenClaims) Type() string { return "access-token" }
|
||||
|
||||
// CreateAccessToken creates an access token with the default 15 minute duration
|
||||
func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
|
||||
func CreateAccessToken(p *mjwt.Issuer, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
|
||||
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms)
|
||||
}
|
||||
|
||||
// CreateAccessTokenWithDuration creates an access token with a custom duration
|
||||
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
|
||||
func CreateAccessTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
|
||||
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
|
||||
}
|
||||
|
@ -1,29 +1,27 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreateAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps := NewPermStorage()
|
||||
ps.Set("mjwt:test")
|
||||
ps.Set("mjwt:test2")
|
||||
|
||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||
kStore := mjwt.NewKeyStore()
|
||||
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
|
@ -2,20 +2,19 @@ package auth
|
||||
|
||||
import (
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CreateTokenPair creates an access and refresh token pair using the default
|
||||
// 15 minute and 7 day durations respectively
|
||||
func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
|
||||
func CreateTokenPair(p *mjwt.Issuer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
|
||||
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms)
|
||||
}
|
||||
|
||||
// CreateTokenPairWithDuration creates an access and refresh token pair using
|
||||
// custom durations for the access and refresh tokens
|
||||
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
|
||||
func CreateTokenPairWithDuration(p *mjwt.Issuer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
|
||||
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
@ -1,29 +1,27 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreateTokenPair(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps := NewPermStorage()
|
||||
ps.Set("mjwt:test")
|
||||
ps.Set("mjwt:test2")
|
||||
|
||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||
kStore := mjwt.NewKeyStore()
|
||||
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key2", jwt.SigningMethodRS512, kStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
@ -31,7 +29,7 @@ func TestCreateTokenPair(t *testing.T) {
|
||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||
|
||||
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b2.Subject)
|
||||
assert.Equal(t, "test2", b2.ID)
|
||||
|
@ -1,10 +1,12 @@
|
||||
package claims
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/becheran/wildmatch-go"
|
||||
"gopkg.in/yaml.v3"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type PermStorage struct {
|
||||
@ -15,6 +17,16 @@ func NewPermStorage() *PermStorage {
|
||||
return new(PermStorage).setup()
|
||||
}
|
||||
|
||||
func ParsePermStorage(perms string) *PermStorage {
|
||||
ps := NewPermStorage()
|
||||
sc := bufio.NewScanner(strings.NewReader(perms))
|
||||
sc.Split(bufio.ScanWords)
|
||||
for sc.Scan() {
|
||||
ps.Set(sc.Text())
|
||||
}
|
||||
return ps
|
||||
}
|
||||
|
||||
func (p *PermStorage) setup() *PermStorage {
|
||||
if p.values == nil {
|
||||
p.values = make(map[string]struct{})
|
||||
@ -64,6 +76,16 @@ func (p *PermStorage) Search(v string) []string {
|
||||
return a
|
||||
}
|
||||
|
||||
func (p *PermStorage) Filter(match []string) *PermStorage {
|
||||
out := NewPermStorage()
|
||||
for _, i := range match {
|
||||
for _, j := range p.Search(i) {
|
||||
out.Set(j)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *PermStorage) prepare(a []string) {
|
||||
for _, i := range a {
|
||||
p.Set(i)
|
@ -1,11 +1,20 @@
|
||||
package claims
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParsePermStorage(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := ParsePermStorage("mjwt:test mjwt:test2")
|
||||
if _, ok := ps.values["mjwt:test"]; !ok {
|
||||
assert.Fail(t, "perm not set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermStorage_Set(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := NewPermStorage()
|
||||
@ -51,6 +60,22 @@ func TestPermStorage_OneOf(t *testing.T) {
|
||||
assert.False(t, ps.OneOf(o))
|
||||
}
|
||||
|
||||
func TestPermStorage_Search(t *testing.T) {
|
||||
ps := ParsePermStorage("mjwt:test mjwt:test2 mjwt:other")
|
||||
a := ps.Search("mjwt:test*")
|
||||
sort.Strings(a)
|
||||
assert.Equal(t, []string{"mjwt:test", "mjwt:test2"}, a)
|
||||
assert.Equal(t, []string{"mjwt:other"}, ps.Search("mjwt:other"))
|
||||
}
|
||||
|
||||
func TestPermStorage_Filter(t *testing.T) {
|
||||
ps := ParsePermStorage("mjwt:test mjwt:test2 mjwt:other mjwt:other2 mjwt:another")
|
||||
a := ps.Filter([]string{"mjwt:test*", "mjwt:other*"}).Dump()
|
||||
sort.Strings(a)
|
||||
assert.Equal(t, []string{"mjwt:other", "mjwt:other2", "mjwt:test", "mjwt:test2"}, a)
|
||||
assert.Equal(t, []string{"mjwt:another"}, ps.Filter([]string{"mjwt:another"}).Dump())
|
||||
}
|
||||
|
||||
func TestPermStorage_MarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := NewPermStorage()
|
@ -16,11 +16,11 @@ func (r RefreshTokenClaims) Valid() error { return nil }
|
||||
func (r RefreshTokenClaims) Type() string { return "refresh-token" }
|
||||
|
||||
// CreateRefreshToken creates a refresh token with the default 7 day duration
|
||||
func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||
func CreateRefreshToken(p *mjwt.Issuer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud)
|
||||
}
|
||||
|
||||
// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
|
||||
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||
func CreateRefreshTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
|
||||
}
|
||||
|
@ -1,24 +1,23 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreateRefreshToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||
kStore := mjwt.NewKeyStore()
|
||||
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
|
@ -10,11 +10,11 @@ import (
|
||||
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
|
||||
|
||||
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
|
||||
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
||||
func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
||||
now := time.Now()
|
||||
return (&BaseTypeClaims[T]{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: p.Issuer(),
|
||||
Issuer: issuer,
|
||||
Subject: sub,
|
||||
Audience: aud,
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
|
||||
@ -28,12 +28,12 @@ func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur ti
|
||||
|
||||
// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
|
||||
// token and BaseTypeClaims
|
||||
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
|
||||
func ExtractClaims[T Claims](ks *KeyStore, token string) (*jwt.Token, BaseTypeClaims[T], error) {
|
||||
b := BaseTypeClaims[T]{
|
||||
RegisteredClaims: jwt.RegisteredClaims{},
|
||||
Claims: *new(T),
|
||||
}
|
||||
tok, err := p.VerifyJwt(token, &b)
|
||||
tok, err := ks.VerifyJwt(token, &b)
|
||||
return tok, b, err
|
||||
}
|
||||
|
101
claims_test.go
Normal file
101
claims_test.go
Normal file
@ -0,0 +1,101 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testClaims struct{ TestValue string }
|
||||
|
||||
func (t testClaims) Valid() error {
|
||||
if t.TestValue != "hello" && t.TestValue != "world" {
|
||||
return fmt.Errorf("TestValue should be hello")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t testClaims) Type() string { return "testClaims" }
|
||||
|
||||
type testClaims2 struct{ TestValue2 string }
|
||||
|
||||
func (t testClaims2) Valid() error {
|
||||
if t.TestValue2 != "world" {
|
||||
return fmt.Errorf("TestValue2 should be world")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t testClaims2) Type() string { return "testClaims2" }
|
||||
|
||||
func TestExtractClaims(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore := NewKeyStore()
|
||||
|
||||
t.Run("TestNoKID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
|
||||
assert.NoError(t, err)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
a, _, err := ExtractClaims[testClaims](kStore, token)
|
||||
assert.NoError(t, err)
|
||||
kid, _ := a.Header["kid"].(string)
|
||||
assert.Equal(t, "key1", kid)
|
||||
})
|
||||
|
||||
t.Run("TestKID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", jwt.SigningMethodRS512, kStore)
|
||||
assert.NoError(t, err)
|
||||
s2, err := NewIssuerWithKeyStore("mjwt.test", "key3", jwt.SigningMethodRS512, kStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||
assert.NoError(t, err)
|
||||
token2, err := s2.GenerateJwt("2", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
k1, _, err := ExtractClaims[testClaims](kStore, token1)
|
||||
assert.NoError(t, err)
|
||||
k2, _, err := ExtractClaims[testClaims](kStore, token2)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, k1.Header["kid"], k2.Header["kid"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractClaimsFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore := NewKeyStore()
|
||||
|
||||
t.Run("TestInvalidClaims", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
|
||||
assert.NoError(t, err)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _, err = ExtractClaims[testClaims2](kStore, token)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
|
||||
})
|
||||
|
||||
t.Run("TestKIDNonExist", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", jwt.SigningMethodRS512, kStore)
|
||||
assert.NoError(t, err)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore.RemoveKey("key2")
|
||||
assert.NotContains(t, kStore.ListKeys(), "key2")
|
||||
|
||||
_, _, err = ExtractClaims[testClaims](kStore, token)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPublicKey)
|
||||
})
|
||||
}
|
@ -2,14 +2,11 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/auth"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/subcommands"
|
||||
"os"
|
||||
@ -18,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
type accessCmd struct {
|
||||
issuer, subject, id, audience, duration string
|
||||
issuer, subject, id, audience, duration, kID string
|
||||
}
|
||||
|
||||
func (s *accessCmd) Name() string { return "access" }
|
||||
@ -26,7 +23,7 @@ func (s *accessCmd) Synopsis() string {
|
||||
return "Generates an access token with permissions using the private key"
|
||||
}
|
||||
func (s *accessCmd) Usage() string {
|
||||
return `sign [-iss <issuer>] [-sub <subject>] [-id <id>] [-aud <audience>] [-dur <duration>] <private key path> <space separated permissions>
|
||||
return `sign [-iss <issuer>] [-sub <subject>] [-id <id>] [-aud <audience>] [-dur <duration>] [-kid <name>] <private key path> <space separated permissions>
|
||||
Output a signed MJWT token with the specified permissions.
|
||||
`
|
||||
}
|
||||
@ -37,6 +34,7 @@ func (s *accessCmd) SetFlags(f *flag.FlagSet) {
|
||||
f.StringVar(&s.id, "id", "", "MJWT ID")
|
||||
f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT")
|
||||
f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)")
|
||||
f.StringVar(&s.kID, "kid", "", "The Key ID of the signing key")
|
||||
}
|
||||
|
||||
func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
||||
@ -46,13 +44,13 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
|
||||
}
|
||||
|
||||
args := f.Args()
|
||||
key, err := s.parseKey(args[0])
|
||||
key, err := rsaprivate.Read(args[0])
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to parse private key: ", err)
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps := auth.NewPermStorage()
|
||||
for i := 1; i < len(args); i++ {
|
||||
ps.Set(args[i])
|
||||
}
|
||||
@ -67,8 +65,17 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
|
||||
signer := mjwt.NewMJwtSigner(s.issuer, key)
|
||||
token, err := signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
|
||||
var token string
|
||||
|
||||
kStore := mjwt.NewKeyStore()
|
||||
kStore.LoadPrivateKey(s.kID, key)
|
||||
|
||||
issuer, err := mjwt.NewIssuerWithKeyStore(s.issuer, s.kID, jwt.SigningMethodRS512, kStore)
|
||||
if err != nil {
|
||||
panic("this should not fail")
|
||||
}
|
||||
|
||||
token, err = issuer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err)
|
||||
return subcommands.ExitFailure
|
||||
@ -77,13 +84,3 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
|
||||
fmt.Println(token)
|
||||
return subcommands.ExitSuccess
|
||||
}
|
||||
|
||||
func (s *accessCmd) parseKey(privKeyFile string) (*rsa.PrivateKey, error) {
|
||||
b, err := os.ReadFile(privKeyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p, _ := pem.Decode(b)
|
||||
return x509.ParsePKCS1PrivateKey(p.Bytes)
|
||||
}
|
||||
|
@ -3,10 +3,10 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"github.com/google/subcommands"
|
||||
"math/rand"
|
||||
"os"
|
||||
@ -49,29 +49,14 @@ func (g *genCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) s
|
||||
}
|
||||
|
||||
func (g *genCmd) gen(privPath, pubPath string) error {
|
||||
createPriv, err := os.OpenFile(privPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer createPriv.Close()
|
||||
|
||||
createPub, err := os.OpenFile(pubPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer createPub.Close()
|
||||
|
||||
key, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), g.bits)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyBytes := x509.MarshalPKCS1PrivateKey(key)
|
||||
pubBytes := x509.MarshalPKCS1PublicKey(&key.PublicKey)
|
||||
err = pem.Encode(createPriv, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes})
|
||||
err = rsaprivate.Write(privPath, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = pem.Encode(createPub, &pem.Block{Type: "RSA PUBLIC KEY", Bytes: pubBytes})
|
||||
return err
|
||||
return rsapublic.Write(pubPath, &key.PublicKey)
|
||||
}
|
||||
|
8
empty-claims.go
Normal file
8
empty-claims.go
Normal file
@ -0,0 +1,8 @@
|
||||
package mjwt
|
||||
|
||||
// EmptyClaims contains no claims
|
||||
type EmptyClaims struct{}
|
||||
|
||||
func (e EmptyClaims) Valid() error { return nil }
|
||||
|
||||
func (e EmptyClaims) Type() string { return "empty-claims" }
|
15
go.mod
15
go.mod
@ -1,17 +1,28 @@
|
||||
module github.com/1f349/mjwt
|
||||
|
||||
go 1.19
|
||||
go 1.22
|
||||
|
||||
toolchain go1.22.3
|
||||
|
||||
require (
|
||||
github.com/1f349/rsa-helper v0.0.2
|
||||
github.com/becheran/wildmatch-go v1.0.0
|
||||
github.com/go-jose/go-jose/v4 v4.0.4
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0
|
||||
github.com/google/subcommands v1.2.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/spf13/afero v1.11.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
golang.org/x/sync v0.7.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||
golang.org/x/crypto v0.25.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
)
|
||||
|
33
go.sum
33
go.sum
@ -1,18 +1,45 @@
|
||||
github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9U=
|
||||
github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
|
||||
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
|
||||
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-jose/go-jose/v4 v4.0.4 h1:VsjPI33J0SB9vQM6PLmNjoHqMQNGPiZ0rHL7Ni7Q6/E=
|
||||
github.com/go-jose/go-jose/v4 v4.0.4/go.mod h1:NKb5HO1EZccyMpiZNbdUw/14tiXNyUJh188dfnMCAfc=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
|
||||
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
@ -1,23 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Signer is used to generate MJWT tokens.
|
||||
// Signer can also be used as a Verifier.
|
||||
type Signer interface {
|
||||
Verifier
|
||||
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
|
||||
SignJwt(claims jwt.Claims) (string, error)
|
||||
Issuer() string
|
||||
PrivateKey() *rsa.PrivateKey
|
||||
}
|
||||
|
||||
// Verifier is used to verify the validity MJWT tokens and extract the claim values.
|
||||
type Verifier interface {
|
||||
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
|
||||
PublicKey() *rsa.PublicKey
|
||||
}
|
62
issuer.go
Normal file
62
issuer.go
Normal file
@ -0,0 +1,62 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Issuer provides the signing for a PrivateKey identified by the KID in the
|
||||
// provided KeyStore
|
||||
type Issuer struct {
|
||||
issuer string
|
||||
kid string
|
||||
signing jwt.SigningMethod
|
||||
keystore *KeyStore
|
||||
}
|
||||
|
||||
// NewIssuer creates an Issuer with an empty KeyStore
|
||||
func NewIssuer(name, kid string, signing jwt.SigningMethod) (*Issuer, error) {
|
||||
return NewIssuerWithKeyStore(name, kid, signing, NewKeyStore())
|
||||
}
|
||||
|
||||
// NewIssuerWithKeyStore creates an Issuer with a provided KeyStore
|
||||
func NewIssuerWithKeyStore(name, kid string, signing jwt.SigningMethod, keystore *KeyStore) (*Issuer, error) {
|
||||
i := &Issuer{name, kid, signing, keystore}
|
||||
if i.keystore.HasPrivateKey(kid) {
|
||||
return i, nil
|
||||
}
|
||||
key, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i.keystore.LoadPrivateKey(kid, key)
|
||||
return i, i.keystore.SaveSingleKey(kid)
|
||||
}
|
||||
|
||||
// GenerateJwt produces a signed JWT in string form
|
||||
func (i *Issuer) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
|
||||
return i.SignJwt(wrapClaims[Claims](sub, id, i.issuer, aud, dur, claims))
|
||||
}
|
||||
|
||||
// SignJwt produces a signed JWT in string form from a raw jwt.Claims structure
|
||||
func (i *Issuer) SignJwt(wrapped jwt.Claims) (string, error) {
|
||||
key, err := i.PrivateKey()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
token := jwt.NewWithClaims(i.signing, wrapped)
|
||||
token.Header["kid"] = i.kid
|
||||
return token.SignedString(key)
|
||||
}
|
||||
|
||||
// PrivateKey outputs the rsa.PrivateKey from the KID of the Issuer
|
||||
func (i *Issuer) PrivateKey() (*rsa.PrivateKey, error) {
|
||||
return i.keystore.GetPrivateKey(i.kid)
|
||||
}
|
||||
|
||||
// KeyStore outputs the underlying KeyStore used by the Issuer
|
||||
func (i *Issuer) KeyStore() *KeyStore {
|
||||
return i.keystore
|
||||
}
|
59
issuer_test.go
Normal file
59
issuer_test.go
Normal file
@ -0,0 +1,59 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewIssuer(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("generate missing key for issuer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore := NewKeyStore()
|
||||
issuer, err := NewIssuerWithKeyStore("Test", "test", jwt.SigningMethodRS512, kStore)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, kStore.HasPrivateKey("test"))
|
||||
assert.True(t, kStore.HasPublicKey("test"))
|
||||
assert.Equal(t, "Test", issuer.issuer)
|
||||
assert.Equal(t, "test", issuer.kid)
|
||||
})
|
||||
t.Run("use existing issuer key", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore := NewKeyStore()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
kStore.LoadPrivateKey("test", key)
|
||||
issuer, err := NewIssuerWithKeyStore("Test", "test", jwt.SigningMethodRS512, kStore)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, kStore.HasPrivateKey("test"))
|
||||
assert.True(t, kStore.HasPublicKey("test"))
|
||||
assert.Equal(t, "Test", issuer.issuer)
|
||||
assert.Equal(t, "test", issuer.kid)
|
||||
privateKey, err := issuer.PrivateKey()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, key.Equal(privateKey))
|
||||
})
|
||||
t.Run("generate missing key in filesystem", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := afero.NewMemMapFs()
|
||||
kStore := NewKeyStoreWithDir(dir)
|
||||
issuer, err := NewIssuerWithKeyStore("Test", "test", jwt.SigningMethodRS512, kStore)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, kStore.HasPrivateKey("test"))
|
||||
assert.True(t, kStore.HasPublicKey("test"))
|
||||
assert.Equal(t, "Test", issuer.issuer)
|
||||
assert.Equal(t, "test", issuer.kid)
|
||||
privKeyFile, err := dir.Open("test.private.pem")
|
||||
assert.NoError(t, err)
|
||||
privKey, err := rsaprivate.Decode(privKeyFile)
|
||||
assert.NoError(t, err)
|
||||
key, err := issuer.PrivateKey()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, key.Equal(privKey))
|
||||
})
|
||||
}
|
31
jwks.go
Normal file
31
jwks.go
Normal file
@ -0,0 +1,31 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"io"
|
||||
)
|
||||
|
||||
// WriteJwkSetJson outputs the public keys used by the Issuers
|
||||
func WriteJwkSetJson(w io.Writer, issuers []*Issuer) error {
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
var j jose.JSONWebKeySet
|
||||
for _, issuer := range issuers {
|
||||
// get public key from private key
|
||||
key, err := issuer.PrivateKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pubKey := &key.PublicKey
|
||||
|
||||
// format as JWK
|
||||
j.Keys = append(j.Keys, jose.JSONWebKey{
|
||||
Algorithm: issuer.signing.Alg(),
|
||||
Use: "sig",
|
||||
KeyID: issuer.kid,
|
||||
Key: pubKey,
|
||||
})
|
||||
}
|
||||
return enc.Encode(j)
|
||||
}
|
265
keystore.go
Normal file
265
keystore.go
Normal file
@ -0,0 +1,265 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/spf13/afero"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var ErrMissingPrivateKey = errors.New("missing private key")
|
||||
var ErrMissingPublicKey = errors.New("missing public key")
|
||||
var ErrMissingKeyPair = errors.New("missing key pair")
|
||||
|
||||
const PrivateStr = ".private"
|
||||
const PublicStr = ".public"
|
||||
|
||||
const PemExt = ".pem"
|
||||
const PrivatePemExt = PrivateStr + PemExt
|
||||
const PublicPemExt = PublicStr + PemExt
|
||||
|
||||
// KeyStore provides a store for a collection of private/public keypair structs
|
||||
type KeyStore struct {
|
||||
mu *sync.RWMutex
|
||||
store map[string]*keyPair
|
||||
dir afero.Fs
|
||||
}
|
||||
|
||||
// NewKeyStore creates an empty KeyStore
|
||||
func NewKeyStore() *KeyStore {
|
||||
return &KeyStore{
|
||||
mu: new(sync.RWMutex),
|
||||
store: make(map[string]*keyPair),
|
||||
}
|
||||
}
|
||||
|
||||
// NewKeyStoreWithDir creates an empty KeyStore with an underlying afero.Fs
|
||||
// filesystem for saving the internal store data
|
||||
func NewKeyStoreWithDir(dir afero.Fs) *KeyStore {
|
||||
keyStore := NewKeyStore()
|
||||
keyStore.dir = dir
|
||||
return keyStore
|
||||
}
|
||||
|
||||
// NewKeyStoreFromPath creates an empty KeyStore. The provided path is walked to
|
||||
// load the private/public keys. See implementation in NewKeyStoreFromDir.
|
||||
func NewKeyStoreFromPath(dir string) (*KeyStore, error) {
|
||||
abs, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewKeyStoreFromDir(afero.NewBasePathFs(afero.NewOsFs(), abs))
|
||||
}
|
||||
|
||||
// NewKeyStoreFromDir creates an empty KeyStore. The provided afero.Fs is walked
|
||||
// to find all private/public keys in files named `.private.pem` and
|
||||
// `.public.pem` respectively. The keys are loaded into the KeyStore and any
|
||||
// errors are returned immediately.
|
||||
func NewKeyStoreFromDir(dir afero.Fs) (*KeyStore, error) {
|
||||
keyStore := NewKeyStoreWithDir(dir)
|
||||
err := afero.Walk(dir, ".", func(path string, d fs.FileInfo, err error) error {
|
||||
// maybe this is "name.private.pem"
|
||||
name := filepath.Base(path)
|
||||
ext := filepath.Ext(name)
|
||||
if ext != PemExt {
|
||||
return nil
|
||||
}
|
||||
|
||||
name = strings.TrimSuffix(name, ext)
|
||||
ext = filepath.Ext(name)
|
||||
name = strings.TrimSuffix(name, ext)
|
||||
switch ext {
|
||||
case PrivateStr:
|
||||
open, err := dir.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decode, err := rsaprivate.Decode(open)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyStore.LoadPrivateKey(name, decode)
|
||||
return nil
|
||||
case PublicStr:
|
||||
open, err := dir.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decode, err := rsapublic.Decode(open)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyStore.LoadPublicKey(name, decode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// still invalid
|
||||
return nil
|
||||
})
|
||||
return keyStore, err
|
||||
}
|
||||
|
||||
type keyPair struct {
|
||||
private *rsa.PrivateKey
|
||||
public *rsa.PublicKey
|
||||
}
|
||||
|
||||
// LoadPrivateKey sets the rsa.PrivateKey/rsa.PublicKey for the KID
|
||||
func (k *KeyStore) LoadPrivateKey(kid string, key *rsa.PrivateKey) {
|
||||
k.mu.Lock()
|
||||
if k.store[kid] == nil {
|
||||
k.store[kid] = &keyPair{}
|
||||
}
|
||||
k.store[kid].private = key
|
||||
k.store[kid].public = &key.PublicKey
|
||||
k.mu.Unlock()
|
||||
}
|
||||
|
||||
// LoadPublicKey sets the rsa.PublicKey for the KID
|
||||
func (k *KeyStore) LoadPublicKey(kid string, key *rsa.PublicKey) {
|
||||
k.mu.Lock()
|
||||
if k.store[kid] == nil {
|
||||
k.store[kid] = &keyPair{}
|
||||
}
|
||||
k.store[kid].public = key
|
||||
k.mu.Unlock()
|
||||
}
|
||||
|
||||
// RemoveKey deletes the KID keypair from the KeyStore
|
||||
func (k *KeyStore) RemoveKey(kid string) {
|
||||
k.mu.Lock()
|
||||
delete(k.store, kid)
|
||||
k.mu.Unlock()
|
||||
}
|
||||
|
||||
// ListKeys provides a slice of the KIDs for all keys loaded in the KeyStore
|
||||
func (k *KeyStore) ListKeys() []string {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
keys := make([]string, 0, len(k.store))
|
||||
for k, _ := range k.store {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// GetPrivateKey outputs the rsa.PrivateKey for the KID from the KeyStore
|
||||
func (k *KeyStore) GetPrivateKey(kid string) (*rsa.PrivateKey, error) {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
if !k.internalHasPrivateKey(kid) {
|
||||
return nil, ErrMissingPrivateKey
|
||||
}
|
||||
return k.store[kid].private, nil
|
||||
}
|
||||
|
||||
// GetPublicKey outputs the rsa.PublicKey for the KID from the KeyStore
|
||||
func (k *KeyStore) GetPublicKey(kid string) (*rsa.PublicKey, error) {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
if !k.internalHasPublicKey(kid) {
|
||||
return nil, ErrMissingPublicKey
|
||||
}
|
||||
return k.store[kid].public, nil
|
||||
}
|
||||
|
||||
// ClearKeys clears the internal map and makes a new map to release used memory
|
||||
func (k *KeyStore) ClearKeys() {
|
||||
k.mu.Lock()
|
||||
clear(k.store)
|
||||
k.store = make(map[string]*keyPair)
|
||||
k.mu.Unlock()
|
||||
}
|
||||
|
||||
// HasPrivateKey outputs true if the KID is found in the KeyStore
|
||||
func (k *KeyStore) HasPrivateKey(kid string) bool {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
return k.internalHasPrivateKey(kid)
|
||||
}
|
||||
|
||||
func (k *KeyStore) internalHasPrivateKey(kid string) bool {
|
||||
v := k.store[kid]
|
||||
return v != nil && v.private != nil
|
||||
}
|
||||
|
||||
// HasPublicKey outputs true if the KID is found in the KeyStore
|
||||
func (k *KeyStore) HasPublicKey(kid string) bool {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
return k.internalHasPublicKey(kid)
|
||||
}
|
||||
|
||||
func (k *KeyStore) internalHasPublicKey(kid string) bool {
|
||||
v := k.store[kid]
|
||||
return v != nil && v.public != nil
|
||||
}
|
||||
|
||||
// VerifyJwt parses the provided token string and validates it against the KID
|
||||
// using the KeyStore. An error is returned if the token fails to parse or if
|
||||
// there is no matching KID in the KeyStore.
|
||||
func (k *KeyStore) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
||||
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
return nil, ErrMissingPublicKey
|
||||
}
|
||||
return k.GetPublicKey(kid)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return withClaims, claims.Valid()
|
||||
}
|
||||
|
||||
// SaveSingleKey writes the rsa.PrivateKey/rsa.PublicKey for the requested KID to
|
||||
// the underlying afero.Fs.
|
||||
func (k *KeyStore) SaveSingleKey(kid string) error {
|
||||
if k.dir == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
k.mu.RLock()
|
||||
pair := k.store[kid]
|
||||
k.mu.RUnlock()
|
||||
if pair == nil {
|
||||
return ErrMissingKeyPair
|
||||
}
|
||||
|
||||
return writeSingleKey(k.dir, kid, pair)
|
||||
}
|
||||
|
||||
// SaveKeys writes the rsa.PrivateKey/rsa.PublicKey for the requested KID to the
|
||||
// underlying afero.Fs.
|
||||
func (k *KeyStore) SaveKeys() error {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
|
||||
workers := new(errgroup.Group)
|
||||
workers.SetLimit(runtime.NumCPU())
|
||||
for kid, pair := range k.store {
|
||||
workers.Go(func() error {
|
||||
return writeSingleKey(k.dir, kid, pair)
|
||||
})
|
||||
}
|
||||
return workers.Wait()
|
||||
}
|
||||
|
||||
func writeSingleKey(dir afero.Fs, kid string, pair *keyPair) error {
|
||||
var errs []error
|
||||
if pair.private != nil {
|
||||
errs = append(errs, afero.WriteFile(dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
|
||||
}
|
||||
if pair.public != nil {
|
||||
errs = append(errs, afero.WriteFile(dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
175
keystore_test.go
Normal file
175
keystore_test.go
Normal file
@ -0,0 +1,175 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const kst_prvExt = "prv"
|
||||
const kst_pubExt = "pub"
|
||||
|
||||
func setupTestDirKeyStore(t *testing.T, genKeys bool) afero.Fs {
|
||||
tempDir := afero.NewMemMapFs()
|
||||
|
||||
if genKeys {
|
||||
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = afero.WriteFile(tempDir, "key1.private.pem", rsaprivate.Encode(key1), 0600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = afero.WriteFile(tempDir, "key2.private.pem", rsaprivate.Encode(key2), 0600)
|
||||
assert.NoError(t, err)
|
||||
err = afero.WriteFile(tempDir, "key2.public.pem", rsapublic.Encode(&key2.PublicKey), 0600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key3, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = afero.WriteFile(tempDir, "key3.public.pem", rsapublic.Encode(&key3.PublicKey), 0600)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
return tempDir
|
||||
}
|
||||
|
||||
func commonSubTestsKeyStore(t *testing.T, kStore *KeyStore) {
|
||||
key4, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key5, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
const extraKID1 = "key4"
|
||||
const extraKID2 = "key5"
|
||||
|
||||
t.Run("TestSetKey", func(t *testing.T) {
|
||||
kStore.LoadPrivateKey(extraKID1, key4)
|
||||
assert.Contains(t, kStore.ListKeys(), extraKID1)
|
||||
})
|
||||
|
||||
t.Run("TestSetKeyPublic", func(t *testing.T) {
|
||||
kStore.LoadPublicKey(extraKID2, &key5.PublicKey)
|
||||
assert.Contains(t, kStore.ListKeys(), extraKID2)
|
||||
})
|
||||
|
||||
t.Run("TestGetPrivateKey", func(t *testing.T) {
|
||||
oKey, err := kStore.GetPrivateKey(extraKID1)
|
||||
assert.NoError(t, err)
|
||||
assert.Same(t, key4, oKey)
|
||||
pKey, err := kStore.GetPrivateKey(extraKID2)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPrivateKey)
|
||||
assert.Nil(t, pKey)
|
||||
aKey, err := kStore.GetPrivateKey("key1")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, aKey)
|
||||
bKey, err := kStore.GetPrivateKey("key2")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, bKey)
|
||||
cKey, err := kStore.GetPrivateKey("key3")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPrivateKey)
|
||||
assert.Nil(t, cKey)
|
||||
wKey, err := kStore.GetPrivateKey("key1337")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPrivateKey)
|
||||
assert.Nil(t, wKey)
|
||||
})
|
||||
|
||||
t.Run("TestGetPublicKey", func(t *testing.T) {
|
||||
oKey, err := kStore.GetPublicKey(extraKID1)
|
||||
assert.NoError(t, err)
|
||||
assert.Same(t, &key4.PublicKey, oKey)
|
||||
pKey, err := kStore.GetPublicKey(extraKID2)
|
||||
assert.NoError(t, err)
|
||||
assert.Same(t, &key5.PublicKey, pKey)
|
||||
aKey, err := kStore.GetPublicKey("key1")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, aKey)
|
||||
bKey, err := kStore.GetPublicKey("key2")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, bKey)
|
||||
cKey, err := kStore.GetPublicKey("key3")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cKey)
|
||||
wKey, err := kStore.GetPublicKey("key1337")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPublicKey)
|
||||
assert.Nil(t, wKey)
|
||||
})
|
||||
|
||||
t.Run("TestRemoveKey", func(t *testing.T) {
|
||||
kStore.RemoveKey(extraKID1)
|
||||
assert.NotContains(t, kStore.ListKeys(), extraKID1)
|
||||
oKey1, err := kStore.GetPrivateKey(extraKID1)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPrivateKey)
|
||||
assert.Nil(t, oKey1)
|
||||
oKey2, err := kStore.GetPublicKey(extraKID1)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPublicKey)
|
||||
assert.Nil(t, oKey2)
|
||||
})
|
||||
|
||||
t.Run("TestClearKeys", func(t *testing.T) {
|
||||
kStore.ClearKeys()
|
||||
assert.Empty(t, kStore.ListKeys())
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := setupTestDirKeyStore(t, true)
|
||||
|
||||
kStore, err := NewKeyStoreFromDir(tempDir)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, kStore.ListKeys(), 3)
|
||||
kIDsToFind := []string{"key1", "key2", "key3"}
|
||||
for _, k := range kIDsToFind {
|
||||
assert.Contains(t, kStore.ListKeys(), k)
|
||||
}
|
||||
assert.True(t, kStore.HasPrivateKey("key1"))
|
||||
assert.True(t, kStore.HasPublicKey("key1")) // loading a private key also loads the public key
|
||||
assert.True(t, kStore.HasPrivateKey("key2"))
|
||||
assert.True(t, kStore.HasPublicKey("key2"))
|
||||
assert.False(t, kStore.HasPrivateKey("key3"))
|
||||
assert.True(t, kStore.HasPublicKey("key3"))
|
||||
|
||||
commonSubTestsKeyStore(t, kStore)
|
||||
}
|
||||
|
||||
func TestExportKeyStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := setupTestDirKeyStore(t, true)
|
||||
tempDir2 := setupTestDirKeyStore(t, false)
|
||||
|
||||
kStore, err := NewKeyStoreFromDir(tempDir)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// internally swap directory
|
||||
kStore.dir = tempDir2
|
||||
|
||||
err = kStore.SaveKeys()
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore2, err := NewKeyStoreFromDir(tempDir2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kidList1 := kStore.ListKeys()
|
||||
kidList2 := kStore2.ListKeys()
|
||||
sort.Strings(kidList1)
|
||||
sort.Strings(kidList2)
|
||||
assert.Equal(t, kidList1, kidList2)
|
||||
|
||||
commonSubTestsKeyStore(t, kStore2)
|
||||
}
|
61
mjwt_test.go
61
mjwt_test.go
@ -1,61 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testClaims struct{ TestValue string }
|
||||
|
||||
func (t testClaims) Valid() error {
|
||||
if t.TestValue != "hello" && t.TestValue != "world" {
|
||||
return fmt.Errorf("TestValue should be hello")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t testClaims) Type() string { return "testClaims" }
|
||||
|
||||
type testClaims2 struct{ TestValue2 string }
|
||||
|
||||
func (t testClaims2) Valid() error {
|
||||
if t.TestValue2 != "world" {
|
||||
return fmt.Errorf("TestValue2 should be world")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t testClaims2) Type() string { return "testClaims2" }
|
||||
|
||||
func TestExtractClaims(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := NewMJwtSigner("mjwt.test", key)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := NewMJwtVerifier(&key.PublicKey)
|
||||
_, _, err = ExtractClaims[testClaims](m, token)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestExtractClaimsFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := NewMJwtSigner("mjwt.test", key)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := NewMJwtVerifier(&key.PublicKey)
|
||||
_, _, err = ExtractClaims[testClaims2](m, token)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
|
||||
}
|
143
signer.go
143
signer.go
@ -1,143 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name
|
||||
// to generate MJWT tokens
|
||||
type defaultMJwtSigner struct {
|
||||
issuer string
|
||||
key *rsa.PrivateKey
|
||||
verify *defaultMJwtVerifier
|
||||
}
|
||||
|
||||
var _ Signer = &defaultMJwtSigner{}
|
||||
var _ Verifier = &defaultMJwtSigner{}
|
||||
|
||||
// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey
|
||||
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
|
||||
return &defaultMJwtSigner{
|
||||
issuer: issuer,
|
||||
key: key,
|
||||
verify: newMJwtVerifier(&key.PublicKey),
|
||||
}
|
||||
}
|
||||
|
||||
// NewMJwtSignerFromFileOrCreate creates a new defaultMJwtSigner using the path
|
||||
// of a rsa.PrivateKey file. If the file does not exist then the file is created
|
||||
// and a new private key is generated.
|
||||
func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits int) (Signer, error) {
|
||||
privateKey, err := readOrCreatePrivateKey(file, random, bits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewMJwtSigner(issuer, privateKey), nil
|
||||
}
|
||||
|
||||
// NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a
|
||||
// rsa.PrivateKey file.
|
||||
func NewMJwtSignerFromFile(issuer, file string) (Signer, error) {
|
||||
// read file
|
||||
raw, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// decode pem block
|
||||
block, _ := pem.Decode(raw)
|
||||
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||
return nil, fmt.Errorf("invalid rsa private key pem block")
|
||||
}
|
||||
|
||||
// parse private key from pem block
|
||||
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create signer using rsa.PrivateKey
|
||||
return NewMJwtSigner(issuer, key), nil
|
||||
}
|
||||
|
||||
// Issuer returns the name of the issuer
|
||||
func (d *defaultMJwtSigner) Issuer() string { return d.issuer }
|
||||
|
||||
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims
|
||||
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
|
||||
return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims))
|
||||
}
|
||||
|
||||
// SignJwt signs a jwt.Claims compatible struct, this is used internally by
|
||||
// GenerateJwt but is available for signing custom structs
|
||||
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
|
||||
return token.SignedString(d.key)
|
||||
}
|
||||
|
||||
// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
|
||||
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
||||
return d.verify.VerifyJwt(token, claims)
|
||||
}
|
||||
|
||||
func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey { return d.key }
|
||||
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey { return d.verify.pub }
|
||||
|
||||
// readOrCreatePrivateKey returns the private key it the file already exists,
|
||||
// generates a new private key and saves it to the file, or returns an error if
|
||||
// reading or generating failed.
|
||||
func readOrCreatePrivateKey(file string, random io.Reader, bits int) (*rsa.PrivateKey, error) {
|
||||
// read the file or return nil
|
||||
f, err := readOrEmptyFile(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if f == nil {
|
||||
// generate a new key
|
||||
key, err := rsa.GenerateKey(random, bits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyBytes := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// write the key to the file
|
||||
err = os.WriteFile(file, keyBytes, 0600)
|
||||
return key, err
|
||||
} else {
|
||||
// decode pem block
|
||||
block, _ := pem.Decode(f)
|
||||
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||
return nil, fmt.Errorf("invalid rsa private key pem block")
|
||||
}
|
||||
|
||||
// try to parse the private key
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
}
|
||||
}
|
||||
|
||||
// readOrEmptyFile returns bytes and errors from os.ReadFile or (nil, nil) if the
|
||||
// file does not exist.
|
||||
func readOrEmptyFile(file string) ([]byte, error) {
|
||||
raw, err := os.ReadFile(file)
|
||||
if err == nil {
|
||||
return raw, nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
@ -1,69 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewMJwtSigner(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
NewMJwtSigner("Test", key)
|
||||
}
|
||||
|
||||
func TestNewMJwtSignerFromFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempKey, err := os.CreateTemp("", "key-test-*.pem")
|
||||
assert.NoError(t, err)
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||
_, err = tempKey.Write(b)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, tempKey.Close())
|
||||
signer, err := NewMJwtSignerFromFile("Test", tempKey.Name())
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, os.Remove(tempKey.Name()))
|
||||
_, err = NewMJwtSignerFromFile("Test", tempKey.Name())
|
||||
assert.Error(t, err)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
assert.True(t, signer.(*defaultMJwtSigner).key.Equal(key))
|
||||
}
|
||||
|
||||
func TestNewMJwtSignerFromFileOrCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempKey, err := os.CreateTemp("", "key-test-*.pem")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, tempKey.Close())
|
||||
assert.NoError(t, os.Remove(tempKey.Name()))
|
||||
signer, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
signer2, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, signer.PrivateKey().Equal(signer2.PrivateKey()))
|
||||
}
|
||||
|
||||
func TestReadOrCreatePrivateKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempKey, err := os.CreateTemp("", "key-test-*.pem")
|
||||
assert.NoError(t, err)
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||
_, err = tempKey.Write(b)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, tempKey.Close())
|
||||
key2, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, key.Equal(key2))
|
||||
assert.NoError(t, os.Remove(tempKey.Name()))
|
||||
key3, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, key3.Validate())
|
||||
}
|
61
verifier.go
61
verifier.go
@ -1,61 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"os"
|
||||
)
|
||||
|
||||
// defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate
|
||||
// MJWT tokens
|
||||
type defaultMJwtVerifier struct {
|
||||
pub *rsa.PublicKey
|
||||
}
|
||||
|
||||
var _ Verifier = &defaultMJwtVerifier{}
|
||||
|
||||
// NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey
|
||||
func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
|
||||
return newMJwtVerifier(key)
|
||||
}
|
||||
|
||||
func newMJwtVerifier(key *rsa.PublicKey) *defaultMJwtVerifier {
|
||||
return &defaultMJwtVerifier{pub: key}
|
||||
}
|
||||
|
||||
// NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a
|
||||
// rsa.PublicKey file
|
||||
func NewMJwtVerifierFromFile(file string) (Verifier, error) {
|
||||
// read file
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// decode pem block
|
||||
block, _ := pem.Decode(f)
|
||||
|
||||
// parse public key from pem block
|
||||
pub, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create verifier using rsa.PublicKey
|
||||
return NewMJwtVerifier(pub), nil
|
||||
}
|
||||
|
||||
// VerifyJwt validates and parses MJWT tokens and returns the claims
|
||||
func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
||||
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return d.pub, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return withClaims, claims.Valid()
|
||||
}
|
||||
|
||||
func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey { return d.pub }
|
@ -1,34 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewMJwtVerifierFromFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := NewMJwtSigner("mjwt.test", key)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey)})
|
||||
temp, err := os.CreateTemp("", "this-is-a-test-file.pem")
|
||||
assert.NoError(t, err)
|
||||
_, err = temp.Write(b)
|
||||
assert.NoError(t, err)
|
||||
file, err := NewMJwtVerifierFromFile(temp.Name())
|
||||
assert.NoError(t, err)
|
||||
_, _, err = ExtractClaims[testClaims](file, token)
|
||||
assert.NoError(t, err)
|
||||
err = os.Remove(temp.Name())
|
||||
assert.NoError(t, err)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user