mirror of
https://github.com/1f349/lavender.git
synced 2025-04-15 15:27:55 +01:00
Rewrite OAuth manager, OAuth source and URL wrapper
This commit is contained in:
parent
d9b0074133
commit
50df217b66
@ -5,10 +5,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/1f349/lavender/auth"
|
"github.com/1f349/lavender/auth"
|
||||||
"github.com/1f349/lavender/auth/authContext"
|
"github.com/1f349/lavender/auth/authContext"
|
||||||
process "github.com/1f349/lavender/auth/process"
|
"github.com/1f349/lavender/auth/process"
|
||||||
"github.com/1f349/lavender/database"
|
"github.com/1f349/lavender/database"
|
||||||
"github.com/1f349/lavender/issuer"
|
"github.com/1f349/lavender/issuer"
|
||||||
"github.com/1f349/lavender/logger"
|
"github.com/1f349/lavender/logger"
|
||||||
|
"github.com/1f349/lavender/utils"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -25,6 +26,7 @@ type InitialLogin struct {
|
|||||||
DB *database.Queries
|
DB *database.Queries
|
||||||
MyNamespace string
|
MyNamespace string
|
||||||
Manager *issuer.Manager
|
Manager *issuer.Manager
|
||||||
|
OAuth *OAuthLogin
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *InitialLogin) AccessState() process.State { return process.StateUnauthorized }
|
func (m *InitialLogin) AccessState() process.State { return process.StateUnauthorized }
|
||||||
@ -101,36 +103,39 @@ func (m *InitialLogin) AttemptLogin(ctx authContext.FormContext) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// append local namespace if @ is missing
|
// append local namespace if @ is missing
|
||||||
n := strings.IndexByte(loginName, '@')
|
if !strings.ContainsRune(loginName, '@') {
|
||||||
if n < 0 {
|
|
||||||
// correct the @ index
|
|
||||||
n = len(loginName)
|
|
||||||
loginName += "@" + m.MyNamespace
|
loginName += "@" + m.MyNamespace
|
||||||
}
|
}
|
||||||
|
|
||||||
login := m.Manager.FindServiceFromLogin(loginName)
|
user, namespace, err := utils.ParseLoginName(loginName)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, "Invalid login name", http.StatusBadRequest)
|
||||||
|
return fmt.Errorf("invalid login name %s", loginName)
|
||||||
|
}
|
||||||
|
|
||||||
|
login := m.Manager.GetService(namespace)
|
||||||
if login == nil {
|
if login == nil {
|
||||||
http.Error(rw, "No login service defined for this username", http.StatusBadRequest)
|
http.Error(rw, "No login service defined for this username", http.StatusBadRequest)
|
||||||
return errors.New("no login service defined for this username")
|
return errors.New("no login service defined for this username")
|
||||||
}
|
}
|
||||||
|
|
||||||
// the @ must exist if the service is defined
|
|
||||||
loginUn := loginName[:n]
|
|
||||||
|
|
||||||
// TODO: finish migrating this shit
|
|
||||||
|
|
||||||
// the login is not for this namespace
|
// the login is not for this namespace
|
||||||
if login != issuer.MeWellKnown {
|
if login != issuer.MeWellKnown {
|
||||||
|
// TODO: this is oauth request code start oauth request
|
||||||
|
// TODO: this could all be 1 function call into oauth
|
||||||
|
|
||||||
// save state for use later
|
// save state for use later
|
||||||
state := login.Config.Namespace + ":" + uuid.NewString()
|
state := login.Config.Namespace + ":" + uuid.NewString()
|
||||||
h.flowState.Set(state, flowStateData{loginName, login, req.PostFormValue("redirect")}, time.Now().Add(15*time.Minute))
|
m.OAuth.flow.Set(state, flowStateData{loginName, login, req.PostFormValue("redirect")}, time.Now().Add(15*time.Minute))
|
||||||
|
|
||||||
// generate oauth2 config and redirect to authorize URL
|
// generate oauth2 config and redirect to authorize URL
|
||||||
oa2conf := login.OAuth2Config
|
oa2conf := login.OAuth2Config
|
||||||
oa2conf.RedirectURL = h.conf.BaseUrl.JoinPath("callback").String()
|
oa2conf.RedirectURL = m.OAuth.BaseUrl.JoinPath("callback").String()
|
||||||
nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginUn))
|
nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", user))
|
||||||
http.Redirect(rw, req, nextUrl, http.StatusFound)
|
return auth.RedirectError{
|
||||||
return
|
Target: nextUrl,
|
||||||
|
Code: http.StatusFound,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.UpdateSession(process.LoginProcessData{
|
ctx.UpdateSession(process.LoginProcessData{
|
||||||
|
@ -13,17 +13,18 @@ import (
|
|||||||
"github.com/1f349/lavender/database"
|
"github.com/1f349/lavender/database"
|
||||||
"github.com/1f349/lavender/database/types"
|
"github.com/1f349/lavender/database/types"
|
||||||
"github.com/1f349/lavender/issuer"
|
"github.com/1f349/lavender/issuer"
|
||||||
"github.com/1f349/lavender/url"
|
"github.com/1f349/lavender/utils"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/mrmelon54/pronouns"
|
"github.com/mrmelon54/pronouns"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/text/language"
|
"golang.org/x/text/language"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OauthCallback interface {
|
type OauthCallback interface {
|
||||||
OAuthCallback(rw http.ResponseWriter, req *http.Request, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request))
|
OAuthCallback(rw http.ResponseWriter, req *http.Request, namespace string, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request))
|
||||||
}
|
}
|
||||||
|
|
||||||
type flowStateData struct {
|
type flowStateData struct {
|
||||||
@ -40,9 +41,9 @@ var (
|
|||||||
type OAuthLogin struct {
|
type OAuthLogin struct {
|
||||||
DB *database.Queries
|
DB *database.Queries
|
||||||
|
|
||||||
BaseUrl *url.URL
|
BaseUrl *utils.URL
|
||||||
|
flow *cache.Cache[string, flowStateData]
|
||||||
flow *cache.Cache[string, flowStateData]
|
Manager *issuer.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o OAuthLogin) Init() {
|
func (o OAuthLogin) Init() {
|
||||||
@ -50,7 +51,7 @@ func (o OAuthLogin) Init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o OAuthLogin) authUrlBase(ref string) *url.URL {
|
func (o OAuthLogin) authUrlBase(ref string) *url.URL {
|
||||||
return o.BaseUrl.Resolve("oauth", o.Name(), ref)
|
return o.BaseUrl.Resolve("oauth", o.Name(), ref).URL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o OAuthLogin) AccessState() process.State { return process.StateUnauthorized }
|
func (o OAuthLogin) AccessState() process.State { return process.StateUnauthorized }
|
||||||
@ -83,12 +84,16 @@ func (o OAuthLogin) AttemptLogin(ctx authContext.FormContext) error {
|
|||||||
return auth.RedirectError{Target: nextUrl, Code: http.StatusFound}
|
return auth.RedirectError{Target: nextUrl, Code: http.StatusFound}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request)) {
|
func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, namespace string, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request)) {
|
||||||
flowState, ok := o.flow.Get(req.FormValue("state"))
|
flowState, ok := o.flow.Get(req.FormValue("state"))
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(rw, "Invalid flow state", http.StatusBadRequest)
|
http.Error(rw, "Invalid flow state", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if flowState.sso.Namespace != namespace {
|
||||||
|
http.Error(rw, "OAuth source mismatch", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
token, err := flowState.sso.OAuth2Config.Exchange(context.Background(), req.FormValue("code"), oauth2.SetAuthURLParam("redirect_uri", o.authUrlBase("callback").String()))
|
token, err := flowState.sso.OAuth2Config.Exchange(context.Background(), req.FormValue("code"), oauth2.SetAuthURLParam("redirect_uri", o.authUrlBase("callback").String()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, "Failed to exchange code for token", http.StatusInternalServerError)
|
http.Error(rw, "Failed to exchange code for token", http.StatusInternalServerError)
|
||||||
@ -112,17 +117,32 @@ func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, inf
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o OAuthLogin) RenderButtonTemplate(ctx authContext.TemplateContext) {
|
func (o OAuthLogin) RenderButtonTemplate(ctx authContext.TemplateContext) {
|
||||||
|
// TODO: idk what this is
|
||||||
// o.authUrlBase("button")
|
// o.authUrlBase("button")
|
||||||
// provide something non-nil
|
// provide something non-nil
|
||||||
ctx.Render(struct {
|
ctx.Render(struct {
|
||||||
Href string
|
Href string
|
||||||
ButtonName string
|
ButtonName string
|
||||||
}{
|
}{
|
||||||
Href: o.authUrlBase("button").String(),
|
Href: o.authUrlBase("start").String(),
|
||||||
ButtonName: "Login with Unknown OAuth Button", // TODO: actually get the service name
|
ButtonName: "Login with Unknown OAuth Button", // TODO: actually get the service name
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RedirectToAuthorize returns true when redirecting to the authorize endpoint
|
||||||
|
// for the requested namespace.
|
||||||
|
func (o OAuthLogin) RedirectToAuthorize(rw http.ResponseWriter, req *http.Request, namespace string) bool {
|
||||||
|
oidc := o.Manager.GetService(namespace)
|
||||||
|
if oidc == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: come back to fix this
|
||||||
|
oidc.OAuth2Config.AuthCodeURL("state")
|
||||||
|
// http.Redirect(rw, req, oidc.AuthorizationEndpoint, http.StatusOK)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
type oauthServiceLogin int
|
type oauthServiceLogin int
|
||||||
|
|
||||||
func WithWellKnown(ctx context.Context, login *issuer.WellKnownOIDC) context.Context {
|
func WithWellKnown(ctx context.Context, login *issuer.WellKnownOIDC) context.Context {
|
||||||
@ -256,7 +276,7 @@ func (o OAuthLogin) updateOAuth2UserProfile(ctx context.Context, tx *database.Qu
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o OAuthLogin) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
|
func (o OAuthLogin) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
|
||||||
res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint)
|
res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint.String())
|
||||||
if err != nil || res.StatusCode != http.StatusOK {
|
if err != nil || res.StatusCode != http.StatusOK {
|
||||||
return auth.UserAuth{}, fmt.Errorf("request failed")
|
return auth.UserAuth{}, fmt.Errorf("request failed")
|
||||||
}
|
}
|
||||||
|
20
conf/conf.go
20
conf/conf.go
@ -1,19 +1,17 @@
|
|||||||
package conf
|
package conf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/1f349/lavender/issuer"
|
"github.com/1f349/lavender/utils"
|
||||||
"github.com/1f349/lavender/url"
|
|
||||||
"github.com/1f349/simplemail"
|
"github.com/1f349/simplemail"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conf struct {
|
type Conf struct {
|
||||||
Listen string `yaml:"listen"`
|
Listen string `yaml:"listen"`
|
||||||
BaseUrl url.URL `yaml:"baseUrl"`
|
BaseUrl *utils.URL `yaml:"baseUrl"`
|
||||||
ServiceName string `yaml:"serviceName"`
|
ServiceName string `yaml:"serviceName"`
|
||||||
Issuer string `yaml:"issuer"`
|
Issuer string `yaml:"issuer"`
|
||||||
Kid string `yaml:"kid"`
|
Kid string `yaml:"kid"`
|
||||||
Namespace string `yaml:"namespace"`
|
Namespace string `yaml:"namespace"`
|
||||||
OtpIssuer string `yaml:"otpIssuer"`
|
OtpIssuer string `yaml:"otpIssuer"`
|
||||||
Mail simplemail.Mail `yaml:"mail"`
|
Mail simplemail.Mail `yaml:"mail"`
|
||||||
SsoServices []issuer.SsoConfig `yaml:"ssoServices"`
|
|
||||||
}
|
}
|
||||||
|
45
database/manage-oauth-source.sql.go
Normal file
45
database/manage-oauth-source.sql.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.28.0
|
||||||
|
// source: manage-oauth-source.sql
|
||||||
|
|
||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
const getOAuthSources = `-- name: GetOAuthSources :many
|
||||||
|
SELECT namespace, address, registration, button, client_id, client_secret, client_scopes FROM oauth_sources
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOAuthSources(ctx context.Context) ([]OauthSource, error) {
|
||||||
|
rows, err := q.db.QueryContext(ctx, getOAuthSources)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []OauthSource
|
||||||
|
for rows.Next() {
|
||||||
|
var i OauthSource
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.Namespace,
|
||||||
|
&i.Address,
|
||||||
|
&i.Registration,
|
||||||
|
&i.Button,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ClientSecret,
|
||||||
|
&i.ClientScopes,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
10
database/migrations/20250303230832_oauth_source.up.sql
Normal file
10
database/migrations/20250303230832_oauth_source.up.sql
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
CREATE TABLE oauth_sources
|
||||||
|
(
|
||||||
|
namespace TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||||
|
address TEXT NOT NULL,
|
||||||
|
registration BOOLEAN NOT NULL,
|
||||||
|
button BOOLEAN NOT NULL,
|
||||||
|
client_id TEXT NOT NULL,
|
||||||
|
client_secret TEXT NOT NULL,
|
||||||
|
client_scopes TEXT NOT NULL
|
||||||
|
);
|
@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/1f349/lavender/database/types"
|
"github.com/1f349/lavender/database/types"
|
||||||
"github.com/1f349/lavender/password"
|
"github.com/1f349/lavender/password"
|
||||||
|
"github.com/1f349/lavender/utils"
|
||||||
"github.com/hardfinhq/go-date"
|
"github.com/hardfinhq/go-date"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -25,6 +26,16 @@ type ClientStore struct {
|
|||||||
Active bool `json:"active"`
|
Active bool `json:"active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OauthSource struct {
|
||||||
|
Namespace string `json:"namespace"`
|
||||||
|
Address utils.URL `json:"address"`
|
||||||
|
Registration bool `json:"registration"`
|
||||||
|
Button bool `json:"button"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
ClientSecret string `json:"client_secret"`
|
||||||
|
ClientScopes string `json:"client_scopes"`
|
||||||
|
}
|
||||||
|
|
||||||
type Role struct {
|
type Role struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
|
2
database/queries/manage-oauth-source.sql
Normal file
2
database/queries/manage-oauth-source.sql
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
-- name: GetOAuthSources :many
|
||||||
|
SELECT * FROM oauth_sources;
|
1
go.mod
1
go.mod
@ -19,6 +19,7 @@ require (
|
|||||||
github.com/julienschmidt/httprouter v1.3.0
|
github.com/julienschmidt/httprouter v1.3.0
|
||||||
github.com/mattn/go-sqlite3 v1.14.24
|
github.com/mattn/go-sqlite3 v1.14.24
|
||||||
github.com/mrmelon54/pronouns v1.0.3
|
github.com/mrmelon54/pronouns v1.0.3
|
||||||
|
github.com/robfig/cron v1.2.0
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
github.com/spf13/afero v1.12.0
|
github.com/spf13/afero v1.12.0
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
|
2
go.sum
2
go.sum
@ -138,6 +138,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
|||||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||||
|
github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ=
|
||||||
|
github.com/robfig/cron v1.2.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k=
|
||||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||||
github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=
|
github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=
|
||||||
|
@ -1,57 +1,226 @@
|
|||||||
package issuer
|
package issuer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
|
"github.com/1f349/lavender/logger"
|
||||||
|
"github.com/robfig/cron"
|
||||||
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
fetchPoolSize = 4
|
||||||
|
fetchTimeout = 2 * time.Minute
|
||||||
|
fetchRetryDelay = 15 * time.Second
|
||||||
|
fetchRetryAfterTimeout = 15 * time.Minute
|
||||||
|
fetchRetryCount = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
var isValidNamespace = regexp.MustCompile("^[0-9a-z.]+$")
|
var isValidNamespace = regexp.MustCompile("^[0-9a-z.]+$")
|
||||||
|
|
||||||
var MeWellKnown = &WellKnownOIDC{}
|
var MeWellKnown = &WellKnownOIDC{}
|
||||||
|
|
||||||
type Manager struct {
|
type managerReloadDB interface {
|
||||||
m map[string]*WellKnownOIDC
|
GetOAuthSources(ctx context.Context) ([]database.OauthSource, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(myNamespace string, services []SsoConfig) (*Manager, error) {
|
type Manager struct {
|
||||||
l := &Manager{m: make(map[string]*WellKnownOIDC)}
|
db managerReloadDB
|
||||||
l.m[myNamespace] = MeWellKnown
|
|
||||||
for _, ssoService := range services {
|
|
||||||
if !isValidNamespace.MatchString(ssoService.Namespace) {
|
|
||||||
return nil, fmt.Errorf("invalid namespace: %s", ssoService.Namespace)
|
|
||||||
}
|
|
||||||
|
|
||||||
conf, err := ssoService.FetchConfig()
|
client *http.Client
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// save by namespace
|
sourceMu *sync.RWMutex
|
||||||
l.m[ssoService.Namespace] = conf
|
sourceMap map[string]*WellKnownOIDC
|
||||||
|
|
||||||
|
errMu *sync.RWMutex
|
||||||
|
errMap map[string]FetchErrorDetails
|
||||||
|
|
||||||
|
shutdownCall context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
type FetchErrorDetails struct {
|
||||||
|
Retry time.Time
|
||||||
|
Errors []error
|
||||||
|
}
|
||||||
|
|
||||||
|
type managerRoundTripper struct {
|
||||||
|
Base http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *managerRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
userAgent := strings.Join(request.Header.Values("User-Agent"), " ")
|
||||||
|
if strings.TrimSpace(userAgent) == "" {
|
||||||
|
request.Header.Set("User-Agent", "Lavender/1.0 (OAuth Manager)")
|
||||||
}
|
}
|
||||||
|
return a.Base.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(db managerReloadDB, myNamespace string) (*Manager, error) {
|
||||||
|
return NewManagerWithClient(db, myNamespace, &http.Client{
|
||||||
|
Transport: &managerRoundTripper{
|
||||||
|
Base: http.DefaultTransport,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManagerWithClient(db managerReloadDB, myNamespace string, client *http.Client) (*Manager, error) {
|
||||||
|
l := &Manager{
|
||||||
|
db: db,
|
||||||
|
client: client,
|
||||||
|
sourceMu: new(sync.RWMutex),
|
||||||
|
sourceMap: make(map[string]*WellKnownOIDC),
|
||||||
|
errMap: make(map[string]FetchErrorDetails),
|
||||||
|
}
|
||||||
|
l.sourceMap[myNamespace] = MeWellKnown
|
||||||
|
|
||||||
|
// reload should run blocking on start up
|
||||||
|
l.internalReload(context.Background())
|
||||||
|
|
||||||
|
reloadDaily := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
tab := cron.New()
|
||||||
|
err := tab.AddFunc("0 2 * * *", func() {
|
||||||
|
reloadDaily <- struct{}{}
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go l.reloadWorker(ctx, reloadDaily)
|
||||||
|
|
||||||
|
l.shutdownCall = cancel
|
||||||
|
|
||||||
return l, nil
|
return l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Shutdown() { m.shutdownCall() }
|
||||||
|
|
||||||
|
func (m *Manager) reloadWorker(ctx context.Context, reload <-chan struct{}) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Logger.Info("Shutting down oauth manager reload worker")
|
||||||
|
return
|
||||||
|
case <-reload:
|
||||||
|
m.internalReload(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) internalReload(ctx context.Context) {
|
||||||
|
logger.Logger.Info("Reloading oauth sources")
|
||||||
|
dbCtx, cancel := context.WithTimeout(ctx, fetchTimeout)
|
||||||
|
sources, err := m.db.GetOAuthSources(dbCtx)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
// TODO: send email to admin
|
||||||
|
logger.Logger.Warn("Failed to load OAuth sources from the database", "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// construct map of db oauth sources
|
||||||
|
toAdd := make(map[string]database.OauthSource)
|
||||||
|
for _, source := range sources {
|
||||||
|
toAdd[source.Namespace] = source
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove sources from the active map
|
||||||
|
m.sourceMu.Lock()
|
||||||
|
for k, v := range m.sourceMap {
|
||||||
|
// ignore my own well-known
|
||||||
|
if v == MeWellKnown {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, shouldAdd := toAdd[k]; !shouldAdd {
|
||||||
|
delete(m.sourceMap, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.sourceMu.Unlock()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(fetchPoolSize)
|
||||||
|
|
||||||
|
fetchChan := make(chan database.OauthSource, fetchPoolSize)
|
||||||
|
|
||||||
|
for range fetchPoolSize {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
for source := range fetchChan {
|
||||||
|
itemCtx, cancel := context.WithTimeout(ctx, fetchTimeout)
|
||||||
|
errs := m.reloadSource(itemCtx, source)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// this also resets the error if it was set previously
|
||||||
|
m.sourceMu.Lock()
|
||||||
|
if len(errs) > 0 {
|
||||||
|
m.errMap[source.Namespace] = FetchErrorDetails{
|
||||||
|
Retry: time.Now().Add(fetchRetryAfterTimeout),
|
||||||
|
Errors: errs,
|
||||||
|
}
|
||||||
|
// TODO: send email to admin
|
||||||
|
} else {
|
||||||
|
delete(m.errMap, source.Namespace)
|
||||||
|
// TODO: send resolved email to admin
|
||||||
|
}
|
||||||
|
m.sourceMu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// reload sources
|
||||||
|
for _, source := range toAdd {
|
||||||
|
fetchChan <- source
|
||||||
|
}
|
||||||
|
|
||||||
|
close(fetchChan)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) reloadSource(ctx context.Context, source database.OauthSource) []error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
for range fetchRetryCount {
|
||||||
|
oidc, err := fetchConfig(ctx, m.client, source)
|
||||||
|
fmt.Println(oidc, err)
|
||||||
|
if err == nil {
|
||||||
|
// fetch was successful
|
||||||
|
m.sourceMu.Lock()
|
||||||
|
m.sourceMap[source.Namespace] = oidc
|
||||||
|
m.sourceMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
errs = append(errs, fmt.Errorf("failed to fetch OIDC well-known configuration: %w", err))
|
||||||
|
|
||||||
|
// wait before retrying
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return errs
|
||||||
|
case <-time.After(fetchRetryDelay):
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) CheckNamespace(namespace string) bool {
|
func (m *Manager) CheckNamespace(namespace string) bool {
|
||||||
_, ok := m.m[namespace]
|
m.sourceMu.RLock()
|
||||||
|
defer m.sourceMu.RUnlock()
|
||||||
|
_, ok := m.sourceMap[namespace]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) GetService(namespace string) *WellKnownOIDC {
|
func (m *Manager) GetService(namespace string) *WellKnownOIDC {
|
||||||
return m.m[namespace]
|
m.sourceMu.RLock()
|
||||||
}
|
defer m.sourceMu.RUnlock()
|
||||||
|
return m.sourceMap[namespace]
|
||||||
func (m *Manager) FindServiceFromLogin(login string) *WellKnownOIDC {
|
|
||||||
// @ should have at least one byte before it
|
|
||||||
n := strings.IndexByte(login, '@')
|
|
||||||
if n < 1 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// there should not be a second @
|
|
||||||
n2 := strings.IndexByte(login[n+1:], '@')
|
|
||||||
if n2 != -1 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m.GetService(login[n+1:])
|
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
package issuer
|
package issuer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
"github.com/1f349/lavender/utils"
|
"github.com/1f349/lavender/utils"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"io"
|
"io"
|
||||||
@ -8,29 +11,54 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testAddrUrl = func() utils.JsonUrl {
|
var testAddrUrl = func() utils.URL {
|
||||||
a, err := url.Parse("https://example.com")
|
a, err := url.Parse("https://example.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return utils.JsonUrl{URL: a}
|
return utils.URL{URL: a}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
func testBody() io.ReadCloser {
|
func testBody() io.ReadCloser {
|
||||||
return io.NopCloser(strings.NewReader("{}"))
|
return io.NopCloser(strings.NewReader(`{
|
||||||
|
"issuer": "Example.com Issuer",
|
||||||
|
"authorization_endpoint": "https://example.com/oauth/authorize",
|
||||||
|
"token_endpoint": "https://example.com/oauth/token",
|
||||||
|
"userinfo_endpoint": "https://example.com/oauth/userinfo",
|
||||||
|
"revocation_endpoint": "https://example.com/oauth/revoke",
|
||||||
|
"response_types_supported": [
|
||||||
|
"code"
|
||||||
|
],
|
||||||
|
"scopes_supported": [
|
||||||
|
"openid"
|
||||||
|
],
|
||||||
|
"claims_supported": [
|
||||||
|
"sub"
|
||||||
|
],
|
||||||
|
"grant_types_supported": [
|
||||||
|
"authorization_code",
|
||||||
|
"refresh_token"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
`))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_CheckNamespace(t *testing.T) {
|
func TestManager_CheckNamespace(t *testing.T) {
|
||||||
httpGet = func(url string) (resp *http.Response, err error) {
|
manager, err := NewManagerWithClient(&testDB{
|
||||||
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
|
Rows: []database.OauthSource{
|
||||||
}
|
{Address: testAddrUrl, Namespace: "example.com"},
|
||||||
manager, err := NewManager("example.org", []SsoConfig{
|
|
||||||
{
|
|
||||||
Addr: testAddrUrl,
|
|
||||||
Namespace: "example.com",
|
|
||||||
},
|
},
|
||||||
|
}, "example.org", &http.Client{
|
||||||
|
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
fmt.Println("Request URL:", req.URL)
|
||||||
|
if req.URL.String() == "https://example.com/.well-known/openid-configuration" {
|
||||||
|
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("request failed")
|
||||||
|
}),
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, manager.CheckNamespace("example.org"))
|
assert.True(t, manager.CheckNamespace("example.org"))
|
||||||
@ -39,17 +67,36 @@ func TestManager_CheckNamespace(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_FindServiceFromLogin(t *testing.T) {
|
func TestManager_FindServiceFromLogin(t *testing.T) {
|
||||||
httpGet = func(url string) (resp *http.Response, err error) {
|
manager, err := NewManagerWithClient(&testDB{
|
||||||
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
|
Rows: []database.OauthSource{
|
||||||
}
|
{Address: testAddrUrl, Namespace: "example.com"},
|
||||||
manager, err := NewManager("example.org", []SsoConfig{
|
|
||||||
{
|
|
||||||
Addr: testAddrUrl,
|
|
||||||
Namespace: "example.com",
|
|
||||||
},
|
},
|
||||||
|
}, "example.org", &http.Client{
|
||||||
|
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
fmt.Println("Request URL:", req.URL)
|
||||||
|
if req.URL.String() != "https://example.com/.well-known/openid-configuration" {
|
||||||
|
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("request failed")
|
||||||
|
}),
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, manager.FindServiceFromLogin("jane@example.org"), MeWellKnown)
|
assert.Equal(t, manager.GetService("example.org"), MeWellKnown)
|
||||||
assert.Equal(t, manager.FindServiceFromLogin("jane@example.com"), manager.m["example.com"])
|
assert.Equal(t, manager.GetService("example.com"), manager.sourceMap["example.com"])
|
||||||
assert.Nil(t, manager.FindServiceFromLogin("jane@missing.example.com"))
|
assert.Nil(t, manager.GetService("missing.example.com"))
|
||||||
|
}
|
||||||
|
|
||||||
|
type testDB struct {
|
||||||
|
Rows []database.OauthSource
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testDB) GetOAuthSources(ctx context.Context) ([]database.OauthSource, error) {
|
||||||
|
return t.Rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (r roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
time.Sleep(time.Minute)
|
||||||
|
return r(req)
|
||||||
}
|
}
|
||||||
|
@ -1,94 +1,82 @@
|
|||||||
package issuer
|
package issuer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
"github.com/1f349/lavender/utils"
|
"github.com/1f349/lavender/utils"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var httpGet = http.Get
|
type WellKnownOIDC struct {
|
||||||
|
Namespace string `json:"-"`
|
||||||
// SsoConfig is the base URL for an OAUTH/OPENID/SSO login service
|
Config database.OauthSource `json:"-"`
|
||||||
// The path `/.well-known/openid-configuration` should be available
|
Available atomic.Bool `json:"-"`
|
||||||
type SsoConfig struct {
|
Issuer string `json:"issuer"`
|
||||||
Addr utils.JsonUrl `json:"addr" yaml:"addr"` // https://login.example.com
|
AuthorizationEndpoint *utils.URL `json:"authorization_endpoint"`
|
||||||
Namespace string `json:"namespace" yaml:"namespace"` // example.com
|
TokenEndpoint *utils.URL `json:"token_endpoint"`
|
||||||
Registration bool `json:"registration" yaml:"registration"`
|
UserInfoEndpoint *utils.URL `json:"userinfo_endpoint"`
|
||||||
LoginWithButton bool `json:"login_with_button" yaml:"loginWithButton"`
|
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||||
Client SsoConfigClient `json:"client" yaml:"client"`
|
ScopesSupported []string `json:"scopes_supported"`
|
||||||
|
ClaimsSupported []string `json:"claims_supported"`
|
||||||
|
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||||
|
JwksUri *utils.URL `json:"jwks_uri"`
|
||||||
|
OAuth2Config oauth2.Config `json:"-"`
|
||||||
|
LastFetch time.Time `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SsoConfigClient struct {
|
func fetchConfig(ctx context.Context, httpClient *http.Client, s database.OauthSource) (*WellKnownOIDC, error) {
|
||||||
ID string `json:"id"`
|
|
||||||
Secret string `json:"secret"`
|
|
||||||
Scopes []string `json:"scopes"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s SsoConfig) FetchConfig() (*WellKnownOIDC, error) {
|
|
||||||
// generate openid config url
|
// generate openid config url
|
||||||
u := s.Addr.JoinPath(".well-known/openid-configuration")
|
u := s.Address.JoinPath(".well-known/openid-configuration")
|
||||||
|
|
||||||
// fetch metadata
|
// fetch metadata
|
||||||
get, err := httpGet(u.String())
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer get.Body.Close()
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
var c WellKnownOIDC
|
var c WellKnownOIDC
|
||||||
err = json.NewDecoder(get.Body).Decode(&c)
|
err = json.NewDecoder(resp.Body).Decode(&c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.Config = s
|
c.Config = s
|
||||||
c.OAuth2Config = oauth2.Config{
|
c.OAuth2Config = oauth2.Config{
|
||||||
ClientID: c.Config.Client.ID,
|
ClientID: c.Config.ClientID,
|
||||||
ClientSecret: c.Config.Client.Secret,
|
ClientSecret: c.Config.ClientSecret,
|
||||||
Endpoint: oauth2.Endpoint{
|
Endpoint: oauth2.Endpoint{
|
||||||
AuthURL: c.AuthorizationEndpoint,
|
AuthURL: c.AuthorizationEndpoint.String(),
|
||||||
TokenURL: c.TokenEndpoint,
|
TokenURL: c.TokenEndpoint.String(),
|
||||||
AuthStyle: oauth2.AuthStyleInHeader,
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
},
|
},
|
||||||
Scopes: c.Config.Client.Scopes,
|
Scopes: strings.Fields(c.Config.ClientScopes),
|
||||||
}
|
}
|
||||||
|
c.Available.Store(true)
|
||||||
return &c, nil
|
return &c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type WellKnownOIDC struct {
|
func (o *WellKnownOIDC) Validate() error {
|
||||||
Namespace string `json:"-"`
|
|
||||||
Config SsoConfig `json:"-"`
|
|
||||||
Issuer string `json:"issuer"`
|
|
||||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
|
||||||
TokenEndpoint string `json:"token_endpoint"`
|
|
||||||
UserInfoEndpoint string `json:"userinfo_endpoint"`
|
|
||||||
ResponseTypesSupported []string `json:"response_types_supported"`
|
|
||||||
ScopesSupported []string `json:"scopes_supported"`
|
|
||||||
ClaimsSupported []string `json:"claims_supported"`
|
|
||||||
GrantTypesSupported []string `json:"grant_types_supported"`
|
|
||||||
OAuth2Config oauth2.Config `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o WellKnownOIDC) Validate() error {
|
|
||||||
if o.Issuer == "" {
|
if o.Issuer == "" {
|
||||||
return errors.New("missing issuer")
|
return errors.New("missing issuer")
|
||||||
}
|
}
|
||||||
|
|
||||||
// check URLs are valid
|
|
||||||
if _, err := url.Parse(o.AuthorizationEndpoint); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := url.Parse(o.TokenEndpoint); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := url.Parse(o.UserInfoEndpoint); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// check oidc supported values
|
// check oidc supported values
|
||||||
if !slices.Contains(o.ResponseTypesSupported, "code") {
|
if !slices.Contains(o.ResponseTypesSupported, "code") {
|
||||||
return errors.New("missing required response type 'code'")
|
return errors.New("missing required response type 'code'")
|
||||||
@ -107,6 +95,6 @@ func (o WellKnownOIDC) Validate() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o WellKnownOIDC) ValidReturnUrl(u *url.URL) bool {
|
func (o *WellKnownOIDC) ValidReturnUrl(u *url.URL) bool {
|
||||||
return o.Config.Addr.Scheme == u.Scheme && o.Config.Addr.Host == u.Host
|
return o.Config.Address.Scheme == u.Scheme && o.Config.Address.Host == u.Host
|
||||||
}
|
}
|
||||||
|
@ -1,30 +1,29 @@
|
|||||||
package openid
|
package openid
|
||||||
|
|
||||||
import "github.com/1f349/lavender/url"
|
import "github.com/1f349/lavender/utils"
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Issuer string `json:"issuer"`
|
Issuer string `json:"issuer"`
|
||||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
AuthorizationEndpoint *utils.URL `json:"authorization_endpoint"`
|
||||||
TokenEndpoint string `json:"token_endpoint"`
|
TokenEndpoint *utils.URL `json:"token_endpoint"`
|
||||||
UserInfoEndpoint string `json:"userinfo_endpoint"`
|
UserInfoEndpoint *utils.URL `json:"userinfo_endpoint"`
|
||||||
ResponseTypesSupported []string `json:"response_types_supported"`
|
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||||
ScopesSupported []string `json:"scopes_supported"`
|
ScopesSupported []string `json:"scopes_supported"`
|
||||||
ClaimsSupported []string `json:"claims_supported"`
|
ClaimsSupported []string `json:"claims_supported"`
|
||||||
GrantTypesSupported []string `json:"grant_types_supported"`
|
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||||
JwksUri string `json:"jwks_uri"`
|
JwksUri *utils.URL `json:"jwks_uri"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenConfig(baseUrl *url.URL, scopes, claims []string) Config {
|
func GenConfig(baseUrl *utils.URL, scopes, claims []string) Config {
|
||||||
|
|
||||||
return Config{
|
return Config{
|
||||||
Issuer: baseUrl.String(),
|
Issuer: baseUrl.String(),
|
||||||
AuthorizationEndpoint: baseUrl.Resolve("authorize").String(),
|
AuthorizationEndpoint: baseUrl.Resolve("authorize"),
|
||||||
TokenEndpoint: baseUrl.Resolve("token").String(),
|
TokenEndpoint: baseUrl.Resolve("token"),
|
||||||
UserInfoEndpoint: baseUrl.Resolve("userinfo").String(),
|
UserInfoEndpoint: baseUrl.Resolve("userinfo"),
|
||||||
ResponseTypesSupported: []string{"code"},
|
ResponseTypesSupported: []string{"code"},
|
||||||
ScopesSupported: scopes,
|
ScopesSupported: scopes,
|
||||||
ClaimsSupported: claims,
|
ClaimsSupported: claims,
|
||||||
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
|
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
|
||||||
JwksUri: baseUrl.Resolve(".well-known/jwks.json").String(),
|
JwksUri: baseUrl.Resolve(".well-known/jwks.json"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package openid
|
package openid
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/1f349/lavender/url"
|
"github.com/1f349/lavender/utils"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -9,13 +9,13 @@ import (
|
|||||||
func TestGenConfig(t *testing.T) {
|
func TestGenConfig(t *testing.T) {
|
||||||
assert.Equal(t, Config{
|
assert.Equal(t, Config{
|
||||||
Issuer: "https://example.com",
|
Issuer: "https://example.com",
|
||||||
AuthorizationEndpoint: "https://example.com/authorize",
|
AuthorizationEndpoint: utils.MustParse("https://example.com/authorize"),
|
||||||
TokenEndpoint: "https://example.com/token",
|
TokenEndpoint: utils.MustParse("https://example.com/token"),
|
||||||
UserInfoEndpoint: "https://example.com/userinfo",
|
UserInfoEndpoint: utils.MustParse("https://example.com/userinfo"),
|
||||||
ResponseTypesSupported: []string{"code"},
|
ResponseTypesSupported: []string{"code"},
|
||||||
ScopesSupported: []string{"openid", "email"},
|
ScopesSupported: []string{"openid", "email"},
|
||||||
ClaimsSupported: []string{"name", "email", "preferred_username"},
|
ClaimsSupported: []string{"name", "email", "preferred_username"},
|
||||||
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
|
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
|
||||||
JwksUri: "https://example.com/.well-known/jwks.json",
|
JwksUri: utils.MustParse("https://example.com/.well-known/jwks.json"),
|
||||||
}, GenConfig(url.MustParse("https://example.com"), []string{"openid", "email"}, []string{"name", "email", "preferred_username"}))
|
}, GenConfig(utils.MustParse("https://example.com"), []string{"openid", "email"}, []string{"name", "email", "preferred_username"}))
|
||||||
}
|
}
|
||||||
|
@ -309,7 +309,15 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
sso := h.manager.FindServiceFromLogin(refreshData.Claims.Login)
|
_, namespace, err := utils.ParseLoginName(refreshData.Claims.Login)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sso := h.manager.GetService(namespace)
|
||||||
|
if sso == nil {
|
||||||
|
return fmt.Errorf("invalid namespace: %s", namespace)
|
||||||
|
}
|
||||||
|
|
||||||
var oauthToken *oauth2.Token
|
var oauthToken *oauth2.Token
|
||||||
|
|
||||||
|
@ -5,13 +5,13 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/1f349/lavender/logger"
|
"github.com/1f349/lavender/logger"
|
||||||
"github.com/1f349/lavender/openid"
|
"github.com/1f349/lavender/openid"
|
||||||
"github.com/1f349/lavender/url"
|
"github.com/1f349/lavender/utils"
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetupOpenId(r *httprouter.Router, baseUrl *url.URL, signingKey *mjwt.Issuer) {
|
func SetupOpenId(r *httprouter.Router, baseUrl *utils.URL, signingKey *mjwt.Issuer) {
|
||||||
openIdConf := openid.GenConfig(baseUrl, []string{
|
openIdConf := openid.GenConfig(baseUrl, []string{
|
||||||
"openid", "name", "username", "profile", "email", "birthdate", "age", "zoneinfo", "locale",
|
"openid", "name", "username", "profile", "email", "birthdate", "age", "zoneinfo", "locale",
|
||||||
}, []string{
|
}, []string{
|
||||||
|
@ -1,20 +1,25 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"github.com/1f349/cache"
|
"github.com/1f349/cache"
|
||||||
"github.com/1f349/lavender/auth"
|
"github.com/1f349/lavender/auth"
|
||||||
|
"github.com/1f349/lavender/auth/process"
|
||||||
"github.com/1f349/lavender/auth/providers"
|
"github.com/1f349/lavender/auth/providers"
|
||||||
"github.com/1f349/lavender/conf"
|
"github.com/1f349/lavender/conf"
|
||||||
"github.com/1f349/lavender/database"
|
"github.com/1f349/lavender/database"
|
||||||
"github.com/1f349/lavender/issuer"
|
"github.com/1f349/lavender/issuer"
|
||||||
"github.com/1f349/lavender/logger"
|
"github.com/1f349/lavender/logger"
|
||||||
"github.com/1f349/lavender/mail"
|
"github.com/1f349/lavender/mail"
|
||||||
|
"github.com/1f349/lavender/utils"
|
||||||
"github.com/1f349/lavender/web"
|
"github.com/1f349/lavender/web"
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
"github.com/go-oauth2/oauth2/v4/manage"
|
"github.com/go-oauth2/oauth2/v4/manage"
|
||||||
"github.com/go-oauth2/oauth2/v4/server"
|
"github.com/go-oauth2/oauth2/v4/server"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
@ -65,12 +70,7 @@ type mailLinkKey struct {
|
|||||||
func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail, db *database.Queries, signingKey *mjwt.Issuer) {
|
func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail, db *database.Queries, signingKey *mjwt.Issuer) {
|
||||||
// TODO: move auth provider init to main function
|
// TODO: move auth provider init to main function
|
||||||
// TODO: allow dynamically changing the providers based on database information
|
// TODO: allow dynamically changing the providers based on database information
|
||||||
authInitial := &providers.InitialLogin{}
|
// TODO: move oauth setup into oauth provider
|
||||||
authPassword := &providers.PasswordLogin{DB: db}
|
|
||||||
authOtp := &providers.OtpLogin{DB: db}
|
|
||||||
authOAuth := &providers.OAuthLogin{DB: db, BaseUrl: &config.BaseUrl}
|
|
||||||
authOAuth.Init()
|
|
||||||
authPasskey := &providers.PasskeyLogin{DB: db}
|
|
||||||
|
|
||||||
hs := &httpServer{
|
hs := &httpServer{
|
||||||
r: r,
|
r: r,
|
||||||
@ -80,19 +80,71 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail,
|
|||||||
mailSender: mailSender,
|
mailSender: mailSender,
|
||||||
|
|
||||||
mailLinkCache: cache.New[mailLinkKey, string](),
|
mailLinkCache: cache.New[mailLinkKey, string](),
|
||||||
|
|
||||||
authSources: []auth.Provider{
|
|
||||||
authInitial,
|
|
||||||
authPassword,
|
|
||||||
authOtp,
|
|
||||||
authOAuth,
|
|
||||||
authPasskey,
|
|
||||||
},
|
|
||||||
authButtons: make([]auth.Button, 0),
|
|
||||||
formProviderLookup: make(map[string]auth.Form),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
hs.manager, err = issuer.NewManager(db, config.Namespace)
|
||||||
|
if err != nil {
|
||||||
|
logger.Logger.Fatal("Failed to load SSO services", "err", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
authPassword := &providers.PasswordLogin{DB: db}
|
||||||
|
authOtp := &providers.OtpLogin{DB: db}
|
||||||
|
authOAuth := &providers.OAuthLogin{DB: db, BaseUrl: config.BaseUrl, Manager: hs.manager}
|
||||||
|
authOAuth.Init()
|
||||||
|
authPasskey := &providers.PasskeyLogin{DB: db}
|
||||||
|
authInitial := &providers.InitialLogin{DB: db, MyNamespace: config.Namespace, Manager: hs.manager, OAuth: authOAuth}
|
||||||
|
|
||||||
|
hs.authSources = []auth.Provider{
|
||||||
|
authInitial,
|
||||||
|
authPassword,
|
||||||
|
authOtp,
|
||||||
|
authOAuth,
|
||||||
|
authPasskey,
|
||||||
|
}
|
||||||
|
|
||||||
|
r.GET("/oauth/:namespace/start", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||||
|
namespace := params.ByName("namespace")
|
||||||
|
if !authOAuth.RedirectToAuthorize(rw, req, namespace) {
|
||||||
|
http.Error(rw, "Invalid OAuth namespace", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
r.GET("/oauth/:namespace/callback", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||||
|
namespace := params.ByName("namespace")
|
||||||
|
authOAuth.OAuthCallback(rw, req, namespace, func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
|
||||||
|
resp, err := sso.OAuth2Config.Client(req.Context(), token).Get(sso.UserInfoEndpoint.String())
|
||||||
|
if err != nil {
|
||||||
|
return auth.UserAuth{}, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var userInfoJson auth.UserInfoFields
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&userInfoJson); err != nil {
|
||||||
|
return auth.UserAuth{}, err
|
||||||
|
}
|
||||||
|
subject, ok := userInfoJson.GetString("sub")
|
||||||
|
if !ok {
|
||||||
|
return auth.UserAuth{}, fmt.Errorf("invalid subject")
|
||||||
|
}
|
||||||
|
subject += "@" + sso.Config.Namespace
|
||||||
|
|
||||||
|
return auth.UserAuth{
|
||||||
|
Subject: subject,
|
||||||
|
Factor: process.StateBasic,
|
||||||
|
UserInfo: userInfoJson,
|
||||||
|
}, nil
|
||||||
|
}, func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool {
|
||||||
|
// TODO: this should be using the existing auth flow calls
|
||||||
|
return hs.setLoginDataCookie(rw, authData, loginName)
|
||||||
|
}, func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// TODO: is this really needed like this?
|
||||||
|
utils.SafeRedirect(rw, req)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
// build slices and maps for quick access to auth interfaces
|
// build slices and maps for quick access to auth interfaces
|
||||||
|
hs.authButtons = make([]auth.Button, 0)
|
||||||
|
hs.formProviderLookup = make(map[string]auth.Form)
|
||||||
for _, source := range hs.authSources {
|
for _, source := range hs.authSources {
|
||||||
if button, isButton := source.(auth.Button); isButton {
|
if button, isButton := source.(auth.Button); isButton {
|
||||||
hs.authButtons = append(hs.authButtons, button)
|
hs.authButtons = append(hs.authButtons, button)
|
||||||
@ -103,13 +155,7 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
SetupOpenId(r, config.BaseUrl, signingKey)
|
||||||
hs.manager, err = issuer.NewManager(config.Namespace, config.SsoServices)
|
|
||||||
if err != nil {
|
|
||||||
logger.Logger.Fatal("Failed to load SSO services", "err", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
SetupOpenId(r, &config.BaseUrl, signingKey)
|
|
||||||
r.GET("/", hs.OptionalAuthentication(false, hs.Home))
|
r.GET("/", hs.OptionalAuthentication(false, hs.Home))
|
||||||
r.POST("/logout", hs.RequireAuthentication(hs.logoutPost))
|
r.POST("/logout", hs.RequireAuthentication(hs.logoutPost))
|
||||||
|
|
||||||
|
@ -29,3 +29,5 @@ sql:
|
|||||||
go_type: "database/sql.NullString"
|
go_type: "database/sql.NullString"
|
||||||
- column: "users.token_expiry"
|
- column: "users.token_expiry"
|
||||||
go_type: "database/sql.NullTime"
|
go_type: "database/sql.NullTime"
|
||||||
|
- column: "oauth_sources.address"
|
||||||
|
go_type: "github.com/1f349/lavender/utils.URL"
|
||||||
|
@ -1,21 +0,0 @@
|
|||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding"
|
|
||||||
"net/url"
|
|
||||||
)
|
|
||||||
|
|
||||||
type JsonUrl struct {
|
|
||||||
*url.URL
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ encoding.TextUnmarshaler = &JsonUrl{}
|
|
||||||
|
|
||||||
func (s *JsonUrl) UnmarshalText(text []byte) error {
|
|
||||||
parse, err := url.Parse(string(text))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.URL = parse
|
|
||||||
return nil
|
|
||||||
}
|
|
51
utils/loginname.go
Normal file
51
utils/loginname.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrInvalidLoginNameFormat = errors.New("invalid login name format")
|
||||||
|
|
||||||
|
func ParseLoginName(loginName string) (user string, namespace string, err error) {
|
||||||
|
if loginName == "" || strings.HasPrefix(loginName, "@") || strings.HasSuffix(loginName, "@") || containsInvalidLoginNameRunes(loginName) {
|
||||||
|
return "", "", ErrInvalidLoginNameFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
// @ should have at least one byte before it
|
||||||
|
n := strings.IndexByte(loginName, '@')
|
||||||
|
if n < 1 {
|
||||||
|
return "", "", ErrInvalidLoginNameFormat
|
||||||
|
}
|
||||||
|
// there should not be a second @
|
||||||
|
n2 := strings.IndexByte(loginName[n+1:], '@')
|
||||||
|
if n2 != -1 {
|
||||||
|
return "", "", ErrInvalidLoginNameFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
return loginName[:n], loginName[n+1:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsInvalidLoginNameRunes(loginName string) bool {
|
||||||
|
// check if the name contains an invalid rune
|
||||||
|
return strings.ContainsFunc(loginName, func(r rune) bool {
|
||||||
|
return !isValidLoginNameRune(r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidLoginNameRune(r rune) bool {
|
||||||
|
switch {
|
||||||
|
case r >= '0' && r <= '9':
|
||||||
|
return true
|
||||||
|
case r >= 'a' && r <= 'z':
|
||||||
|
return true
|
||||||
|
case r == '.':
|
||||||
|
return true
|
||||||
|
case r == '-':
|
||||||
|
return true
|
||||||
|
case r == '@':
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
174
utils/loginname_test.go
Normal file
174
utils/loginname_test.go
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var parseLoginNameTests = []struct {
|
||||||
|
User string
|
||||||
|
Namespace string
|
||||||
|
HasError bool
|
||||||
|
Input string
|
||||||
|
}{
|
||||||
|
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com", false, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com"},
|
||||||
|
{"", "", true, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com\u1111"},
|
||||||
|
{"", "", true, "\u1111aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com"},
|
||||||
|
{"", "", true, "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com"},
|
||||||
|
{"", "", true, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com@"},
|
||||||
|
{"a", "b.com", false, "a@b.com"},
|
||||||
|
{"aa", "bb.com", false, "aa@bb.com"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var parseLoginNameImplementations = []struct {
|
||||||
|
Name string
|
||||||
|
Func func(string) (string, string, error)
|
||||||
|
}{
|
||||||
|
{"ParseLoginName", ParseLoginName},
|
||||||
|
{"parseLoginNameLoopRunes", parseLoginNameLoopRunes},
|
||||||
|
{"parseLoginNameRegex", parseLoginNameRegex},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseLoginName(t *testing.T) {
|
||||||
|
for _, impl := range parseLoginNameImplementations {
|
||||||
|
t.Run(impl.Name, func(t *testing.T) {
|
||||||
|
for _, i := range parseLoginNameTests {
|
||||||
|
t.Run(i.Input, func(t *testing.T) {
|
||||||
|
user, namespace, err := ParseLoginName(i.Input)
|
||||||
|
if i.HasError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, i.User, user)
|
||||||
|
assert.Equal(t, i.Namespace, namespace)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func FuzzParseLoginName(f *testing.F) {
|
||||||
|
for _, i := range parseLoginNameTests {
|
||||||
|
f.Add(i.Input)
|
||||||
|
}
|
||||||
|
f.Fuzz(func(t *testing.T, s string) {
|
||||||
|
t.Log("Input: ", s, hex.EncodeToString([]byte(s)))
|
||||||
|
hasError := s == "" || strings.HasPrefix(s, "@") || strings.HasSuffix(s, "@") || strings.Count(s, "@") != 1 || strings.ContainsFunc(s, func(r rune) bool {
|
||||||
|
return !isValidLoginNameRune(r)
|
||||||
|
})
|
||||||
|
login, namespace, err := ParseLoginName(s)
|
||||||
|
if hasError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
n := strings.IndexRune(s, '@')
|
||||||
|
assert.Equal(t, s[:n], login)
|
||||||
|
assert.Equal(t, s[n+1:], namespace)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, login, "", "Login should be empty if an error occurred")
|
||||||
|
assert.Equal(t, namespace, "", "Namespace should be empty if an error occurred")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLoginBench(b *testing.B, f func(string) (string, string, error)) {
|
||||||
|
for _, i := range parseLoginNameTests {
|
||||||
|
b.Run(i.Input, func(b *testing.B) {
|
||||||
|
for range b.N {
|
||||||
|
_, _, _ = f(i.Input)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkParseLoginName(b *testing.B) {
|
||||||
|
b.Run("ParseLoginName", func(b *testing.B) {
|
||||||
|
parseLoginBench(b, ParseLoginName)
|
||||||
|
})
|
||||||
|
b.Run("parseLoginNameLoopRunes", func(b *testing.B) {
|
||||||
|
parseLoginBench(b, parseLoginNameLoopRunes)
|
||||||
|
})
|
||||||
|
b.Run("parseLoginNameRegex", func(b *testing.B) {
|
||||||
|
parseLoginBench(b, parseLoginNameRegex)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLoginNameLoopRunes(s string) (user string, namespace string, err error) {
|
||||||
|
if s == "" || strings.HasPrefix(s, "@") || strings.HasSuffix(s, "@") {
|
||||||
|
return "", "", ErrInvalidLoginNameFormat
|
||||||
|
}
|
||||||
|
hasAt := false
|
||||||
|
atIndex := 0
|
||||||
|
for i, r := range []rune(s) {
|
||||||
|
switch {
|
||||||
|
case r == '@':
|
||||||
|
if hasAt {
|
||||||
|
return "", "", ErrInvalidLoginNameFormat
|
||||||
|
}
|
||||||
|
hasAt = true
|
||||||
|
atIndex = i
|
||||||
|
case r == '.':
|
||||||
|
continue
|
||||||
|
case r == '-':
|
||||||
|
continue
|
||||||
|
case r >= 'a' && r <= 'z':
|
||||||
|
continue
|
||||||
|
case r >= '0' && r <= '9':
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
return "", "", ErrInvalidLoginNameFormat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s[:atIndex], s[atIndex+1:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// regexLoginName is a regex based implementation of ParseLoginName
|
||||||
|
//
|
||||||
|
// This implementation prevents using - or . at the start and end of the user and namespace
|
||||||
|
var regexLoginName = regexp.MustCompile(`^([a-z0-9]([a-z0-9-.]*[a-z0-9]|))@([a-z0-9]([a-z0-9-.]*[a-z0-9]|))$`)
|
||||||
|
|
||||||
|
func parseLoginNameRegex(s string) (user string, namespace string, err error) {
|
||||||
|
if s == "" {
|
||||||
|
return "", "", ErrInvalidLoginNameFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := regexLoginName.FindStringSubmatch(s)
|
||||||
|
if matches == nil {
|
||||||
|
return "", "", ErrInvalidLoginNameFormat
|
||||||
|
}
|
||||||
|
return matches[1], matches[3], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkA(b *testing.B) {
|
||||||
|
b.Run("A", func(b *testing.B) {
|
||||||
|
for range b.N {
|
||||||
|
_ = isValidLoginNameRune('b')
|
||||||
|
_ = isValidLoginNameRune('4')
|
||||||
|
_ = isValidLoginNameRune('.')
|
||||||
|
_ = isValidLoginNameRune('-')
|
||||||
|
_ = isValidLoginNameRune('@')
|
||||||
|
_ = isValidLoginNameRune(' ')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
b.Run("B", func(b *testing.B) {
|
||||||
|
for range b.N {
|
||||||
|
_ = isValidLoginNameRune2('b')
|
||||||
|
_ = isValidLoginNameRune2('4')
|
||||||
|
_ = isValidLoginNameRune2('.')
|
||||||
|
_ = isValidLoginNameRune2('-')
|
||||||
|
_ = isValidLoginNameRune2('@')
|
||||||
|
_ = isValidLoginNameRune2(' ')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidLoginNameRune2(r rune) bool {
|
||||||
|
return (r >= '0' && r <= '9') || (r >= 'a' && r <= 'z') || r == '.' || r == '-' || r == '@'
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package url
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding"
|
"encoding"
|
||||||
@ -6,35 +6,33 @@ import (
|
|||||||
"path"
|
"path"
|
||||||
)
|
)
|
||||||
|
|
||||||
type URL struct {
|
|
||||||
url.URL
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *URL) Resolve(paths ...string) *URL {
|
|
||||||
return &URL{URL: *u.URL.ResolveReference(&url.URL{Path: path.Join(paths...)})}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u URL) MarshalText() (text []byte, err error) {
|
|
||||||
return []byte(u.String()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *URL) UnmarshalText(text []byte) error {
|
|
||||||
parse, err := u.Parse(string(text))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
u.URL = *parse
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ encoding.TextMarshaler = (*URL)(nil)
|
var _ encoding.TextMarshaler = (*URL)(nil)
|
||||||
var _ encoding.TextUnmarshaler = (*URL)(nil)
|
var _ encoding.TextUnmarshaler = (*URL)(nil)
|
||||||
|
|
||||||
|
type URL struct{ *url.URL }
|
||||||
|
|
||||||
func MustParse(rawURL string) *URL {
|
func MustParse(rawURL string) *URL {
|
||||||
u, err := url.Parse(rawURL)
|
u, err := url.Parse(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return &URL{*u}
|
return &URL{u}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *URL) Resolve(paths ...string) *URL {
|
||||||
|
return &URL{URL: u.URL.ResolveReference(&url.URL{Path: path.Join(paths...)})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *URL) MarshalText() (text []byte, err error) {
|
||||||
|
return []byte(u.URL.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *URL) UnmarshalText(text []byte) error {
|
||||||
|
parse, err := url.Parse(string(text))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
u.URL = parse
|
||||||
|
return nil
|
||||||
}
|
}
|
Loading…
x
Reference in New Issue
Block a user