commit 3312a37264b85af5c4aea5dea5f51f3c434f9529 Author: MrMelon54 Date: Sun Dec 4 13:42:35 2022 +0000 First commit diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/discord.xml b/.idea/discord.xml new file mode 100644 index 0000000..d8e9561 --- /dev/null +++ b/.idea/discord.xml @@ -0,0 +1,7 @@ + + + + + \ No newline at end of file diff --git a/.idea/mjwt.iml b/.idea/mjwt.iml new file mode 100644 index 0000000..5e764c4 --- /dev/null +++ b/.idea/mjwt.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..cc778c5 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..99a5366 --- /dev/null +++ b/go.mod @@ -0,0 +1,9 @@ +module github.com/mrmelon54/mjwt + +go 1.19 + +require ( + github.com/golang-jwt/jwt/v4 v4.4.3 + github.com/pkg/errors v0.9.1 + gopkg.in/yaml.v3 v3.0.1 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..0b96d8e --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/mjwt.go b/mjwt.go new file mode 100644 index 0000000..e203dfc --- /dev/null +++ b/mjwt.go @@ -0,0 +1,129 @@ +package mjwt + +import ( + "bytes" + "encoding/json" + "github.com/golang-jwt/jwt/v4" + "github.com/pkg/errors" + "time" +) + +var ErrClaimTypeMismatch = errors.New("claim type mismatch") + +type Provider interface { + GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error) + VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) + Issuer() string +} + +func wrapClaims[T Claims](p Provider, sub, id string, dur time.Duration, claims T) *BaseTypeClaims[T] { + now := time.Now() + return (&BaseTypeClaims[T]{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: p.Issuer(), + Subject: sub, + ExpiresAt: jwt.NewNumericDate(now.Add(dur)), + NotBefore: jwt.NewNumericDate(now), + IssuedAt: jwt.NewNumericDate(now), + ID: id, + }, + Claims: claims, + }).init() +} + +func ExtractClaims[T Claims](p Provider, token string) (*jwt.Token, BaseTypeClaims[T], error) { + b := BaseTypeClaims[T]{ + RegisteredClaims: jwt.RegisteredClaims{}, + Claims: *new(T), + } + tok, err := p.VerifyJwt(token, &b) + return tok, b, err +} + +type Claims interface { + jwt.Claims + Type() string +} + +type baseTypeClaim interface { + jwt.Claims + InternalClaimType() string +} + +type BaseTypeClaims[T Claims] struct { + jwt.RegisteredClaims + ClaimType string + Claims T +} + +func (b *BaseTypeClaims[T]) init() *BaseTypeClaims[T] { + b.ClaimType = b.InternalClaimType() + return b +} + +func (b *BaseTypeClaims[T]) Valid() error { + if b.ClaimType != b.InternalClaimType() { + return ErrClaimTypeMismatch + } + return b.Claims.Valid() +} + +func (b *BaseTypeClaims[T]) InternalClaimType() string { + return b.Claims.Type() +} + +func (b *BaseTypeClaims[T]) MarshalJSON() ([]byte, error) { + // setup buffers + buf := new(bytes.Buffer) + buf2 := new(bytes.Buffer) + + // encode into both buffers + err := json.NewEncoder(buf).Encode(internalBaseTypeClaims{ + RegisteredClaims: b.RegisteredClaims, + ClaimType: b.InternalClaimType(), + }) + if err != nil { + return nil, err + } + err = json.NewEncoder(buf2).Encode(b.Claims) + if err != nil { + return nil, err + } + + // decode into a single map + var a map[string]any + err = json.NewDecoder(buf).Decode(&a) + if err != nil { + return nil, err + } + err = json.NewDecoder(buf2).Decode(&a) + if err != nil { + return nil, err + } + + // encode to output + return json.Marshal(a) +} + +func (b *BaseTypeClaims[T]) UnmarshalJSON(bytes []byte) error { + a := internalBaseTypeClaims{} + var t T + err := json.Unmarshal(bytes, &a) + if err != nil { + return err + } + err = json.Unmarshal(bytes, &t) + if err != nil { + return err + } + + b.RegisteredClaims = a.RegisteredClaims + b.ClaimType = a.ClaimType + b.Claims = t + return err +} + +type internalBaseTypeClaims struct { + jwt.RegisteredClaims + ClaimType string `json:"mct"` +} diff --git a/signer.go b/signer.go new file mode 100644 index 0000000..b8a6448 --- /dev/null +++ b/signer.go @@ -0,0 +1,35 @@ +package mjwt + +import ( + "crypto/rsa" + "github.com/golang-jwt/jwt/v4" + "time" +) + +type defaultMJwtSigner struct { + issuer string + key *rsa.PrivateKey + verify *defaultMJwtVerifier +} + +var _ Provider = &defaultMJwtSigner{} + +func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Provider { + return &defaultMJwtSigner{ + issuer: issuer, + key: key, + verify: newMJwtVerifier(&key.PublicKey), + } +} + +func (d *defaultMJwtSigner) Issuer() string { return d.issuer } + +func (d *defaultMJwtSigner) GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error) { + wrapped := wrapClaims[Claims](d, sub, id, dur, claims) + token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped) + return token.SignedString(d.key) +} + +func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { + return d.verify.VerifyJwt(token, claims) +} diff --git a/verifier.go b/verifier.go new file mode 100644 index 0000000..739c73d --- /dev/null +++ b/verifier.go @@ -0,0 +1,56 @@ +package mjwt + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "github.com/golang-jwt/jwt/v4" + "github.com/pkg/errors" + "os" + "time" +) + +var ErrCannotGenerateMJwtToken = errors.New("cannot generate mjwt token with verifier") + +type defaultMJwtVerifier struct { + pub *rsa.PublicKey +} + +var _ Provider = &defaultMJwtVerifier{} + +func NewMJwtVerifier(key *rsa.PublicKey) Provider { + return newMJwtVerifier(key) +} + +func newMJwtVerifier(key *rsa.PublicKey) *defaultMJwtVerifier { + return &defaultMJwtVerifier{pub: key} +} + +func NewMJwtVerifierFromFile(file string) (Provider, error) { + f, err := os.ReadFile(file) + if err != nil { + return nil, err + } + block, _ := pem.Decode(f) + pub, err := x509.ParsePKCS1PublicKey(block.Bytes) + if err != nil { + return nil, err + } + return NewMJwtVerifier(pub), nil +} + +func (d *defaultMJwtVerifier) Issuer() string { return "" } + +func (d *defaultMJwtVerifier) GenerateJwt(_, _ string, _ time.Duration, _ Claims) (string, error) { + return "", ErrCannotGenerateMJwtToken +} + +func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { + withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { + return d.pub, nil + }) + if err != nil { + return nil, err + } + return withClaims, claims.Valid() +}