mirror of
https://github.com/1f349/lavender.git
synced 2025-04-14 06:55:55 +01:00
Rewrite OAuth manager, OAuth source and URL wrapper
This commit is contained in:
parent
d9b0074133
commit
50df217b66
@ -5,10 +5,11 @@ import (
|
||||
"fmt"
|
||||
"github.com/1f349/lavender/auth"
|
||||
"github.com/1f349/lavender/auth/authContext"
|
||||
process "github.com/1f349/lavender/auth/process"
|
||||
"github.com/1f349/lavender/auth/process"
|
||||
"github.com/1f349/lavender/database"
|
||||
"github.com/1f349/lavender/issuer"
|
||||
"github.com/1f349/lavender/logger"
|
||||
"github.com/1f349/lavender/utils"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
@ -25,6 +26,7 @@ type InitialLogin struct {
|
||||
DB *database.Queries
|
||||
MyNamespace string
|
||||
Manager *issuer.Manager
|
||||
OAuth *OAuthLogin
|
||||
}
|
||||
|
||||
func (m *InitialLogin) AccessState() process.State { return process.StateUnauthorized }
|
||||
@ -101,36 +103,39 @@ func (m *InitialLogin) AttemptLogin(ctx authContext.FormContext) error {
|
||||
}
|
||||
|
||||
// append local namespace if @ is missing
|
||||
n := strings.IndexByte(loginName, '@')
|
||||
if n < 0 {
|
||||
// correct the @ index
|
||||
n = len(loginName)
|
||||
if !strings.ContainsRune(loginName, '@') {
|
||||
loginName += "@" + m.MyNamespace
|
||||
}
|
||||
|
||||
login := m.Manager.FindServiceFromLogin(loginName)
|
||||
user, namespace, err := utils.ParseLoginName(loginName)
|
||||
if err != nil {
|
||||
http.Error(rw, "Invalid login name", http.StatusBadRequest)
|
||||
return fmt.Errorf("invalid login name %s", loginName)
|
||||
}
|
||||
|
||||
login := m.Manager.GetService(namespace)
|
||||
if login == nil {
|
||||
http.Error(rw, "No login service defined for this username", http.StatusBadRequest)
|
||||
return errors.New("no login service defined for this username")
|
||||
}
|
||||
|
||||
// the @ must exist if the service is defined
|
||||
loginUn := loginName[:n]
|
||||
|
||||
// TODO: finish migrating this shit
|
||||
|
||||
// the login is not for this namespace
|
||||
if login != issuer.MeWellKnown {
|
||||
// TODO: this is oauth request code start oauth request
|
||||
// TODO: this could all be 1 function call into oauth
|
||||
|
||||
// save state for use later
|
||||
state := login.Config.Namespace + ":" + uuid.NewString()
|
||||
h.flowState.Set(state, flowStateData{loginName, login, req.PostFormValue("redirect")}, time.Now().Add(15*time.Minute))
|
||||
m.OAuth.flow.Set(state, flowStateData{loginName, login, req.PostFormValue("redirect")}, time.Now().Add(15*time.Minute))
|
||||
|
||||
// generate oauth2 config and redirect to authorize URL
|
||||
oa2conf := login.OAuth2Config
|
||||
oa2conf.RedirectURL = h.conf.BaseUrl.JoinPath("callback").String()
|
||||
nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginUn))
|
||||
http.Redirect(rw, req, nextUrl, http.StatusFound)
|
||||
return
|
||||
oa2conf.RedirectURL = m.OAuth.BaseUrl.JoinPath("callback").String()
|
||||
nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", user))
|
||||
return auth.RedirectError{
|
||||
Target: nextUrl,
|
||||
Code: http.StatusFound,
|
||||
}
|
||||
}
|
||||
|
||||
ctx.UpdateSession(process.LoginProcessData{
|
||||
|
@ -13,17 +13,18 @@ import (
|
||||
"github.com/1f349/lavender/database"
|
||||
"github.com/1f349/lavender/database/types"
|
||||
"github.com/1f349/lavender/issuer"
|
||||
"github.com/1f349/lavender/url"
|
||||
"github.com/1f349/lavender/utils"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mrmelon54/pronouns"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OauthCallback interface {
|
||||
OAuthCallback(rw http.ResponseWriter, req *http.Request, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request))
|
||||
OAuthCallback(rw http.ResponseWriter, req *http.Request, namespace string, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request))
|
||||
}
|
||||
|
||||
type flowStateData struct {
|
||||
@ -40,9 +41,9 @@ var (
|
||||
type OAuthLogin struct {
|
||||
DB *database.Queries
|
||||
|
||||
BaseUrl *url.URL
|
||||
|
||||
flow *cache.Cache[string, flowStateData]
|
||||
BaseUrl *utils.URL
|
||||
flow *cache.Cache[string, flowStateData]
|
||||
Manager *issuer.Manager
|
||||
}
|
||||
|
||||
func (o OAuthLogin) Init() {
|
||||
@ -50,7 +51,7 @@ func (o OAuthLogin) Init() {
|
||||
}
|
||||
|
||||
func (o OAuthLogin) authUrlBase(ref string) *url.URL {
|
||||
return o.BaseUrl.Resolve("oauth", o.Name(), ref)
|
||||
return o.BaseUrl.Resolve("oauth", o.Name(), ref).URL
|
||||
}
|
||||
|
||||
func (o OAuthLogin) AccessState() process.State { return process.StateUnauthorized }
|
||||
@ -83,12 +84,16 @@ func (o OAuthLogin) AttemptLogin(ctx authContext.FormContext) error {
|
||||
return auth.RedirectError{Target: nextUrl, Code: http.StatusFound}
|
||||
}
|
||||
|
||||
func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request)) {
|
||||
func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, namespace string, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request)) {
|
||||
flowState, ok := o.flow.Get(req.FormValue("state"))
|
||||
if !ok {
|
||||
http.Error(rw, "Invalid flow state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if flowState.sso.Namespace != namespace {
|
||||
http.Error(rw, "OAuth source mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
token, err := flowState.sso.OAuth2Config.Exchange(context.Background(), req.FormValue("code"), oauth2.SetAuthURLParam("redirect_uri", o.authUrlBase("callback").String()))
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to exchange code for token", http.StatusInternalServerError)
|
||||
@ -112,17 +117,32 @@ func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, inf
|
||||
}
|
||||
|
||||
func (o OAuthLogin) RenderButtonTemplate(ctx authContext.TemplateContext) {
|
||||
// TODO: idk what this is
|
||||
// o.authUrlBase("button")
|
||||
// provide something non-nil
|
||||
ctx.Render(struct {
|
||||
Href string
|
||||
ButtonName string
|
||||
}{
|
||||
Href: o.authUrlBase("button").String(),
|
||||
Href: o.authUrlBase("start").String(),
|
||||
ButtonName: "Login with Unknown OAuth Button", // TODO: actually get the service name
|
||||
})
|
||||
}
|
||||
|
||||
// RedirectToAuthorize returns true when redirecting to the authorize endpoint
|
||||
// for the requested namespace.
|
||||
func (o OAuthLogin) RedirectToAuthorize(rw http.ResponseWriter, req *http.Request, namespace string) bool {
|
||||
oidc := o.Manager.GetService(namespace)
|
||||
if oidc == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: come back to fix this
|
||||
oidc.OAuth2Config.AuthCodeURL("state")
|
||||
// http.Redirect(rw, req, oidc.AuthorizationEndpoint, http.StatusOK)
|
||||
return true
|
||||
}
|
||||
|
||||
type oauthServiceLogin int
|
||||
|
||||
func WithWellKnown(ctx context.Context, login *issuer.WellKnownOIDC) context.Context {
|
||||
@ -256,7 +276,7 @@ func (o OAuthLogin) updateOAuth2UserProfile(ctx context.Context, tx *database.Qu
|
||||
}
|
||||
|
||||
func (o OAuthLogin) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
|
||||
res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint)
|
||||
res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint.String())
|
||||
if err != nil || res.StatusCode != http.StatusOK {
|
||||
return auth.UserAuth{}, fmt.Errorf("request failed")
|
||||
}
|
||||
|
20
conf/conf.go
20
conf/conf.go
@ -1,19 +1,17 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"github.com/1f349/lavender/issuer"
|
||||
"github.com/1f349/lavender/url"
|
||||
"github.com/1f349/lavender/utils"
|
||||
"github.com/1f349/simplemail"
|
||||
)
|
||||
|
||||
type Conf struct {
|
||||
Listen string `yaml:"listen"`
|
||||
BaseUrl url.URL `yaml:"baseUrl"`
|
||||
ServiceName string `yaml:"serviceName"`
|
||||
Issuer string `yaml:"issuer"`
|
||||
Kid string `yaml:"kid"`
|
||||
Namespace string `yaml:"namespace"`
|
||||
OtpIssuer string `yaml:"otpIssuer"`
|
||||
Mail simplemail.Mail `yaml:"mail"`
|
||||
SsoServices []issuer.SsoConfig `yaml:"ssoServices"`
|
||||
Listen string `yaml:"listen"`
|
||||
BaseUrl *utils.URL `yaml:"baseUrl"`
|
||||
ServiceName string `yaml:"serviceName"`
|
||||
Issuer string `yaml:"issuer"`
|
||||
Kid string `yaml:"kid"`
|
||||
Namespace string `yaml:"namespace"`
|
||||
OtpIssuer string `yaml:"otpIssuer"`
|
||||
Mail simplemail.Mail `yaml:"mail"`
|
||||
}
|
||||
|
45
database/manage-oauth-source.sql.go
Normal file
45
database/manage-oauth-source.sql.go
Normal file
@ -0,0 +1,45 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.28.0
|
||||
// source: manage-oauth-source.sql
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const getOAuthSources = `-- name: GetOAuthSources :many
|
||||
SELECT namespace, address, registration, button, client_id, client_secret, client_scopes FROM oauth_sources
|
||||
`
|
||||
|
||||
func (q *Queries) GetOAuthSources(ctx context.Context) ([]OauthSource, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getOAuthSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []OauthSource
|
||||
for rows.Next() {
|
||||
var i OauthSource
|
||||
if err := rows.Scan(
|
||||
&i.Namespace,
|
||||
&i.Address,
|
||||
&i.Registration,
|
||||
&i.Button,
|
||||
&i.ClientID,
|
||||
&i.ClientSecret,
|
||||
&i.ClientScopes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
10
database/migrations/20250303230832_oauth_source.up.sql
Normal file
10
database/migrations/20250303230832_oauth_source.up.sql
Normal file
@ -0,0 +1,10 @@
|
||||
CREATE TABLE oauth_sources
|
||||
(
|
||||
namespace TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||
address TEXT NOT NULL,
|
||||
registration BOOLEAN NOT NULL,
|
||||
button BOOLEAN NOT NULL,
|
||||
client_id TEXT NOT NULL,
|
||||
client_secret TEXT NOT NULL,
|
||||
client_scopes TEXT NOT NULL
|
||||
);
|
@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/1f349/lavender/database/types"
|
||||
"github.com/1f349/lavender/password"
|
||||
"github.com/1f349/lavender/utils"
|
||||
"github.com/hardfinhq/go-date"
|
||||
)
|
||||
|
||||
@ -25,6 +26,16 @@ type ClientStore struct {
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
type OauthSource struct {
|
||||
Namespace string `json:"namespace"`
|
||||
Address utils.URL `json:"address"`
|
||||
Registration bool `json:"registration"`
|
||||
Button bool `json:"button"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
ClientScopes string `json:"client_scopes"`
|
||||
}
|
||||
|
||||
type Role struct {
|
||||
ID int64 `json:"id"`
|
||||
Role string `json:"role"`
|
||||
|
2
database/queries/manage-oauth-source.sql
Normal file
2
database/queries/manage-oauth-source.sql
Normal file
@ -0,0 +1,2 @@
|
||||
-- name: GetOAuthSources :many
|
||||
SELECT * FROM oauth_sources;
|
1
go.mod
1
go.mod
@ -19,6 +19,7 @@ require (
|
||||
github.com/julienschmidt/httprouter v1.3.0
|
||||
github.com/mattn/go-sqlite3 v1.14.24
|
||||
github.com/mrmelon54/pronouns v1.0.3
|
||||
github.com/robfig/cron v1.2.0
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
github.com/spf13/afero v1.12.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
|
2
go.sum
2
go.sum
@ -138,6 +138,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ=
|
||||
github.com/robfig/cron v1.2.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k=
|
||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=
|
||||
|
@ -1,57 +1,226 @@
|
||||
package issuer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/1f349/lavender/database"
|
||||
"github.com/1f349/lavender/logger"
|
||||
"github.com/robfig/cron"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
fetchPoolSize = 4
|
||||
fetchTimeout = 2 * time.Minute
|
||||
fetchRetryDelay = 15 * time.Second
|
||||
fetchRetryAfterTimeout = 15 * time.Minute
|
||||
fetchRetryCount = 3
|
||||
)
|
||||
|
||||
var isValidNamespace = regexp.MustCompile("^[0-9a-z.]+$")
|
||||
|
||||
var MeWellKnown = &WellKnownOIDC{}
|
||||
|
||||
type Manager struct {
|
||||
m map[string]*WellKnownOIDC
|
||||
type managerReloadDB interface {
|
||||
GetOAuthSources(ctx context.Context) ([]database.OauthSource, error)
|
||||
}
|
||||
|
||||
func NewManager(myNamespace string, services []SsoConfig) (*Manager, error) {
|
||||
l := &Manager{m: make(map[string]*WellKnownOIDC)}
|
||||
l.m[myNamespace] = MeWellKnown
|
||||
for _, ssoService := range services {
|
||||
if !isValidNamespace.MatchString(ssoService.Namespace) {
|
||||
return nil, fmt.Errorf("invalid namespace: %s", ssoService.Namespace)
|
||||
}
|
||||
type Manager struct {
|
||||
db managerReloadDB
|
||||
|
||||
conf, err := ssoService.FetchConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client *http.Client
|
||||
|
||||
// save by namespace
|
||||
l.m[ssoService.Namespace] = conf
|
||||
sourceMu *sync.RWMutex
|
||||
sourceMap map[string]*WellKnownOIDC
|
||||
|
||||
errMu *sync.RWMutex
|
||||
errMap map[string]FetchErrorDetails
|
||||
|
||||
shutdownCall context.CancelFunc
|
||||
}
|
||||
|
||||
type FetchErrorDetails struct {
|
||||
Retry time.Time
|
||||
Errors []error
|
||||
}
|
||||
|
||||
type managerRoundTripper struct {
|
||||
Base http.RoundTripper
|
||||
}
|
||||
|
||||
func (a *managerRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
userAgent := strings.Join(request.Header.Values("User-Agent"), " ")
|
||||
if strings.TrimSpace(userAgent) == "" {
|
||||
request.Header.Set("User-Agent", "Lavender/1.0 (OAuth Manager)")
|
||||
}
|
||||
return a.Base.RoundTrip(request)
|
||||
}
|
||||
|
||||
func NewManager(db managerReloadDB, myNamespace string) (*Manager, error) {
|
||||
return NewManagerWithClient(db, myNamespace, &http.Client{
|
||||
Transport: &managerRoundTripper{
|
||||
Base: http.DefaultTransport,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func NewManagerWithClient(db managerReloadDB, myNamespace string, client *http.Client) (*Manager, error) {
|
||||
l := &Manager{
|
||||
db: db,
|
||||
client: client,
|
||||
sourceMu: new(sync.RWMutex),
|
||||
sourceMap: make(map[string]*WellKnownOIDC),
|
||||
errMap: make(map[string]FetchErrorDetails),
|
||||
}
|
||||
l.sourceMap[myNamespace] = MeWellKnown
|
||||
|
||||
// reload should run blocking on start up
|
||||
l.internalReload(context.Background())
|
||||
|
||||
reloadDaily := make(chan struct{}, 1)
|
||||
|
||||
tab := cron.New()
|
||||
err := tab.AddFunc("0 2 * * *", func() {
|
||||
reloadDaily <- struct{}{}
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go l.reloadWorker(ctx, reloadDaily)
|
||||
|
||||
l.shutdownCall = cancel
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Shutdown() { m.shutdownCall() }
|
||||
|
||||
func (m *Manager) reloadWorker(ctx context.Context, reload <-chan struct{}) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Logger.Info("Shutting down oauth manager reload worker")
|
||||
return
|
||||
case <-reload:
|
||||
m.internalReload(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) internalReload(ctx context.Context) {
|
||||
logger.Logger.Info("Reloading oauth sources")
|
||||
dbCtx, cancel := context.WithTimeout(ctx, fetchTimeout)
|
||||
sources, err := m.db.GetOAuthSources(dbCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
// TODO: send email to admin
|
||||
logger.Logger.Warn("Failed to load OAuth sources from the database", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
// construct map of db oauth sources
|
||||
toAdd := make(map[string]database.OauthSource)
|
||||
for _, source := range sources {
|
||||
toAdd[source.Namespace] = source
|
||||
}
|
||||
|
||||
// remove sources from the active map
|
||||
m.sourceMu.Lock()
|
||||
for k, v := range m.sourceMap {
|
||||
// ignore my own well-known
|
||||
if v == MeWellKnown {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, shouldAdd := toAdd[k]; !shouldAdd {
|
||||
delete(m.sourceMap, k)
|
||||
}
|
||||
}
|
||||
m.sourceMu.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(fetchPoolSize)
|
||||
|
||||
fetchChan := make(chan database.OauthSource, fetchPoolSize)
|
||||
|
||||
for range fetchPoolSize {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for source := range fetchChan {
|
||||
itemCtx, cancel := context.WithTimeout(ctx, fetchTimeout)
|
||||
errs := m.reloadSource(itemCtx, source)
|
||||
cancel()
|
||||
|
||||
// this also resets the error if it was set previously
|
||||
m.sourceMu.Lock()
|
||||
if len(errs) > 0 {
|
||||
m.errMap[source.Namespace] = FetchErrorDetails{
|
||||
Retry: time.Now().Add(fetchRetryAfterTimeout),
|
||||
Errors: errs,
|
||||
}
|
||||
// TODO: send email to admin
|
||||
} else {
|
||||
delete(m.errMap, source.Namespace)
|
||||
// TODO: send resolved email to admin
|
||||
}
|
||||
m.sourceMu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// reload sources
|
||||
for _, source := range toAdd {
|
||||
fetchChan <- source
|
||||
}
|
||||
|
||||
close(fetchChan)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (m *Manager) reloadSource(ctx context.Context, source database.OauthSource) []error {
|
||||
var errs []error
|
||||
|
||||
for range fetchRetryCount {
|
||||
oidc, err := fetchConfig(ctx, m.client, source)
|
||||
fmt.Println(oidc, err)
|
||||
if err == nil {
|
||||
// fetch was successful
|
||||
m.sourceMu.Lock()
|
||||
m.sourceMap[source.Namespace] = oidc
|
||||
m.sourceMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("failed to fetch OIDC well-known configuration: %w", err))
|
||||
|
||||
// wait before retrying
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errs
|
||||
case <-time.After(fetchRetryDelay):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
func (m *Manager) CheckNamespace(namespace string) bool {
|
||||
_, ok := m.m[namespace]
|
||||
m.sourceMu.RLock()
|
||||
defer m.sourceMu.RUnlock()
|
||||
_, ok := m.sourceMap[namespace]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (m *Manager) GetService(namespace string) *WellKnownOIDC {
|
||||
return m.m[namespace]
|
||||
}
|
||||
|
||||
func (m *Manager) FindServiceFromLogin(login string) *WellKnownOIDC {
|
||||
// @ should have at least one byte before it
|
||||
n := strings.IndexByte(login, '@')
|
||||
if n < 1 {
|
||||
return nil
|
||||
}
|
||||
// there should not be a second @
|
||||
n2 := strings.IndexByte(login[n+1:], '@')
|
||||
if n2 != -1 {
|
||||
return nil
|
||||
}
|
||||
return m.GetService(login[n+1:])
|
||||
m.sourceMu.RLock()
|
||||
defer m.sourceMu.RUnlock()
|
||||
return m.sourceMap[namespace]
|
||||
}
|
||||
|
@ -1,6 +1,9 @@
|
||||
package issuer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/1f349/lavender/database"
|
||||
"github.com/1f349/lavender/utils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
@ -8,29 +11,54 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var testAddrUrl = func() utils.JsonUrl {
|
||||
var testAddrUrl = func() utils.URL {
|
||||
a, err := url.Parse("https://example.com")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return utils.JsonUrl{URL: a}
|
||||
return utils.URL{URL: a}
|
||||
}()
|
||||
|
||||
func testBody() io.ReadCloser {
|
||||
return io.NopCloser(strings.NewReader("{}"))
|
||||
return io.NopCloser(strings.NewReader(`{
|
||||
"issuer": "Example.com Issuer",
|
||||
"authorization_endpoint": "https://example.com/oauth/authorize",
|
||||
"token_endpoint": "https://example.com/oauth/token",
|
||||
"userinfo_endpoint": "https://example.com/oauth/userinfo",
|
||||
"revocation_endpoint": "https://example.com/oauth/revoke",
|
||||
"response_types_supported": [
|
||||
"code"
|
||||
],
|
||||
"scopes_supported": [
|
||||
"openid"
|
||||
],
|
||||
"claims_supported": [
|
||||
"sub"
|
||||
],
|
||||
"grant_types_supported": [
|
||||
"authorization_code",
|
||||
"refresh_token"
|
||||
]
|
||||
}
|
||||
`))
|
||||
}
|
||||
|
||||
func TestManager_CheckNamespace(t *testing.T) {
|
||||
httpGet = func(url string) (resp *http.Response, err error) {
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
|
||||
}
|
||||
manager, err := NewManager("example.org", []SsoConfig{
|
||||
{
|
||||
Addr: testAddrUrl,
|
||||
Namespace: "example.com",
|
||||
manager, err := NewManagerWithClient(&testDB{
|
||||
Rows: []database.OauthSource{
|
||||
{Address: testAddrUrl, Namespace: "example.com"},
|
||||
},
|
||||
}, "example.org", &http.Client{
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
fmt.Println("Request URL:", req.URL)
|
||||
if req.URL.String() == "https://example.com/.well-known/openid-configuration" {
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("request failed")
|
||||
}),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, manager.CheckNamespace("example.org"))
|
||||
@ -39,17 +67,36 @@ func TestManager_CheckNamespace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestManager_FindServiceFromLogin(t *testing.T) {
|
||||
httpGet = func(url string) (resp *http.Response, err error) {
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
|
||||
}
|
||||
manager, err := NewManager("example.org", []SsoConfig{
|
||||
{
|
||||
Addr: testAddrUrl,
|
||||
Namespace: "example.com",
|
||||
manager, err := NewManagerWithClient(&testDB{
|
||||
Rows: []database.OauthSource{
|
||||
{Address: testAddrUrl, Namespace: "example.com"},
|
||||
},
|
||||
}, "example.org", &http.Client{
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
fmt.Println("Request URL:", req.URL)
|
||||
if req.URL.String() != "https://example.com/.well-known/openid-configuration" {
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("request failed")
|
||||
}),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, manager.FindServiceFromLogin("jane@example.org"), MeWellKnown)
|
||||
assert.Equal(t, manager.FindServiceFromLogin("jane@example.com"), manager.m["example.com"])
|
||||
assert.Nil(t, manager.FindServiceFromLogin("jane@missing.example.com"))
|
||||
assert.Equal(t, manager.GetService("example.org"), MeWellKnown)
|
||||
assert.Equal(t, manager.GetService("example.com"), manager.sourceMap["example.com"])
|
||||
assert.Nil(t, manager.GetService("missing.example.com"))
|
||||
}
|
||||
|
||||
type testDB struct {
|
||||
Rows []database.OauthSource
|
||||
}
|
||||
|
||||
func (t *testDB) GetOAuthSources(ctx context.Context) ([]database.OauthSource, error) {
|
||||
return t.Rows, nil
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (r roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
time.Sleep(time.Minute)
|
||||
return r(req)
|
||||
}
|
||||
|
@ -1,94 +1,82 @@
|
||||
package issuer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/1f349/lavender/database"
|
||||
"github.com/1f349/lavender/utils"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var httpGet = http.Get
|
||||
|
||||
// SsoConfig is the base URL for an OAUTH/OPENID/SSO login service
|
||||
// The path `/.well-known/openid-configuration` should be available
|
||||
type SsoConfig struct {
|
||||
Addr utils.JsonUrl `json:"addr" yaml:"addr"` // https://login.example.com
|
||||
Namespace string `json:"namespace" yaml:"namespace"` // example.com
|
||||
Registration bool `json:"registration" yaml:"registration"`
|
||||
LoginWithButton bool `json:"login_with_button" yaml:"loginWithButton"`
|
||||
Client SsoConfigClient `json:"client" yaml:"client"`
|
||||
type WellKnownOIDC struct {
|
||||
Namespace string `json:"-"`
|
||||
Config database.OauthSource `json:"-"`
|
||||
Available atomic.Bool `json:"-"`
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint *utils.URL `json:"authorization_endpoint"`
|
||||
TokenEndpoint *utils.URL `json:"token_endpoint"`
|
||||
UserInfoEndpoint *utils.URL `json:"userinfo_endpoint"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||
ScopesSupported []string `json:"scopes_supported"`
|
||||
ClaimsSupported []string `json:"claims_supported"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||
JwksUri *utils.URL `json:"jwks_uri"`
|
||||
OAuth2Config oauth2.Config `json:"-"`
|
||||
LastFetch time.Time `json:"-"`
|
||||
}
|
||||
|
||||
type SsoConfigClient struct {
|
||||
ID string `json:"id"`
|
||||
Secret string `json:"secret"`
|
||||
Scopes []string `json:"scopes"`
|
||||
}
|
||||
|
||||
func (s SsoConfig) FetchConfig() (*WellKnownOIDC, error) {
|
||||
func fetchConfig(ctx context.Context, httpClient *http.Client, s database.OauthSource) (*WellKnownOIDC, error) {
|
||||
// generate openid config url
|
||||
u := s.Addr.JoinPath(".well-known/openid-configuration")
|
||||
u := s.Address.JoinPath(".well-known/openid-configuration")
|
||||
|
||||
// fetch metadata
|
||||
get, err := httpGet(u.String())
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer get.Body.Close()
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var c WellKnownOIDC
|
||||
err = json.NewDecoder(get.Body).Decode(&c)
|
||||
err = json.NewDecoder(resp.Body).Decode(&c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Config = s
|
||||
c.OAuth2Config = oauth2.Config{
|
||||
ClientID: c.Config.Client.ID,
|
||||
ClientSecret: c.Config.Client.Secret,
|
||||
ClientID: c.Config.ClientID,
|
||||
ClientSecret: c.Config.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: c.AuthorizationEndpoint,
|
||||
TokenURL: c.TokenEndpoint,
|
||||
AuthURL: c.AuthorizationEndpoint.String(),
|
||||
TokenURL: c.TokenEndpoint.String(),
|
||||
AuthStyle: oauth2.AuthStyleInHeader,
|
||||
},
|
||||
Scopes: c.Config.Client.Scopes,
|
||||
Scopes: strings.Fields(c.Config.ClientScopes),
|
||||
}
|
||||
c.Available.Store(true)
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
type WellKnownOIDC struct {
|
||||
Namespace string `json:"-"`
|
||||
Config SsoConfig `json:"-"`
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"userinfo_endpoint"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||
ScopesSupported []string `json:"scopes_supported"`
|
||||
ClaimsSupported []string `json:"claims_supported"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||
OAuth2Config oauth2.Config `json:"-"`
|
||||
}
|
||||
|
||||
func (o WellKnownOIDC) Validate() error {
|
||||
func (o *WellKnownOIDC) Validate() error {
|
||||
if o.Issuer == "" {
|
||||
return errors.New("missing issuer")
|
||||
}
|
||||
|
||||
// check URLs are valid
|
||||
if _, err := url.Parse(o.AuthorizationEndpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := url.Parse(o.TokenEndpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := url.Parse(o.UserInfoEndpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check oidc supported values
|
||||
if !slices.Contains(o.ResponseTypesSupported, "code") {
|
||||
return errors.New("missing required response type 'code'")
|
||||
@ -107,6 +95,6 @@ func (o WellKnownOIDC) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o WellKnownOIDC) ValidReturnUrl(u *url.URL) bool {
|
||||
return o.Config.Addr.Scheme == u.Scheme && o.Config.Addr.Host == u.Host
|
||||
func (o *WellKnownOIDC) ValidReturnUrl(u *url.URL) bool {
|
||||
return o.Config.Address.Scheme == u.Scheme && o.Config.Address.Host == u.Host
|
||||
}
|
||||
|
@ -1,30 +1,29 @@
|
||||
package openid
|
||||
|
||||
import "github.com/1f349/lavender/url"
|
||||
import "github.com/1f349/lavender/utils"
|
||||
|
||||
type Config struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"userinfo_endpoint"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||
ScopesSupported []string `json:"scopes_supported"`
|
||||
ClaimsSupported []string `json:"claims_supported"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||
JwksUri string `json:"jwks_uri"`
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint *utils.URL `json:"authorization_endpoint"`
|
||||
TokenEndpoint *utils.URL `json:"token_endpoint"`
|
||||
UserInfoEndpoint *utils.URL `json:"userinfo_endpoint"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||
ScopesSupported []string `json:"scopes_supported"`
|
||||
ClaimsSupported []string `json:"claims_supported"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||
JwksUri *utils.URL `json:"jwks_uri"`
|
||||
}
|
||||
|
||||
func GenConfig(baseUrl *url.URL, scopes, claims []string) Config {
|
||||
|
||||
func GenConfig(baseUrl *utils.URL, scopes, claims []string) Config {
|
||||
return Config{
|
||||
Issuer: baseUrl.String(),
|
||||
AuthorizationEndpoint: baseUrl.Resolve("authorize").String(),
|
||||
TokenEndpoint: baseUrl.Resolve("token").String(),
|
||||
UserInfoEndpoint: baseUrl.Resolve("userinfo").String(),
|
||||
AuthorizationEndpoint: baseUrl.Resolve("authorize"),
|
||||
TokenEndpoint: baseUrl.Resolve("token"),
|
||||
UserInfoEndpoint: baseUrl.Resolve("userinfo"),
|
||||
ResponseTypesSupported: []string{"code"},
|
||||
ScopesSupported: scopes,
|
||||
ClaimsSupported: claims,
|
||||
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
|
||||
JwksUri: baseUrl.Resolve(".well-known/jwks.json").String(),
|
||||
JwksUri: baseUrl.Resolve(".well-known/jwks.json"),
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
package openid
|
||||
|
||||
import (
|
||||
"github.com/1f349/lavender/url"
|
||||
"github.com/1f349/lavender/utils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
@ -9,13 +9,13 @@ import (
|
||||
func TestGenConfig(t *testing.T) {
|
||||
assert.Equal(t, Config{
|
||||
Issuer: "https://example.com",
|
||||
AuthorizationEndpoint: "https://example.com/authorize",
|
||||
TokenEndpoint: "https://example.com/token",
|
||||
UserInfoEndpoint: "https://example.com/userinfo",
|
||||
AuthorizationEndpoint: utils.MustParse("https://example.com/authorize"),
|
||||
TokenEndpoint: utils.MustParse("https://example.com/token"),
|
||||
UserInfoEndpoint: utils.MustParse("https://example.com/userinfo"),
|
||||
ResponseTypesSupported: []string{"code"},
|
||||
ScopesSupported: []string{"openid", "email"},
|
||||
ClaimsSupported: []string{"name", "email", "preferred_username"},
|
||||
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
|
||||
JwksUri: "https://example.com/.well-known/jwks.json",
|
||||
}, GenConfig(url.MustParse("https://example.com"), []string{"openid", "email"}, []string{"name", "email", "preferred_username"}))
|
||||
JwksUri: utils.MustParse("https://example.com/.well-known/jwks.json"),
|
||||
}, GenConfig(utils.MustParse("https://example.com"), []string{"openid", "email"}, []string{"name", "email", "preferred_username"}))
|
||||
}
|
||||
|
@ -309,7 +309,15 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re
|
||||
return err
|
||||
}
|
||||
|
||||
sso := h.manager.FindServiceFromLogin(refreshData.Claims.Login)
|
||||
_, namespace, err := utils.ParseLoginName(refreshData.Claims.Login)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sso := h.manager.GetService(namespace)
|
||||
if sso == nil {
|
||||
return fmt.Errorf("invalid namespace: %s", namespace)
|
||||
}
|
||||
|
||||
var oauthToken *oauth2.Token
|
||||
|
||||
|
@ -5,13 +5,13 @@ import (
|
||||
"encoding/json"
|
||||
"github.com/1f349/lavender/logger"
|
||||
"github.com/1f349/lavender/openid"
|
||||
"github.com/1f349/lavender/url"
|
||||
"github.com/1f349/lavender/utils"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func SetupOpenId(r *httprouter.Router, baseUrl *url.URL, signingKey *mjwt.Issuer) {
|
||||
func SetupOpenId(r *httprouter.Router, baseUrl *utils.URL, signingKey *mjwt.Issuer) {
|
||||
openIdConf := openid.GenConfig(baseUrl, []string{
|
||||
"openid", "name", "username", "profile", "email", "birthdate", "age", "zoneinfo", "locale",
|
||||
}, []string{
|
||||
|
@ -1,20 +1,25 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/1f349/cache"
|
||||
"github.com/1f349/lavender/auth"
|
||||
"github.com/1f349/lavender/auth/process"
|
||||
"github.com/1f349/lavender/auth/providers"
|
||||
"github.com/1f349/lavender/conf"
|
||||
"github.com/1f349/lavender/database"
|
||||
"github.com/1f349/lavender/issuer"
|
||||
"github.com/1f349/lavender/logger"
|
||||
"github.com/1f349/lavender/mail"
|
||||
"github.com/1f349/lavender/utils"
|
||||
"github.com/1f349/lavender/web"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/go-oauth2/oauth2/v4/manage"
|
||||
"github.com/go-oauth2/oauth2/v4/server"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
@ -65,12 +70,7 @@ type mailLinkKey struct {
|
||||
func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail, db *database.Queries, signingKey *mjwt.Issuer) {
|
||||
// TODO: move auth provider init to main function
|
||||
// TODO: allow dynamically changing the providers based on database information
|
||||
authInitial := &providers.InitialLogin{}
|
||||
authPassword := &providers.PasswordLogin{DB: db}
|
||||
authOtp := &providers.OtpLogin{DB: db}
|
||||
authOAuth := &providers.OAuthLogin{DB: db, BaseUrl: &config.BaseUrl}
|
||||
authOAuth.Init()
|
||||
authPasskey := &providers.PasskeyLogin{DB: db}
|
||||
// TODO: move oauth setup into oauth provider
|
||||
|
||||
hs := &httpServer{
|
||||
r: r,
|
||||
@ -80,19 +80,71 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail,
|
||||
mailSender: mailSender,
|
||||
|
||||
mailLinkCache: cache.New[mailLinkKey, string](),
|
||||
|
||||
authSources: []auth.Provider{
|
||||
authInitial,
|
||||
authPassword,
|
||||
authOtp,
|
||||
authOAuth,
|
||||
authPasskey,
|
||||
},
|
||||
authButtons: make([]auth.Button, 0),
|
||||
formProviderLookup: make(map[string]auth.Form),
|
||||
}
|
||||
|
||||
var err error
|
||||
hs.manager, err = issuer.NewManager(db, config.Namespace)
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("Failed to load SSO services", "err", err)
|
||||
}
|
||||
|
||||
authPassword := &providers.PasswordLogin{DB: db}
|
||||
authOtp := &providers.OtpLogin{DB: db}
|
||||
authOAuth := &providers.OAuthLogin{DB: db, BaseUrl: config.BaseUrl, Manager: hs.manager}
|
||||
authOAuth.Init()
|
||||
authPasskey := &providers.PasskeyLogin{DB: db}
|
||||
authInitial := &providers.InitialLogin{DB: db, MyNamespace: config.Namespace, Manager: hs.manager, OAuth: authOAuth}
|
||||
|
||||
hs.authSources = []auth.Provider{
|
||||
authInitial,
|
||||
authPassword,
|
||||
authOtp,
|
||||
authOAuth,
|
||||
authPasskey,
|
||||
}
|
||||
|
||||
r.GET("/oauth/:namespace/start", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
namespace := params.ByName("namespace")
|
||||
if !authOAuth.RedirectToAuthorize(rw, req, namespace) {
|
||||
http.Error(rw, "Invalid OAuth namespace", http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
r.GET("/oauth/:namespace/callback", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
namespace := params.ByName("namespace")
|
||||
authOAuth.OAuthCallback(rw, req, namespace, func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
|
||||
resp, err := sso.OAuth2Config.Client(req.Context(), token).Get(sso.UserInfoEndpoint.String())
|
||||
if err != nil {
|
||||
return auth.UserAuth{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var userInfoJson auth.UserInfoFields
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfoJson); err != nil {
|
||||
return auth.UserAuth{}, err
|
||||
}
|
||||
subject, ok := userInfoJson.GetString("sub")
|
||||
if !ok {
|
||||
return auth.UserAuth{}, fmt.Errorf("invalid subject")
|
||||
}
|
||||
subject += "@" + sso.Config.Namespace
|
||||
|
||||
return auth.UserAuth{
|
||||
Subject: subject,
|
||||
Factor: process.StateBasic,
|
||||
UserInfo: userInfoJson,
|
||||
}, nil
|
||||
}, func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool {
|
||||
// TODO: this should be using the existing auth flow calls
|
||||
return hs.setLoginDataCookie(rw, authData, loginName)
|
||||
}, func(rw http.ResponseWriter, req *http.Request) {
|
||||
// TODO: is this really needed like this?
|
||||
utils.SafeRedirect(rw, req)
|
||||
})
|
||||
})
|
||||
|
||||
// build slices and maps for quick access to auth interfaces
|
||||
hs.authButtons = make([]auth.Button, 0)
|
||||
hs.formProviderLookup = make(map[string]auth.Form)
|
||||
for _, source := range hs.authSources {
|
||||
if button, isButton := source.(auth.Button); isButton {
|
||||
hs.authButtons = append(hs.authButtons, button)
|
||||
@ -103,13 +155,7 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail,
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
hs.manager, err = issuer.NewManager(config.Namespace, config.SsoServices)
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("Failed to load SSO services", "err", err)
|
||||
}
|
||||
|
||||
SetupOpenId(r, &config.BaseUrl, signingKey)
|
||||
SetupOpenId(r, config.BaseUrl, signingKey)
|
||||
r.GET("/", hs.OptionalAuthentication(false, hs.Home))
|
||||
r.POST("/logout", hs.RequireAuthentication(hs.logoutPost))
|
||||
|
||||
|
@ -29,3 +29,5 @@ sql:
|
||||
go_type: "database/sql.NullString"
|
||||
- column: "users.token_expiry"
|
||||
go_type: "database/sql.NullTime"
|
||||
- column: "oauth_sources.address"
|
||||
go_type: "github.com/1f349/lavender/utils.URL"
|
||||
|
@ -1,21 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type JsonUrl struct {
|
||||
*url.URL
|
||||
}
|
||||
|
||||
var _ encoding.TextUnmarshaler = &JsonUrl{}
|
||||
|
||||
func (s *JsonUrl) UnmarshalText(text []byte) error {
|
||||
parse, err := url.Parse(string(text))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.URL = parse
|
||||
return nil
|
||||
}
|
51
utils/loginname.go
Normal file
51
utils/loginname.go
Normal file
@ -0,0 +1,51 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrInvalidLoginNameFormat = errors.New("invalid login name format")
|
||||
|
||||
func ParseLoginName(loginName string) (user string, namespace string, err error) {
|
||||
if loginName == "" || strings.HasPrefix(loginName, "@") || strings.HasSuffix(loginName, "@") || containsInvalidLoginNameRunes(loginName) {
|
||||
return "", "", ErrInvalidLoginNameFormat
|
||||
}
|
||||
|
||||
// @ should have at least one byte before it
|
||||
n := strings.IndexByte(loginName, '@')
|
||||
if n < 1 {
|
||||
return "", "", ErrInvalidLoginNameFormat
|
||||
}
|
||||
// there should not be a second @
|
||||
n2 := strings.IndexByte(loginName[n+1:], '@')
|
||||
if n2 != -1 {
|
||||
return "", "", ErrInvalidLoginNameFormat
|
||||
}
|
||||
|
||||
return loginName[:n], loginName[n+1:], nil
|
||||
}
|
||||
|
||||
func containsInvalidLoginNameRunes(loginName string) bool {
|
||||
// check if the name contains an invalid rune
|
||||
return strings.ContainsFunc(loginName, func(r rune) bool {
|
||||
return !isValidLoginNameRune(r)
|
||||
})
|
||||
}
|
||||
|
||||
func isValidLoginNameRune(r rune) bool {
|
||||
switch {
|
||||
case r >= '0' && r <= '9':
|
||||
return true
|
||||
case r >= 'a' && r <= 'z':
|
||||
return true
|
||||
case r == '.':
|
||||
return true
|
||||
case r == '-':
|
||||
return true
|
||||
case r == '@':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
174
utils/loginname_test.go
Normal file
174
utils/loginname_test.go
Normal file
@ -0,0 +1,174 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var parseLoginNameTests = []struct {
|
||||
User string
|
||||
Namespace string
|
||||
HasError bool
|
||||
Input string
|
||||
}{
|
||||
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com", false, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com"},
|
||||
{"", "", true, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com\u1111"},
|
||||
{"", "", true, "\u1111aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com"},
|
||||
{"", "", true, "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com"},
|
||||
{"", "", true, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com@"},
|
||||
{"a", "b.com", false, "a@b.com"},
|
||||
{"aa", "bb.com", false, "aa@bb.com"},
|
||||
}
|
||||
|
||||
var parseLoginNameImplementations = []struct {
|
||||
Name string
|
||||
Func func(string) (string, string, error)
|
||||
}{
|
||||
{"ParseLoginName", ParseLoginName},
|
||||
{"parseLoginNameLoopRunes", parseLoginNameLoopRunes},
|
||||
{"parseLoginNameRegex", parseLoginNameRegex},
|
||||
}
|
||||
|
||||
func TestParseLoginName(t *testing.T) {
|
||||
for _, impl := range parseLoginNameImplementations {
|
||||
t.Run(impl.Name, func(t *testing.T) {
|
||||
for _, i := range parseLoginNameTests {
|
||||
t.Run(i.Input, func(t *testing.T) {
|
||||
user, namespace, err := ParseLoginName(i.Input)
|
||||
if i.HasError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, i.User, user)
|
||||
assert.Equal(t, i.Namespace, namespace)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzParseLoginName(f *testing.F) {
|
||||
for _, i := range parseLoginNameTests {
|
||||
f.Add(i.Input)
|
||||
}
|
||||
f.Fuzz(func(t *testing.T, s string) {
|
||||
t.Log("Input: ", s, hex.EncodeToString([]byte(s)))
|
||||
hasError := s == "" || strings.HasPrefix(s, "@") || strings.HasSuffix(s, "@") || strings.Count(s, "@") != 1 || strings.ContainsFunc(s, func(r rune) bool {
|
||||
return !isValidLoginNameRune(r)
|
||||
})
|
||||
login, namespace, err := ParseLoginName(s)
|
||||
if hasError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if err == nil {
|
||||
n := strings.IndexRune(s, '@')
|
||||
assert.Equal(t, s[:n], login)
|
||||
assert.Equal(t, s[n+1:], namespace)
|
||||
} else {
|
||||
assert.Equal(t, login, "", "Login should be empty if an error occurred")
|
||||
assert.Equal(t, namespace, "", "Namespace should be empty if an error occurred")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func parseLoginBench(b *testing.B, f func(string) (string, string, error)) {
|
||||
for _, i := range parseLoginNameTests {
|
||||
b.Run(i.Input, func(b *testing.B) {
|
||||
for range b.N {
|
||||
_, _, _ = f(i.Input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseLoginName(b *testing.B) {
|
||||
b.Run("ParseLoginName", func(b *testing.B) {
|
||||
parseLoginBench(b, ParseLoginName)
|
||||
})
|
||||
b.Run("parseLoginNameLoopRunes", func(b *testing.B) {
|
||||
parseLoginBench(b, parseLoginNameLoopRunes)
|
||||
})
|
||||
b.Run("parseLoginNameRegex", func(b *testing.B) {
|
||||
parseLoginBench(b, parseLoginNameRegex)
|
||||
})
|
||||
}
|
||||
|
||||
func parseLoginNameLoopRunes(s string) (user string, namespace string, err error) {
|
||||
if s == "" || strings.HasPrefix(s, "@") || strings.HasSuffix(s, "@") {
|
||||
return "", "", ErrInvalidLoginNameFormat
|
||||
}
|
||||
hasAt := false
|
||||
atIndex := 0
|
||||
for i, r := range []rune(s) {
|
||||
switch {
|
||||
case r == '@':
|
||||
if hasAt {
|
||||
return "", "", ErrInvalidLoginNameFormat
|
||||
}
|
||||
hasAt = true
|
||||
atIndex = i
|
||||
case r == '.':
|
||||
continue
|
||||
case r == '-':
|
||||
continue
|
||||
case r >= 'a' && r <= 'z':
|
||||
continue
|
||||
case r >= '0' && r <= '9':
|
||||
continue
|
||||
default:
|
||||
return "", "", ErrInvalidLoginNameFormat
|
||||
}
|
||||
}
|
||||
|
||||
return s[:atIndex], s[atIndex+1:], nil
|
||||
}
|
||||
|
||||
// regexLoginName is a regex based implementation of ParseLoginName
|
||||
//
|
||||
// This implementation prevents using - or . at the start and end of the user and namespace
|
||||
var regexLoginName = regexp.MustCompile(`^([a-z0-9]([a-z0-9-.]*[a-z0-9]|))@([a-z0-9]([a-z0-9-.]*[a-z0-9]|))$`)
|
||||
|
||||
func parseLoginNameRegex(s string) (user string, namespace string, err error) {
|
||||
if s == "" {
|
||||
return "", "", ErrInvalidLoginNameFormat
|
||||
}
|
||||
|
||||
matches := regexLoginName.FindStringSubmatch(s)
|
||||
if matches == nil {
|
||||
return "", "", ErrInvalidLoginNameFormat
|
||||
}
|
||||
return matches[1], matches[3], nil
|
||||
}
|
||||
|
||||
func BenchmarkA(b *testing.B) {
|
||||
b.Run("A", func(b *testing.B) {
|
||||
for range b.N {
|
||||
_ = isValidLoginNameRune('b')
|
||||
_ = isValidLoginNameRune('4')
|
||||
_ = isValidLoginNameRune('.')
|
||||
_ = isValidLoginNameRune('-')
|
||||
_ = isValidLoginNameRune('@')
|
||||
_ = isValidLoginNameRune(' ')
|
||||
}
|
||||
})
|
||||
b.Run("B", func(b *testing.B) {
|
||||
for range b.N {
|
||||
_ = isValidLoginNameRune2('b')
|
||||
_ = isValidLoginNameRune2('4')
|
||||
_ = isValidLoginNameRune2('.')
|
||||
_ = isValidLoginNameRune2('-')
|
||||
_ = isValidLoginNameRune2('@')
|
||||
_ = isValidLoginNameRune2(' ')
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func isValidLoginNameRune2(r rune) bool {
|
||||
return (r >= '0' && r <= '9') || (r >= 'a' && r <= 'z') || r == '.' || r == '-' || r == '@'
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package url
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
@ -6,35 +6,33 @@ import (
|
||||
"path"
|
||||
)
|
||||
|
||||
type URL struct {
|
||||
url.URL
|
||||
}
|
||||
|
||||
func (u *URL) Resolve(paths ...string) *URL {
|
||||
return &URL{URL: *u.URL.ResolveReference(&url.URL{Path: path.Join(paths...)})}
|
||||
}
|
||||
|
||||
func (u URL) MarshalText() (text []byte, err error) {
|
||||
return []byte(u.String()), nil
|
||||
}
|
||||
|
||||
func (u *URL) UnmarshalText(text []byte) error {
|
||||
parse, err := u.Parse(string(text))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.URL = *parse
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ encoding.TextMarshaler = (*URL)(nil)
|
||||
var _ encoding.TextUnmarshaler = (*URL)(nil)
|
||||
|
||||
type URL struct{ *url.URL }
|
||||
|
||||
func MustParse(rawURL string) *URL {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &URL{*u}
|
||||
return &URL{u}
|
||||
}
|
||||
|
||||
func (u *URL) Resolve(paths ...string) *URL {
|
||||
return &URL{URL: u.URL.ResolveReference(&url.URL{Path: path.Join(paths...)})}
|
||||
}
|
||||
|
||||
func (u *URL) MarshalText() (text []byte, err error) {
|
||||
return []byte(u.URL.String()), nil
|
||||
}
|
||||
|
||||
func (u *URL) UnmarshalText(text []byte) error {
|
||||
parse, err := url.Parse(string(text))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.URL = parse
|
||||
return nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user