From 9a859eb2d7137a40d58992f0a74a97be1459f1de Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Sun, 29 Oct 2023 12:28:21 +0000 Subject: [PATCH] Major refactor --- auth/access-token.go | 4 ++-- auth/access-token_test.go | 5 +++-- auth/pair.go | 4 ++-- auth/pair_test.go | 5 +++-- auth/refresh-token.go | 2 +- auth/refresh-token_test.go | 3 ++- claims/perms_test.go | 6 ++++++ cmd/mjwt/access.go | 6 +++--- go.mod | 2 +- mjwt_test.go | 2 ++ signer.go | 33 ++++++++++++++++++++++++++++++--- signer_test.go | 13 ++++++++++--- verifier_test.go | 1 + 13 files changed, 66 insertions(+), 20 deletions(-) diff --git a/auth/access-token.go b/auth/access-token.go index 0bd7ac3..7b54e4d 100644 --- a/auth/access-token.go +++ b/auth/access-token.go @@ -1,8 +1,8 @@ package auth import ( - "github.com/MrMelon54/mjwt" - "github.com/MrMelon54/mjwt/claims" + "github.com/1f349/mjwt" + "github.com/1f349/mjwt/claims" "github.com/golang-jwt/jwt/v4" "time" ) diff --git a/auth/access-token_test.go b/auth/access-token_test.go index eb2b3a4..acf53aa 100644 --- a/auth/access-token_test.go +++ b/auth/access-token_test.go @@ -3,13 +3,14 @@ package auth import ( "crypto/rand" "crypto/rsa" - "github.com/MrMelon54/mjwt" - "github.com/MrMelon54/mjwt/claims" + "github.com/1f349/mjwt" + "github.com/1f349/mjwt/claims" "github.com/stretchr/testify/assert" "testing" ) func TestCreateAccessToken(t *testing.T) { + t.Parallel() key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) diff --git a/auth/pair.go b/auth/pair.go index b99cd06..b2ce969 100644 --- a/auth/pair.go +++ b/auth/pair.go @@ -1,8 +1,8 @@ package auth import ( - "github.com/MrMelon54/mjwt" - "github.com/MrMelon54/mjwt/claims" + "github.com/1f349/mjwt" + "github.com/1f349/mjwt/claims" "github.com/golang-jwt/jwt/v4" "time" ) diff --git a/auth/pair_test.go b/auth/pair_test.go index f5cc542..5a7dd77 100644 --- a/auth/pair_test.go +++ b/auth/pair_test.go @@ -3,13 +3,14 @@ package auth import ( "crypto/rand" "crypto/rsa" - "github.com/MrMelon54/mjwt" - "github.com/MrMelon54/mjwt/claims" + "github.com/1f349/mjwt" + "github.com/1f349/mjwt/claims" "github.com/stretchr/testify/assert" "testing" ) func TestCreateTokenPair(t *testing.T) { + t.Parallel() key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) diff --git a/auth/refresh-token.go b/auth/refresh-token.go index f87fd40..5667885 100644 --- a/auth/refresh-token.go +++ b/auth/refresh-token.go @@ -1,7 +1,7 @@ package auth import ( - "github.com/MrMelon54/mjwt" + "github.com/1f349/mjwt" "github.com/golang-jwt/jwt/v4" "time" ) diff --git a/auth/refresh-token_test.go b/auth/refresh-token_test.go index 479e727..4765c35 100644 --- a/auth/refresh-token_test.go +++ b/auth/refresh-token_test.go @@ -3,12 +3,13 @@ package auth import ( "crypto/rand" "crypto/rsa" - "github.com/MrMelon54/mjwt" + "github.com/1f349/mjwt" "github.com/stretchr/testify/assert" "testing" ) func TestCreateRefreshToken(t *testing.T) { + t.Parallel() key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) diff --git a/claims/perms_test.go b/claims/perms_test.go index 108ba73..a046177 100644 --- a/claims/perms_test.go +++ b/claims/perms_test.go @@ -7,6 +7,7 @@ import ( ) func TestPermStorage_Set(t *testing.T) { + t.Parallel() ps := NewPermStorage() ps.Set("mjwt:test") if _, ok := ps.values["mjwt:test"]; !ok { @@ -15,6 +16,7 @@ func TestPermStorage_Set(t *testing.T) { } func TestPermStorage_Clear(t *testing.T) { + t.Parallel() ps := NewPermStorage() ps.values["mjwt:test"] = struct{}{} ps.Clear("mjwt:test") @@ -24,6 +26,7 @@ func TestPermStorage_Clear(t *testing.T) { } func TestPermStorage_Has(t *testing.T) { + t.Parallel() ps := NewPermStorage() assert.False(t, ps.Has("mjwt:test")) ps.values["mjwt:test"] = struct{}{} @@ -31,6 +34,7 @@ func TestPermStorage_Has(t *testing.T) { } func TestPermStorage_OneOf(t *testing.T) { + t.Parallel() o := NewPermStorage() o.Set("mjwt:test") o.Set("mjwt:test2") @@ -48,6 +52,7 @@ func TestPermStorage_OneOf(t *testing.T) { } func TestPermStorage_MarshalJSON(t *testing.T) { + t.Parallel() ps := NewPermStorage() ps.Set("mjwt:test") ps.Set("mjwt:test2") @@ -57,6 +62,7 @@ func TestPermStorage_MarshalJSON(t *testing.T) { } func TestPermStorage_MarshalYAML(t *testing.T) { + t.Parallel() ps := NewPermStorage() ps.Set("mjwt:test") ps.Set("mjwt:test2") diff --git a/cmd/mjwt/access.go b/cmd/mjwt/access.go index e00bb96..32a3dd3 100644 --- a/cmd/mjwt/access.go +++ b/cmd/mjwt/access.go @@ -7,9 +7,9 @@ import ( "encoding/pem" "flag" "fmt" - "github.com/MrMelon54/mjwt" - "github.com/MrMelon54/mjwt/auth" - "github.com/MrMelon54/mjwt/claims" + "github.com/1f349/mjwt" + "github.com/1f349/mjwt/auth" + "github.com/1f349/mjwt/claims" "github.com/golang-jwt/jwt/v4" "github.com/google/subcommands" "os" diff --git a/go.mod b/go.mod index 0d3c7fb..150ebd1 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/MrMelon54/mjwt +module github.com/1f349/mjwt go 1.19 diff --git a/mjwt_test.go b/mjwt_test.go index d9a957f..3a83fe8 100644 --- a/mjwt_test.go +++ b/mjwt_test.go @@ -32,6 +32,7 @@ func (t testClaims2) Valid() error { func (t testClaims2) Type() string { return "testClaims2" } func TestExtractClaims(t *testing.T) { + t.Parallel() key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) @@ -45,6 +46,7 @@ func TestExtractClaims(t *testing.T) { } func TestExtractClaimsFail(t *testing.T) { + t.Parallel() key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) diff --git a/signer.go b/signer.go index b44ca79..fa7947b 100644 --- a/signer.go +++ b/signer.go @@ -3,6 +3,8 @@ package mjwt import ( "crypto/rsa" "crypto/x509" + "encoding/pem" + "fmt" "github.com/golang-jwt/jwt/v4" "io" "os" @@ -43,14 +45,25 @@ func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits i // NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a // rsa.PrivateKey file. func NewMJwtSignerFromFile(issuer, file string) (Signer, error) { + // read file raw, err := os.ReadFile(file) if err != nil { return nil, err } - key, err := x509.ParsePKCS1PrivateKey(raw) + + // decode pem block + block, _ := pem.Decode(raw) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return nil, fmt.Errorf("invalid rsa private key pem block") + } + + // parse private key from pem block + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return nil, err } + + // create signer using rsa.PrivateKey return NewMJwtSigner(issuer, key), nil } @@ -93,12 +106,26 @@ func readOrCreatePrivateKey(file string, random io.Reader, bits int) (*rsa.Priva return nil, err } + keyBytes := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + if err != nil { + return nil, err + } + // write the key to the file - err = os.WriteFile(file, x509.MarshalPKCS1PrivateKey(key), 0600) + err = os.WriteFile(file, keyBytes, 0600) return key, err } else { + // decode pem block + block, _ := pem.Decode(f) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return nil, fmt.Errorf("invalid rsa private key pem block") + } + // try to parse the private key - return x509.ParsePKCS1PrivateKey(f) + return x509.ParsePKCS1PrivateKey(block.Bytes) } } diff --git a/signer_test.go b/signer_test.go index 8691da3..64d8fda 100644 --- a/signer_test.go +++ b/signer_test.go @@ -4,23 +4,27 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "encoding/pem" "github.com/stretchr/testify/assert" "os" "testing" ) func TestNewMJwtSigner(t *testing.T) { + t.Parallel() key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) NewMJwtSigner("Test", key) } func TestNewMJwtSignerFromFile(t *testing.T) { + t.Parallel() tempKey, err := os.CreateTemp("", "key-test-*.pem") assert.NoError(t, err) key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) - _, err = tempKey.Write(x509.MarshalPKCS1PrivateKey(key)) + b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + _, err = tempKey.Write(b) assert.NoError(t, err) assert.NoError(t, tempKey.Close()) signer, err := NewMJwtSignerFromFile("Test", tempKey.Name()) @@ -33,6 +37,7 @@ func TestNewMJwtSignerFromFile(t *testing.T) { } func TestNewMJwtSignerFromFileOrCreate(t *testing.T) { + t.Parallel() tempKey, err := os.CreateTemp("", "key-test-*.pem") assert.NoError(t, err) assert.NoError(t, tempKey.Close()) @@ -41,15 +46,17 @@ func TestNewMJwtSignerFromFileOrCreate(t *testing.T) { assert.NoError(t, err) signer2, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048) assert.NoError(t, err) - assert.True(t, signer.(*defaultMJwtSigner).key.Equal(signer2.(*defaultMJwtSigner).key)) + assert.True(t, signer.PrivateKey().Equal(signer2.PrivateKey())) } func TestReadOrCreatePrivateKey(t *testing.T) { + t.Parallel() tempKey, err := os.CreateTemp("", "key-test-*.pem") assert.NoError(t, err) key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) - _, err = tempKey.Write(x509.MarshalPKCS1PrivateKey(key)) + b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + _, err = tempKey.Write(b) assert.NoError(t, err) assert.NoError(t, tempKey.Close()) key2, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048) diff --git a/verifier_test.go b/verifier_test.go index 98cc4bc..8e448dd 100644 --- a/verifier_test.go +++ b/verifier_test.go @@ -12,6 +12,7 @@ import ( ) func TestNewMJwtVerifierFromFile(t *testing.T) { + t.Parallel() key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err)