mjwt/mjwt.go

130 lines
3.4 KiB
Go

package mjwt
import (
"encoding/json"
"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
"time"
)
// ErrClaimTypeMismatch is returned by BaseTypeClaims.Valid when the stored
// ClaimType differs from the type reported by the generic claims value.
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
// wrapClaims builds a BaseTypeClaims around the supplied generic claims.
//
// The registered claims are filled in as follows: the issuer comes from
// p.Issuer(), the subject, ID and audience come from sub, id and aud, and a
// single time.Now() reading is used for both IssuedAt and NotBefore, with
// ExpiresAt set to that instant plus dur. The returned wrapper has had init
// called on it, so its ClaimType field is already populated.
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
	// Take one timestamp so that iat, nbf and exp are all mutually consistent.
	issuedAt := time.Now()

	registered := jwt.RegisteredClaims{
		Issuer:    p.Issuer(),
		Subject:   sub,
		Audience:  aud,
		ExpiresAt: jwt.NewNumericDate(issuedAt.Add(dur)),
		NotBefore: jwt.NewNumericDate(issuedAt),
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		ID:        id,
	}

	wrapped := &BaseTypeClaims[T]{
		RegisteredClaims: registered,
		Claims:           claims,
	}
	return wrapped.init()
}
// ExtractClaims hands the token string to p.VerifyJwt together with a
// zero-valued BaseTypeClaims[T] to decode into, and returns the parsed token,
// the claims value and the verifier's error.
//
// The claims value is returned as-is even when err is non-nil; callers should
// check err before trusting its contents.
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
	// The zero value is the decode target: empty registered claims, zero T.
	var decoded BaseTypeClaims[T]
	parsed, err := p.VerifyJwt(token, &decoded)
	return parsed, decoded, err
}
// Claims is the constraint for the generic claim structs carried inside
// BaseTypeClaims. It extends jwt.Claims with a Type method that names the
// claim struct; that name is what BaseTypeClaims stores and compares as its
// ClaimType.
type Claims interface {
jwt.Claims
// Type returns the name identifying this kind of claim struct.
Type() string
}
// baseTypeClaim describes a jwt.Claims value that can also report the type
// name of the claims it wraps; *BaseTypeClaims[T] provides both methods.
// NOTE(review): not referenced in the visible part of this file — confirm
// where it is used.
type baseTypeClaim interface {
jwt.Claims
// InternalClaimType returns the type name of the wrapped claims.
InternalClaimType() string
}
// BaseTypeClaims is a wrapper for combining the jwt.RegisteredClaims with a ClaimType
// and generic Claims data.
//
// Its custom MarshalJSON/UnmarshalJSON methods place the registered claims,
// the claim type (under the "mct" key) and the fields of the generic claims
// side by side in a single JSON object.
type BaseTypeClaims[T Claims] struct {
jwt.RegisteredClaims
// ClaimType is the claim type name; init sets it from the generic claims
// and UnmarshalJSON sets it from the "mct" field. Valid compares it with
// InternalClaimType.
ClaimType string
// Claims is the application-specific claims payload.
Claims T
}
// init records the type name reported by the generic claims in ClaimType and
// returns the receiver so the call can be chained onto a constructor.
func (b *BaseTypeClaims[T]) init() *BaseTypeClaims[T] {
	b.ClaimType = b.Claims.Type()
	return b
}
// Valid reports whether the claims are acceptable. It checks, in order:
//
//  1. that the stored ClaimType equals the type reported by the generic
//     claims (ErrClaimTypeMismatch otherwise);
//  2. the embedded jwt.RegisteredClaims, via their own Valid method;
//  3. the generic claims, via their Valid method.
//
// The first failing check's error is returned; nil means all checks passed.
func (b *BaseTypeClaims[T]) Valid() error {
	if b.ClaimType != b.InternalClaimType() {
		return ErrClaimTypeMismatch
	}
	// This method shadows the Valid promoted from the embedded
	// RegisteredClaims, so the registered claims must be validated explicitly
	// or they would never be checked when this wrapper is validated.
	if err := b.RegisteredClaims.Valid(); err != nil {
		return err
	}
	return b.Claims.Valid()
}
// InternalClaimType reports the type name of the wrapped generic claims by
// delegating to their Type method. Unlike the ClaimType field, this is always
// derived from the Claims value itself.
func (b *BaseTypeClaims[T]) InternalClaimType() string {
	return b.Claims.Type()
}
// MarshalJSON serializes the registered claims, the claim type and the
// generic claims into one flat JSON object.
//
// The registered claims and claim type are encoded through
// internalBaseTypeClaims, the generic claims are encoded separately, and the
// two objects are spliced together. The claim type written is
// InternalClaimType(), not the ClaimType field.
//
// An error is returned if either encoding fails or if the generic claims do
// not encode to a JSON object (for example "null" for a nil pointer). Generic
// claims that encode to an empty object contribute no fields.
func (b *BaseTypeClaims[T]) MarshalJSON() ([]byte, error) {
	// encode the internalBaseTypeClaims
	base, err := json.Marshal(internalBaseTypeClaims{
		RegisteredClaims: b.RegisteredClaims,
		ClaimType:        b.InternalClaimType(),
	})
	if err != nil {
		return nil, err
	}

	// encode the generic claims struct
	extra, err := json.Marshal(b.Claims)
	if err != nil {
		return nil, err
	}

	// The splice below is only meaningful when both encodings are objects.
	if len(base) < 2 || base[len(base)-1] != '}' ||
		len(extra) < 2 || extra[0] != '{' || extra[len(extra)-1] != '}' {
		return nil, errors.New("mjwt: claims must encode to a JSON object")
	}

	// "{}" carries no fields; splicing it in would leave a trailing comma.
	if len(extra) == 2 {
		return base, nil
	}

	// Join as: base without its closing '}' + ',' + extra without its opening '{'.
	out := make([]byte, 0, len(base)+len(extra))
	out = append(out, base[:len(base)-1]...)
	out = append(out, ',')
	out = append(out, extra[1:]...)
	return out, nil
}
// UnmarshalJSON populates the receiver from a flat JSON object by decoding the
// same bytes twice: once into internalBaseTypeClaims for the registered claims
// and the "mct" claim type, and once into a fresh T for the generic claims.
//
// The receiver is only modified after both decodes succeed; on failure the
// decode error is returned and the receiver is left untouched.
func (b *BaseTypeClaims[T]) UnmarshalJSON(bytes []byte) error {
	var (
		base   internalBaseTypeClaims
		custom T
	)

	// registered claims + claim type
	if err := json.Unmarshal(bytes, &base); err != nil {
		return err
	}
	// application-specific claims
	if err := json.Unmarshal(bytes, &custom); err != nil {
		return err
	}

	b.RegisteredClaims, b.ClaimType, b.Claims = base.RegisteredClaims, base.ClaimType, custom
	return nil
}
// internalBaseTypeClaims is a wrapper for jwt.RegisteredClaims which adds a
// ClaimType field containing the type of the generic claim struct. It is the
// intermediate shape BaseTypeClaims encodes to and decodes from for
// everything except the generic claims themselves.
type internalBaseTypeClaims struct {
jwt.RegisteredClaims
// ClaimType is serialized under the "mct" key; with no omitempty the key
// is always written, even when the value is empty.
ClaimType string `json:"mct"`
}