mjwt/verifier_test.go

112 lines
3.4 KiB
Go

package mjwt
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/stretchr/testify/assert"
"os"
"path"
"testing"
"time"
)
// File-name extensions used by the verifier tests to mark private ("prv")
// and public ("pub") key files in a key directory.
const (
	vt_prvExt = "prv"
	vt_pubExt = "pub"
)
// setupTestDirVerifier creates a temporary key directory for the verifier
// tests.
//
// When genKeys is true it generates three 2048-bit RSA keys and writes them
// using the vt_prvExt / vt_pubExt file extensions:
//   - key1: private key file only
//   - key2: private and public key files
//   - key3: public key file only
//
// It returns the directory path, key3's private key (nil when genKeys is
// false) so callers can register it with a signer themselves, and a cleanup
// function that removes the directory.
//
// The directory comes from t.TempDir, so it is also removed automatically if
// the test stops before the returned cleanup runs. Any setup error stops the
// test immediately rather than continuing with an empty path or nil keys
// (an empty path would make the key files land in the working directory).
func setupTestDirVerifier(t *testing.T, genKeys bool) (string, *rsa.PrivateKey, func(t *testing.T)) {
	t.Helper()
	requireNoError := func(err error) {
		t.Helper()
		if !assert.NoError(t, err) {
			t.FailNow()
		}
	}

	tempDir := t.TempDir()
	cleanup := func(t *testing.T) {
		// Removing early is harmless: os.RemoveAll returns nil for a missing
		// path, so t.TempDir's own cleanup still succeeds afterwards.
		assert.NoError(t, os.RemoveAll(tempDir))
	}
	if !genKeys {
		return tempDir, nil, cleanup
	}

	key1, err := rsa.GenerateKey(rand.Reader, 2048)
	requireNoError(err)
	requireNoError(rsaprivate.Write(path.Join(tempDir, "key1.pem."+vt_prvExt), key1))

	key2, err := rsa.GenerateKey(rand.Reader, 2048)
	requireNoError(err)
	requireNoError(rsaprivate.Write(path.Join(tempDir, "key2.pem."+vt_prvExt), key2))
	requireNoError(rsapublic.Write(path.Join(tempDir, "key2.pem."+vt_pubExt), &key2.PublicKey))

	// Only key3's public half goes on disk; the private half is returned.
	key3, err := rsa.GenerateKey(rand.Reader, 2048)
	requireNoError(err)
	requireNoError(rsapublic.Write(path.Join(tempDir, "key3.pem."+vt_pubExt), &key3.PublicKey))

	return tempDir, key3, cleanup
}
// TestNewMJwtVerifierFromFile checks that a verifier loaded from a PEM file
// holding a PKCS#1 "RSA PUBLIC KEY" block accepts a token signed by the
// matching private key.
func TestNewMJwtVerifierFromFile(t *testing.T) {
	t.Parallel()
	requireNoError := func(err error) {
		t.Helper()
		if !assert.NoError(t, err) {
			t.FailNow()
		}
	}

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	requireNoError(err)
	s := NewMJwtSigner("mjwt.test", key)
	token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"})
	requireNoError(err)

	b := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey)})
	// The "*" marks where the random part goes, so the name keeps its .pem suffix.
	temp, err := os.CreateTemp("", "this-is-a-test-file-*.pem")
	requireNoError(err)
	// Deferred so the file is removed even when a later check stops the test.
	defer func() {
		assert.NoError(t, os.Remove(temp.Name()))
	}()

	_, err = temp.Write(b)
	// Always close before the verifier re-opens the file: this releases the
	// descriptor (open files cannot be removed on Windows) and surfaces any
	// write error reported at close time.
	closeErr := temp.Close()
	requireNoError(err)
	requireNoError(closeErr)

	v, err := NewMJwtVerifierFromFile(temp.Name())
	requireNoError(err)
	_, _, err = ExtractClaims[testClaims](v, token)
	assert.NoError(t, err)
}
func TestNewMJwtVerifierFromDirectory(t *testing.T) {
t.Parallel()
tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true)
defer cleaner(t)
s, err := NewMJwtSignerFromDirectory("mjwt.test", tempDir, vt_prvExt, vt_pubExt)
assert.NoError(t, err)
s.GetKeyStore().SetKey("key3", prvKey3)
token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3")
assert.NoError(t, err)
v, err := NewMJwtVerifierFromDirectory(tempDir, vt_prvExt, vt_pubExt)
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims](v, token)
assert.NoError(t, err)
}
func TestNewMJwtVerifierFromFileAndDirectory(t *testing.T) {
t.Parallel()
tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true)
defer cleaner(t)
s, err := NewMJwtSignerFromFileAndDirectory("mjwt.test", path.Join(tempDir, "key2.pem."+vt_prvExt), tempDir, vt_prvExt, vt_pubExt)
assert.NoError(t, err)
s.GetKeyStore().SetKey("key3", prvKey3)
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"})
assert.NoError(t, err)
token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3")
assert.NoError(t, err)
v, err := NewMJwtVerifierFromFileAndDirectory(path.Join(tempDir, "key2.pem."+vt_pubExt), tempDir, vt_prvExt, vt_pubExt)
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims](v, token1)
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims](v, token2)
assert.NoError(t, err)
}