2022-12-04 13:42:35 +00:00
|
|
|
package mjwt
|
|
|
|
|
|
|
|
import (
|
|
|
|
"encoding/json"
|
|
|
|
"github.com/golang-jwt/jwt/v4"
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
// ErrClaimTypeMismatch is returned by BaseTypeClaims.Valid when the stored
// ClaimType does not equal the Type reported by the wrapped claims struct.
var (
	ErrClaimTypeMismatch = errors.New("claim type mismatch")
)
|
|
|
|
|
2023-06-18 13:03:41 +01:00
|
|
|
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
|
2024-07-27 17:05:27 +01:00
|
|
|
func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
2022-12-04 13:42:35 +00:00
|
|
|
now := time.Now()
|
|
|
|
return (&BaseTypeClaims[T]{
|
|
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
2024-07-27 17:05:27 +01:00
|
|
|
Issuer: issuer,
|
2022-12-04 13:42:35 +00:00
|
|
|
Subject: sub,
|
2023-06-20 00:32:16 +01:00
|
|
|
Audience: aud,
|
2022-12-04 13:42:35 +00:00
|
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
|
|
|
|
NotBefore: jwt.NewNumericDate(now),
|
|
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
|
|
ID: id,
|
|
|
|
},
|
|
|
|
Claims: claims,
|
|
|
|
}).init()
|
|
|
|
}
|
|
|
|
|
2023-06-18 13:03:41 +01:00
|
|
|
// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
|
|
|
|
// token and BaseTypeClaims
|
2024-07-27 17:05:27 +01:00
|
|
|
func ExtractClaims[T Claims](ks *KeyStore, token string) (*jwt.Token, BaseTypeClaims[T], error) {
|
2022-12-04 13:42:35 +00:00
|
|
|
b := BaseTypeClaims[T]{
|
|
|
|
RegisteredClaims: jwt.RegisteredClaims{},
|
|
|
|
Claims: *new(T),
|
|
|
|
}
|
2024-07-27 17:05:27 +01:00
|
|
|
tok, err := ks.VerifyJwt(token, &b)
|
2022-12-04 13:42:35 +00:00
|
|
|
return tok, b, err
|
|
|
|
}
|
|
|
|
|
2023-06-18 13:03:41 +01:00
|
|
|
// Claims is a wrapper for jwt.Claims and adds a Type method to name internal claim structs
|
2022-12-04 13:42:35 +00:00
|
|
|
type Claims interface {
|
|
|
|
jwt.Claims
|
|
|
|
Type() string
|
|
|
|
}
|
|
|
|
|
|
|
|
type baseTypeClaim interface {
|
|
|
|
jwt.Claims
|
|
|
|
InternalClaimType() string
|
|
|
|
}
|
|
|
|
|
2023-06-18 18:09:49 +01:00
|
|
|
// BaseTypeClaims is a wrapper for combining the jwt.RegisteredClaims with a ClaimType
// and generic Claims data.
//
// Its MarshalJSON and UnmarshalJSON methods (pointer receivers) flatten the
// registered claims, the ClaimType (under the "mct" key) and the fields of
// Claims into a single JSON object.
type BaseTypeClaims[T Claims] struct {
	// RegisteredClaims holds the standard JWT claims (issuer, subject, expiry, ...).
	jwt.RegisteredClaims

	// ClaimType names the type of Claims. init sets it from Claims.Type(), and
	// Valid rejects the claims when the two disagree.
	ClaimType string

	// Claims is the application-specific claims struct.
	Claims T
}
|
|
|
|
|
|
|
|
func (b *BaseTypeClaims[T]) init() *BaseTypeClaims[T] {
|
|
|
|
b.ClaimType = b.InternalClaimType()
|
|
|
|
return b
|
|
|
|
}
|
|
|
|
|
2023-06-18 18:09:49 +01:00
|
|
|
// Valid checks the InternalClaimType matches and the type claim type
|
2022-12-04 13:42:35 +00:00
|
|
|
func (b *BaseTypeClaims[T]) Valid() error {
|
2023-11-28 17:47:00 +00:00
|
|
|
if err := b.RegisteredClaims.Valid(); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2022-12-04 13:42:35 +00:00
|
|
|
if b.ClaimType != b.InternalClaimType() {
|
|
|
|
return ErrClaimTypeMismatch
|
|
|
|
}
|
|
|
|
return b.Claims.Valid()
|
|
|
|
}
|
|
|
|
|
2023-06-18 18:09:49 +01:00
|
|
|
// InternalClaimType returns the Type of the generic claim struct
|
|
|
|
func (b *BaseTypeClaims[T]) InternalClaimType() string { return b.Claims.Type() }
|
2022-12-04 13:42:35 +00:00
|
|
|
|
2023-06-18 18:09:49 +01:00
|
|
|
// MarshalJSON converts the internalBaseTypeClaims and generic claim struct into
|
|
|
|
// a serialized JSON byte array
|
2022-12-04 13:42:35 +00:00
|
|
|
func (b *BaseTypeClaims[T]) MarshalJSON() ([]byte, error) {
|
2023-06-18 18:09:49 +01:00
|
|
|
// encode the internalBaseTypeClaims
|
|
|
|
b1, err := json.Marshal(internalBaseTypeClaims{
|
2022-12-04 13:42:35 +00:00
|
|
|
RegisteredClaims: b.RegisteredClaims,
|
|
|
|
ClaimType: b.InternalClaimType(),
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2023-06-18 18:09:49 +01:00
|
|
|
// encode the generic claims struct
|
|
|
|
b2, err := json.Marshal(b.Claims)
|
2022-12-04 13:42:35 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2023-06-18 18:09:49 +01:00
|
|
|
// replace starting '{' with ','
|
|
|
|
b2[0] = ','
|
|
|
|
// join the two json strings and remove the last char '}' from the first string
|
|
|
|
return append(b1[:len(b1)-1], b2...), nil
|
2022-12-04 13:42:35 +00:00
|
|
|
}
|
|
|
|
|
2023-06-18 18:09:49 +01:00
|
|
|
// UnmarshalJSON reads the internalBaseTypeClaims and generic claim struct from
|
|
|
|
// a serialized JSON byte array
|
2022-12-04 13:42:35 +00:00
|
|
|
func (b *BaseTypeClaims[T]) UnmarshalJSON(bytes []byte) error {
|
|
|
|
a := internalBaseTypeClaims{}
|
|
|
|
var t T
|
2023-06-18 18:09:49 +01:00
|
|
|
|
|
|
|
// convert JSON to internalBaseTypeClaims
|
2022-12-04 13:42:35 +00:00
|
|
|
err := json.Unmarshal(bytes, &a)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2023-06-18 18:09:49 +01:00
|
|
|
|
|
|
|
// convert JSON to the generic claim struct
|
2022-12-04 13:42:35 +00:00
|
|
|
err = json.Unmarshal(bytes, &t)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2023-06-18 18:09:49 +01:00
|
|
|
// assign the fields in BaseTypeClaims
|
2022-12-04 13:42:35 +00:00
|
|
|
b.RegisteredClaims = a.RegisteredClaims
|
|
|
|
b.ClaimType = a.ClaimType
|
|
|
|
b.Claims = t
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2023-06-18 18:09:49 +01:00
|
|
|
// internalBaseTypeClaims is a wrapper for jwt.RegisteredClaims which adds a
// ClaimType field containing the type of the generic claim struct.
//
// BaseTypeClaims.MarshalJSON and UnmarshalJSON use this shape for everything
// except the generic claims' own fields. Field order affects the encoded
// bytes, so do not reorder.
type internalBaseTypeClaims struct {
	// RegisteredClaims is embedded without a tag, so encoding/json places its
	// fields at the top level of the object.
	jwt.RegisteredClaims

	// ClaimType is always encoded (no omitempty) under the "mct" key.
	ClaimType string `json:"mct"`
}
|