mirror of
https://github.com/1f349/mjwt.git
synced 2024-12-22 15:34:08 +00:00
Thunder and lightning caused a powercut just committing my progress before continuing
This commit is contained in:
parent
b859e3a63a
commit
fc6e076f24
30
auth/access-token.go
Normal file
30
auth/access-token.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AccessTokenClaims contains the JWT claims for an access token
|
||||||
|
type AccessTokenClaims struct {
|
||||||
|
UserId uint64 `json:"uid"`
|
||||||
|
Perms *claims.PermStorage `json:"per"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a AccessTokenClaims) Valid() error { return nil }
|
||||||
|
|
||||||
|
func (a AccessTokenClaims) Type() string { return "access-token" }
|
||||||
|
|
||||||
|
// CreateAccessToken creates an access token with the default 15 minute duration
|
||||||
|
func CreateAccessToken(p mjwt.Signer, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) {
|
||||||
|
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, userId, perms)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAccessTokenWithDuration creates an access token with a custom duration
|
||||||
|
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) {
|
||||||
|
return p.GenerateJwt(sub, id, dur, &AccessTokenClaims{
|
||||||
|
UserId: userId,
|
||||||
|
Perms: perms,
|
||||||
|
})
|
||||||
|
}
|
33
auth/access-token_test.go
Normal file
33
auth/access-token_test.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateAccessToken(t *testing.T) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
ps := claims.NewPermStorage()
|
||||||
|
ps.Set("mjwt:test")
|
||||||
|
ps.Set("mjwt:test2")
|
||||||
|
|
||||||
|
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||||
|
|
||||||
|
accessToken, err := CreateAccessToken(s, "1", "test", 1, ps)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "1", b.Subject)
|
||||||
|
assert.Equal(t, "test", b.ID)
|
||||||
|
assert.Equal(t, uint64(1), b.Claims.UserId)
|
||||||
|
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
||||||
|
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||||
|
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||||
|
}
|
27
auth/pair.go
Normal file
27
auth/pair.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateTokenPair creates an access and refresh token pair using the default
|
||||||
|
// 15 minute and 7 day durations respectively
|
||||||
|
func CreateTokenPair(p mjwt.Signer, sub, id string, userId uint64, perms *claims.PermStorage) (string, string, error) {
|
||||||
|
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, userId, perms)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTokenPairWithDuration creates an access and refresh token pair using
|
||||||
|
// custom durations for the access and refresh tokens
|
||||||
|
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id string, userId uint64, perms *claims.PermStorage) (string, string, error) {
|
||||||
|
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, userId, perms)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
refreshToken, err := CreateRefreshTokenWithDuration(p, refreshDur, sub, id, userId, perms)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
return accessToken, refreshToken, nil
|
||||||
|
}
|
42
auth/pair_test.go
Normal file
42
auth/pair_test.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateTokenPair(t *testing.T) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
ps := claims.NewPermStorage()
|
||||||
|
ps.Set("mjwt:test")
|
||||||
|
ps.Set("mjwt:test2")
|
||||||
|
|
||||||
|
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||||
|
|
||||||
|
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", 1, ps)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "1", b.Subject)
|
||||||
|
assert.Equal(t, "test", b.ID)
|
||||||
|
assert.Equal(t, uint64(1), b.Claims.UserId)
|
||||||
|
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
||||||
|
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||||
|
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||||
|
|
||||||
|
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "1", b2.Subject)
|
||||||
|
assert.Equal(t, "test", b2.ID)
|
||||||
|
assert.Equal(t, uint64(1), b2.Claims.UserId)
|
||||||
|
assert.True(t, b2.Claims.Perms.Has("mjwt:test"))
|
||||||
|
assert.True(t, b2.Claims.Perms.Has("mjwt:test2"))
|
||||||
|
assert.False(t, b2.Claims.Perms.Has("mjwt:test3"))
|
||||||
|
}
|
29
auth/refresh-token.go
Normal file
29
auth/refresh-token.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RefreshTokenClaims contains the JWT claims for a refresh token
|
||||||
|
type RefreshTokenClaims struct {
|
||||||
|
UserId uint64 `json:"uid"`
|
||||||
|
Perms *claims.PermStorage `json:"per"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RefreshTokenClaims) Valid() error { return nil }
|
||||||
|
func (r RefreshTokenClaims) Type() string { return "refresh-token" }
|
||||||
|
|
||||||
|
// CreateRefreshToken creates a refresh token with the default 7 day duration
|
||||||
|
func CreateRefreshToken(p mjwt.Signer, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) {
|
||||||
|
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, userId, perms)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
|
||||||
|
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) {
|
||||||
|
return p.GenerateJwt(sub, id, dur, RefreshTokenClaims{
|
||||||
|
UserId: userId,
|
||||||
|
Perms: perms,
|
||||||
|
})
|
||||||
|
}
|
33
auth/refresh-token_test.go
Normal file
33
auth/refresh-token_test.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateRefreshToken(t *testing.T) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
ps := claims.NewPermStorage()
|
||||||
|
ps.Set("mjwt:test")
|
||||||
|
ps.Set("mjwt:test2")
|
||||||
|
|
||||||
|
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||||
|
|
||||||
|
refreshToken, err := CreateRefreshToken(s, "1", "test", 1, ps)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "1", b.Subject)
|
||||||
|
assert.Equal(t, "test", b.ID)
|
||||||
|
assert.Equal(t, uint64(1), b.Claims.UserId)
|
||||||
|
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
||||||
|
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||||
|
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||||
|
}
|
85
claims/perms.go
Normal file
85
claims/perms.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package claims
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PermStorage struct {
|
||||||
|
values map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPermStorage() *PermStorage {
|
||||||
|
return new(PermStorage).setup()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PermStorage) setup() *PermStorage {
|
||||||
|
if p.values == nil {
|
||||||
|
p.values = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PermStorage) Set(perm string) {
|
||||||
|
p.values[perm] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PermStorage) Clear(perm string) {
|
||||||
|
delete(p.values, perm)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PermStorage) Has(perm string) bool {
|
||||||
|
_, ok := p.values[perm]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PermStorage) OneOf(o *PermStorage) bool {
|
||||||
|
for i := range o.values {
|
||||||
|
if p.Has(i) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PermStorage) dump() []string {
|
||||||
|
var a []string
|
||||||
|
for i := range p.values {
|
||||||
|
a = append(a, i)
|
||||||
|
}
|
||||||
|
sort.Strings(a)
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PermStorage) prepare(a []string) {
|
||||||
|
for _, i := range a {
|
||||||
|
p.Set(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PermStorage) MarshalJSON() ([]byte, error) { return json.Marshal(p.dump()) }
|
||||||
|
|
||||||
|
func (p *PermStorage) UnmarshalJSON(bytes []byte) error {
|
||||||
|
p.setup()
|
||||||
|
var a []string
|
||||||
|
err := json.Unmarshal(bytes, &a)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.prepare(a)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PermStorage) MarshalYAML() (interface{}, error) { return yaml.Marshal(p.dump()) }
|
||||||
|
|
||||||
|
func (p *PermStorage) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
p.setup()
|
||||||
|
var a []string
|
||||||
|
err := value.Decode(&a)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.prepare(a)
|
||||||
|
return nil
|
||||||
|
}
|
66
claims/perms_test.go
Normal file
66
claims/perms_test.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
package claims
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPermStorage_Set(t *testing.T) {
|
||||||
|
ps := NewPermStorage()
|
||||||
|
ps.Set("mjwt:test")
|
||||||
|
if _, ok := ps.values["mjwt:test"]; !ok {
|
||||||
|
assert.Fail(t, "perm not set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermStorage_Clear(t *testing.T) {
|
||||||
|
ps := NewPermStorage()
|
||||||
|
ps.values["mjwt:test"] = struct{}{}
|
||||||
|
ps.Clear("mjwt:test")
|
||||||
|
if _, ok := ps.values["mjwt:test"]; ok {
|
||||||
|
assert.Fail(t, "perm not cleared")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermStorage_Has(t *testing.T) {
|
||||||
|
ps := NewPermStorage()
|
||||||
|
assert.False(t, ps.Has("mjwt:test"))
|
||||||
|
ps.values["mjwt:test"] = struct{}{}
|
||||||
|
assert.True(t, ps.Has("mjwt:test"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermStorage_OneOf(t *testing.T) {
|
||||||
|
o := NewPermStorage()
|
||||||
|
o.Set("mjwt:test")
|
||||||
|
o.Set("mjwt:test2")
|
||||||
|
|
||||||
|
ps := NewPermStorage()
|
||||||
|
assert.False(t, ps.OneOf(o))
|
||||||
|
ps.values["mjwt:test"] = struct{}{}
|
||||||
|
assert.True(t, ps.OneOf(o))
|
||||||
|
ps.values["mjwt:test2"] = struct{}{}
|
||||||
|
assert.True(t, ps.OneOf(o))
|
||||||
|
delete(ps.values, "mjwt:test")
|
||||||
|
assert.True(t, ps.OneOf(o))
|
||||||
|
delete(ps.values, "mjwt:test2")
|
||||||
|
assert.False(t, ps.OneOf(o))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermStorage_MarshalJSON(t *testing.T) {
|
||||||
|
ps := NewPermStorage()
|
||||||
|
ps.Set("mjwt:test")
|
||||||
|
ps.Set("mjwt:test2")
|
||||||
|
b, err := ps.MarshalJSON()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, bytes.Compare([]byte(`["mjwt:test","mjwt:test2"]`), b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermStorage_MarshalYAML(t *testing.T) {
|
||||||
|
ps := NewPermStorage()
|
||||||
|
ps.Set("mjwt:test")
|
||||||
|
ps.Set("mjwt:test2")
|
||||||
|
b, err := ps.MarshalYAML()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, bytes.Compare([]byte("- mjwt:test\n- mjwt:test2\n"), b.([]byte)))
|
||||||
|
}
|
8
go.mod
8
go.mod
@ -1,9 +1,15 @@
|
|||||||
module github.com/mrmelon54/mjwt
|
module github.com/MrMelon54/mjwt
|
||||||
|
|
||||||
go 1.19
|
go 1.19
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/golang-jwt/jwt/v4 v4.4.3
|
github.com/golang-jwt/jwt/v4 v4.4.3
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
|
github.com/stretchr/testify v1.8.4
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
)
|
||||||
|
6
go.sum
6
go.sum
@ -1,7 +1,13 @@
|
|||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU=
|
github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU=
|
||||||
github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||||
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
20
interfaces.go
Normal file
20
interfaces.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package mjwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Signer is used to generate MJWT tokens.
|
||||||
|
// Signer can also be used as a Verifier.
|
||||||
|
type Signer interface {
|
||||||
|
Verifier
|
||||||
|
GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error)
|
||||||
|
SignJwt(claims jwt.Claims) (string, error)
|
||||||
|
Issuer() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verifier is used to verify the validity MJWT tokens and extract the claim values.
|
||||||
|
type Verifier interface {
|
||||||
|
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
|
||||||
|
}
|
15
mjwt.go
15
mjwt.go
@ -10,14 +10,8 @@ import (
|
|||||||
|
|
||||||
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
|
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
|
||||||
|
|
||||||
type Provider interface {
|
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
|
||||||
GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error)
|
func wrapClaims[T Claims](p Signer, sub, id string, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
||||||
SignJwt(claims jwt.Claims) (string, error)
|
|
||||||
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
|
|
||||||
Issuer() string
|
|
||||||
}
|
|
||||||
|
|
||||||
func wrapClaims[T Claims](p Provider, sub, id string, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return (&BaseTypeClaims[T]{
|
return (&BaseTypeClaims[T]{
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
@ -32,7 +26,9 @@ func wrapClaims[T Claims](p Provider, sub, id string, dur time.Duration, claims
|
|||||||
}).init()
|
}).init()
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExtractClaims[T Claims](p Provider, token string) (*jwt.Token, BaseTypeClaims[T], error) {
|
// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
|
||||||
|
// token and BaseTypeClaims
|
||||||
|
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
|
||||||
b := BaseTypeClaims[T]{
|
b := BaseTypeClaims[T]{
|
||||||
RegisteredClaims: jwt.RegisteredClaims{},
|
RegisteredClaims: jwt.RegisteredClaims{},
|
||||||
Claims: *new(T),
|
Claims: *new(T),
|
||||||
@ -41,6 +37,7 @@ func ExtractClaims[T Claims](p Provider, token string) (*jwt.Token, BaseTypeClai
|
|||||||
return tok, b, err
|
return tok, b, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Claims is a wrapper for jwt.Claims and adds a Type method to name internal claim structs
|
||||||
type Claims interface {
|
type Claims interface {
|
||||||
jwt.Claims
|
jwt.Claims
|
||||||
Type() string
|
Type() string
|
||||||
|
59
mjwt_test.go
Normal file
59
mjwt_test.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package mjwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"fmt"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testClaims struct{ TestValue string }
|
||||||
|
|
||||||
|
func (t testClaims) Valid() error {
|
||||||
|
if t.TestValue != "hello" {
|
||||||
|
return fmt.Errorf("TestValue should be hello")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t testClaims) Type() string { return "testClaims" }
|
||||||
|
|
||||||
|
type testClaims2 struct{ TestValue2 string }
|
||||||
|
|
||||||
|
func (t testClaims2) Valid() error {
|
||||||
|
if t.TestValue2 != "world" {
|
||||||
|
return fmt.Errorf("TestValue2 should be world")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t testClaims2) Type() string { return "testClaims2" }
|
||||||
|
|
||||||
|
func TestExtractClaims(t *testing.T) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
s := NewMJwtSigner("mjwt.test", key)
|
||||||
|
token, err := s.GenerateJwt("1", "test", 10*time.Minute, testClaims{TestValue: "hello"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
m := NewMJwtVerifier(&key.PublicKey)
|
||||||
|
_, _, err = ExtractClaims[testClaims](m, token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractClaimsFail(t *testing.T) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
s := NewMJwtSigner("mjwt.test", key)
|
||||||
|
token, err := s.GenerateJwt("1", "test", 10*time.Minute, testClaims{TestValue: "test"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
m := NewMJwtVerifier(&key.PublicKey)
|
||||||
|
_, _, err = ExtractClaims[testClaims2](m, token)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
|
||||||
|
}
|
@ -12,9 +12,9 @@ type defaultMJwtSigner struct {
|
|||||||
verify *defaultMJwtVerifier
|
verify *defaultMJwtVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Provider = &defaultMJwtSigner{}
|
var _ Signer = &defaultMJwtSigner{}
|
||||||
|
|
||||||
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Provider {
|
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
|
||||||
return &defaultMJwtSigner{
|
return &defaultMJwtSigner{
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
key: key,
|
key: key,
|
||||||
|
@ -16,9 +16,9 @@ type defaultMJwtVerifier struct {
|
|||||||
pub *rsa.PublicKey
|
pub *rsa.PublicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Provider = &defaultMJwtVerifier{}
|
var _ Verifier = &defaultMJwtVerifier{}
|
||||||
|
|
||||||
func NewMJwtVerifier(key *rsa.PublicKey) Provider {
|
func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
|
||||||
return newMJwtVerifier(key)
|
return newMJwtVerifier(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ func newMJwtVerifier(key *rsa.PublicKey) *defaultMJwtVerifier {
|
|||||||
return &defaultMJwtVerifier{pub: key}
|
return &defaultMJwtVerifier{pub: key}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMJwtVerifierFromFile(file string) (Provider, error) {
|
func NewMJwtVerifierFromFile(file string) (Verifier, error) {
|
||||||
f, err := os.ReadFile(file)
|
f, err := os.ReadFile(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
Loading…
Reference in New Issue
Block a user