diff --git a/auth/access-token.go b/auth/access-token.go index 8172576..0bd7ac3 100644 --- a/auth/access-token.go +++ b/auth/access-token.go @@ -3,13 +3,13 @@ package auth import ( "github.com/MrMelon54/mjwt" "github.com/MrMelon54/mjwt/claims" + "github.com/golang-jwt/jwt/v4" "time" ) // AccessTokenClaims contains the JWT claims for an access token type AccessTokenClaims struct { - UserId uint64 `json:"uid"` - Perms *claims.PermStorage `json:"per"` + Perms *claims.PermStorage `json:"per"` } func (a AccessTokenClaims) Valid() error { return nil } @@ -17,14 +17,11 @@ func (a AccessTokenClaims) Valid() error { return nil } func (a AccessTokenClaims) Type() string { return "access-token" } // CreateAccessToken creates an access token with the default 15 minute duration -func CreateAccessToken(p mjwt.Signer, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) { - return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, userId, perms) +func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) { + return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms) } // CreateAccessTokenWithDuration creates an access token with a custom duration -func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) { - return p.GenerateJwt(sub, id, dur, &AccessTokenClaims{ - UserId: userId, - Perms: perms, - }) +func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) { + return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}) } diff --git a/auth/access-token_test.go b/auth/access-token_test.go index dada28b..eb2b3a4 100644 --- a/auth/access-token_test.go +++ b/auth/access-token_test.go @@ -19,14 +19,13 @@ func TestCreateAccessToken(t *testing.T) { s := mjwt.NewMJwtSigner("mjwt.test", key) - accessToken, err := CreateAccessToken(s, "1", "test", 1, ps) + accessToken, err := CreateAccessToken(s, "1", "test", nil, ps) assert.NoError(t, err) _, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken) assert.NoError(t, err) assert.Equal(t, "1", b.Subject) assert.Equal(t, "test", b.ID) - assert.Equal(t, uint64(1), b.Claims.UserId) assert.True(t, b.Claims.Perms.Has("mjwt:test")) assert.True(t, b.Claims.Perms.Has("mjwt:test2")) assert.False(t, b.Claims.Perms.Has("mjwt:test3")) diff --git a/auth/pair.go b/auth/pair.go index d3175b4..b99cd06 100644 --- a/auth/pair.go +++ b/auth/pair.go @@ -3,23 +3,24 @@ package auth import ( "github.com/MrMelon54/mjwt" "github.com/MrMelon54/mjwt/claims" + "github.com/golang-jwt/jwt/v4" "time" ) // CreateTokenPair creates an access and refresh token pair using the default // 15 minute and 7 day durations respectively -func CreateTokenPair(p mjwt.Signer, sub, id string, userId uint64, perms *claims.PermStorage) (string, string, error) { - return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, userId, perms) +func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) { + return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms) } // CreateTokenPairWithDuration creates an access and refresh token pair using // custom durations for the access and refresh tokens -func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id string, userId uint64, perms *claims.PermStorage) (string, string, error) { - accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, userId, perms) +func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) { + accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms) if err != nil { return "", "", err } - refreshToken, err := CreateRefreshTokenWithDuration(p, refreshDur, sub, id, userId, perms) + refreshToken, err := CreateRefreshTokenWithDuration(p, refreshDur, sub, rId, id, rAud) if err != nil { return "", "", err } diff --git a/auth/pair_test.go b/auth/pair_test.go index 733b524..f5cc542 100644 --- a/auth/pair_test.go +++ b/auth/pair_test.go @@ -19,14 +19,13 @@ func TestCreateTokenPair(t *testing.T) { s := mjwt.NewMJwtSigner("mjwt.test", key) - accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", 1, ps) + accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps) assert.NoError(t, err) _, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken) assert.NoError(t, err) assert.Equal(t, "1", b.Subject) assert.Equal(t, "test", b.ID) - assert.Equal(t, uint64(1), b.Claims.UserId) assert.True(t, b.Claims.Perms.Has("mjwt:test")) assert.True(t, b.Claims.Perms.Has("mjwt:test2")) assert.False(t, b.Claims.Perms.Has("mjwt:test3")) @@ -34,9 +33,5 @@ func TestCreateTokenPair(t *testing.T) { _, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken) assert.NoError(t, err) assert.Equal(t, "1", b2.Subject) - assert.Equal(t, "test", b2.ID) - assert.Equal(t, uint64(1), b2.Claims.UserId) - assert.True(t, b2.Claims.Perms.Has("mjwt:test")) - assert.True(t, b2.Claims.Perms.Has("mjwt:test2")) - assert.False(t, b2.Claims.Perms.Has("mjwt:test3")) + assert.Equal(t, "test2", b2.ID) } diff --git a/auth/refresh-token.go b/auth/refresh-token.go index 969525f..f87fd40 100644 --- a/auth/refresh-token.go +++ b/auth/refresh-token.go @@ -2,28 +2,25 @@ package auth import ( "github.com/MrMelon54/mjwt" - "github.com/MrMelon54/mjwt/claims" + "github.com/golang-jwt/jwt/v4" "time" ) // RefreshTokenClaims contains the JWT claims for a refresh token +// AccessTokenId (ati) must match the similar JWT ID (jti) claim type RefreshTokenClaims struct { - UserId uint64 `json:"uid"` - Perms *claims.PermStorage `json:"per"` + AccessTokenId string `json:"ati"` } func (r RefreshTokenClaims) Valid() error { return nil } func (r RefreshTokenClaims) Type() string { return "refresh-token" } // CreateRefreshToken creates a refresh token with the default 7 day duration -func CreateRefreshToken(p mjwt.Signer, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) { - return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, userId, perms) +func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) { + return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud) } // CreateRefreshTokenWithDuration creates a refresh token with a custom duration -func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) { - return p.GenerateJwt(sub, id, dur, RefreshTokenClaims{ - UserId: userId, - Perms: perms, - }) +func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) { + return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}) } diff --git a/auth/refresh-token_test.go b/auth/refresh-token_test.go index fc74fcb..479e727 100644 --- a/auth/refresh-token_test.go +++ b/auth/refresh-token_test.go @@ -4,7 +4,6 @@ import ( "crypto/rand" "crypto/rsa" "github.com/MrMelon54/mjwt" - "github.com/MrMelon54/mjwt/claims" "github.com/stretchr/testify/assert" "testing" ) @@ -13,21 +12,14 @@ func TestCreateRefreshToken(t *testing.T) { key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) - ps := claims.NewPermStorage() - ps.Set("mjwt:test") - ps.Set("mjwt:test2") - s := mjwt.NewMJwtSigner("mjwt.test", key) - refreshToken, err := CreateRefreshToken(s, "1", "test", 1, ps) + refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil) assert.NoError(t, err) _, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken) assert.NoError(t, err) assert.Equal(t, "1", b.Subject) assert.Equal(t, "test", b.ID) - assert.Equal(t, uint64(1), b.Claims.UserId) - assert.True(t, b.Claims.Perms.Has("mjwt:test")) - assert.True(t, b.Claims.Perms.Has("mjwt:test2")) - assert.False(t, b.Claims.Perms.Has("mjwt:test3")) + assert.Equal(t, "test2", b.Claims.AccessTokenId) } diff --git a/cmd/mjwt/access.go b/cmd/mjwt/access.go new file mode 100644 index 0000000..e00bb96 --- /dev/null +++ b/cmd/mjwt/access.go @@ -0,0 +1,89 @@ +package main + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "flag" + "fmt" + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/auth" + "github.com/MrMelon54/mjwt/claims" + "github.com/golang-jwt/jwt/v4" + "github.com/google/subcommands" + "os" + "strings" + "time" +) + +type accessCmd struct { + issuer, subject, id, audience, duration string +} + +func (s *accessCmd) Name() string { return "access" } +func (s *accessCmd) Synopsis() string { + return "Generates an access token with permissions using the private key" +} +func (s *accessCmd) Usage() string { + return `sign [-iss ] [-sub ] [-id ] [-aud ] [-dur ] + Output a signed MJWT token with the specified permissions. +` +} + +func (s *accessCmd) SetFlags(f *flag.FlagSet) { + f.StringVar(&s.issuer, "iss", "MJWT Utility", "The name of the MJWT issuer (default: MJWT Utility)") + f.StringVar(&s.subject, "sub", "", "MJWT Subject") + f.StringVar(&s.id, "id", "", "MJWT ID") + f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT") + f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)") +} + +func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + if f.NArg() < 1 { + _, _ = fmt.Fprintln(os.Stderr, "Error: Missing private key path argument") + return subcommands.ExitFailure + } + + args := f.Args() + key, err := s.parseKey(args[0]) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Error: Failed to parse private key: ", err) + return subcommands.ExitFailure + } + + ps := claims.NewPermStorage() + for i := 1; i < len(args); i++ { + ps.Set(args[i]) + } + + var aud jwt.ClaimStrings + if s.audience != "" { + aud = strings.Split(s.audience, ",") + } + dur, err := time.ParseDuration(s.duration) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Error: Failed to parse duration: ", err) + return subcommands.ExitFailure + } + + signer := mjwt.NewMJwtSigner(s.issuer, key) + token, err := signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err) + return subcommands.ExitFailure + } + + fmt.Println(token) + return subcommands.ExitSuccess +} + +func (s *accessCmd) parseKey(privKeyFile string) (*rsa.PrivateKey, error) { + b, err := os.ReadFile(privKeyFile) + if err != nil { + return nil, err + } + + p, _ := pem.Decode(b) + return x509.ParsePKCS1PrivateKey(p.Bytes) +} diff --git a/cmd/mjwt/gen.go b/cmd/mjwt/gen.go new file mode 100644 index 0000000..9eb0b47 --- /dev/null +++ b/cmd/mjwt/gen.go @@ -0,0 +1,77 @@ +package main + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "flag" + "fmt" + "github.com/google/subcommands" + "math/rand" + "os" + "time" +) + +type genCmd struct { + bits int +} + +func (g *genCmd) Name() string { return "gen" } +func (g *genCmd) Synopsis() string { return "Generate RSA private key" } +func (g *genCmd) Usage() string { + return `gen + Output RSA private key to the provided file. +` +} + +func (g *genCmd) SetFlags(f *flag.FlagSet) { + f.IntVar(&g.bits, "bits", 4096, "Number of bits to generate (default: 4096)") +} + +func (g *genCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + if f.NArg() != 2 { + _, _ = fmt.Fprintln(os.Stderr, "Error: Missing private and public key file") + return subcommands.ExitFailure + } + + // arguments + privPath := f.Arg(0) + pubPath := f.Arg(1) + + if err := g.gen(privPath, pubPath); err != nil { + _, _ = fmt.Fprintln(os.Stderr, "An error occured generating the private and public keys: ", err) + return subcommands.ExitFailure + } + + fmt.Println("Success generating RSA private key") + return subcommands.ExitSuccess +} + +func (g *genCmd) gen(privPath, pubPath string) error { + createPriv, err := os.OpenFile(privPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer createPriv.Close() + + createPub, err := os.OpenFile(pubPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer createPub.Close() + + key, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), g.bits) + if err != nil { + return err + } + + keyBytes := x509.MarshalPKCS1PrivateKey(key) + pubBytes := x509.MarshalPKCS1PublicKey(&key.PublicKey) + err = pem.Encode(createPriv, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes}) + if err != nil { + return err + } + err = pem.Encode(createPub, &pem.Block{Type: "RSA PUBLIC KEY", Bytes: pubBytes}) + return err +} diff --git a/cmd/mjwt/main.go b/cmd/mjwt/main.go new file mode 100644 index 0000000..9aeb15c --- /dev/null +++ b/cmd/mjwt/main.go @@ -0,0 +1,20 @@ +package main + +import ( + "context" + "flag" + "github.com/google/subcommands" + "os" +) + +func main() { + subcommands.Register(subcommands.HelpCommand(), "") + subcommands.Register(subcommands.FlagsCommand(), "") + subcommands.Register(subcommands.CommandsCommand(), "") + subcommands.Register(&genCmd{}, "") + subcommands.Register(&accessCmd{}, "") + + flag.Parse() + ctx := context.Background() + os.Exit(int(subcommands.Execute(ctx))) +} diff --git a/go.mod b/go.mod index 598e75f..1edb8ec 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.19 require ( github.com/golang-jwt/jwt/v4 v4.4.3 + github.com/google/subcommands v1.2.0 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.8.4 gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index 162adfa..22d18c5 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/interfaces.go b/interfaces.go index 2cf3886..f0ca8c5 100644 --- a/interfaces.go +++ b/interfaces.go @@ -9,7 +9,7 @@ import ( // Signer can also be used as a Verifier. type Signer interface { Verifier - GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error) + GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) SignJwt(claims jwt.Claims) (string, error) Issuer() string } diff --git a/mjwt.go b/mjwt.go index b147ba8..96b00b0 100644 --- a/mjwt.go +++ b/mjwt.go @@ -10,12 +10,13 @@ import ( var ErrClaimTypeMismatch = errors.New("claim type mismatch") // wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct -func wrapClaims[T Claims](p Signer, sub, id string, dur time.Duration, claims T) *BaseTypeClaims[T] { +func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] { now := time.Now() return (&BaseTypeClaims[T]{ RegisteredClaims: jwt.RegisteredClaims{ Issuer: p.Issuer(), Subject: sub, + Audience: aud, ExpiresAt: jwt.NewNumericDate(now.Add(dur)), NotBefore: jwt.NewNumericDate(now), IssuedAt: jwt.NewNumericDate(now), diff --git a/mjwt_test.go b/mjwt_test.go index e6dff92..d9a957f 100644 --- a/mjwt_test.go +++ b/mjwt_test.go @@ -36,7 +36,7 @@ func TestExtractClaims(t *testing.T) { assert.NoError(t, err) s := NewMJwtSigner("mjwt.test", key) - token, err := s.GenerateJwt("1", "test", 10*time.Minute, testClaims{TestValue: "hello"}) + token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) assert.NoError(t, err) m := NewMJwtVerifier(&key.PublicKey) @@ -49,7 +49,7 @@ func TestExtractClaimsFail(t *testing.T) { assert.NoError(t, err) s := NewMJwtSigner("mjwt.test", key) - token, err := s.GenerateJwt("1", "test", 10*time.Minute, testClaims{TestValue: "test"}) + token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}) assert.NoError(t, err) m := NewMJwtVerifier(&key.PublicKey) diff --git a/signer.go b/signer.go index bb50311..636c3ff 100644 --- a/signer.go +++ b/signer.go @@ -29,8 +29,8 @@ func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer { func (d *defaultMJwtSigner) Issuer() string { return d.issuer } // GenerateJwt generates and returns a JWT string using the sub, id, duration and claims -func (d *defaultMJwtSigner) GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error) { - return d.SignJwt(wrapClaims[Claims](d, sub, id, dur, claims)) +func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) { + return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims)) } // SignJwt signs a jwt.Claims compatible struct, this is used internally by diff --git a/verifier_test.go b/verifier_test.go index 87916f7..98cc4bc 100644 --- a/verifier_test.go +++ b/verifier_test.go @@ -16,7 +16,7 @@ func TestNewMJwtVerifierFromFile(t *testing.T) { assert.NoError(t, err) s := NewMJwtSigner("mjwt.test", key) - token, err := s.GenerateJwt("1", "test", 10*time.Minute, testClaims{TestValue: "world"}) + token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}) assert.NoError(t, err) b := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey)})