mirror of
https://github.com/1f349/mjwt.git
synced 2024-12-22 15:34:08 +00:00
Add comments to all of this
This commit is contained in:
parent
fc6e076f24
commit
8806e30591
49
mjwt.go
49
mjwt.go
@ -1,7 +1,6 @@
|
|||||||
package mjwt
|
package mjwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@ -48,6 +47,8 @@ type baseTypeClaim interface {
|
|||||||
InternalClaimType() string
|
InternalClaimType() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BaseTypeClaims is a wrapper for combining the jwt.RegisteredClaims with a ClaimType
|
||||||
|
// and generic Claims data
|
||||||
type BaseTypeClaims[T Claims] struct {
|
type BaseTypeClaims[T Claims] struct {
|
||||||
jwt.RegisteredClaims
|
jwt.RegisteredClaims
|
||||||
ClaimType string
|
ClaimType string
|
||||||
@ -59,6 +60,7 @@ func (b *BaseTypeClaims[T]) init() *BaseTypeClaims[T] {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Valid checks the InternalClaimType matches and the type claim type
|
||||||
func (b *BaseTypeClaims[T]) Valid() error {
|
func (b *BaseTypeClaims[T]) Valid() error {
|
||||||
if b.ClaimType != b.InternalClaimType() {
|
if b.ClaimType != b.InternalClaimType() {
|
||||||
return ErrClaimTypeMismatch
|
return ErrClaimTypeMismatch
|
||||||
@ -66,61 +68,60 @@ func (b *BaseTypeClaims[T]) Valid() error {
|
|||||||
return b.Claims.Valid()
|
return b.Claims.Valid()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BaseTypeClaims[T]) InternalClaimType() string {
|
// InternalClaimType returns the Type of the generic claim struct
|
||||||
return b.Claims.Type()
|
func (b *BaseTypeClaims[T]) InternalClaimType() string { return b.Claims.Type() }
|
||||||
}
|
|
||||||
|
|
||||||
|
// MarshalJSON converts the internalBaseTypeClaims and generic claim struct into
|
||||||
|
// a serialized JSON byte array
|
||||||
func (b *BaseTypeClaims[T]) MarshalJSON() ([]byte, error) {
|
func (b *BaseTypeClaims[T]) MarshalJSON() ([]byte, error) {
|
||||||
// setup buffers
|
// encode the internalBaseTypeClaims
|
||||||
buf := new(bytes.Buffer)
|
b1, err := json.Marshal(internalBaseTypeClaims{
|
||||||
buf2 := new(bytes.Buffer)
|
|
||||||
|
|
||||||
// encode into both buffers
|
|
||||||
err := json.NewEncoder(buf).Encode(internalBaseTypeClaims{
|
|
||||||
RegisteredClaims: b.RegisteredClaims,
|
RegisteredClaims: b.RegisteredClaims,
|
||||||
ClaimType: b.InternalClaimType(),
|
ClaimType: b.InternalClaimType(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = json.NewEncoder(buf2).Encode(b.Claims)
|
|
||||||
|
// encode the generic claims struct
|
||||||
|
b2, err := json.Marshal(b.Claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// decode into a single map
|
// replace starting '{' with ','
|
||||||
var a map[string]any
|
b2[0] = ','
|
||||||
err = json.NewDecoder(buf).Decode(&a)
|
// join the two json strings and remove the last char '}' from the first string
|
||||||
if err != nil {
|
return append(b1[:len(b1)-1], b2...), nil
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = json.NewDecoder(buf2).Decode(&a)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// encode to output
|
|
||||||
return json.Marshal(a)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON reads the internalBaseTypeClaims and generic claim struct from
|
||||||
|
// a serialized JSON byte array
|
||||||
func (b *BaseTypeClaims[T]) UnmarshalJSON(bytes []byte) error {
|
func (b *BaseTypeClaims[T]) UnmarshalJSON(bytes []byte) error {
|
||||||
a := internalBaseTypeClaims{}
|
a := internalBaseTypeClaims{}
|
||||||
var t T
|
var t T
|
||||||
|
|
||||||
|
// convert JSON to internalBaseTypeClaims
|
||||||
err := json.Unmarshal(bytes, &a)
|
err := json.Unmarshal(bytes, &a)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// convert JSON to the generic claim struct
|
||||||
err = json.Unmarshal(bytes, &t)
|
err = json.Unmarshal(bytes, &t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// assign the fields in BaseTypeClaims
|
||||||
b.RegisteredClaims = a.RegisteredClaims
|
b.RegisteredClaims = a.RegisteredClaims
|
||||||
b.ClaimType = a.ClaimType
|
b.ClaimType = a.ClaimType
|
||||||
b.Claims = t
|
b.Claims = t
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// internalBaseTypeClaims is a wrapper for jwt.RegisteredClaims which adds a
|
||||||
|
// ClaimType field containing the type of the generic claim struct
|
||||||
type internalBaseTypeClaims struct {
|
type internalBaseTypeClaims struct {
|
||||||
jwt.RegisteredClaims
|
jwt.RegisteredClaims
|
||||||
ClaimType string `json:"mct"`
|
ClaimType string `json:"mct"`
|
||||||
|
@ -6,6 +6,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name
|
||||||
|
// to generate MJWT tokens
|
||||||
type defaultMJwtSigner struct {
|
type defaultMJwtSigner struct {
|
||||||
issuer string
|
issuer string
|
||||||
key *rsa.PrivateKey
|
key *rsa.PrivateKey
|
||||||
@ -14,6 +16,7 @@ type defaultMJwtSigner struct {
|
|||||||
|
|
||||||
var _ Signer = &defaultMJwtSigner{}
|
var _ Signer = &defaultMJwtSigner{}
|
||||||
|
|
||||||
|
// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey
|
||||||
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
|
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
|
||||||
return &defaultMJwtSigner{
|
return &defaultMJwtSigner{
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
@ -22,17 +25,22 @@ func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Issuer returns the name of the issuer
|
||||||
func (d *defaultMJwtSigner) Issuer() string { return d.issuer }
|
func (d *defaultMJwtSigner) Issuer() string { return d.issuer }
|
||||||
|
|
||||||
|
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims
|
||||||
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error) {
|
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error) {
|
||||||
return d.SignJwt(wrapClaims[Claims](d, sub, id, dur, claims))
|
return d.SignJwt(wrapClaims[Claims](d, sub, id, dur, claims))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SignJwt signs a jwt.Claims compatible struct, this is used internally by
|
||||||
|
// GenerateJwt but is available for signing custom structs
|
||||||
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
|
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
|
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
|
||||||
return token.SignedString(d.key)
|
return token.SignedString(d.key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
|
||||||
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
||||||
return d.verify.VerifyJwt(token, claims)
|
return d.verify.VerifyJwt(token, claims)
|
||||||
}
|
}
|
||||||
|
24
verifier.go
24
verifier.go
@ -7,17 +7,19 @@ import (
|
|||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrCannotGenerateMJwtToken = errors.New("cannot generate mjwt token with verifier")
|
var ErrCannotGenerateMJwtToken = errors.New("cannot generate mjwt token with verifier")
|
||||||
|
|
||||||
|
// defaultMJwtVerifier implements Verifier and uses an rsa.PublicKey to validate
|
||||||
|
// MJWT tokens
|
||||||
type defaultMJwtVerifier struct {
|
type defaultMJwtVerifier struct {
|
||||||
pub *rsa.PublicKey
|
pub *rsa.PublicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Verifier = &defaultMJwtVerifier{}
|
var _ Verifier = &defaultMJwtVerifier{}
|
||||||
|
|
||||||
|
// NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey
|
||||||
func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
|
func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
|
||||||
return newMJwtVerifier(key)
|
return newMJwtVerifier(key)
|
||||||
}
|
}
|
||||||
@ -26,29 +28,29 @@ func newMJwtVerifier(key *rsa.PublicKey) *defaultMJwtVerifier {
|
|||||||
return &defaultMJwtVerifier{pub: key}
|
return &defaultMJwtVerifier{pub: key}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a
|
||||||
|
// rsa.PublicKey file
|
||||||
func NewMJwtVerifierFromFile(file string) (Verifier, error) {
|
func NewMJwtVerifierFromFile(file string) (Verifier, error) {
|
||||||
|
// read file
|
||||||
f, err := os.ReadFile(file)
|
f, err := os.ReadFile(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// decode pem block
|
||||||
block, _ := pem.Decode(f)
|
block, _ := pem.Decode(f)
|
||||||
|
|
||||||
|
// parse public key from pem block
|
||||||
pub, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
pub, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create verifier using rsa.PublicKey
|
||||||
return NewMJwtVerifier(pub), nil
|
return NewMJwtVerifier(pub), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *defaultMJwtVerifier) Issuer() string { return "" }
|
// VerifyJwt validates and parses MJWT tokens and returns the claims
|
||||||
|
|
||||||
func (d *defaultMJwtVerifier) GenerateJwt(_, _ string, _ time.Duration, _ Claims) (string, error) {
|
|
||||||
return "", ErrCannotGenerateMJwtToken
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *defaultMJwtVerifier) SignJwt(_ jwt.Claims) (string, error) {
|
|
||||||
return "", ErrCannotGenerateMJwtToken
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
||||||
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||||
return d.pub, nil
|
return d.pub, nil
|
||||||
|
Loading…
Reference in New Issue
Block a user