mjwt/key_store.go

201 lines
4.9 KiB
Go
Raw Normal View History

package mjwt
import (
"crypto/rsa"
2024-06-08 23:57:52 +01:00
"errors"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"os"
"path"
"strings"
"sync"
)
// defaultMJwtKeyStore implements KeyStore and stores kIDs against just rsa.PublicKey
// or with rsa.PrivateKey instances as well.
type defaultMJwtKeyStore struct {
rwLocker *sync.RWMutex
store map[string]*rsa.PrivateKey
storePub map[string]*rsa.PublicKey
}
var _ KeyStore = &defaultMJwtKeyStore{}
2024-06-09 00:49:27 +01:00
// NewMJwtKeyStore returns a new, empty KeyStore backed by defaultMJwtKeyStore.
// Its lock and both kID maps are allocated up front, so it is ready for use.
func NewMJwtKeyStore() KeyStore {
	ks := new(defaultMJwtKeyStore)
	ks.rwLocker = &sync.RWMutex{}
	ks.store = map[string]*rsa.PrivateKey{}
	ks.storePub = map[string]*rsa.PublicKey{}
	return ks
}
// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private
// rsa keys; the kID is the filename of the key up to the first .
func NewMJwtKeyStoreFromDirectory(directory, keyPrvExt, keyPubExt string) (KeyStore, error) {
// Create empty KeyStore
2024-06-09 00:49:27 +01:00
ks := NewMJwtKeyStore().(*defaultMJwtKeyStore)
// List directory contents
dirEntries, err := os.ReadDir(directory)
if err != nil {
return nil, err
}
2024-06-09 21:15:28 +01:00
errs := make([]error, 0, len(dirEntries)/2)
// Import keys from files, based on extension
for _, entry := range dirEntries {
2024-06-09 21:00:18 +01:00
if entry.IsDir() {
continue
}
kID, _, _ := strings.Cut(entry.Name(), ".")
if kID == "" {
continue
}
2024-06-09 21:20:46 +01:00
pExt := path.Ext(entry.Name())
if pExt == "."+keyPrvExt {
2024-06-09 21:00:18 +01:00
// Load rsa private key with the file name as the kID (Up to the first .)
key, err2 := rsaprivate.Read(path.Join(directory, entry.Name()))
if err2 == nil {
ks.store[kID] = key
ks.storePub[kID] = &key.PublicKey
}
2024-06-09 21:15:28 +01:00
errs = append(errs, err2)
2024-06-09 21:20:46 +01:00
} else if pExt == "."+keyPubExt {
2024-06-09 21:00:18 +01:00
// Load rsa public key with the file name as the kID (Up to the first .)
key, err2 := rsapublic.Read(path.Join(directory, entry.Name()))
if err2 == nil {
_, exs := ks.store[kID]
if !exs {
ks.store[kID] = nil
}
2024-06-09 21:00:18 +01:00
ks.storePub[kID] = key
}
2024-06-09 21:15:28 +01:00
errs = append(errs, err2)
}
}
2024-06-09 21:20:46 +01:00
return ks, errors.Join(errs...)
2024-06-08 23:57:52 +01:00
}
// ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified
// extensions for public and private keys
func ExportKeyStore(ks KeyStore, directory, keyPrvExt, keyPubExt string) error {
2024-06-08 23:57:52 +01:00
if ks == nil {
return errors.New("ks is nil")
}
// Create directory
err := os.MkdirAll(directory, 0700)
if err != nil {
return err
}
2024-06-09 21:15:28 +01:00
errs := make([]error, 0, len(ks.ListKeys())/2)
2024-06-08 23:57:52 +01:00
// Export all keys
for _, kID := range ks.ListKeys() {
kPrv := ks.GetKey(kID)
if kPrv != nil {
err2 := rsaprivate.Write(path.Join(directory, kID+"."+keyPrvExt), kPrv)
2024-06-09 21:15:28 +01:00
errs = append(errs, err2)
2024-06-08 23:57:52 +01:00
}
kPub := ks.GetKeyPublic(kID)
if kPub != nil {
err2 := rsapublic.Write(path.Join(directory, kID+"."+keyPubExt), kPub)
2024-06-09 21:15:28 +01:00
errs = append(errs, err2)
2024-06-08 23:57:52 +01:00
}
}
2024-06-09 21:20:46 +01:00
return errors.Join(errs...)
}
// SetKey stores prvKey under kID, replacing any previous entry for that kID,
// and records prvKey's public half so GetKeyPublic(kID) returns
// &prvKey.PublicKey. It does nothing on a nil receiver or a nil prvKey.
func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) {
	if d != nil && prvKey != nil {
		d.rwLocker.Lock()
		defer d.rwLocker.Unlock()
		d.store[kID] = prvKey
		d.storePub[kID] = &prvKey.PublicKey
	}
}
// SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore.
func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) {
if d == nil || pubKey == nil {
return
}
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
2024-06-08 23:57:52 +01:00
_, exs := d.store[kID]
if !exs {
d.store[kID] = nil
}
d.storePub[kID] = pubKey
return
}
// RemoveKey removes a specified kID from the KeyStore.
2024-06-09 00:49:27 +01:00
func (d *defaultMJwtKeyStore) RemoveKey(kID string) {
if d == nil {
2024-06-09 00:49:27 +01:00
return
}
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
delete(d.store, kID)
delete(d.storePub, kID)
2024-06-09 00:49:27 +01:00
return
}
// ListKeys lists the kIDs of all the keys in the KeyStore.
func (d *defaultMJwtKeyStore) ListKeys() []string {
if d == nil {
return nil
}
d.rwLocker.RLock()
defer d.rwLocker.RUnlock()
lKeys := make([]string, len(d.store))
i := 0
for k := range d.store {
lKeys[i] = k
2024-06-08 23:57:52 +01:00
i++
}
return lKeys
}
// GetKey returns the rsa.PrivateKey stored for kID. It returns nil when kID is
// unknown, when kID has only a public key, or when the receiver is nil.
func (d *defaultMJwtKeyStore) GetKey(kID string) *rsa.PrivateKey {
	if d == nil {
		return nil
	}
	d.rwLocker.RLock()
	defer d.rwLocker.RUnlock()
	// A missing kID yields the map's zero value, which is nil.
	return d.store[kID]
}
// GetKeyPublic returns the rsa.PublicKey stored for kID. That is either the key
// given to SetKeyPublic or the public half of the key given to SetKey. It
// returns nil when kID is unknown or when the receiver is nil.
func (d *defaultMJwtKeyStore) GetKeyPublic(kID string) *rsa.PublicKey {
	if d == nil {
		return nil
	}
	d.rwLocker.RLock()
	defer d.rwLocker.RUnlock()
	// A missing kID yields the map's zero value, which is nil.
	return d.storePub[kID]
}
// ClearKeys removes all the stored keys in the KeyStore.
func (d *defaultMJwtKeyStore) ClearKeys() {
if d == nil {
return
}
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
2024-06-08 23:57:52 +01:00
clear(d.store)
clear(d.storePub)
}