Add comments to all of this

This commit is contained in:
Melon 2023-06-18 18:09:49 +01:00
parent fc6e076f24
commit 8806e30591
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
3 changed files with 46 additions and 35 deletions

49
mjwt.go
View File

@ -1,7 +1,6 @@
package mjwt
import (
"bytes"
"encoding/json"
"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
@ -48,6 +47,8 @@ type baseTypeClaim interface {
InternalClaimType() string
}
// BaseTypeClaims is a wrapper for combining the jwt.RegisteredClaims with a ClaimType
// and generic Claims data
type BaseTypeClaims[T Claims] struct {
jwt.RegisteredClaims
ClaimType string
@ -59,6 +60,7 @@ func (b *BaseTypeClaims[T]) init() *BaseTypeClaims[T] {
return b
}
// Valid checks the InternalClaimType matches and the type claim type
func (b *BaseTypeClaims[T]) Valid() error {
if b.ClaimType != b.InternalClaimType() {
return ErrClaimTypeMismatch
@ -66,61 +68,60 @@ func (b *BaseTypeClaims[T]) Valid() error {
return b.Claims.Valid()
}
func (b *BaseTypeClaims[T]) InternalClaimType() string {
return b.Claims.Type()
}
// InternalClaimType returns the Type of the generic claim struct
func (b *BaseTypeClaims[T]) InternalClaimType() string { return b.Claims.Type() }
// MarshalJSON converts the internalBaseTypeClaims and generic claim struct into
// a serialized JSON byte array
func (b *BaseTypeClaims[T]) MarshalJSON() ([]byte, error) {
// setup buffers
buf := new(bytes.Buffer)
buf2 := new(bytes.Buffer)
// encode into both buffers
err := json.NewEncoder(buf).Encode(internalBaseTypeClaims{
// encode the internalBaseTypeClaims
b1, err := json.Marshal(internalBaseTypeClaims{
RegisteredClaims: b.RegisteredClaims,
ClaimType: b.InternalClaimType(),
})
if err != nil {
return nil, err
}
err = json.NewEncoder(buf2).Encode(b.Claims)
// encode the generic claims struct
b2, err := json.Marshal(b.Claims)
if err != nil {
return nil, err
}
// decode into a single map
var a map[string]any
err = json.NewDecoder(buf).Decode(&a)
if err != nil {
return nil, err
}
err = json.NewDecoder(buf2).Decode(&a)
if err != nil {
return nil, err
}
// encode to output
return json.Marshal(a)
// replace starting '{' with ','
b2[0] = ','
// join the two json strings and remove the last char '}' from the first string
return append(b1[:len(b1)-1], b2...), nil
}
// UnmarshalJSON reads the internalBaseTypeClaims and generic claim struct from
// a serialized JSON byte array
func (b *BaseTypeClaims[T]) UnmarshalJSON(bytes []byte) error {
a := internalBaseTypeClaims{}
var t T
// convert JSON to internalBaseTypeClaims
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
// convert JSON to the generic claim struct
err = json.Unmarshal(bytes, &t)
if err != nil {
return err
}
// assign the fields in BaseTypeClaims
b.RegisteredClaims = a.RegisteredClaims
b.ClaimType = a.ClaimType
b.Claims = t
return err
}
// internalBaseTypeClaims is a wrapper for jwt.RegisteredClaims which adds a
// ClaimType field containing the type of the generic claim struct
type internalBaseTypeClaims struct {
jwt.RegisteredClaims
ClaimType string `json:"mct"`

View File

@ -6,6 +6,8 @@ import (
"time"
)
// defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name
// to generate MJWT tokens
type defaultMJwtSigner struct {
issuer string
key *rsa.PrivateKey
@ -14,6 +16,7 @@ type defaultMJwtSigner struct {
var _ Signer = &defaultMJwtSigner{}
// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
return &defaultMJwtSigner{
issuer: issuer,
@ -22,17 +25,22 @@ func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
}
}
// Issuer returns the name of the issuer
func (d *defaultMJwtSigner) Issuer() string { return d.issuer }
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error) {
return d.SignJwt(wrapClaims[Claims](d, sub, id, dur, claims))
}
// SignJwt signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
return token.SignedString(d.key)
}
// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
return d.verify.VerifyJwt(token, claims)
}

View File

@ -7,17 +7,19 @@ import (
"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
"os"
"time"
)
var ErrCannotGenerateMJwtToken = errors.New("cannot generate mjwt token with verifier")
// defaultMJwtVerifier implements Verifier and uses an rsa.PublicKey to validate
// MJWT tokens
type defaultMJwtVerifier struct {
pub *rsa.PublicKey
}
var _ Verifier = &defaultMJwtVerifier{}
// NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey
func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
return newMJwtVerifier(key)
}
@ -26,29 +28,29 @@ func newMJwtVerifier(key *rsa.PublicKey) *defaultMJwtVerifier {
return &defaultMJwtVerifier{pub: key}
}
// NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a
// rsa.PublicKey file
func NewMJwtVerifierFromFile(file string) (Verifier, error) {
// read file
f, err := os.ReadFile(file)
if err != nil {
return nil, err
}
// decode pem block
block, _ := pem.Decode(f)
// parse public key from pem block
pub, err := x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return nil, err
}
// create verifier using rsa.PublicKey
return NewMJwtVerifier(pub), nil
}
func (d *defaultMJwtVerifier) Issuer() string { return "" }
func (d *defaultMJwtVerifier) GenerateJwt(_, _ string, _ time.Duration, _ Claims) (string, error) {
return "", ErrCannotGenerateMJwtToken
}
func (d *defaultMJwtVerifier) SignJwt(_ jwt.Claims) (string, error) {
return "", ErrCannotGenerateMJwtToken
}
// VerifyJwt validates and parses MJWT tokens and returns the claims
func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return d.pub, nil