diff --git a/cmd/mjwt/access.go b/cmd/mjwt/access.go index 32a3dd3..d4c2330 100644 --- a/cmd/mjwt/access.go +++ b/cmd/mjwt/access.go @@ -2,14 +2,12 @@ package main import ( "context" - "crypto/rsa" - "crypto/x509" - "encoding/pem" "flag" "fmt" "github.com/1f349/mjwt" "github.com/1f349/mjwt/auth" "github.com/1f349/mjwt/claims" + "github.com/1f349/rsa-helper/rsaprivate" "github.com/golang-jwt/jwt/v4" "github.com/google/subcommands" "os" @@ -46,7 +44,7 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{} } args := f.Args() - key, err := s.parseKey(args[0]) + key, err := rsaprivate.Read(args[0]) if err != nil { _, _ = fmt.Fprintln(os.Stderr, "Error: Failed to parse private key: ", err) return subcommands.ExitFailure @@ -77,13 +75,3 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{} fmt.Println(token) return subcommands.ExitSuccess } - -func (s *accessCmd) parseKey(privKeyFile string) (*rsa.PrivateKey, error) { - b, err := os.ReadFile(privKeyFile) - if err != nil { - return nil, err - } - - p, _ := pem.Decode(b) - return x509.ParsePKCS1PrivateKey(p.Bytes) -} diff --git a/cmd/mjwt/gen.go b/cmd/mjwt/gen.go index 9eb0b47..b56e789 100644 --- a/cmd/mjwt/gen.go +++ b/cmd/mjwt/gen.go @@ -3,10 +3,10 @@ package main import ( "context" "crypto/rsa" - "crypto/x509" - "encoding/pem" "flag" "fmt" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/1f349/rsa-helper/rsapublic" "github.com/google/subcommands" "math/rand" "os" @@ -49,29 +49,14 @@ func (g *genCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) s } func (g *genCmd) gen(privPath, pubPath string) error { - createPriv, err := os.OpenFile(privPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return err - } - defer createPriv.Close() - - createPub, err := os.OpenFile(pubPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return err - } - defer createPub.Close() - key, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), g.bits) if err != nil { return err } - keyBytes := x509.MarshalPKCS1PrivateKey(key) - pubBytes := x509.MarshalPKCS1PublicKey(&key.PublicKey) - err = pem.Encode(createPriv, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes}) + err = rsaprivate.Write(privPath, key) if err != nil { return err } - err = pem.Encode(createPub, &pem.Block{Type: "RSA PUBLIC KEY", Bytes: pubBytes}) - return err + return rsapublic.Write(pubPath, &key.PublicKey) } diff --git a/go.mod b/go.mod index 150ebd1..a81ad33 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,11 @@ module github.com/1f349/mjwt -go 1.19 +go 1.22 + +toolchain go1.22.3 require ( + github.com/1f349/rsa-helper v0.0.0-20240608023351-e4382c728b17 github.com/becheran/wildmatch-go v1.0.0 github.com/golang-jwt/jwt/v4 v4.5.0 github.com/google/subcommands v1.2.0 diff --git a/go.sum b/go.sum index 239b42d..6fe0930 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/1f349/rsa-helper v0.0.0-20240608023351-e4382c728b17 h1:cgJDS14TTM8hvg1qNhXFj3IfFEZ99IXR00D8gcwbY98= +github.com/1f349/rsa-helper v0.0.0-20240608023351-e4382c728b17/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg= github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA= github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/interfaces.go b/interfaces.go index 939a466..ce63f1d 100644 --- a/interfaces.go +++ b/interfaces.go @@ -21,3 +21,14 @@ type Verifier interface { VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) PublicKey() *rsa.PublicKey } + +// KeyStore is used for the kid header support in Signer and Verifier. +type KeyStore interface { + SetKey(kID string, prvKey *rsa.PrivateKey) bool + SetKeyPublic(kID string, pubKey *rsa.PublicKey) bool + RemoveKey(kID string) bool + ListKeys() []string + GetKey(kID string) *rsa.PrivateKey + GetKeyPublic(kID string) *rsa.PublicKey + ClearKeys() +} diff --git a/key_store.go b/key_store.go new file mode 100644 index 0000000..fa39ad2 --- /dev/null +++ b/key_store.go @@ -0,0 +1,159 @@ +package mjwt + +import ( + "crypto/rsa" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/1f349/rsa-helper/rsapublic" + "os" + "path" + "strings" + "sync" +) + +// defaultMJwtKeyStore implements KeyStore and stores kIDs against just rsa.PublicKey +// or with rsa.PrivateKey instances as well. +type defaultMJwtKeyStore struct { + rwLocker *sync.RWMutex + store map[string]*rsa.PrivateKey + storePub map[string]*rsa.PublicKey +} + +var _ KeyStore = &defaultMJwtKeyStore{} + +// NewMJwtKeyStore creates a new defaultMJwtKeyStore. +func NewMJwtKeyStore() KeyStore { + return &defaultMJwtKeyStore{ + rwLocker: new(sync.RWMutex), + store: make(map[string]*rsa.PrivateKey), + storePub: make(map[string]*rsa.PublicKey), + } +} + +// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private +// rsa keys; the kID is the filename of the key up to the first . +func NewMJwtKeyStoreFromDirectory(directory string, keyPrvExt string, keyPubExt string) (KeyStore, error) { + // Create empty KeyStore + ks := NewMJwtKeyStore() + // List directory contents + dirEntries, err := os.ReadDir(directory) + if err != nil { + return nil, err + } + // Import keys from files, based on extension + for _, entry := range dirEntries { + if !entry.IsDir() { + firstDotIdx := strings.Index(entry.Name(), ".") + lastDotIdx := strings.LastIndex(entry.Name(), ".") + if firstDotIdx > 0 && lastDotIdx+1 < len(entry.Name()) { + if entry.Name()[lastDotIdx+1:] == keyPrvExt { + // Load rsa private key with the file name as the kID (Up to the first .) + key, err := rsaprivate.Read(path.Join(directory, entry.Name())) + if err == nil { + ks.SetKey(entry.Name()[:firstDotIdx], key) + } + } else if entry.Name()[lastDotIdx+1:] == keyPubExt { + // Load rsa public key with the file name as the kID (Up to the first .) + key, err := rsapublic.Read(path.Join(directory, entry.Name())) + if err == nil { + ks.SetKeyPublic(entry.Name()[:firstDotIdx], key) + } + } + } + } + } + return ks, nil +} + +// SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore. +func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) bool { + if d == nil || prvKey == nil { + return false + } + d.rwLocker.Lock() + defer d.rwLocker.Unlock() + d.store[kID] = prvKey + d.storePub[kID] = &prvKey.PublicKey + return true +} + +// SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore. +func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) bool { + if d == nil || pubKey == nil { + return false + } + d.rwLocker.Lock() + defer d.rwLocker.Unlock() + delete(d.store, kID) + d.storePub[kID] = pubKey + return true +} + +// RemoveKey removes a specified kID from the KeyStore. +func (d *defaultMJwtKeyStore) RemoveKey(kID string) bool { + if d == nil { + return false + } + d.rwLocker.Lock() + defer d.rwLocker.Unlock() + delete(d.store, kID) + delete(d.storePub, kID) + return true +} + +// ListKeys lists the kIDs of all the keys in the KeyStore. +func (d *defaultMJwtKeyStore) ListKeys() []string { + if d == nil { + return nil + } + d.rwLocker.RLock() + defer d.rwLocker.RUnlock() + lKeys := make([]string, len(d.store)) + i := 0 + for k := range d.store { + lKeys[i] = k + } + return lKeys +} + +// GetKey gets the rsa.PrivateKey given the kID in the KeyStore or null if not found. +func (d *defaultMJwtKeyStore) GetKey(kID string) *rsa.PrivateKey { + if d == nil { + return nil + } + d.rwLocker.RLock() + defer d.rwLocker.RUnlock() + kPrv, ok := d.store[kID] + if ok { + return kPrv + } + return nil +} + +// GetKeyPublic gets the rsa.PublicKey given the kID in the KeyStore or null if not found. +func (d *defaultMJwtKeyStore) GetKeyPublic(kID string) *rsa.PublicKey { + if d == nil { + return nil + } + d.rwLocker.RLock() + defer d.rwLocker.RUnlock() + kPub, ok := d.storePub[kID] + if ok { + return kPub + } + return nil +} + +// ClearKeys removes all the stored keys in the KeyStore. +func (d *defaultMJwtKeyStore) ClearKeys() { + if d == nil { + return + } + d.rwLocker.Lock() + defer d.rwLocker.Unlock() + for k := range d.store { + delete(d.store, k) + } + for k := range d.storePub { + delete(d.storePub, k) + } +} diff --git a/signer.go b/signer.go index fa7947b..5276065 100644 --- a/signer.go +++ b/signer.go @@ -1,10 +1,9 @@ package mjwt import ( + "bytes" "crypto/rsa" - "crypto/x509" - "encoding/pem" - "fmt" + "github.com/1f349/rsa-helper/rsaprivate" "github.com/golang-jwt/jwt/v4" "io" "os" @@ -45,20 +44,8 @@ func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits i // NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a // rsa.PrivateKey file. func NewMJwtSignerFromFile(issuer, file string) (Signer, error) { - // read file - raw, err := os.ReadFile(file) - if err != nil { - return nil, err - } - - // decode pem block - block, _ := pem.Decode(raw) - if block == nil || block.Type != "RSA PRIVATE KEY" { - return nil, fmt.Errorf("invalid rsa private key pem block") - } - - // parse private key from pem block - key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + // read key + key, err := rsaprivate.Read(file) if err != nil { return nil, err } @@ -106,26 +93,15 @@ func readOrCreatePrivateKey(file string, random io.Reader, bits int) (*rsa.Priva return nil, err } - keyBytes := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(key), - }) + // save key to file + err = rsaprivate.Write(file, key) if err != nil { return nil, err } - - // write the key to the file - err = os.WriteFile(file, keyBytes, 0600) return key, err } else { - // decode pem block - block, _ := pem.Decode(f) - if block == nil || block.Type != "RSA PRIVATE KEY" { - return nil, fmt.Errorf("invalid rsa private key pem block") - } - - // try to parse the private key - return x509.ParsePKCS1PrivateKey(block.Bytes) + // return key + return rsaprivate.Decode(bytes.NewReader(f)) } } diff --git a/verifier.go b/verifier.go index 0fe9411..35f5817 100644 --- a/verifier.go +++ b/verifier.go @@ -2,10 +2,8 @@ package mjwt import ( "crypto/rsa" - "crypto/x509" - "encoding/pem" + "github.com/1f349/rsa-helper/rsapublic" "github.com/golang-jwt/jwt/v4" - "os" ) // defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate @@ -28,17 +26,8 @@ func newMJwtVerifier(key *rsa.PublicKey) *defaultMJwtVerifier { // NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a // rsa.PublicKey file func NewMJwtVerifierFromFile(file string) (Verifier, error) { - // read file - f, err := os.ReadFile(file) - if err != nil { - return nil, err - } - - // decode pem block - block, _ := pem.Decode(f) - - // parse public key from pem block - pub, err := x509.ParsePKCS1PublicKey(block.Bytes) + // read key + pub, err := rsapublic.Read(file) if err != nil { return nil, err }