mirror of
https://github.com/1f349/mjwt.git
synced 2024-12-22 07:24:05 +00:00
Update api to add more usability and add cli
This commit is contained in:
parent
9ea1842360
commit
37f499440d
@ -3,13 +3,13 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"github.com/MrMelon54/mjwt"
|
"github.com/MrMelon54/mjwt"
|
||||||
"github.com/MrMelon54/mjwt/claims"
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccessTokenClaims contains the JWT claims for an access token
|
// AccessTokenClaims contains the JWT claims for an access token
|
||||||
type AccessTokenClaims struct {
|
type AccessTokenClaims struct {
|
||||||
UserId uint64 `json:"uid"`
|
Perms *claims.PermStorage `json:"per"`
|
||||||
Perms *claims.PermStorage `json:"per"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a AccessTokenClaims) Valid() error { return nil }
|
func (a AccessTokenClaims) Valid() error { return nil }
|
||||||
@ -17,14 +17,11 @@ func (a AccessTokenClaims) Valid() error { return nil }
|
|||||||
func (a AccessTokenClaims) Type() string { return "access-token" }
|
func (a AccessTokenClaims) Type() string { return "access-token" }
|
||||||
|
|
||||||
// CreateAccessToken creates an access token with the default 15 minute duration
|
// CreateAccessToken creates an access token with the default 15 minute duration
|
||||||
func CreateAccessToken(p mjwt.Signer, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) {
|
func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
|
||||||
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, userId, perms)
|
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccessTokenWithDuration creates an access token with a custom duration
|
// CreateAccessTokenWithDuration creates an access token with a custom duration
|
||||||
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) {
|
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
|
||||||
return p.GenerateJwt(sub, id, dur, &AccessTokenClaims{
|
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
|
||||||
UserId: userId,
|
|
||||||
Perms: perms,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
@ -19,14 +19,13 @@ func TestCreateAccessToken(t *testing.T) {
|
|||||||
|
|
||||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||||
|
|
||||||
accessToken, err := CreateAccessToken(s, "1", "test", 1, ps)
|
accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "1", b.Subject)
|
assert.Equal(t, "1", b.Subject)
|
||||||
assert.Equal(t, "test", b.ID)
|
assert.Equal(t, "test", b.ID)
|
||||||
assert.Equal(t, uint64(1), b.Claims.UserId)
|
|
||||||
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
||||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||||
|
11
auth/pair.go
11
auth/pair.go
@ -3,23 +3,24 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"github.com/MrMelon54/mjwt"
|
"github.com/MrMelon54/mjwt"
|
||||||
"github.com/MrMelon54/mjwt/claims"
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateTokenPair creates an access and refresh token pair using the default
|
// CreateTokenPair creates an access and refresh token pair using the default
|
||||||
// 15 minute and 7 day durations respectively
|
// 15 minute and 7 day durations respectively
|
||||||
func CreateTokenPair(p mjwt.Signer, sub, id string, userId uint64, perms *claims.PermStorage) (string, string, error) {
|
func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
|
||||||
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, userId, perms)
|
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTokenPairWithDuration creates an access and refresh token pair using
|
// CreateTokenPairWithDuration creates an access and refresh token pair using
|
||||||
// custom durations for the access and refresh tokens
|
// custom durations for the access and refresh tokens
|
||||||
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id string, userId uint64, perms *claims.PermStorage) (string, string, error) {
|
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
|
||||||
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, userId, perms)
|
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
refreshToken, err := CreateRefreshTokenWithDuration(p, refreshDur, sub, id, userId, perms)
|
refreshToken, err := CreateRefreshTokenWithDuration(p, refreshDur, sub, rId, id, rAud)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
@ -19,14 +19,13 @@ func TestCreateTokenPair(t *testing.T) {
|
|||||||
|
|
||||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||||
|
|
||||||
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", 1, ps)
|
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "1", b.Subject)
|
assert.Equal(t, "1", b.Subject)
|
||||||
assert.Equal(t, "test", b.ID)
|
assert.Equal(t, "test", b.ID)
|
||||||
assert.Equal(t, uint64(1), b.Claims.UserId)
|
|
||||||
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
||||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||||
@ -34,9 +33,5 @@ func TestCreateTokenPair(t *testing.T) {
|
|||||||
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "1", b2.Subject)
|
assert.Equal(t, "1", b2.Subject)
|
||||||
assert.Equal(t, "test", b2.ID)
|
assert.Equal(t, "test2", b2.ID)
|
||||||
assert.Equal(t, uint64(1), b2.Claims.UserId)
|
|
||||||
assert.True(t, b2.Claims.Perms.Has("mjwt:test"))
|
|
||||||
assert.True(t, b2.Claims.Perms.Has("mjwt:test2"))
|
|
||||||
assert.False(t, b2.Claims.Perms.Has("mjwt:test3"))
|
|
||||||
}
|
}
|
||||||
|
@ -2,28 +2,25 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/MrMelon54/mjwt"
|
"github.com/MrMelon54/mjwt"
|
||||||
"github.com/MrMelon54/mjwt/claims"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RefreshTokenClaims contains the JWT claims for a refresh token
|
// RefreshTokenClaims contains the JWT claims for a refresh token
|
||||||
|
// AccessTokenId (ati) must match the similar JWT ID (jti) claim
|
||||||
type RefreshTokenClaims struct {
|
type RefreshTokenClaims struct {
|
||||||
UserId uint64 `json:"uid"`
|
AccessTokenId string `json:"ati"`
|
||||||
Perms *claims.PermStorage `json:"per"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r RefreshTokenClaims) Valid() error { return nil }
|
func (r RefreshTokenClaims) Valid() error { return nil }
|
||||||
func (r RefreshTokenClaims) Type() string { return "refresh-token" }
|
func (r RefreshTokenClaims) Type() string { return "refresh-token" }
|
||||||
|
|
||||||
// CreateRefreshToken creates a refresh token with the default 7 day duration
|
// CreateRefreshToken creates a refresh token with the default 7 day duration
|
||||||
func CreateRefreshToken(p mjwt.Signer, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) {
|
func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||||
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, userId, perms)
|
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
|
// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
|
||||||
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, userId uint64, perms *claims.PermStorage) (string, error) {
|
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||||
return p.GenerateJwt(sub, id, dur, RefreshTokenClaims{
|
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
|
||||||
UserId: userId,
|
|
||||||
Perms: perms,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"github.com/MrMelon54/mjwt"
|
"github.com/MrMelon54/mjwt"
|
||||||
"github.com/MrMelon54/mjwt/claims"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -13,21 +12,14 @@ func TestCreateRefreshToken(t *testing.T) {
|
|||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
ps := claims.NewPermStorage()
|
|
||||||
ps.Set("mjwt:test")
|
|
||||||
ps.Set("mjwt:test2")
|
|
||||||
|
|
||||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||||
|
|
||||||
refreshToken, err := CreateRefreshToken(s, "1", "test", 1, ps)
|
refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "1", b.Subject)
|
assert.Equal(t, "1", b.Subject)
|
||||||
assert.Equal(t, "test", b.ID)
|
assert.Equal(t, "test", b.ID)
|
||||||
assert.Equal(t, uint64(1), b.Claims.UserId)
|
assert.Equal(t, "test2", b.Claims.AccessTokenId)
|
||||||
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
|
||||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
|
||||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
|
||||||
}
|
}
|
||||||
|
89
cmd/mjwt/access.go
Normal file
89
cmd/mjwt/access.go
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/mjwt/auth"
|
||||||
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/google/subcommands"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type accessCmd struct {
|
||||||
|
issuer, subject, id, audience, duration string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accessCmd) Name() string { return "access" }
|
||||||
|
func (s *accessCmd) Synopsis() string {
|
||||||
|
return "Generates an access token with permissions using the private key"
|
||||||
|
}
|
||||||
|
func (s *accessCmd) Usage() string {
|
||||||
|
return `sign [-iss <issuer>] [-sub <subject>] [-id <id>] [-aud <audience>] [-dur <duration>] <private key path> <space separated permissions>
|
||||||
|
Output a signed MJWT token with the specified permissions.
|
||||||
|
`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accessCmd) SetFlags(f *flag.FlagSet) {
|
||||||
|
f.StringVar(&s.issuer, "iss", "MJWT Utility", "The name of the MJWT issuer (default: MJWT Utility)")
|
||||||
|
f.StringVar(&s.subject, "sub", "", "MJWT Subject")
|
||||||
|
f.StringVar(&s.id, "id", "", "MJWT ID")
|
||||||
|
f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT")
|
||||||
|
f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
||||||
|
if f.NArg() < 1 {
|
||||||
|
_, _ = fmt.Fprintln(os.Stderr, "Error: Missing private key path argument")
|
||||||
|
return subcommands.ExitFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
args := f.Args()
|
||||||
|
key, err := s.parseKey(args[0])
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to parse private key: ", err)
|
||||||
|
return subcommands.ExitFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
ps := claims.NewPermStorage()
|
||||||
|
for i := 1; i < len(args); i++ {
|
||||||
|
ps.Set(args[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
var aud jwt.ClaimStrings
|
||||||
|
if s.audience != "" {
|
||||||
|
aud = strings.Split(s.audience, ",")
|
||||||
|
}
|
||||||
|
dur, err := time.ParseDuration(s.duration)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to parse duration: ", err)
|
||||||
|
return subcommands.ExitFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
signer := mjwt.NewMJwtSigner(s.issuer, key)
|
||||||
|
token, err := signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err)
|
||||||
|
return subcommands.ExitFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(token)
|
||||||
|
return subcommands.ExitSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accessCmd) parseKey(privKeyFile string) (*rsa.PrivateKey, error) {
|
||||||
|
b, err := os.ReadFile(privKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p, _ := pem.Decode(b)
|
||||||
|
return x509.ParsePKCS1PrivateKey(p.Bytes)
|
||||||
|
}
|
77
cmd/mjwt/gen.go
Normal file
77
cmd/mjwt/gen.go
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"github.com/google/subcommands"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type genCmd struct {
|
||||||
|
bits int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *genCmd) Name() string { return "gen" }
|
||||||
|
func (g *genCmd) Synopsis() string { return "Generate RSA private key" }
|
||||||
|
func (g *genCmd) Usage() string {
|
||||||
|
return `gen <private key path> <public key path>
|
||||||
|
Output RSA private key to the provided file.
|
||||||
|
`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *genCmd) SetFlags(f *flag.FlagSet) {
|
||||||
|
f.IntVar(&g.bits, "bits", 4096, "Number of bits to generate (default: 4096)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *genCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
||||||
|
if f.NArg() != 2 {
|
||||||
|
_, _ = fmt.Fprintln(os.Stderr, "Error: Missing private and public key file")
|
||||||
|
return subcommands.ExitFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// arguments
|
||||||
|
privPath := f.Arg(0)
|
||||||
|
pubPath := f.Arg(1)
|
||||||
|
|
||||||
|
if err := g.gen(privPath, pubPath); err != nil {
|
||||||
|
_, _ = fmt.Fprintln(os.Stderr, "An error occured generating the private and public keys: ", err)
|
||||||
|
return subcommands.ExitFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Success generating RSA private key")
|
||||||
|
return subcommands.ExitSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *genCmd) gen(privPath, pubPath string) error {
|
||||||
|
createPriv, err := os.OpenFile(privPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer createPriv.Close()
|
||||||
|
|
||||||
|
createPub, err := os.OpenFile(pubPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer createPub.Close()
|
||||||
|
|
||||||
|
key, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), g.bits)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
keyBytes := x509.MarshalPKCS1PrivateKey(key)
|
||||||
|
pubBytes := x509.MarshalPKCS1PublicKey(&key.PublicKey)
|
||||||
|
err = pem.Encode(createPriv, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = pem.Encode(createPub, &pem.Block{Type: "RSA PUBLIC KEY", Bytes: pubBytes})
|
||||||
|
return err
|
||||||
|
}
|
20
cmd/mjwt/main.go
Normal file
20
cmd/mjwt/main.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"github.com/google/subcommands"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
subcommands.Register(subcommands.HelpCommand(), "")
|
||||||
|
subcommands.Register(subcommands.FlagsCommand(), "")
|
||||||
|
subcommands.Register(subcommands.CommandsCommand(), "")
|
||||||
|
subcommands.Register(&genCmd{}, "")
|
||||||
|
subcommands.Register(&accessCmd{}, "")
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
ctx := context.Background()
|
||||||
|
os.Exit(int(subcommands.Execute(ctx)))
|
||||||
|
}
|
1
go.mod
1
go.mod
@ -4,6 +4,7 @@ go 1.19
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/golang-jwt/jwt/v4 v4.4.3
|
github.com/golang-jwt/jwt/v4 v4.4.3
|
||||||
|
github.com/google/subcommands v1.2.0
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/stretchr/testify v1.8.4
|
github.com/stretchr/testify v1.8.4
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
2
go.sum
2
go.sum
@ -2,6 +2,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU=
|
github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU=
|
||||||
github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||||
|
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||||
|
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
// Signer can also be used as a Verifier.
|
// Signer can also be used as a Verifier.
|
||||||
type Signer interface {
|
type Signer interface {
|
||||||
Verifier
|
Verifier
|
||||||
GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error)
|
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
|
||||||
SignJwt(claims jwt.Claims) (string, error)
|
SignJwt(claims jwt.Claims) (string, error)
|
||||||
Issuer() string
|
Issuer() string
|
||||||
}
|
}
|
||||||
|
3
mjwt.go
3
mjwt.go
@ -10,12 +10,13 @@ import (
|
|||||||
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
|
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
|
||||||
|
|
||||||
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
|
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
|
||||||
func wrapClaims[T Claims](p Signer, sub, id string, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return (&BaseTypeClaims[T]{
|
return (&BaseTypeClaims[T]{
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
Issuer: p.Issuer(),
|
Issuer: p.Issuer(),
|
||||||
Subject: sub,
|
Subject: sub,
|
||||||
|
Audience: aud,
|
||||||
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
|
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
IssuedAt: jwt.NewNumericDate(now),
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
@ -36,7 +36,7 @@ func TestExtractClaims(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
s := NewMJwtSigner("mjwt.test", key)
|
s := NewMJwtSigner("mjwt.test", key)
|
||||||
token, err := s.GenerateJwt("1", "test", 10*time.Minute, testClaims{TestValue: "hello"})
|
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
m := NewMJwtVerifier(&key.PublicKey)
|
m := NewMJwtVerifier(&key.PublicKey)
|
||||||
@ -49,7 +49,7 @@ func TestExtractClaimsFail(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
s := NewMJwtSigner("mjwt.test", key)
|
s := NewMJwtSigner("mjwt.test", key)
|
||||||
token, err := s.GenerateJwt("1", "test", 10*time.Minute, testClaims{TestValue: "test"})
|
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
m := NewMJwtVerifier(&key.PublicKey)
|
m := NewMJwtVerifier(&key.PublicKey)
|
||||||
|
@ -29,8 +29,8 @@ func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
|
|||||||
func (d *defaultMJwtSigner) Issuer() string { return d.issuer }
|
func (d *defaultMJwtSigner) Issuer() string { return d.issuer }
|
||||||
|
|
||||||
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims
|
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims
|
||||||
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, dur time.Duration, claims Claims) (string, error) {
|
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
|
||||||
return d.SignJwt(wrapClaims[Claims](d, sub, id, dur, claims))
|
return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignJwt signs a jwt.Claims compatible struct, this is used internally by
|
// SignJwt signs a jwt.Claims compatible struct, this is used internally by
|
||||||
|
@ -16,7 +16,7 @@ func TestNewMJwtVerifierFromFile(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
s := NewMJwtSigner("mjwt.test", key)
|
s := NewMJwtSigner("mjwt.test", key)
|
||||||
token, err := s.GenerateJwt("1", "test", 10*time.Minute, testClaims{TestValue: "world"})
|
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey)})
|
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey)})
|
||||||
|
Loading…
Reference in New Issue
Block a user