diff --git a/auth/access-token.go b/auth/access-token.go new file mode 100644 index 0000000..8172576 --- /dev/null +++ b/auth/access-token.go @@ -0,0 +1,30 @@ +package auth + +import ( + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/claims" + "time" +) + +// AccessTokenClaims contains the JWT claims for an access token +type AccessTokenClaims struct { + UserId uint64 `json:"uid"` + Perms *claims.PermStorage `json:"per"` +} + +func (a AccessTokenClaims) Valid() error { return nil } + +func (a AccessTokenClaims) Type() string { return "access-token" } + +// CreateAccessToken creates an access token with the default 15 minute duration +func CreateAccessToken(p mjwt.Signer, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) { + return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, userId, perms) +} + +// CreateAccessTokenWithDuration creates an access token with a custom duration +func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) { + return p.GenerateJwt(sub, id, dur, &AccessTokenClaims{ + UserId: userId, + Perms: perms, + }) +} diff --git a/auth/access-token_test.go b/auth/access-token_test.go new file mode 100644 index 0000000..dada28b --- /dev/null +++ b/auth/access-token_test.go @@ -0,0 +1,33 @@ +package auth + +import ( + "crypto/rand" + "crypto/rsa" + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/claims" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestCreateAccessToken(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + ps := claims.NewPermStorage() + ps.Set("mjwt:test") + ps.Set("mjwt:test2") + + s := mjwt.NewMJwtSigner("mjwt.test", key) + + accessToken, err := CreateAccessToken(s, "1", "test", 1, ps) + assert.NoError(t, err) + + _, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken) + assert.NoError(t, err) + assert.Equal(t, "1", b.Subject) + assert.Equal(t, "test", b.ID) + assert.Equal(t, uint64(1), b.Claims.UserId) + assert.True(t, b.Claims.Perms.Has("mjwt:test")) + assert.True(t, b.Claims.Perms.Has("mjwt:test2")) + assert.False(t, b.Claims.Perms.Has("mjwt:test3")) +} diff --git a/auth/pair.go b/auth/pair.go new file mode 100644 index 0000000..d3175b4 --- /dev/null +++ b/auth/pair.go @@ -0,0 +1,27 @@ +package auth + +import ( + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/claims" + "time" +) + +// CreateTokenPair creates an access and refresh token pair using the default +// 15 minute and 7 day durations respectively +func CreateTokenPair(p mjwt.Signer, sub, id string, userId uint64, perms *claims.PermStorage) (string, string, error) { + return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, userId, perms) +} + +// CreateTokenPairWithDuration creates an access and refresh token pair using +// custom durations for the access and refresh tokens +func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id string, userId uint64, perms *claims.PermStorage) (string, string, error) { + accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, userId, perms) + if err != nil { + return "", "", err + } + refreshToken, err := CreateRefreshTokenWithDuration(p, refreshDur, sub, id, userId, perms) + if err != nil { + return "", "", err + } + return accessToken, refreshToken, nil +} diff --git a/auth/pair_test.go b/auth/pair_test.go new file mode 100644 index 0000000..733b524 --- /dev/null +++ b/auth/pair_test.go @@ -0,0 +1,42 @@ +package auth + +import ( + "crypto/rand" + "crypto/rsa" + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/claims" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestCreateTokenPair(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + ps := claims.NewPermStorage() + ps.Set("mjwt:test") + ps.Set("mjwt:test2") + + s := mjwt.NewMJwtSigner("mjwt.test", key) + + accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", 1, ps) + assert.NoError(t, err) + + _, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken) + assert.NoError(t, err) + assert.Equal(t, "1", b.Subject) + assert.Equal(t, "test", b.ID) + assert.Equal(t, uint64(1), b.Claims.UserId) + assert.True(t, b.Claims.Perms.Has("mjwt:test")) + assert.True(t, b.Claims.Perms.Has("mjwt:test2")) + assert.False(t, b.Claims.Perms.Has("mjwt:test3")) + + _, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken) + assert.NoError(t, err) + assert.Equal(t, "1", b2.Subject) + assert.Equal(t, "test", b2.ID) + assert.Equal(t, uint64(1), b2.Claims.UserId) + assert.True(t, b2.Claims.Perms.Has("mjwt:test")) + assert.True(t, b2.Claims.Perms.Has("mjwt:test2")) + assert.False(t, b2.Claims.Perms.Has("mjwt:test3")) +} diff --git a/auth/refresh-token.go b/auth/refresh-token.go new file mode 100644 index 0000000..969525f --- /dev/null +++ b/auth/refresh-token.go @@ -0,0 +1,29 @@ +package auth + +import ( + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/claims" + "time" +) + +// RefreshTokenClaims contains the JWT claims for a refresh token +type RefreshTokenClaims struct { + UserId uint64 `json:"uid"` + Perms *claims.PermStorage `json:"per"` +} + +func (r RefreshTokenClaims) Valid() error { return nil } +func (r RefreshTokenClaims) Type() string { return "refresh-token" } + +// CreateRefreshToken creates a refresh token with the default 7 day duration +func CreateRefreshToken(p mjwt.Signer, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) { + return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, userId, perms) +} + +// CreateRefreshTokenWithDuration creates a refresh token with a custom duration +func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) { + return p.GenerateJwt(sub, id, dur, RefreshTokenClaims{ + UserId: userId, + Perms: perms, + }) +} diff --git a/auth/refresh-token_test.go b/auth/refresh-token_test.go new file mode 100644 index 0000000..fc74fcb --- /dev/null +++ b/auth/refresh-token_test.go @@ -0,0 +1,33 @@ +package auth + +import ( + "crypto/rand" + "crypto/rsa" + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/claims" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestCreateRefreshToken(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + ps := claims.NewPermStorage() + ps.Set("mjwt:test") + ps.Set("mjwt:test2") + + s := mjwt.NewMJwtSigner("mjwt.test", key) + + refreshToken, err := CreateRefreshToken(s, "1", "test", 1, ps) + assert.NoError(t, err) + + _, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken) + assert.NoError(t, err) + assert.Equal(t, "1", b.Subject) + assert.Equal(t, "test", b.ID) + assert.Equal(t, uint64(1), b.Claims.UserId) + assert.True(t, b.Claims.Perms.Has("mjwt:test")) + assert.True(t, b.Claims.Perms.Has("mjwt:test2")) + assert.False(t, b.Claims.Perms.Has("mjwt:test3")) +} diff --git a/claims/perms.go b/claims/perms.go new file mode 100644 index 0000000..b94f56b --- /dev/null +++ b/claims/perms.go @@ -0,0 +1,85 @@ +package claims + +import ( + "encoding/json" + "gopkg.in/yaml.v3" + "sort" +) + +type PermStorage struct { + values map[string]struct{} +} + +func NewPermStorage() *PermStorage { + return new(PermStorage).setup() +} + +func (p *PermStorage) setup() *PermStorage { + if p.values == nil { + p.values = make(map[string]struct{}) + } + return p +} + +func (p *PermStorage) Set(perm string) { + p.values[perm] = struct{}{} +} + +func (p *PermStorage) Clear(perm string) { + delete(p.values, perm) +} + +func (p *PermStorage) Has(perm string) bool { + _, ok := p.values[perm] + return ok +} + +func (p *PermStorage) OneOf(o *PermStorage) bool { + for i := range o.values { + if p.Has(i) { + return true + } + } + return false +} + +func (p *PermStorage) dump() []string { + var a []string + for i := range p.values { + a = append(a, i) + } + sort.Strings(a) + return a +} + +func (p *PermStorage) prepare(a []string) { + for _, i := range a { + p.Set(i) + } +} + +func (p *PermStorage) MarshalJSON() ([]byte, error) { return json.Marshal(p.dump()) } + +func (p *PermStorage) UnmarshalJSON(bytes []byte) error { + p.setup() + var a []string + err := json.Unmarshal(bytes, &a) + if err != nil { + return err + } + p.prepare(a) + return nil +} + +func (p *PermStorage) MarshalYAML() (interface{}, error) { return yaml.Marshal(p.dump()) } + +func (p *PermStorage) UnmarshalYAML(value *yaml.Node) error { + p.setup() + var a []string + err := value.Decode(&a) + if err != nil { + return err + } + p.prepare(a) + return nil +} diff --git a/claims/perms_test.go b/claims/perms_test.go new file mode 100644 index 0000000..108ba73 --- /dev/null +++ b/claims/perms_test.go @@ -0,0 +1,66 @@ +package claims + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestPermStorage_Set(t *testing.T) { + ps := NewPermStorage() + ps.Set("mjwt:test") + if _, ok := ps.values["mjwt:test"]; !ok { + assert.Fail(t, "perm not set") + } +} + +func TestPermStorage_Clear(t *testing.T) { + ps := NewPermStorage() + ps.values["mjwt:test"] = struct{}{} + ps.Clear("mjwt:test") + if _, ok := ps.values["mjwt:test"]; ok { + assert.Fail(t, "perm not cleared") + } +} + +func TestPermStorage_Has(t *testing.T) { + ps := NewPermStorage() + assert.False(t, ps.Has("mjwt:test")) + ps.values["mjwt:test"] = struct{}{} + assert.True(t, ps.Has("mjwt:test")) +} + +func TestPermStorage_OneOf(t *testing.T) { + o := NewPermStorage() + o.Set("mjwt:test") + o.Set("mjwt:test2") + + ps := NewPermStorage() + assert.False(t, ps.OneOf(o)) + ps.values["mjwt:test"] = struct{}{} + assert.True(t, ps.OneOf(o)) + ps.values["mjwt:test2"] = struct{}{} + assert.True(t, ps.OneOf(o)) + delete(ps.values, "mjwt:test") + assert.True(t, ps.OneOf(o)) + delete(ps.values, "mjwt:test2") + assert.False(t, ps.OneOf(o)) +} + +func TestPermStorage_MarshalJSON(t *testing.T) { + ps := NewPermStorage() + ps.Set("mjwt:test") + ps.Set("mjwt:test2") + b, err := ps.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, 0, bytes.Compare([]byte(`["mjwt:test","mjwt:test2"]`), b)) +} + +func TestPermStorage_MarshalYAML(t *testing.T) { + ps := NewPermStorage() + ps.Set("mjwt:test") + ps.Set("mjwt:test2") + b, err := ps.MarshalYAML() + assert.NoError(t, err) + assert.Equal(t, 0, bytes.Compare([]byte("- mjwt:test\n- mjwt:test2\n"), b.([]byte))) +} diff --git a/go.mod b/go.mod index 99a5366..598e75f 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,15 @@ -module github.com/mrmelon54/mjwt +module github.com/MrMelon54/mjwt go 1.19 require ( github.com/golang-jwt/jwt/v4 v4.4.3 github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.8.4 gopkg.in/yaml.v3 v3.0.1 ) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect +) diff --git a/go.sum b/go.sum index 0b96d8e..162adfa 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,13 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/interfaces.go b/interfaces.go new file mode 100644 index 0000000..2cf3886 --- /dev/null +++ b/interfaces.go @@ -0,0 +1,20 @@ +package mjwt + +import ( + "github.com/golang-jwt/jwt/v4" + "time" +) + +// Signer is used to generate MJWT tokens. +// Signer can also be used as a Verifier. +type Signer interface { + Verifier + GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error) + SignJwt(claims jwt.Claims) (string, error) + Issuer() string +} + +// Verifier is used to verify the validity MJWT tokens and extract the claim values. +type Verifier interface { + VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) +} diff --git a/mjwt.go b/mjwt.go index 7a9bd60..6871b8a 100644 --- a/mjwt.go +++ b/mjwt.go @@ -10,14 +10,8 @@ import ( var ErrClaimTypeMismatch = errors.New("claim type mismatch") -type Provider interface { - GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error) - SignJwt(claims jwt.Claims) (string, error) - VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) - Issuer() string -} - -func wrapClaims[T Claims](p Provider, sub, id string, dur time.Duration, claims T) *BaseTypeClaims[T] { +// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct +func wrapClaims[T Claims](p Signer, sub, id string, dur time.Duration, claims T) *BaseTypeClaims[T] { now := time.Now() return (&BaseTypeClaims[T]{ RegisteredClaims: jwt.RegisteredClaims{ @@ -32,7 +26,9 @@ func wrapClaims[T Claims](p Provider, sub, id string, dur time.Duration, claims }).init() } -func ExtractClaims[T Claims](p Provider, token string) (*jwt.Token, BaseTypeClaims[T], error) { +// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed +// token and BaseTypeClaims +func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) { b := BaseTypeClaims[T]{ RegisteredClaims: jwt.RegisteredClaims{}, Claims: *new(T), @@ -41,6 +37,7 @@ func ExtractClaims[T Claims](p Provider, token string) (*jwt.Token, BaseTypeClai return tok, b, err } +// Claims is a wrapper for jwt.Claims and adds a Type method to name internal claim structs type Claims interface { jwt.Claims Type() string diff --git a/mjwt_test.go b/mjwt_test.go new file mode 100644 index 0000000..14297e5 --- /dev/null +++ b/mjwt_test.go @@ -0,0 +1,59 @@ +package mjwt + +import ( + "crypto/rand" + "crypto/rsa" + "fmt" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +type testClaims struct{ TestValue string } + +func (t testClaims) Valid() error { + if t.TestValue != "hello" { + return fmt.Errorf("TestValue should be hello") + } + return nil +} + +func (t testClaims) Type() string { return "testClaims" } + +type testClaims2 struct{ TestValue2 string } + +func (t testClaims2) Valid() error { + if t.TestValue2 != "world" { + return fmt.Errorf("TestValue2 should be world") + } + return nil +} + +func (t testClaims2) Type() string { return "testClaims2" } + +func TestExtractClaims(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + s := NewMJwtSigner("mjwt.test", key) + token, err := s.GenerateJwt("1", "test", 10*time.Minute, testClaims{TestValue: "hello"}) + assert.NoError(t, err) + + m := NewMJwtVerifier(&key.PublicKey) + _, _, err = ExtractClaims[testClaims](m, token) + assert.NoError(t, err) +} + +func TestExtractClaimsFail(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + s := NewMJwtSigner("mjwt.test", key) + token, err := s.GenerateJwt("1", "test", 10*time.Minute, testClaims{TestValue: "test"}) + assert.NoError(t, err) + + m := NewMJwtVerifier(&key.PublicKey) + _, _, err = ExtractClaims[testClaims2](m, token) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrClaimTypeMismatch) +} diff --git a/signer.go b/signer.go index b63e99d..22d966f 100644 --- a/signer.go +++ b/signer.go @@ -12,9 +12,9 @@ type defaultMJwtSigner struct { verify *defaultMJwtVerifier } -var _ Provider = &defaultMJwtSigner{} +var _ Signer = &defaultMJwtSigner{} -func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Provider { +func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer { return &defaultMJwtSigner{ issuer: issuer, key: key, diff --git a/verifier.go b/verifier.go index c38af89..da7edf8 100644 --- a/verifier.go +++ b/verifier.go @@ -16,9 +16,9 @@ type defaultMJwtVerifier struct { pub *rsa.PublicKey } -var _ Provider = &defaultMJwtVerifier{} +var _ Verifier = &defaultMJwtVerifier{} -func NewMJwtVerifier(key *rsa.PublicKey) Provider { +func NewMJwtVerifier(key *rsa.PublicKey) Verifier { return newMJwtVerifier(key) } @@ -26,7 +26,7 @@ func newMJwtVerifier(key *rsa.PublicKey) *defaultMJwtVerifier { return &defaultMJwtVerifier{pub: key} } -func NewMJwtVerifierFromFile(file string) (Provider, error) { +func NewMJwtVerifierFromFile(file string) (Verifier, error) { f, err := os.ReadFile(file) if err != nil { return nil, err