mirror of
https://github.com/1f349/mjwt.git
synced 2024-11-13 23:11:34 +00:00
266 lines
6.7 KiB
Go
266 lines
6.7 KiB
Go
package mjwt
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"errors"
|
|
"github.com/1f349/rsa-helper/rsaprivate"
|
|
"github.com/1f349/rsa-helper/rsapublic"
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"github.com/spf13/afero"
|
|
"golang.org/x/sync/errgroup"
|
|
"io/fs"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// Sentinel errors returned by KeyStore operations.
var (
	// ErrMissingPrivateKey is returned when no private key is stored for a KID.
	ErrMissingPrivateKey = errors.New("missing private key")
	// ErrMissingPublicKey is returned when no public key is stored for a KID,
	// or when a token header has no string "kid" value.
	ErrMissingPublicKey = errors.New("missing public key")
	// ErrMissingKeyPair is returned when a KID has no entry in the KeyStore.
	ErrMissingKeyPair = errors.New("missing key pair")
)
|
|
|
|
// File-name suffixes used to recognise and write PEM-encoded key files,
// e.g. "<kid>.private.pem" and "<kid>.public.pem".
const (
	// PrivateStr is the inner suffix marking a private key file.
	PrivateStr = ".private"
	// PublicStr is the inner suffix marking a public key file.
	PublicStr = ".public"

	// PemExt is the final extension shared by every key file.
	PemExt = ".pem"
	// PrivatePemExt is the complete suffix of a private key file.
	PrivatePemExt = PrivateStr + PemExt
	// PublicPemExt is the complete suffix of a public key file.
	PublicPemExt = PublicStr + PemExt
)
|
|
|
|
// KeyStore provides a store for a collection of private/public keypair structs
|
|
type KeyStore struct {
|
|
mu *sync.RWMutex
|
|
store map[string]*keyPair
|
|
dir afero.Fs
|
|
}
|
|
|
|
// NewKeyStore creates an empty KeyStore
|
|
func NewKeyStore() *KeyStore {
|
|
return &KeyStore{
|
|
mu: new(sync.RWMutex),
|
|
store: make(map[string]*keyPair),
|
|
}
|
|
}
|
|
|
|
// NewKeyStoreWithDir creates an empty KeyStore with an underlying afero.Fs
|
|
// filesystem for saving the internal store data
|
|
func NewKeyStoreWithDir(dir afero.Fs) *KeyStore {
|
|
keyStore := NewKeyStore()
|
|
keyStore.dir = dir
|
|
return keyStore
|
|
}
|
|
|
|
// NewKeyStoreFromPath creates an empty KeyStore. The provided path is walked to
|
|
// load the private/public keys. See implementation in NewKeyStoreFromDir.
|
|
func NewKeyStoreFromPath(dir string) (*KeyStore, error) {
|
|
abs, err := filepath.Abs(dir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewKeyStoreFromDir(afero.NewBasePathFs(afero.NewOsFs(), abs))
|
|
}
|
|
|
|
// NewKeyStoreFromDir creates an empty KeyStore. The provided afero.Fs is walked
|
|
// to find all private/public keys in files named `.private.pem` and
|
|
// `.public.pem` respectively. The keys are loaded into the KeyStore and any
|
|
// errors are returned immediately.
|
|
func NewKeyStoreFromDir(dir afero.Fs) (*KeyStore, error) {
|
|
keyStore := NewKeyStoreWithDir(dir)
|
|
err := afero.Walk(dir, ".", func(path string, d fs.FileInfo, err error) error {
|
|
// maybe this is "name.private.pem"
|
|
name := filepath.Base(path)
|
|
ext := filepath.Ext(name)
|
|
if ext != PemExt {
|
|
return nil
|
|
}
|
|
|
|
name = strings.TrimSuffix(name, ext)
|
|
ext = filepath.Ext(name)
|
|
name = strings.TrimSuffix(name, ext)
|
|
switch ext {
|
|
case PrivateStr:
|
|
open, err := dir.Open(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
decode, err := rsaprivate.Decode(open)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
keyStore.LoadPrivateKey(name, decode)
|
|
return nil
|
|
case PublicStr:
|
|
open, err := dir.Open(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
decode, err := rsapublic.Decode(open)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
keyStore.LoadPublicKey(name, decode)
|
|
return nil
|
|
}
|
|
|
|
// still invalid
|
|
return nil
|
|
})
|
|
return keyStore, err
|
|
}
|
|
|
|
// keyPair holds the keys stored for a single KID. Either field may be nil: an
// entry created only through LoadPublicKey has no private key.
type keyPair struct {
	private *rsa.PrivateKey // nil when only a public key was loaded
	public  *rsa.PublicKey  // set by both LoadPrivateKey and LoadPublicKey
}
|
|
|
|
// LoadPrivateKey sets the rsa.PrivateKey/rsa.PublicKey for the KID
|
|
func (k *KeyStore) LoadPrivateKey(kid string, key *rsa.PrivateKey) {
|
|
k.mu.Lock()
|
|
if k.store[kid] == nil {
|
|
k.store[kid] = &keyPair{}
|
|
}
|
|
k.store[kid].private = key
|
|
k.store[kid].public = &key.PublicKey
|
|
k.mu.Unlock()
|
|
}
|
|
|
|
// LoadPublicKey sets the rsa.PublicKey for the KID
|
|
func (k *KeyStore) LoadPublicKey(kid string, key *rsa.PublicKey) {
|
|
k.mu.Lock()
|
|
if k.store[kid] == nil {
|
|
k.store[kid] = &keyPair{}
|
|
}
|
|
k.store[kid].public = key
|
|
k.mu.Unlock()
|
|
}
|
|
|
|
// RemoveKey deletes the KID keypair from the KeyStore
|
|
func (k *KeyStore) RemoveKey(kid string) {
|
|
k.mu.Lock()
|
|
delete(k.store, kid)
|
|
k.mu.Unlock()
|
|
}
|
|
|
|
// ListKeys provides a slice of the KIDs for all keys loaded in the KeyStore
|
|
func (k *KeyStore) ListKeys() []string {
|
|
k.mu.RLock()
|
|
defer k.mu.RUnlock()
|
|
keys := make([]string, 0, len(k.store))
|
|
for k, _ := range k.store {
|
|
keys = append(keys, k)
|
|
}
|
|
return keys
|
|
}
|
|
|
|
// GetPrivateKey outputs the rsa.PrivateKey for the KID from the KeyStore
|
|
func (k *KeyStore) GetPrivateKey(kid string) (*rsa.PrivateKey, error) {
|
|
k.mu.RLock()
|
|
defer k.mu.RUnlock()
|
|
if !k.internalHasPrivateKey(kid) {
|
|
return nil, ErrMissingPrivateKey
|
|
}
|
|
return k.store[kid].private, nil
|
|
}
|
|
|
|
// GetPublicKey outputs the rsa.PublicKey for the KID from the KeyStore
|
|
func (k *KeyStore) GetPublicKey(kid string) (*rsa.PublicKey, error) {
|
|
k.mu.RLock()
|
|
defer k.mu.RUnlock()
|
|
if !k.internalHasPublicKey(kid) {
|
|
return nil, ErrMissingPublicKey
|
|
}
|
|
return k.store[kid].public, nil
|
|
}
|
|
|
|
// ClearKeys clears the internal map and makes a new map to release used memory
|
|
func (k *KeyStore) ClearKeys() {
|
|
k.mu.Lock()
|
|
clear(k.store)
|
|
k.store = make(map[string]*keyPair)
|
|
k.mu.Unlock()
|
|
}
|
|
|
|
// HasPrivateKey outputs true if the KID is found in the KeyStore
|
|
func (k *KeyStore) HasPrivateKey(kid string) bool {
|
|
k.mu.RLock()
|
|
defer k.mu.RUnlock()
|
|
return k.internalHasPrivateKey(kid)
|
|
}
|
|
|
|
func (k *KeyStore) internalHasPrivateKey(kid string) bool {
|
|
v := k.store[kid]
|
|
return v != nil && v.private != nil
|
|
}
|
|
|
|
// HasPublicKey outputs true if the KID is found in the KeyStore
|
|
func (k *KeyStore) HasPublicKey(kid string) bool {
|
|
k.mu.RLock()
|
|
defer k.mu.RUnlock()
|
|
return k.internalHasPublicKey(kid)
|
|
}
|
|
|
|
func (k *KeyStore) internalHasPublicKey(kid string) bool {
|
|
v := k.store[kid]
|
|
return v != nil && v.public != nil
|
|
}
|
|
|
|
// VerifyJwt parses the provided token string and validates it against the KID
|
|
// using the KeyStore. An error is returned if the token fails to parse or if
|
|
// there is no matching KID in the KeyStore.
|
|
func (k *KeyStore) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
|
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
|
kid, ok := token.Header["kid"].(string)
|
|
if !ok {
|
|
return nil, ErrMissingPublicKey
|
|
}
|
|
return k.GetPublicKey(kid)
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return withClaims, claims.Valid()
|
|
}
|
|
|
|
// SaveSingleKey writes the rsa.PrivateKey/rsa.PublicKey for the requested KID to
|
|
// the underlying afero.Fs.
|
|
func (k *KeyStore) SaveSingleKey(kid string) error {
|
|
if k.dir == nil {
|
|
return nil
|
|
}
|
|
|
|
k.mu.RLock()
|
|
pair := k.store[kid]
|
|
k.mu.RUnlock()
|
|
if pair == nil {
|
|
return ErrMissingKeyPair
|
|
}
|
|
|
|
return writeSingleKey(k.dir, kid, pair)
|
|
}
|
|
|
|
// SaveKeys writes the rsa.PrivateKey/rsa.PublicKey for the requested KID to the
|
|
// underlying afero.Fs.
|
|
func (k *KeyStore) SaveKeys() error {
|
|
k.mu.RLock()
|
|
defer k.mu.RUnlock()
|
|
|
|
workers := new(errgroup.Group)
|
|
workers.SetLimit(runtime.NumCPU())
|
|
for kid, pair := range k.store {
|
|
workers.Go(func() error {
|
|
return writeSingleKey(k.dir, kid, pair)
|
|
})
|
|
}
|
|
return workers.Wait()
|
|
}
|
|
|
|
func writeSingleKey(dir afero.Fs, kid string, pair *keyPair) error {
|
|
var errs []error
|
|
if pair.private != nil {
|
|
errs = append(errs, afero.WriteFile(dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
|
|
}
|
|
if pair.public != nil {
|
|
errs = append(errs, afero.WriteFile(dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
|
|
}
|
|
return errors.Join(errs...)
|
|
}
|