Rewrite OAuth manager, OAuth source and URL wrapper

This commit is contained in:
Melon 2025-03-13 13:29:57 +00:00
parent d9b0074133
commit 50df217b66
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
24 changed files with 789 additions and 234 deletions

View File

@ -5,10 +5,11 @@ import (
"fmt" "fmt"
"github.com/1f349/lavender/auth" "github.com/1f349/lavender/auth"
"github.com/1f349/lavender/auth/authContext" "github.com/1f349/lavender/auth/authContext"
process "github.com/1f349/lavender/auth/process" "github.com/1f349/lavender/auth/process"
"github.com/1f349/lavender/database" "github.com/1f349/lavender/database"
"github.com/1f349/lavender/issuer" "github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/logger" "github.com/1f349/lavender/logger"
"github.com/1f349/lavender/utils"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"net/http" "net/http"
@ -25,6 +26,7 @@ type InitialLogin struct {
DB *database.Queries DB *database.Queries
MyNamespace string MyNamespace string
Manager *issuer.Manager Manager *issuer.Manager
OAuth *OAuthLogin
} }
func (m *InitialLogin) AccessState() process.State { return process.StateUnauthorized } func (m *InitialLogin) AccessState() process.State { return process.StateUnauthorized }
@ -101,36 +103,39 @@ func (m *InitialLogin) AttemptLogin(ctx authContext.FormContext) error {
} }
// append local namespace if @ is missing // append local namespace if @ is missing
n := strings.IndexByte(loginName, '@') if !strings.ContainsRune(loginName, '@') {
if n < 0 {
// correct the @ index
n = len(loginName)
loginName += "@" + m.MyNamespace loginName += "@" + m.MyNamespace
} }
login := m.Manager.FindServiceFromLogin(loginName) user, namespace, err := utils.ParseLoginName(loginName)
if err != nil {
http.Error(rw, "Invalid login name", http.StatusBadRequest)
return fmt.Errorf("invalid login name %s", loginName)
}
login := m.Manager.GetService(namespace)
if login == nil { if login == nil {
http.Error(rw, "No login service defined for this username", http.StatusBadRequest) http.Error(rw, "No login service defined for this username", http.StatusBadRequest)
return errors.New("no login service defined for this username") return errors.New("no login service defined for this username")
} }
// the @ must exist if the service is defined
loginUn := loginName[:n]
// TODO: finish migrating this shit
// the login is not for this namespace // the login is not for this namespace
if login != issuer.MeWellKnown { if login != issuer.MeWellKnown {
// TODO: this is oauth request code start oauth request
// TODO: this could all be 1 function call into oauth
// save state for use later // save state for use later
state := login.Config.Namespace + ":" + uuid.NewString() state := login.Config.Namespace + ":" + uuid.NewString()
h.flowState.Set(state, flowStateData{loginName, login, req.PostFormValue("redirect")}, time.Now().Add(15*time.Minute)) m.OAuth.flow.Set(state, flowStateData{loginName, login, req.PostFormValue("redirect")}, time.Now().Add(15*time.Minute))
// generate oauth2 config and redirect to authorize URL // generate oauth2 config and redirect to authorize URL
oa2conf := login.OAuth2Config oa2conf := login.OAuth2Config
oa2conf.RedirectURL = h.conf.BaseUrl.JoinPath("callback").String() oa2conf.RedirectURL = m.OAuth.BaseUrl.JoinPath("callback").String()
nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginUn)) nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", user))
http.Redirect(rw, req, nextUrl, http.StatusFound) return auth.RedirectError{
return Target: nextUrl,
Code: http.StatusFound,
}
} }
ctx.UpdateSession(process.LoginProcessData{ ctx.UpdateSession(process.LoginProcessData{

View File

@ -13,17 +13,18 @@ import (
"github.com/1f349/lavender/database" "github.com/1f349/lavender/database"
"github.com/1f349/lavender/database/types" "github.com/1f349/lavender/database/types"
"github.com/1f349/lavender/issuer" "github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/url" "github.com/1f349/lavender/utils"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mrmelon54/pronouns" "github.com/mrmelon54/pronouns"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/text/language" "golang.org/x/text/language"
"net/http" "net/http"
"net/url"
"time" "time"
) )
type OauthCallback interface { type OauthCallback interface {
OAuthCallback(rw http.ResponseWriter, req *http.Request, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request)) OAuthCallback(rw http.ResponseWriter, req *http.Request, namespace string, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request))
} }
type flowStateData struct { type flowStateData struct {
@ -40,9 +41,9 @@ var (
type OAuthLogin struct { type OAuthLogin struct {
DB *database.Queries DB *database.Queries
BaseUrl *url.URL BaseUrl *utils.URL
flow *cache.Cache[string, flowStateData]
flow *cache.Cache[string, flowStateData] Manager *issuer.Manager
} }
func (o OAuthLogin) Init() { func (o OAuthLogin) Init() {
@ -50,7 +51,7 @@ func (o OAuthLogin) Init() {
} }
func (o OAuthLogin) authUrlBase(ref string) *url.URL { func (o OAuthLogin) authUrlBase(ref string) *url.URL {
return o.BaseUrl.Resolve("oauth", o.Name(), ref) return o.BaseUrl.Resolve("oauth", o.Name(), ref).URL
} }
func (o OAuthLogin) AccessState() process.State { return process.StateUnauthorized } func (o OAuthLogin) AccessState() process.State { return process.StateUnauthorized }
@ -83,12 +84,16 @@ func (o OAuthLogin) AttemptLogin(ctx authContext.FormContext) error {
return auth.RedirectError{Target: nextUrl, Code: http.StatusFound} return auth.RedirectError{Target: nextUrl, Code: http.StatusFound}
} }
func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request)) { func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, namespace string, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request)) {
flowState, ok := o.flow.Get(req.FormValue("state")) flowState, ok := o.flow.Get(req.FormValue("state"))
if !ok { if !ok {
http.Error(rw, "Invalid flow state", http.StatusBadRequest) http.Error(rw, "Invalid flow state", http.StatusBadRequest)
return return
} }
if flowState.sso.Namespace != namespace {
http.Error(rw, "OAuth source mismatch", http.StatusBadRequest)
return
}
token, err := flowState.sso.OAuth2Config.Exchange(context.Background(), req.FormValue("code"), oauth2.SetAuthURLParam("redirect_uri", o.authUrlBase("callback").String())) token, err := flowState.sso.OAuth2Config.Exchange(context.Background(), req.FormValue("code"), oauth2.SetAuthURLParam("redirect_uri", o.authUrlBase("callback").String()))
if err != nil { if err != nil {
http.Error(rw, "Failed to exchange code for token", http.StatusInternalServerError) http.Error(rw, "Failed to exchange code for token", http.StatusInternalServerError)
@ -112,17 +117,32 @@ func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, inf
} }
func (o OAuthLogin) RenderButtonTemplate(ctx authContext.TemplateContext) { func (o OAuthLogin) RenderButtonTemplate(ctx authContext.TemplateContext) {
// TODO: idk what this is
// o.authUrlBase("button") // o.authUrlBase("button")
// provide something non-nil // provide something non-nil
ctx.Render(struct { ctx.Render(struct {
Href string Href string
ButtonName string ButtonName string
}{ }{
Href: o.authUrlBase("button").String(), Href: o.authUrlBase("start").String(),
ButtonName: "Login with Unknown OAuth Button", // TODO: actually get the service name ButtonName: "Login with Unknown OAuth Button", // TODO: actually get the service name
}) })
} }
// RedirectToAuthorize returns true when redirecting to the authorize endpoint
// for the requested namespace.
func (o OAuthLogin) RedirectToAuthorize(rw http.ResponseWriter, req *http.Request, namespace string) bool {
oidc := o.Manager.GetService(namespace)
if oidc == nil {
return false
}
// TODO: come back to fix this
oidc.OAuth2Config.AuthCodeURL("state")
// http.Redirect(rw, req, oidc.AuthorizationEndpoint, http.StatusOK)
return true
}
type oauthServiceLogin int type oauthServiceLogin int
func WithWellKnown(ctx context.Context, login *issuer.WellKnownOIDC) context.Context { func WithWellKnown(ctx context.Context, login *issuer.WellKnownOIDC) context.Context {
@ -256,7 +276,7 @@ func (o OAuthLogin) updateOAuth2UserProfile(ctx context.Context, tx *database.Qu
} }
func (o OAuthLogin) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) { func (o OAuthLogin) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint) res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint.String())
if err != nil || res.StatusCode != http.StatusOK { if err != nil || res.StatusCode != http.StatusOK {
return auth.UserAuth{}, fmt.Errorf("request failed") return auth.UserAuth{}, fmt.Errorf("request failed")
} }

View File

@ -1,19 +1,17 @@
package conf package conf
import ( import (
"github.com/1f349/lavender/issuer" "github.com/1f349/lavender/utils"
"github.com/1f349/lavender/url"
"github.com/1f349/simplemail" "github.com/1f349/simplemail"
) )
type Conf struct { type Conf struct {
Listen string `yaml:"listen"` Listen string `yaml:"listen"`
BaseUrl url.URL `yaml:"baseUrl"` BaseUrl *utils.URL `yaml:"baseUrl"`
ServiceName string `yaml:"serviceName"` ServiceName string `yaml:"serviceName"`
Issuer string `yaml:"issuer"` Issuer string `yaml:"issuer"`
Kid string `yaml:"kid"` Kid string `yaml:"kid"`
Namespace string `yaml:"namespace"` Namespace string `yaml:"namespace"`
OtpIssuer string `yaml:"otpIssuer"` OtpIssuer string `yaml:"otpIssuer"`
Mail simplemail.Mail `yaml:"mail"` Mail simplemail.Mail `yaml:"mail"`
SsoServices []issuer.SsoConfig `yaml:"ssoServices"`
} }

View File

@ -0,0 +1,45 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.28.0
// source: manage-oauth-source.sql
package database
import (
"context"
)
const getOAuthSources = `-- name: GetOAuthSources :many
SELECT namespace, address, registration, button, client_id, client_secret, client_scopes FROM oauth_sources
`
func (q *Queries) GetOAuthSources(ctx context.Context) ([]OauthSource, error) {
rows, err := q.db.QueryContext(ctx, getOAuthSources)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OauthSource
for rows.Next() {
var i OauthSource
if err := rows.Scan(
&i.Namespace,
&i.Address,
&i.Registration,
&i.Button,
&i.ClientID,
&i.ClientSecret,
&i.ClientScopes,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -0,0 +1,10 @@
CREATE TABLE oauth_sources
(
namespace TEXT NOT NULL UNIQUE PRIMARY KEY,
address TEXT NOT NULL,
registration BOOLEAN NOT NULL,
button BOOLEAN NOT NULL,
client_id TEXT NOT NULL,
client_secret TEXT NOT NULL,
client_scopes TEXT NOT NULL
);

View File

@ -10,6 +10,7 @@ import (
"github.com/1f349/lavender/database/types" "github.com/1f349/lavender/database/types"
"github.com/1f349/lavender/password" "github.com/1f349/lavender/password"
"github.com/1f349/lavender/utils"
"github.com/hardfinhq/go-date" "github.com/hardfinhq/go-date"
) )
@ -25,6 +26,16 @@ type ClientStore struct {
Active bool `json:"active"` Active bool `json:"active"`
} }
type OauthSource struct {
Namespace string `json:"namespace"`
Address utils.URL `json:"address"`
Registration bool `json:"registration"`
Button bool `json:"button"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
ClientScopes string `json:"client_scopes"`
}
type Role struct { type Role struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Role string `json:"role"` Role string `json:"role"`

View File

@ -0,0 +1,2 @@
-- name: GetOAuthSources :many
SELECT * FROM oauth_sources;

1
go.mod
View File

@ -19,6 +19,7 @@ require (
github.com/julienschmidt/httprouter v1.3.0 github.com/julienschmidt/httprouter v1.3.0
github.com/mattn/go-sqlite3 v1.14.24 github.com/mattn/go-sqlite3 v1.14.24
github.com/mrmelon54/pronouns v1.0.3 github.com/mrmelon54/pronouns v1.0.3
github.com/robfig/cron v1.2.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/spf13/afero v1.12.0 github.com/spf13/afero v1.12.0
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0

2
go.sum
View File

@ -138,6 +138,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ=
github.com/robfig/cron v1.2.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=

View File

@ -1,57 +1,226 @@
package issuer package issuer
import ( import (
"context"
"fmt" "fmt"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/logger"
"github.com/robfig/cron"
"net/http"
"regexp" "regexp"
"strings" "strings"
"sync"
"time"
)
const (
fetchPoolSize = 4
fetchTimeout = 2 * time.Minute
fetchRetryDelay = 15 * time.Second
fetchRetryAfterTimeout = 15 * time.Minute
fetchRetryCount = 3
) )
var isValidNamespace = regexp.MustCompile("^[0-9a-z.]+$") var isValidNamespace = regexp.MustCompile("^[0-9a-z.]+$")
var MeWellKnown = &WellKnownOIDC{} var MeWellKnown = &WellKnownOIDC{}
type Manager struct { type managerReloadDB interface {
m map[string]*WellKnownOIDC GetOAuthSources(ctx context.Context) ([]database.OauthSource, error)
} }
func NewManager(myNamespace string, services []SsoConfig) (*Manager, error) { type Manager struct {
l := &Manager{m: make(map[string]*WellKnownOIDC)} db managerReloadDB
l.m[myNamespace] = MeWellKnown
for _, ssoService := range services {
if !isValidNamespace.MatchString(ssoService.Namespace) {
return nil, fmt.Errorf("invalid namespace: %s", ssoService.Namespace)
}
conf, err := ssoService.FetchConfig() client *http.Client
if err != nil {
return nil, err
}
// save by namespace sourceMu *sync.RWMutex
l.m[ssoService.Namespace] = conf sourceMap map[string]*WellKnownOIDC
errMu *sync.RWMutex
errMap map[string]FetchErrorDetails
shutdownCall context.CancelFunc
}
type FetchErrorDetails struct {
Retry time.Time
Errors []error
}
type managerRoundTripper struct {
Base http.RoundTripper
}
func (a *managerRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
userAgent := strings.Join(request.Header.Values("User-Agent"), " ")
if strings.TrimSpace(userAgent) == "" {
request.Header.Set("User-Agent", "Lavender/1.0 (OAuth Manager)")
} }
return a.Base.RoundTrip(request)
}
func NewManager(db managerReloadDB, myNamespace string) (*Manager, error) {
return NewManagerWithClient(db, myNamespace, &http.Client{
Transport: &managerRoundTripper{
Base: http.DefaultTransport,
},
})
}
func NewManagerWithClient(db managerReloadDB, myNamespace string, client *http.Client) (*Manager, error) {
l := &Manager{
db: db,
client: client,
sourceMu: new(sync.RWMutex),
sourceMap: make(map[string]*WellKnownOIDC),
errMap: make(map[string]FetchErrorDetails),
}
l.sourceMap[myNamespace] = MeWellKnown
// reload should run blocking on start up
l.internalReload(context.Background())
reloadDaily := make(chan struct{}, 1)
tab := cron.New()
err := tab.AddFunc("0 2 * * *", func() {
reloadDaily <- struct{}{}
})
if err != nil {
panic(err)
}
ctx, cancel := context.WithCancel(context.Background())
go l.reloadWorker(ctx, reloadDaily)
l.shutdownCall = cancel
return l, nil return l, nil
} }
func (m *Manager) Shutdown() { m.shutdownCall() }
func (m *Manager) reloadWorker(ctx context.Context, reload <-chan struct{}) {
for {
select {
case <-ctx.Done():
logger.Logger.Info("Shutting down oauth manager reload worker")
return
case <-reload:
m.internalReload(ctx)
}
}
}
func (m *Manager) internalReload(ctx context.Context) {
logger.Logger.Info("Reloading oauth sources")
dbCtx, cancel := context.WithTimeout(ctx, fetchTimeout)
sources, err := m.db.GetOAuthSources(dbCtx)
cancel()
if err != nil {
// TODO: send email to admin
logger.Logger.Warn("Failed to load OAuth sources from the database", "err", err)
return
}
// construct map of db oauth sources
toAdd := make(map[string]database.OauthSource)
for _, source := range sources {
toAdd[source.Namespace] = source
}
// remove sources from the active map
m.sourceMu.Lock()
for k, v := range m.sourceMap {
// ignore my own well-known
if v == MeWellKnown {
continue
}
if _, shouldAdd := toAdd[k]; !shouldAdd {
delete(m.sourceMap, k)
}
}
m.sourceMu.Unlock()
var wg sync.WaitGroup
wg.Add(fetchPoolSize)
fetchChan := make(chan database.OauthSource, fetchPoolSize)
for range fetchPoolSize {
go func() {
defer wg.Done()
for source := range fetchChan {
itemCtx, cancel := context.WithTimeout(ctx, fetchTimeout)
errs := m.reloadSource(itemCtx, source)
cancel()
// this also resets the error if it was set previously
m.sourceMu.Lock()
if len(errs) > 0 {
m.errMap[source.Namespace] = FetchErrorDetails{
Retry: time.Now().Add(fetchRetryAfterTimeout),
Errors: errs,
}
// TODO: send email to admin
} else {
delete(m.errMap, source.Namespace)
// TODO: send resolved email to admin
}
m.sourceMu.Unlock()
}
}()
}
// reload sources
for _, source := range toAdd {
fetchChan <- source
}
close(fetchChan)
wg.Wait()
}
func (m *Manager) reloadSource(ctx context.Context, source database.OauthSource) []error {
var errs []error
for range fetchRetryCount {
oidc, err := fetchConfig(ctx, m.client, source)
fmt.Println(oidc, err)
if err == nil {
// fetch was successful
m.sourceMu.Lock()
m.sourceMap[source.Namespace] = oidc
m.sourceMu.Unlock()
return nil
}
errs = append(errs, fmt.Errorf("failed to fetch OIDC well-known configuration: %w", err))
// wait before retrying
select {
case <-ctx.Done():
return errs
case <-time.After(fetchRetryDelay):
continue
}
}
return errs
}
func (m *Manager) CheckNamespace(namespace string) bool { func (m *Manager) CheckNamespace(namespace string) bool {
_, ok := m.m[namespace] m.sourceMu.RLock()
defer m.sourceMu.RUnlock()
_, ok := m.sourceMap[namespace]
return ok return ok
} }
func (m *Manager) GetService(namespace string) *WellKnownOIDC { func (m *Manager) GetService(namespace string) *WellKnownOIDC {
return m.m[namespace] m.sourceMu.RLock()
} defer m.sourceMu.RUnlock()
return m.sourceMap[namespace]
func (m *Manager) FindServiceFromLogin(login string) *WellKnownOIDC {
// @ should have at least one byte before it
n := strings.IndexByte(login, '@')
if n < 1 {
return nil
}
// there should not be a second @
n2 := strings.IndexByte(login[n+1:], '@')
if n2 != -1 {
return nil
}
return m.GetService(login[n+1:])
} }

View File

@ -1,6 +1,9 @@
package issuer package issuer
import ( import (
"context"
"fmt"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/utils" "github.com/1f349/lavender/utils"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"io" "io"
@ -8,29 +11,54 @@ import (
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time"
) )
var testAddrUrl = func() utils.JsonUrl { var testAddrUrl = func() utils.URL {
a, err := url.Parse("https://example.com") a, err := url.Parse("https://example.com")
if err != nil { if err != nil {
panic(err) panic(err)
} }
return utils.JsonUrl{URL: a} return utils.URL{URL: a}
}() }()
func testBody() io.ReadCloser { func testBody() io.ReadCloser {
return io.NopCloser(strings.NewReader("{}")) return io.NopCloser(strings.NewReader(`{
"issuer": "Example.com Issuer",
"authorization_endpoint": "https://example.com/oauth/authorize",
"token_endpoint": "https://example.com/oauth/token",
"userinfo_endpoint": "https://example.com/oauth/userinfo",
"revocation_endpoint": "https://example.com/oauth/revoke",
"response_types_supported": [
"code"
],
"scopes_supported": [
"openid"
],
"claims_supported": [
"sub"
],
"grant_types_supported": [
"authorization_code",
"refresh_token"
]
}
`))
} }
func TestManager_CheckNamespace(t *testing.T) { func TestManager_CheckNamespace(t *testing.T) {
httpGet = func(url string) (resp *http.Response, err error) { manager, err := NewManagerWithClient(&testDB{
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil Rows: []database.OauthSource{
} {Address: testAddrUrl, Namespace: "example.com"},
manager, err := NewManager("example.org", []SsoConfig{
{
Addr: testAddrUrl,
Namespace: "example.com",
}, },
}, "example.org", &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
fmt.Println("Request URL:", req.URL)
if req.URL.String() == "https://example.com/.well-known/openid-configuration" {
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
}
return nil, fmt.Errorf("request failed")
}),
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, manager.CheckNamespace("example.org")) assert.True(t, manager.CheckNamespace("example.org"))
@ -39,17 +67,36 @@ func TestManager_CheckNamespace(t *testing.T) {
} }
func TestManager_FindServiceFromLogin(t *testing.T) { func TestManager_FindServiceFromLogin(t *testing.T) {
httpGet = func(url string) (resp *http.Response, err error) { manager, err := NewManagerWithClient(&testDB{
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil Rows: []database.OauthSource{
} {Address: testAddrUrl, Namespace: "example.com"},
manager, err := NewManager("example.org", []SsoConfig{
{
Addr: testAddrUrl,
Namespace: "example.com",
}, },
}, "example.org", &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
fmt.Println("Request URL:", req.URL)
if req.URL.String() != "https://example.com/.well-known/openid-configuration" {
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
}
return nil, fmt.Errorf("request failed")
}),
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, manager.FindServiceFromLogin("jane@example.org"), MeWellKnown) assert.Equal(t, manager.GetService("example.org"), MeWellKnown)
assert.Equal(t, manager.FindServiceFromLogin("jane@example.com"), manager.m["example.com"]) assert.Equal(t, manager.GetService("example.com"), manager.sourceMap["example.com"])
assert.Nil(t, manager.FindServiceFromLogin("jane@missing.example.com")) assert.Nil(t, manager.GetService("missing.example.com"))
}
type testDB struct {
Rows []database.OauthSource
}
func (t *testDB) GetOAuthSources(ctx context.Context) ([]database.OauthSource, error) {
return t.Rows, nil
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (r roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
time.Sleep(time.Minute)
return r(req)
} }

View File

@ -1,94 +1,82 @@
package issuer package issuer
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/utils" "github.com/1f349/lavender/utils"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"net/http" "net/http"
"net/url" "net/url"
"slices" "slices"
"strings"
"sync/atomic"
"time"
) )
var httpGet = http.Get type WellKnownOIDC struct {
Namespace string `json:"-"`
// SsoConfig is the base URL for an OAUTH/OPENID/SSO login service Config database.OauthSource `json:"-"`
// The path `/.well-known/openid-configuration` should be available Available atomic.Bool `json:"-"`
type SsoConfig struct { Issuer string `json:"issuer"`
Addr utils.JsonUrl `json:"addr" yaml:"addr"` // https://login.example.com AuthorizationEndpoint *utils.URL `json:"authorization_endpoint"`
Namespace string `json:"namespace" yaml:"namespace"` // example.com TokenEndpoint *utils.URL `json:"token_endpoint"`
Registration bool `json:"registration" yaml:"registration"` UserInfoEndpoint *utils.URL `json:"userinfo_endpoint"`
LoginWithButton bool `json:"login_with_button" yaml:"loginWithButton"` ResponseTypesSupported []string `json:"response_types_supported"`
Client SsoConfigClient `json:"client" yaml:"client"` ScopesSupported []string `json:"scopes_supported"`
ClaimsSupported []string `json:"claims_supported"`
GrantTypesSupported []string `json:"grant_types_supported"`
JwksUri *utils.URL `json:"jwks_uri"`
OAuth2Config oauth2.Config `json:"-"`
LastFetch time.Time `json:"-"`
} }
type SsoConfigClient struct { func fetchConfig(ctx context.Context, httpClient *http.Client, s database.OauthSource) (*WellKnownOIDC, error) {
ID string `json:"id"`
Secret string `json:"secret"`
Scopes []string `json:"scopes"`
}
func (s SsoConfig) FetchConfig() (*WellKnownOIDC, error) {
// generate openid config url // generate openid config url
u := s.Addr.JoinPath(".well-known/openid-configuration") u := s.Address.JoinPath(".well-known/openid-configuration")
// fetch metadata // fetch metadata
get, err := httpGet(u.String()) req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer get.Body.Close() resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var c WellKnownOIDC var c WellKnownOIDC
err = json.NewDecoder(get.Body).Decode(&c) err = json.NewDecoder(resp.Body).Decode(&c)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.Config = s c.Config = s
c.OAuth2Config = oauth2.Config{ c.OAuth2Config = oauth2.Config{
ClientID: c.Config.Client.ID, ClientID: c.Config.ClientID,
ClientSecret: c.Config.Client.Secret, ClientSecret: c.Config.ClientSecret,
Endpoint: oauth2.Endpoint{ Endpoint: oauth2.Endpoint{
AuthURL: c.AuthorizationEndpoint, AuthURL: c.AuthorizationEndpoint.String(),
TokenURL: c.TokenEndpoint, TokenURL: c.TokenEndpoint.String(),
AuthStyle: oauth2.AuthStyleInHeader, AuthStyle: oauth2.AuthStyleInHeader,
}, },
Scopes: c.Config.Client.Scopes, Scopes: strings.Fields(c.Config.ClientScopes),
} }
c.Available.Store(true)
return &c, nil return &c, nil
} }
type WellKnownOIDC struct { func (o *WellKnownOIDC) Validate() error {
Namespace string `json:"-"`
Config SsoConfig `json:"-"`
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"userinfo_endpoint"`
ResponseTypesSupported []string `json:"response_types_supported"`
ScopesSupported []string `json:"scopes_supported"`
ClaimsSupported []string `json:"claims_supported"`
GrantTypesSupported []string `json:"grant_types_supported"`
OAuth2Config oauth2.Config `json:"-"`
}
func (o WellKnownOIDC) Validate() error {
if o.Issuer == "" { if o.Issuer == "" {
return errors.New("missing issuer") return errors.New("missing issuer")
} }
// check URLs are valid
if _, err := url.Parse(o.AuthorizationEndpoint); err != nil {
return err
}
if _, err := url.Parse(o.TokenEndpoint); err != nil {
return err
}
if _, err := url.Parse(o.UserInfoEndpoint); err != nil {
return err
}
// check oidc supported values // check oidc supported values
if !slices.Contains(o.ResponseTypesSupported, "code") { if !slices.Contains(o.ResponseTypesSupported, "code") {
return errors.New("missing required response type 'code'") return errors.New("missing required response type 'code'")
@ -107,6 +95,6 @@ func (o WellKnownOIDC) Validate() error {
return nil return nil
} }
func (o WellKnownOIDC) ValidReturnUrl(u *url.URL) bool { func (o *WellKnownOIDC) ValidReturnUrl(u *url.URL) bool {
return o.Config.Addr.Scheme == u.Scheme && o.Config.Addr.Host == u.Host return o.Config.Address.Scheme == u.Scheme && o.Config.Address.Host == u.Host
} }

View File

@ -1,30 +1,29 @@
package openid package openid
import "github.com/1f349/lavender/url" import "github.com/1f349/lavender/utils"
type Config struct { type Config struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"` AuthorizationEndpoint *utils.URL `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"` TokenEndpoint *utils.URL `json:"token_endpoint"`
UserInfoEndpoint string `json:"userinfo_endpoint"` UserInfoEndpoint *utils.URL `json:"userinfo_endpoint"`
ResponseTypesSupported []string `json:"response_types_supported"` ResponseTypesSupported []string `json:"response_types_supported"`
ScopesSupported []string `json:"scopes_supported"` ScopesSupported []string `json:"scopes_supported"`
ClaimsSupported []string `json:"claims_supported"` ClaimsSupported []string `json:"claims_supported"`
GrantTypesSupported []string `json:"grant_types_supported"` GrantTypesSupported []string `json:"grant_types_supported"`
JwksUri string `json:"jwks_uri"` JwksUri *utils.URL `json:"jwks_uri"`
} }
func GenConfig(baseUrl *url.URL, scopes, claims []string) Config { func GenConfig(baseUrl *utils.URL, scopes, claims []string) Config {
return Config{ return Config{
Issuer: baseUrl.String(), Issuer: baseUrl.String(),
AuthorizationEndpoint: baseUrl.Resolve("authorize").String(), AuthorizationEndpoint: baseUrl.Resolve("authorize"),
TokenEndpoint: baseUrl.Resolve("token").String(), TokenEndpoint: baseUrl.Resolve("token"),
UserInfoEndpoint: baseUrl.Resolve("userinfo").String(), UserInfoEndpoint: baseUrl.Resolve("userinfo"),
ResponseTypesSupported: []string{"code"}, ResponseTypesSupported: []string{"code"},
ScopesSupported: scopes, ScopesSupported: scopes,
ClaimsSupported: claims, ClaimsSupported: claims,
GrantTypesSupported: []string{"authorization_code", "refresh_token"}, GrantTypesSupported: []string{"authorization_code", "refresh_token"},
JwksUri: baseUrl.Resolve(".well-known/jwks.json").String(), JwksUri: baseUrl.Resolve(".well-known/jwks.json"),
} }
} }

View File

@ -1,7 +1,7 @@
package openid package openid
import ( import (
"github.com/1f349/lavender/url" "github.com/1f349/lavender/utils"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )
@ -9,13 +9,13 @@ import (
func TestGenConfig(t *testing.T) { func TestGenConfig(t *testing.T) {
assert.Equal(t, Config{ assert.Equal(t, Config{
Issuer: "https://example.com", Issuer: "https://example.com",
AuthorizationEndpoint: "https://example.com/authorize", AuthorizationEndpoint: utils.MustParse("https://example.com/authorize"),
TokenEndpoint: "https://example.com/token", TokenEndpoint: utils.MustParse("https://example.com/token"),
UserInfoEndpoint: "https://example.com/userinfo", UserInfoEndpoint: utils.MustParse("https://example.com/userinfo"),
ResponseTypesSupported: []string{"code"}, ResponseTypesSupported: []string{"code"},
ScopesSupported: []string{"openid", "email"}, ScopesSupported: []string{"openid", "email"},
ClaimsSupported: []string{"name", "email", "preferred_username"}, ClaimsSupported: []string{"name", "email", "preferred_username"},
GrantTypesSupported: []string{"authorization_code", "refresh_token"}, GrantTypesSupported: []string{"authorization_code", "refresh_token"},
JwksUri: "https://example.com/.well-known/jwks.json", JwksUri: utils.MustParse("https://example.com/.well-known/jwks.json"),
}, GenConfig(url.MustParse("https://example.com"), []string{"openid", "email"}, []string{"name", "email", "preferred_username"})) }, GenConfig(utils.MustParse("https://example.com"), []string{"openid", "email"}, []string{"name", "email", "preferred_username"}))
} }

View File

@ -309,7 +309,15 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re
return err return err
} }
sso := h.manager.FindServiceFromLogin(refreshData.Claims.Login) _, namespace, err := utils.ParseLoginName(refreshData.Claims.Login)
if err != nil {
return err
}
sso := h.manager.GetService(namespace)
if sso == nil {
return fmt.Errorf("invalid namespace: %s", namespace)
}
var oauthToken *oauth2.Token var oauthToken *oauth2.Token

View File

@ -5,13 +5,13 @@ import (
"encoding/json" "encoding/json"
"github.com/1f349/lavender/logger" "github.com/1f349/lavender/logger"
"github.com/1f349/lavender/openid" "github.com/1f349/lavender/openid"
"github.com/1f349/lavender/url" "github.com/1f349/lavender/utils"
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"net/http" "net/http"
) )
func SetupOpenId(r *httprouter.Router, baseUrl *url.URL, signingKey *mjwt.Issuer) { func SetupOpenId(r *httprouter.Router, baseUrl *utils.URL, signingKey *mjwt.Issuer) {
openIdConf := openid.GenConfig(baseUrl, []string{ openIdConf := openid.GenConfig(baseUrl, []string{
"openid", "name", "username", "profile", "email", "birthdate", "age", "zoneinfo", "locale", "openid", "name", "username", "profile", "email", "birthdate", "age", "zoneinfo", "locale",
}, []string{ }, []string{

View File

@ -1,20 +1,25 @@
package server package server
import ( import (
"encoding/json"
"errors" "errors"
"fmt"
"github.com/1f349/cache" "github.com/1f349/cache"
"github.com/1f349/lavender/auth" "github.com/1f349/lavender/auth"
"github.com/1f349/lavender/auth/process"
"github.com/1f349/lavender/auth/providers" "github.com/1f349/lavender/auth/providers"
"github.com/1f349/lavender/conf" "github.com/1f349/lavender/conf"
"github.com/1f349/lavender/database" "github.com/1f349/lavender/database"
"github.com/1f349/lavender/issuer" "github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/logger" "github.com/1f349/lavender/logger"
"github.com/1f349/lavender/mail" "github.com/1f349/lavender/mail"
"github.com/1f349/lavender/utils"
"github.com/1f349/lavender/web" "github.com/1f349/lavender/web"
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
"github.com/go-oauth2/oauth2/v4/manage" "github.com/go-oauth2/oauth2/v4/manage"
"github.com/go-oauth2/oauth2/v4/server" "github.com/go-oauth2/oauth2/v4/server"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"golang.org/x/oauth2"
"net/http" "net/http"
"path" "path"
"strings" "strings"
@ -65,12 +70,7 @@ type mailLinkKey struct {
func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail, db *database.Queries, signingKey *mjwt.Issuer) { func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail, db *database.Queries, signingKey *mjwt.Issuer) {
// TODO: move auth provider init to main function // TODO: move auth provider init to main function
// TODO: allow dynamically changing the providers based on database information // TODO: allow dynamically changing the providers based on database information
authInitial := &providers.InitialLogin{} // TODO: move oauth setup into oauth provider
authPassword := &providers.PasswordLogin{DB: db}
authOtp := &providers.OtpLogin{DB: db}
authOAuth := &providers.OAuthLogin{DB: db, BaseUrl: &config.BaseUrl}
authOAuth.Init()
authPasskey := &providers.PasskeyLogin{DB: db}
hs := &httpServer{ hs := &httpServer{
r: r, r: r,
@ -80,19 +80,71 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail,
mailSender: mailSender, mailSender: mailSender,
mailLinkCache: cache.New[mailLinkKey, string](), mailLinkCache: cache.New[mailLinkKey, string](),
authSources: []auth.Provider{
authInitial,
authPassword,
authOtp,
authOAuth,
authPasskey,
},
authButtons: make([]auth.Button, 0),
formProviderLookup: make(map[string]auth.Form),
} }
var err error
hs.manager, err = issuer.NewManager(db, config.Namespace)
if err != nil {
logger.Logger.Fatal("Failed to load SSO services", "err", err)
}
authPassword := &providers.PasswordLogin{DB: db}
authOtp := &providers.OtpLogin{DB: db}
authOAuth := &providers.OAuthLogin{DB: db, BaseUrl: config.BaseUrl, Manager: hs.manager}
authOAuth.Init()
authPasskey := &providers.PasskeyLogin{DB: db}
authInitial := &providers.InitialLogin{DB: db, MyNamespace: config.Namespace, Manager: hs.manager, OAuth: authOAuth}
hs.authSources = []auth.Provider{
authInitial,
authPassword,
authOtp,
authOAuth,
authPasskey,
}
r.GET("/oauth/:namespace/start", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
namespace := params.ByName("namespace")
if !authOAuth.RedirectToAuthorize(rw, req, namespace) {
http.Error(rw, "Invalid OAuth namespace", http.StatusBadRequest)
}
})
r.GET("/oauth/:namespace/callback", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
namespace := params.ByName("namespace")
authOAuth.OAuthCallback(rw, req, namespace, func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
resp, err := sso.OAuth2Config.Client(req.Context(), token).Get(sso.UserInfoEndpoint.String())
if err != nil {
return auth.UserAuth{}, err
}
defer resp.Body.Close()
var userInfoJson auth.UserInfoFields
if err := json.NewDecoder(resp.Body).Decode(&userInfoJson); err != nil {
return auth.UserAuth{}, err
}
subject, ok := userInfoJson.GetString("sub")
if !ok {
return auth.UserAuth{}, fmt.Errorf("invalid subject")
}
subject += "@" + sso.Config.Namespace
return auth.UserAuth{
Subject: subject,
Factor: process.StateBasic,
UserInfo: userInfoJson,
}, nil
}, func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool {
// TODO: this should be using the existing auth flow calls
return hs.setLoginDataCookie(rw, authData, loginName)
}, func(rw http.ResponseWriter, req *http.Request) {
// TODO: is this really needed like this?
utils.SafeRedirect(rw, req)
})
})
// build slices and maps for quick access to auth interfaces // build slices and maps for quick access to auth interfaces
hs.authButtons = make([]auth.Button, 0)
hs.formProviderLookup = make(map[string]auth.Form)
for _, source := range hs.authSources { for _, source := range hs.authSources {
if button, isButton := source.(auth.Button); isButton { if button, isButton := source.(auth.Button); isButton {
hs.authButtons = append(hs.authButtons, button) hs.authButtons = append(hs.authButtons, button)
@ -103,13 +155,7 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail,
} }
} }
var err error SetupOpenId(r, config.BaseUrl, signingKey)
hs.manager, err = issuer.NewManager(config.Namespace, config.SsoServices)
if err != nil {
logger.Logger.Fatal("Failed to load SSO services", "err", err)
}
SetupOpenId(r, &config.BaseUrl, signingKey)
r.GET("/", hs.OptionalAuthentication(false, hs.Home)) r.GET("/", hs.OptionalAuthentication(false, hs.Home))
r.POST("/logout", hs.RequireAuthentication(hs.logoutPost)) r.POST("/logout", hs.RequireAuthentication(hs.logoutPost))

View File

@ -29,3 +29,5 @@ sql:
go_type: "database/sql.NullString" go_type: "database/sql.NullString"
- column: "users.token_expiry" - column: "users.token_expiry"
go_type: "database/sql.NullTime" go_type: "database/sql.NullTime"
- column: "oauth_sources.address"
go_type: "github.com/1f349/lavender/utils.URL"

BIN
tmp/main

Binary file not shown.

View File

@ -1,21 +0,0 @@
package utils
import (
"encoding"
"net/url"
)
type JsonUrl struct {
*url.URL
}
var _ encoding.TextUnmarshaler = &JsonUrl{}
func (s *JsonUrl) UnmarshalText(text []byte) error {
parse, err := url.Parse(string(text))
if err != nil {
return err
}
s.URL = parse
return nil
}

51
utils/loginname.go Normal file
View File

@ -0,0 +1,51 @@
package utils
import (
"errors"
"strings"
)
var ErrInvalidLoginNameFormat = errors.New("invalid login name format")
func ParseLoginName(loginName string) (user string, namespace string, err error) {
if loginName == "" || strings.HasPrefix(loginName, "@") || strings.HasSuffix(loginName, "@") || containsInvalidLoginNameRunes(loginName) {
return "", "", ErrInvalidLoginNameFormat
}
// @ should have at least one byte before it
n := strings.IndexByte(loginName, '@')
if n < 1 {
return "", "", ErrInvalidLoginNameFormat
}
// there should not be a second @
n2 := strings.IndexByte(loginName[n+1:], '@')
if n2 != -1 {
return "", "", ErrInvalidLoginNameFormat
}
return loginName[:n], loginName[n+1:], nil
}
func containsInvalidLoginNameRunes(loginName string) bool {
// check if the name contains an invalid rune
return strings.ContainsFunc(loginName, func(r rune) bool {
return !isValidLoginNameRune(r)
})
}
func isValidLoginNameRune(r rune) bool {
switch {
case r >= '0' && r <= '9':
return true
case r >= 'a' && r <= 'z':
return true
case r == '.':
return true
case r == '-':
return true
case r == '@':
return true
default:
return false
}
}

174
utils/loginname_test.go Normal file
View File

@ -0,0 +1,174 @@
package utils
import (
"encoding/hex"
"github.com/stretchr/testify/assert"
"regexp"
"strings"
"testing"
)
var parseLoginNameTests = []struct {
User string
Namespace string
HasError bool
Input string
}{
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com", false, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com"},
{"", "", true, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com\u1111"},
{"", "", true, "\u1111aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com"},
{"", "", true, "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com"},
{"", "", true, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com@"},
{"a", "b.com", false, "a@b.com"},
{"aa", "bb.com", false, "aa@bb.com"},
}
var parseLoginNameImplementations = []struct {
Name string
Func func(string) (string, string, error)
}{
{"ParseLoginName", ParseLoginName},
{"parseLoginNameLoopRunes", parseLoginNameLoopRunes},
{"parseLoginNameRegex", parseLoginNameRegex},
}
func TestParseLoginName(t *testing.T) {
for _, impl := range parseLoginNameImplementations {
t.Run(impl.Name, func(t *testing.T) {
for _, i := range parseLoginNameTests {
t.Run(i.Input, func(t *testing.T) {
user, namespace, err := ParseLoginName(i.Input)
if i.HasError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, i.User, user)
assert.Equal(t, i.Namespace, namespace)
})
}
})
}
}
func FuzzParseLoginName(f *testing.F) {
for _, i := range parseLoginNameTests {
f.Add(i.Input)
}
f.Fuzz(func(t *testing.T, s string) {
t.Log("Input: ", s, hex.EncodeToString([]byte(s)))
hasError := s == "" || strings.HasPrefix(s, "@") || strings.HasSuffix(s, "@") || strings.Count(s, "@") != 1 || strings.ContainsFunc(s, func(r rune) bool {
return !isValidLoginNameRune(r)
})
login, namespace, err := ParseLoginName(s)
if hasError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
if err == nil {
n := strings.IndexRune(s, '@')
assert.Equal(t, s[:n], login)
assert.Equal(t, s[n+1:], namespace)
} else {
assert.Equal(t, login, "", "Login should be empty if an error occurred")
assert.Equal(t, namespace, "", "Namespace should be empty if an error occurred")
}
})
}
func parseLoginBench(b *testing.B, f func(string) (string, string, error)) {
for _, i := range parseLoginNameTests {
b.Run(i.Input, func(b *testing.B) {
for range b.N {
_, _, _ = f(i.Input)
}
})
}
}
func BenchmarkParseLoginName(b *testing.B) {
b.Run("ParseLoginName", func(b *testing.B) {
parseLoginBench(b, ParseLoginName)
})
b.Run("parseLoginNameLoopRunes", func(b *testing.B) {
parseLoginBench(b, parseLoginNameLoopRunes)
})
b.Run("parseLoginNameRegex", func(b *testing.B) {
parseLoginBench(b, parseLoginNameRegex)
})
}
func parseLoginNameLoopRunes(s string) (user string, namespace string, err error) {
if s == "" || strings.HasPrefix(s, "@") || strings.HasSuffix(s, "@") {
return "", "", ErrInvalidLoginNameFormat
}
hasAt := false
atIndex := 0
for i, r := range []rune(s) {
switch {
case r == '@':
if hasAt {
return "", "", ErrInvalidLoginNameFormat
}
hasAt = true
atIndex = i
case r == '.':
continue
case r == '-':
continue
case r >= 'a' && r <= 'z':
continue
case r >= '0' && r <= '9':
continue
default:
return "", "", ErrInvalidLoginNameFormat
}
}
return s[:atIndex], s[atIndex+1:], nil
}
// regexLoginName is a regex based implementation of ParseLoginName
//
// This implementation prevents using - or . at the start and end of the user and namespace
var regexLoginName = regexp.MustCompile(`^([a-z0-9]([a-z0-9-.]*[a-z0-9]|))@([a-z0-9]([a-z0-9-.]*[a-z0-9]|))$`)
func parseLoginNameRegex(s string) (user string, namespace string, err error) {
if s == "" {
return "", "", ErrInvalidLoginNameFormat
}
matches := regexLoginName.FindStringSubmatch(s)
if matches == nil {
return "", "", ErrInvalidLoginNameFormat
}
return matches[1], matches[3], nil
}
func BenchmarkA(b *testing.B) {
b.Run("A", func(b *testing.B) {
for range b.N {
_ = isValidLoginNameRune('b')
_ = isValidLoginNameRune('4')
_ = isValidLoginNameRune('.')
_ = isValidLoginNameRune('-')
_ = isValidLoginNameRune('@')
_ = isValidLoginNameRune(' ')
}
})
b.Run("B", func(b *testing.B) {
for range b.N {
_ = isValidLoginNameRune2('b')
_ = isValidLoginNameRune2('4')
_ = isValidLoginNameRune2('.')
_ = isValidLoginNameRune2('-')
_ = isValidLoginNameRune2('@')
_ = isValidLoginNameRune2(' ')
}
})
}
func isValidLoginNameRune2(r rune) bool {
return (r >= '0' && r <= '9') || (r >= 'a' && r <= 'z') || r == '.' || r == '-' || r == '@'
}

View File

@ -1,4 +1,4 @@
package url package utils
import ( import (
"encoding" "encoding"
@ -6,35 +6,33 @@ import (
"path" "path"
) )
type URL struct {
url.URL
}
func (u *URL) Resolve(paths ...string) *URL {
return &URL{URL: *u.URL.ResolveReference(&url.URL{Path: path.Join(paths...)})}
}
func (u URL) MarshalText() (text []byte, err error) {
return []byte(u.String()), nil
}
func (u *URL) UnmarshalText(text []byte) error {
parse, err := u.Parse(string(text))
if err != nil {
return err
}
u.URL = *parse
return nil
}
var _ encoding.TextMarshaler = (*URL)(nil) var _ encoding.TextMarshaler = (*URL)(nil)
var _ encoding.TextUnmarshaler = (*URL)(nil) var _ encoding.TextUnmarshaler = (*URL)(nil)
type URL struct{ *url.URL }
func MustParse(rawURL string) *URL { func MustParse(rawURL string) *URL {
u, err := url.Parse(rawURL) u, err := url.Parse(rawURL)
if err != nil { if err != nil {
panic(err) panic(err)
} }
return &URL{*u} return &URL{u}
}
func (u *URL) Resolve(paths ...string) *URL {
return &URL{URL: u.URL.ResolveReference(&url.URL{Path: path.Join(paths...)})}
}
func (u *URL) MarshalText() (text []byte, err error) {
return []byte(u.URL.String()), nil
}
func (u *URL) UnmarshalText(text []byte) error {
parse, err := url.Parse(string(text))
if err != nil {
return err
}
u.URL = parse
return nil
} }