mjwt/signer.go
Captain ALM 5d1bd6f8fd
Update rsa-helper
Add read limit for key loader in signer
2024-06-10 17:51:11 +01:00

205 lines
6.3 KiB
Go

package mjwt
import (
"bytes"
"crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/golang-jwt/jwt/v4"
"io"
"os"
"time"
)
const readLimit = 10240 // 10 KiB
var ErrNoPrivateKeyFound = errors.New("no private key found")
// defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name
// to generate MJWT tokens
type defaultMJwtSigner struct {
issuer string
key *rsa.PrivateKey
verify *defaultMJwtVerifier
}
var _ Signer = &defaultMJwtSigner{}
var _ Verifier = &defaultMJwtSigner{}
// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
return NewMJwtSignerWithKeyStore(issuer, key, NewMJwtKeyStore())
}
// NewMJwtSignerWithKeyStore creates a new defaultMJwtSigner using the issuer name, a rsa.PrivateKey
// for no kID and a KeyStore for kID based keys
func NewMJwtSignerWithKeyStore(issuer string, key *rsa.PrivateKey, kStore KeyStore) Signer {
var pKey *rsa.PublicKey = nil
if key != nil {
pKey = &key.PublicKey
}
return &defaultMJwtSigner{
issuer: issuer,
key: key,
verify: NewMJwtVerifierWithKeyStore(pKey, kStore).(*defaultMJwtVerifier),
}
}
// NewMJwtSignerFromFileOrCreate creates a new defaultMJwtSigner using the path
// of a rsa.PrivateKey file. If the file does not exist then the file is created
// and a new private key is generated.
func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits int) (Signer, error) {
privateKey, err := readOrCreatePrivateKey(file, random, bits)
if err != nil {
return nil, err
}
return NewMJwtSigner(issuer, privateKey), nil
}
// NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a
// rsa.PrivateKey file.
func NewMJwtSignerFromFile(issuer, file string) (Signer, error) {
return NewMJwtSignerFromFileAndDirectory(issuer, file, "", "", "")
}
// NewMJwtSignerFromDirectory creates a new defaultMJwtSigner using the path of a directory to
// load the keys into a KeyStore; there is no default rsa.PrivateKey
func NewMJwtSignerFromDirectory(issuer, directory, prvExt, pubExt string) (Signer, error) {
return NewMJwtSignerFromFileAndDirectory(issuer, "", directory, prvExt, pubExt)
}
// NewMJwtSignerFromFileAndDirectory creates a new defaultMJwtSigner using the path of a rsa.PrivateKey
// file as the non kID key and the path of a directory to load the keys into a KeyStore
func NewMJwtSignerFromFileAndDirectory(issuer, file, directory, prvExt, pubExt string) (Signer, error) {
var err error
// read key
var prv *rsa.PrivateKey = nil
if file != "" {
prv, err = rsaprivate.Read(file)
if err != nil {
return nil, err
}
}
// read KeyStore
var kStore KeyStore = nil
if directory != "" {
kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt)
if err != nil {
return nil, err
}
}
return NewMJwtSignerWithKeyStore(issuer, prv, kStore), nil
}
// Issuer returns the name of the issuer
func (d *defaultMJwtSigner) Issuer() string {
return d.issuer
}
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims; uses the default key
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims))
}
// SignJwt signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs; uses the default key
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
if d.key == nil {
return "", ErrNoPrivateKeyFound
}
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
return token.SignedString(d.key)
}
// GenerateJwtWithKID generates and returns a JWT string using the sub, id, duration and claims; this gets signed with the specified kID
func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) {
return d.SignJwtWithKID(wrapClaims[Claims](d, sub, id, aud, dur, claims), kID)
}
// SignJwtWithKID signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs; this gets signed with the specified kID
func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (string, error) {
pKey := d.verify.GetKeyStore().GetKey(kID)
if pKey == nil {
return "", ErrNoPrivateKeyFound
}
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
token.Header["kid"] = kID
return token.SignedString(pKey)
}
// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
return d.verify.VerifyJwt(token, claims)
}
func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey {
return d.key
}
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey {
return d.verify.pub
}
func (d *defaultMJwtSigner) PublicKeyOf(kID string) *rsa.PublicKey {
return d.verify.kStore.GetKeyPublic(kID)
}
func (d *defaultMJwtSigner) GetKeyStore() KeyStore {
return d.verify.GetKeyStore()
}
func (d *defaultMJwtSigner) PrivateKeyOf(kID string) *rsa.PrivateKey {
return d.verify.kStore.GetKey(kID)
}
// readOrCreatePrivateKey returns the private key it the file already exists,
// generates a new private key and saves it to the file, or returns an error if
// reading or generating failed.
func readOrCreatePrivateKey(file string, random io.Reader, bits int) (*rsa.PrivateKey, error) {
// read the file or return nil
f, err := readOrEmptyFile(file)
if err != nil {
return nil, err
}
if f == nil {
// generate a new key
key, err := rsa.GenerateKey(random, bits)
if err != nil {
return nil, err
}
// save key to file
err = rsaprivate.Write(file, key)
if err != nil {
return nil, err
}
return key, err
} else {
// return key
return rsaprivate.Decode(bytes.NewReader(f))
}
}
// readOrEmptyFile returns bytes and errors from os.OpenFile or (nil, nil) if the
// file does not exist.
func readOrEmptyFile(file string) ([]byte, error) {
fp, err := os.Open(file)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
defer func() { _ = fp.Close() }()
// add hard limit
limitReader := io.LimitReader(fp, readLimit)
raw, err := io.ReadAll(limitReader)
if err != nil {
return nil, err
}
return raw, nil
}