diff --git a/server/id_token.go b/server/id_token.go new file mode 100644 index 0000000..e7a62b0 --- /dev/null +++ b/server/id_token.go @@ -0,0 +1,55 @@ +package server + +import ( + "github.com/1f349/lavender/database" + "github.com/1f349/mjwt" + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/server" + "github.com/golang-jwt/jwt/v4" + "strings" +) + +func addIdTokenSupport(srv *server.Server, db *database.DB, key mjwt.Signer) { + srv.SetExtensionFieldsHandler(func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) { + scope := ti.GetScope() + if containsScope(scope, "openid") { + idToken, err := generateIDToken(ti, db, key) + if err != nil { + return + } + fieldsValue = map[string]interface{}{} + fieldsValue["id_token"] = idToken + } + return + }) +} + +// IdTokenClaims contains the JWT claims for an access token +type IdTokenClaims struct{} + +func (a IdTokenClaims) Valid() error { return nil } +func (a IdTokenClaims) Type() string { return "access-token" } + +func generateIDToken(ti oauth2.TokenInfo, us *database.DB, key mjwt.Signer) (token string, err error) { + tx, err := us.Begin() + if err != nil { + return "", err + } + user, err := tx.GetUser(ti.GetUserID()) + if err != nil { + return "", err + } + + token, err = key.GenerateJwt(user.Sub, "", jwt.ClaimStrings{ti.GetClientID()}, ti.GetAccessExpiresIn(), IdTokenClaims{}) + return +} + +func containsScope(scopes, s string) bool { + _scopes := strings.Split(scopes, " ") + for _, _s := range _scopes { + if _s == s { + return true + } + } + return false +} diff --git a/server/server.go b/server/server.go index c8f6c1c..7142c88 100644 --- a/server/server.go +++ b/server/server.go @@ -113,6 +113,7 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser } return a, nil }) + addIdTokenSupport(oauthSrv, db, signingKey) r.GET("/.well-known/openid-configuration", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { rw.WriteHeader(http.StatusOK) @@ -166,6 +167,8 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser r.GET("/authorize", hs.RequireAuthentication(hs.authorizeEndpoint)) r.POST("/authorize", hs.RequireAuthentication(hs.authorizeEndpoint)) r.POST("/token", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + // TODO: id_token support + // https://code.mrmelon54.com/melon/summer/src/commit/7b8afa8b91c39eba749f60a45965fd8f75c87147/pkg/oauth-server/server.go#L216 if err := oauthSrv.HandleTokenRequest(rw, req); err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) }