Rewrite OAuth manager, OAuth source and URL wrapper

This commit is contained in:
Melon 2025-03-13 13:29:57 +00:00
parent d9b0074133
commit 50df217b66
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
24 changed files with 789 additions and 234 deletions

View File

@ -5,10 +5,11 @@ import (
"fmt"
"github.com/1f349/lavender/auth"
"github.com/1f349/lavender/auth/authContext"
process "github.com/1f349/lavender/auth/process"
"github.com/1f349/lavender/auth/process"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/logger"
"github.com/1f349/lavender/utils"
"github.com/google/uuid"
"golang.org/x/oauth2"
"net/http"
@ -25,6 +26,7 @@ type InitialLogin struct {
DB *database.Queries
MyNamespace string
Manager *issuer.Manager
OAuth *OAuthLogin
}
func (m *InitialLogin) AccessState() process.State { return process.StateUnauthorized }
@ -101,36 +103,39 @@ func (m *InitialLogin) AttemptLogin(ctx authContext.FormContext) error {
}
// append local namespace if @ is missing
n := strings.IndexByte(loginName, '@')
if n < 0 {
// correct the @ index
n = len(loginName)
if !strings.ContainsRune(loginName, '@') {
loginName += "@" + m.MyNamespace
}
login := m.Manager.FindServiceFromLogin(loginName)
user, namespace, err := utils.ParseLoginName(loginName)
if err != nil {
http.Error(rw, "Invalid login name", http.StatusBadRequest)
return fmt.Errorf("invalid login name %s", loginName)
}
login := m.Manager.GetService(namespace)
if login == nil {
http.Error(rw, "No login service defined for this username", http.StatusBadRequest)
return errors.New("no login service defined for this username")
}
// the @ must exist if the service is defined
loginUn := loginName[:n]
// TODO: finish migrating this shit
// the login is not for this namespace
if login != issuer.MeWellKnown {
// TODO: this is oauth request code start oauth request
// TODO: this could all be 1 function call into oauth
// save state for use later
state := login.Config.Namespace + ":" + uuid.NewString()
h.flowState.Set(state, flowStateData{loginName, login, req.PostFormValue("redirect")}, time.Now().Add(15*time.Minute))
m.OAuth.flow.Set(state, flowStateData{loginName, login, req.PostFormValue("redirect")}, time.Now().Add(15*time.Minute))
// generate oauth2 config and redirect to authorize URL
oa2conf := login.OAuth2Config
oa2conf.RedirectURL = h.conf.BaseUrl.JoinPath("callback").String()
nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginUn))
http.Redirect(rw, req, nextUrl, http.StatusFound)
return
oa2conf.RedirectURL = m.OAuth.BaseUrl.JoinPath("callback").String()
nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", user))
return auth.RedirectError{
Target: nextUrl,
Code: http.StatusFound,
}
}
ctx.UpdateSession(process.LoginProcessData{

View File

@ -13,17 +13,18 @@ import (
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/database/types"
"github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/url"
"github.com/1f349/lavender/utils"
"github.com/google/uuid"
"github.com/mrmelon54/pronouns"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"net/http"
"net/url"
"time"
)
type OauthCallback interface {
OAuthCallback(rw http.ResponseWriter, req *http.Request, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request))
OAuthCallback(rw http.ResponseWriter, req *http.Request, namespace string, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request))
}
type flowStateData struct {
@ -40,9 +41,9 @@ var (
type OAuthLogin struct {
DB *database.Queries
BaseUrl *url.URL
flow *cache.Cache[string, flowStateData]
BaseUrl *utils.URL
flow *cache.Cache[string, flowStateData]
Manager *issuer.Manager
}
func (o OAuthLogin) Init() {
@ -50,7 +51,7 @@ func (o OAuthLogin) Init() {
}
func (o OAuthLogin) authUrlBase(ref string) *url.URL {
return o.BaseUrl.Resolve("oauth", o.Name(), ref)
return o.BaseUrl.Resolve("oauth", o.Name(), ref).URL
}
func (o OAuthLogin) AccessState() process.State { return process.StateUnauthorized }
@ -83,12 +84,16 @@ func (o OAuthLogin) AttemptLogin(ctx authContext.FormContext) error {
return auth.RedirectError{Target: nextUrl, Code: http.StatusFound}
}
func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request)) {
func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, namespace string, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request)) {
flowState, ok := o.flow.Get(req.FormValue("state"))
if !ok {
http.Error(rw, "Invalid flow state", http.StatusBadRequest)
return
}
if flowState.sso.Namespace != namespace {
http.Error(rw, "OAuth source mismatch", http.StatusBadRequest)
return
}
token, err := flowState.sso.OAuth2Config.Exchange(context.Background(), req.FormValue("code"), oauth2.SetAuthURLParam("redirect_uri", o.authUrlBase("callback").String()))
if err != nil {
http.Error(rw, "Failed to exchange code for token", http.StatusInternalServerError)
@ -112,17 +117,32 @@ func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, inf
}
func (o OAuthLogin) RenderButtonTemplate(ctx authContext.TemplateContext) {
// TODO: idk what this is
// o.authUrlBase("button")
// provide something non-nil
ctx.Render(struct {
Href string
ButtonName string
}{
Href: o.authUrlBase("button").String(),
Href: o.authUrlBase("start").String(),
ButtonName: "Login with Unknown OAuth Button", // TODO: actually get the service name
})
}
// RedirectToAuthorize returns true when redirecting to the authorize endpoint
// for the requested namespace.
func (o OAuthLogin) RedirectToAuthorize(rw http.ResponseWriter, req *http.Request, namespace string) bool {
oidc := o.Manager.GetService(namespace)
if oidc == nil {
return false
}
// TODO: come back to fix this
oidc.OAuth2Config.AuthCodeURL("state")
// http.Redirect(rw, req, oidc.AuthorizationEndpoint, http.StatusOK)
return true
}
type oauthServiceLogin int
func WithWellKnown(ctx context.Context, login *issuer.WellKnownOIDC) context.Context {
@ -256,7 +276,7 @@ func (o OAuthLogin) updateOAuth2UserProfile(ctx context.Context, tx *database.Qu
}
func (o OAuthLogin) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint)
res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint.String())
if err != nil || res.StatusCode != http.StatusOK {
return auth.UserAuth{}, fmt.Errorf("request failed")
}

View File

@ -1,19 +1,17 @@
package conf
import (
"github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/url"
"github.com/1f349/lavender/utils"
"github.com/1f349/simplemail"
)
type Conf struct {
Listen string `yaml:"listen"`
BaseUrl url.URL `yaml:"baseUrl"`
ServiceName string `yaml:"serviceName"`
Issuer string `yaml:"issuer"`
Kid string `yaml:"kid"`
Namespace string `yaml:"namespace"`
OtpIssuer string `yaml:"otpIssuer"`
Mail simplemail.Mail `yaml:"mail"`
SsoServices []issuer.SsoConfig `yaml:"ssoServices"`
Listen string `yaml:"listen"`
BaseUrl *utils.URL `yaml:"baseUrl"`
ServiceName string `yaml:"serviceName"`
Issuer string `yaml:"issuer"`
Kid string `yaml:"kid"`
Namespace string `yaml:"namespace"`
OtpIssuer string `yaml:"otpIssuer"`
Mail simplemail.Mail `yaml:"mail"`
}

View File

@ -0,0 +1,45 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.28.0
// source: manage-oauth-source.sql
package database
import (
"context"
)
const getOAuthSources = `-- name: GetOAuthSources :many
SELECT namespace, address, registration, button, client_id, client_secret, client_scopes FROM oauth_sources
`
func (q *Queries) GetOAuthSources(ctx context.Context) ([]OauthSource, error) {
rows, err := q.db.QueryContext(ctx, getOAuthSources)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OauthSource
for rows.Next() {
var i OauthSource
if err := rows.Scan(
&i.Namespace,
&i.Address,
&i.Registration,
&i.Button,
&i.ClientID,
&i.ClientSecret,
&i.ClientScopes,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -0,0 +1,10 @@
CREATE TABLE oauth_sources
(
namespace TEXT NOT NULL UNIQUE PRIMARY KEY,
address TEXT NOT NULL,
registration BOOLEAN NOT NULL,
button BOOLEAN NOT NULL,
client_id TEXT NOT NULL,
client_secret TEXT NOT NULL,
client_scopes TEXT NOT NULL
);

View File

@ -10,6 +10,7 @@ import (
"github.com/1f349/lavender/database/types"
"github.com/1f349/lavender/password"
"github.com/1f349/lavender/utils"
"github.com/hardfinhq/go-date"
)
@ -25,6 +26,16 @@ type ClientStore struct {
Active bool `json:"active"`
}
type OauthSource struct {
Namespace string `json:"namespace"`
Address utils.URL `json:"address"`
Registration bool `json:"registration"`
Button bool `json:"button"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
ClientScopes string `json:"client_scopes"`
}
type Role struct {
ID int64 `json:"id"`
Role string `json:"role"`

View File

@ -0,0 +1,2 @@
-- name: GetOAuthSources :many
SELECT * FROM oauth_sources;

1
go.mod
View File

@ -19,6 +19,7 @@ require (
github.com/julienschmidt/httprouter v1.3.0
github.com/mattn/go-sqlite3 v1.14.24
github.com/mrmelon54/pronouns v1.0.3
github.com/robfig/cron v1.2.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/spf13/afero v1.12.0
github.com/stretchr/testify v1.10.0

2
go.sum
View File

@ -138,6 +138,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ=
github.com/robfig/cron v1.2.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=

View File

@ -1,57 +1,226 @@
package issuer
import (
"context"
"fmt"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/logger"
"github.com/robfig/cron"
"net/http"
"regexp"
"strings"
"sync"
"time"
)
const (
fetchPoolSize = 4
fetchTimeout = 2 * time.Minute
fetchRetryDelay = 15 * time.Second
fetchRetryAfterTimeout = 15 * time.Minute
fetchRetryCount = 3
)
var isValidNamespace = regexp.MustCompile("^[0-9a-z.]+$")
var MeWellKnown = &WellKnownOIDC{}
type Manager struct {
m map[string]*WellKnownOIDC
type managerReloadDB interface {
GetOAuthSources(ctx context.Context) ([]database.OauthSource, error)
}
func NewManager(myNamespace string, services []SsoConfig) (*Manager, error) {
l := &Manager{m: make(map[string]*WellKnownOIDC)}
l.m[myNamespace] = MeWellKnown
for _, ssoService := range services {
if !isValidNamespace.MatchString(ssoService.Namespace) {
return nil, fmt.Errorf("invalid namespace: %s", ssoService.Namespace)
}
type Manager struct {
db managerReloadDB
conf, err := ssoService.FetchConfig()
if err != nil {
return nil, err
}
client *http.Client
// save by namespace
l.m[ssoService.Namespace] = conf
sourceMu *sync.RWMutex
sourceMap map[string]*WellKnownOIDC
errMu *sync.RWMutex
errMap map[string]FetchErrorDetails
shutdownCall context.CancelFunc
}
type FetchErrorDetails struct {
Retry time.Time
Errors []error
}
type managerRoundTripper struct {
Base http.RoundTripper
}
func (a *managerRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
userAgent := strings.Join(request.Header.Values("User-Agent"), " ")
if strings.TrimSpace(userAgent) == "" {
request.Header.Set("User-Agent", "Lavender/1.0 (OAuth Manager)")
}
return a.Base.RoundTrip(request)
}
func NewManager(db managerReloadDB, myNamespace string) (*Manager, error) {
return NewManagerWithClient(db, myNamespace, &http.Client{
Transport: &managerRoundTripper{
Base: http.DefaultTransport,
},
})
}
func NewManagerWithClient(db managerReloadDB, myNamespace string, client *http.Client) (*Manager, error) {
l := &Manager{
db: db,
client: client,
sourceMu: new(sync.RWMutex),
sourceMap: make(map[string]*WellKnownOIDC),
errMap: make(map[string]FetchErrorDetails),
}
l.sourceMap[myNamespace] = MeWellKnown
// reload should run blocking on start up
l.internalReload(context.Background())
reloadDaily := make(chan struct{}, 1)
tab := cron.New()
err := tab.AddFunc("0 2 * * *", func() {
reloadDaily <- struct{}{}
})
if err != nil {
panic(err)
}
ctx, cancel := context.WithCancel(context.Background())
go l.reloadWorker(ctx, reloadDaily)
l.shutdownCall = cancel
return l, nil
}
func (m *Manager) Shutdown() { m.shutdownCall() }
func (m *Manager) reloadWorker(ctx context.Context, reload <-chan struct{}) {
for {
select {
case <-ctx.Done():
logger.Logger.Info("Shutting down oauth manager reload worker")
return
case <-reload:
m.internalReload(ctx)
}
}
}
func (m *Manager) internalReload(ctx context.Context) {
logger.Logger.Info("Reloading oauth sources")
dbCtx, cancel := context.WithTimeout(ctx, fetchTimeout)
sources, err := m.db.GetOAuthSources(dbCtx)
cancel()
if err != nil {
// TODO: send email to admin
logger.Logger.Warn("Failed to load OAuth sources from the database", "err", err)
return
}
// construct map of db oauth sources
toAdd := make(map[string]database.OauthSource)
for _, source := range sources {
toAdd[source.Namespace] = source
}
// remove sources from the active map
m.sourceMu.Lock()
for k, v := range m.sourceMap {
// ignore my own well-known
if v == MeWellKnown {
continue
}
if _, shouldAdd := toAdd[k]; !shouldAdd {
delete(m.sourceMap, k)
}
}
m.sourceMu.Unlock()
var wg sync.WaitGroup
wg.Add(fetchPoolSize)
fetchChan := make(chan database.OauthSource, fetchPoolSize)
for range fetchPoolSize {
go func() {
defer wg.Done()
for source := range fetchChan {
itemCtx, cancel := context.WithTimeout(ctx, fetchTimeout)
errs := m.reloadSource(itemCtx, source)
cancel()
// this also resets the error if it was set previously
m.sourceMu.Lock()
if len(errs) > 0 {
m.errMap[source.Namespace] = FetchErrorDetails{
Retry: time.Now().Add(fetchRetryAfterTimeout),
Errors: errs,
}
// TODO: send email to admin
} else {
delete(m.errMap, source.Namespace)
// TODO: send resolved email to admin
}
m.sourceMu.Unlock()
}
}()
}
// reload sources
for _, source := range toAdd {
fetchChan <- source
}
close(fetchChan)
wg.Wait()
}
func (m *Manager) reloadSource(ctx context.Context, source database.OauthSource) []error {
var errs []error
for range fetchRetryCount {
oidc, err := fetchConfig(ctx, m.client, source)
fmt.Println(oidc, err)
if err == nil {
// fetch was successful
m.sourceMu.Lock()
m.sourceMap[source.Namespace] = oidc
m.sourceMu.Unlock()
return nil
}
errs = append(errs, fmt.Errorf("failed to fetch OIDC well-known configuration: %w", err))
// wait before retrying
select {
case <-ctx.Done():
return errs
case <-time.After(fetchRetryDelay):
continue
}
}
return errs
}
func (m *Manager) CheckNamespace(namespace string) bool {
_, ok := m.m[namespace]
m.sourceMu.RLock()
defer m.sourceMu.RUnlock()
_, ok := m.sourceMap[namespace]
return ok
}
func (m *Manager) GetService(namespace string) *WellKnownOIDC {
return m.m[namespace]
}
func (m *Manager) FindServiceFromLogin(login string) *WellKnownOIDC {
// @ should have at least one byte before it
n := strings.IndexByte(login, '@')
if n < 1 {
return nil
}
// there should not be a second @
n2 := strings.IndexByte(login[n+1:], '@')
if n2 != -1 {
return nil
}
return m.GetService(login[n+1:])
m.sourceMu.RLock()
defer m.sourceMu.RUnlock()
return m.sourceMap[namespace]
}

View File

@ -1,6 +1,9 @@
package issuer
import (
"context"
"fmt"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/utils"
"github.com/stretchr/testify/assert"
"io"
@ -8,29 +11,54 @@ import (
"net/url"
"strings"
"testing"
"time"
)
var testAddrUrl = func() utils.JsonUrl {
var testAddrUrl = func() utils.URL {
a, err := url.Parse("https://example.com")
if err != nil {
panic(err)
}
return utils.JsonUrl{URL: a}
return utils.URL{URL: a}
}()
func testBody() io.ReadCloser {
return io.NopCloser(strings.NewReader("{}"))
return io.NopCloser(strings.NewReader(`{
"issuer": "Example.com Issuer",
"authorization_endpoint": "https://example.com/oauth/authorize",
"token_endpoint": "https://example.com/oauth/token",
"userinfo_endpoint": "https://example.com/oauth/userinfo",
"revocation_endpoint": "https://example.com/oauth/revoke",
"response_types_supported": [
"code"
],
"scopes_supported": [
"openid"
],
"claims_supported": [
"sub"
],
"grant_types_supported": [
"authorization_code",
"refresh_token"
]
}
`))
}
func TestManager_CheckNamespace(t *testing.T) {
httpGet = func(url string) (resp *http.Response, err error) {
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
}
manager, err := NewManager("example.org", []SsoConfig{
{
Addr: testAddrUrl,
Namespace: "example.com",
manager, err := NewManagerWithClient(&testDB{
Rows: []database.OauthSource{
{Address: testAddrUrl, Namespace: "example.com"},
},
}, "example.org", &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
fmt.Println("Request URL:", req.URL)
if req.URL.String() == "https://example.com/.well-known/openid-configuration" {
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
}
return nil, fmt.Errorf("request failed")
}),
})
assert.NoError(t, err)
assert.True(t, manager.CheckNamespace("example.org"))
@ -39,17 +67,36 @@ func TestManager_CheckNamespace(t *testing.T) {
}
func TestManager_FindServiceFromLogin(t *testing.T) {
httpGet = func(url string) (resp *http.Response, err error) {
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
}
manager, err := NewManager("example.org", []SsoConfig{
{
Addr: testAddrUrl,
Namespace: "example.com",
manager, err := NewManagerWithClient(&testDB{
Rows: []database.OauthSource{
{Address: testAddrUrl, Namespace: "example.com"},
},
}, "example.org", &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
fmt.Println("Request URL:", req.URL)
if req.URL.String() != "https://example.com/.well-known/openid-configuration" {
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
}
return nil, fmt.Errorf("request failed")
}),
})
assert.NoError(t, err)
assert.Equal(t, manager.FindServiceFromLogin("jane@example.org"), MeWellKnown)
assert.Equal(t, manager.FindServiceFromLogin("jane@example.com"), manager.m["example.com"])
assert.Nil(t, manager.FindServiceFromLogin("jane@missing.example.com"))
assert.Equal(t, manager.GetService("example.org"), MeWellKnown)
assert.Equal(t, manager.GetService("example.com"), manager.sourceMap["example.com"])
assert.Nil(t, manager.GetService("missing.example.com"))
}
type testDB struct {
Rows []database.OauthSource
}
func (t *testDB) GetOAuthSources(ctx context.Context) ([]database.OauthSource, error) {
return t.Rows, nil
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (r roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
time.Sleep(time.Minute)
return r(req)
}

View File

@ -1,94 +1,82 @@
package issuer
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/utils"
"golang.org/x/oauth2"
"net/http"
"net/url"
"slices"
"strings"
"sync/atomic"
"time"
)
var httpGet = http.Get
// SsoConfig is the base URL for an OAUTH/OPENID/SSO login service
// The path `/.well-known/openid-configuration` should be available
type SsoConfig struct {
Addr utils.JsonUrl `json:"addr" yaml:"addr"` // https://login.example.com
Namespace string `json:"namespace" yaml:"namespace"` // example.com
Registration bool `json:"registration" yaml:"registration"`
LoginWithButton bool `json:"login_with_button" yaml:"loginWithButton"`
Client SsoConfigClient `json:"client" yaml:"client"`
type WellKnownOIDC struct {
Namespace string `json:"-"`
Config database.OauthSource `json:"-"`
Available atomic.Bool `json:"-"`
Issuer string `json:"issuer"`
AuthorizationEndpoint *utils.URL `json:"authorization_endpoint"`
TokenEndpoint *utils.URL `json:"token_endpoint"`
UserInfoEndpoint *utils.URL `json:"userinfo_endpoint"`
ResponseTypesSupported []string `json:"response_types_supported"`
ScopesSupported []string `json:"scopes_supported"`
ClaimsSupported []string `json:"claims_supported"`
GrantTypesSupported []string `json:"grant_types_supported"`
JwksUri *utils.URL `json:"jwks_uri"`
OAuth2Config oauth2.Config `json:"-"`
LastFetch time.Time `json:"-"`
}
type SsoConfigClient struct {
ID string `json:"id"`
Secret string `json:"secret"`
Scopes []string `json:"scopes"`
}
func (s SsoConfig) FetchConfig() (*WellKnownOIDC, error) {
func fetchConfig(ctx context.Context, httpClient *http.Client, s database.OauthSource) (*WellKnownOIDC, error) {
// generate openid config url
u := s.Addr.JoinPath(".well-known/openid-configuration")
u := s.Address.JoinPath(".well-known/openid-configuration")
// fetch metadata
get, err := httpGet(u.String())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return nil, err
}
defer get.Body.Close()
resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var c WellKnownOIDC
err = json.NewDecoder(get.Body).Decode(&c)
err = json.NewDecoder(resp.Body).Decode(&c)
if err != nil {
return nil, err
}
c.Config = s
c.OAuth2Config = oauth2.Config{
ClientID: c.Config.Client.ID,
ClientSecret: c.Config.Client.Secret,
ClientID: c.Config.ClientID,
ClientSecret: c.Config.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: c.AuthorizationEndpoint,
TokenURL: c.TokenEndpoint,
AuthURL: c.AuthorizationEndpoint.String(),
TokenURL: c.TokenEndpoint.String(),
AuthStyle: oauth2.AuthStyleInHeader,
},
Scopes: c.Config.Client.Scopes,
Scopes: strings.Fields(c.Config.ClientScopes),
}
c.Available.Store(true)
return &c, nil
}
type WellKnownOIDC struct {
Namespace string `json:"-"`
Config SsoConfig `json:"-"`
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"userinfo_endpoint"`
ResponseTypesSupported []string `json:"response_types_supported"`
ScopesSupported []string `json:"scopes_supported"`
ClaimsSupported []string `json:"claims_supported"`
GrantTypesSupported []string `json:"grant_types_supported"`
OAuth2Config oauth2.Config `json:"-"`
}
func (o WellKnownOIDC) Validate() error {
func (o *WellKnownOIDC) Validate() error {
if o.Issuer == "" {
return errors.New("missing issuer")
}
// check URLs are valid
if _, err := url.Parse(o.AuthorizationEndpoint); err != nil {
return err
}
if _, err := url.Parse(o.TokenEndpoint); err != nil {
return err
}
if _, err := url.Parse(o.UserInfoEndpoint); err != nil {
return err
}
// check oidc supported values
if !slices.Contains(o.ResponseTypesSupported, "code") {
return errors.New("missing required response type 'code'")
@ -107,6 +95,6 @@ func (o WellKnownOIDC) Validate() error {
return nil
}
func (o WellKnownOIDC) ValidReturnUrl(u *url.URL) bool {
return o.Config.Addr.Scheme == u.Scheme && o.Config.Addr.Host == u.Host
func (o *WellKnownOIDC) ValidReturnUrl(u *url.URL) bool {
return o.Config.Address.Scheme == u.Scheme && o.Config.Address.Host == u.Host
}

View File

@ -1,30 +1,29 @@
package openid
import "github.com/1f349/lavender/url"
import "github.com/1f349/lavender/utils"
type Config struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"userinfo_endpoint"`
ResponseTypesSupported []string `json:"response_types_supported"`
ScopesSupported []string `json:"scopes_supported"`
ClaimsSupported []string `json:"claims_supported"`
GrantTypesSupported []string `json:"grant_types_supported"`
JwksUri string `json:"jwks_uri"`
Issuer string `json:"issuer"`
AuthorizationEndpoint *utils.URL `json:"authorization_endpoint"`
TokenEndpoint *utils.URL `json:"token_endpoint"`
UserInfoEndpoint *utils.URL `json:"userinfo_endpoint"`
ResponseTypesSupported []string `json:"response_types_supported"`
ScopesSupported []string `json:"scopes_supported"`
ClaimsSupported []string `json:"claims_supported"`
GrantTypesSupported []string `json:"grant_types_supported"`
JwksUri *utils.URL `json:"jwks_uri"`
}
func GenConfig(baseUrl *url.URL, scopes, claims []string) Config {
func GenConfig(baseUrl *utils.URL, scopes, claims []string) Config {
return Config{
Issuer: baseUrl.String(),
AuthorizationEndpoint: baseUrl.Resolve("authorize").String(),
TokenEndpoint: baseUrl.Resolve("token").String(),
UserInfoEndpoint: baseUrl.Resolve("userinfo").String(),
AuthorizationEndpoint: baseUrl.Resolve("authorize"),
TokenEndpoint: baseUrl.Resolve("token"),
UserInfoEndpoint: baseUrl.Resolve("userinfo"),
ResponseTypesSupported: []string{"code"},
ScopesSupported: scopes,
ClaimsSupported: claims,
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
JwksUri: baseUrl.Resolve(".well-known/jwks.json").String(),
JwksUri: baseUrl.Resolve(".well-known/jwks.json"),
}
}

View File

@ -1,7 +1,7 @@
package openid
import (
"github.com/1f349/lavender/url"
"github.com/1f349/lavender/utils"
"github.com/stretchr/testify/assert"
"testing"
)
@ -9,13 +9,13 @@ import (
func TestGenConfig(t *testing.T) {
assert.Equal(t, Config{
Issuer: "https://example.com",
AuthorizationEndpoint: "https://example.com/authorize",
TokenEndpoint: "https://example.com/token",
UserInfoEndpoint: "https://example.com/userinfo",
AuthorizationEndpoint: utils.MustParse("https://example.com/authorize"),
TokenEndpoint: utils.MustParse("https://example.com/token"),
UserInfoEndpoint: utils.MustParse("https://example.com/userinfo"),
ResponseTypesSupported: []string{"code"},
ScopesSupported: []string{"openid", "email"},
ClaimsSupported: []string{"name", "email", "preferred_username"},
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
JwksUri: "https://example.com/.well-known/jwks.json",
}, GenConfig(url.MustParse("https://example.com"), []string{"openid", "email"}, []string{"name", "email", "preferred_username"}))
JwksUri: utils.MustParse("https://example.com/.well-known/jwks.json"),
}, GenConfig(utils.MustParse("https://example.com"), []string{"openid", "email"}, []string{"name", "email", "preferred_username"}))
}

View File

@ -309,7 +309,15 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re
return err
}
sso := h.manager.FindServiceFromLogin(refreshData.Claims.Login)
_, namespace, err := utils.ParseLoginName(refreshData.Claims.Login)
if err != nil {
return err
}
sso := h.manager.GetService(namespace)
if sso == nil {
return fmt.Errorf("invalid namespace: %s", namespace)
}
var oauthToken *oauth2.Token

View File

@ -5,13 +5,13 @@ import (
"encoding/json"
"github.com/1f349/lavender/logger"
"github.com/1f349/lavender/openid"
"github.com/1f349/lavender/url"
"github.com/1f349/lavender/utils"
"github.com/1f349/mjwt"
"github.com/julienschmidt/httprouter"
"net/http"
)
func SetupOpenId(r *httprouter.Router, baseUrl *url.URL, signingKey *mjwt.Issuer) {
func SetupOpenId(r *httprouter.Router, baseUrl *utils.URL, signingKey *mjwt.Issuer) {
openIdConf := openid.GenConfig(baseUrl, []string{
"openid", "name", "username", "profile", "email", "birthdate", "age", "zoneinfo", "locale",
}, []string{

View File

@ -1,20 +1,25 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"github.com/1f349/cache"
"github.com/1f349/lavender/auth"
"github.com/1f349/lavender/auth/process"
"github.com/1f349/lavender/auth/providers"
"github.com/1f349/lavender/conf"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/logger"
"github.com/1f349/lavender/mail"
"github.com/1f349/lavender/utils"
"github.com/1f349/lavender/web"
"github.com/1f349/mjwt"
"github.com/go-oauth2/oauth2/v4/manage"
"github.com/go-oauth2/oauth2/v4/server"
"github.com/julienschmidt/httprouter"
"golang.org/x/oauth2"
"net/http"
"path"
"strings"
@ -65,12 +70,7 @@ type mailLinkKey struct {
func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail, db *database.Queries, signingKey *mjwt.Issuer) {
// TODO: move auth provider init to main function
// TODO: allow dynamically changing the providers based on database information
authInitial := &providers.InitialLogin{}
authPassword := &providers.PasswordLogin{DB: db}
authOtp := &providers.OtpLogin{DB: db}
authOAuth := &providers.OAuthLogin{DB: db, BaseUrl: &config.BaseUrl}
authOAuth.Init()
authPasskey := &providers.PasskeyLogin{DB: db}
// TODO: move oauth setup into oauth provider
hs := &httpServer{
r: r,
@ -80,19 +80,71 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail,
mailSender: mailSender,
mailLinkCache: cache.New[mailLinkKey, string](),
authSources: []auth.Provider{
authInitial,
authPassword,
authOtp,
authOAuth,
authPasskey,
},
authButtons: make([]auth.Button, 0),
formProviderLookup: make(map[string]auth.Form),
}
var err error
hs.manager, err = issuer.NewManager(db, config.Namespace)
if err != nil {
logger.Logger.Fatal("Failed to load SSO services", "err", err)
}
authPassword := &providers.PasswordLogin{DB: db}
authOtp := &providers.OtpLogin{DB: db}
authOAuth := &providers.OAuthLogin{DB: db, BaseUrl: config.BaseUrl, Manager: hs.manager}
authOAuth.Init()
authPasskey := &providers.PasskeyLogin{DB: db}
authInitial := &providers.InitialLogin{DB: db, MyNamespace: config.Namespace, Manager: hs.manager, OAuth: authOAuth}
hs.authSources = []auth.Provider{
authInitial,
authPassword,
authOtp,
authOAuth,
authPasskey,
}
r.GET("/oauth/:namespace/start", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
namespace := params.ByName("namespace")
if !authOAuth.RedirectToAuthorize(rw, req, namespace) {
http.Error(rw, "Invalid OAuth namespace", http.StatusBadRequest)
}
})
r.GET("/oauth/:namespace/callback", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
namespace := params.ByName("namespace")
authOAuth.OAuthCallback(rw, req, namespace, func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
resp, err := sso.OAuth2Config.Client(req.Context(), token).Get(sso.UserInfoEndpoint.String())
if err != nil {
return auth.UserAuth{}, err
}
defer resp.Body.Close()
var userInfoJson auth.UserInfoFields
if err := json.NewDecoder(resp.Body).Decode(&userInfoJson); err != nil {
return auth.UserAuth{}, err
}
subject, ok := userInfoJson.GetString("sub")
if !ok {
return auth.UserAuth{}, fmt.Errorf("invalid subject")
}
subject += "@" + sso.Config.Namespace
return auth.UserAuth{
Subject: subject,
Factor: process.StateBasic,
UserInfo: userInfoJson,
}, nil
}, func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool {
// TODO: this should be using the existing auth flow calls
return hs.setLoginDataCookie(rw, authData, loginName)
}, func(rw http.ResponseWriter, req *http.Request) {
// TODO: is this really needed like this?
utils.SafeRedirect(rw, req)
})
})
// build slices and maps for quick access to auth interfaces
hs.authButtons = make([]auth.Button, 0)
hs.formProviderLookup = make(map[string]auth.Form)
for _, source := range hs.authSources {
if button, isButton := source.(auth.Button); isButton {
hs.authButtons = append(hs.authButtons, button)
@ -103,13 +155,7 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail,
}
}
var err error
hs.manager, err = issuer.NewManager(config.Namespace, config.SsoServices)
if err != nil {
logger.Logger.Fatal("Failed to load SSO services", "err", err)
}
SetupOpenId(r, &config.BaseUrl, signingKey)
SetupOpenId(r, config.BaseUrl, signingKey)
r.GET("/", hs.OptionalAuthentication(false, hs.Home))
r.POST("/logout", hs.RequireAuthentication(hs.logoutPost))

View File

@ -29,3 +29,5 @@ sql:
go_type: "database/sql.NullString"
- column: "users.token_expiry"
go_type: "database/sql.NullTime"
- column: "oauth_sources.address"
go_type: "github.com/1f349/lavender/utils.URL"

BIN
tmp/main

Binary file not shown.

View File

@ -1,21 +0,0 @@
package utils
import (
"encoding"
"net/url"
)
type JsonUrl struct {
*url.URL
}
var _ encoding.TextUnmarshaler = &JsonUrl{}
func (s *JsonUrl) UnmarshalText(text []byte) error {
parse, err := url.Parse(string(text))
if err != nil {
return err
}
s.URL = parse
return nil
}

51
utils/loginname.go Normal file
View File

@ -0,0 +1,51 @@
package utils
import (
"errors"
"strings"
)
var ErrInvalidLoginNameFormat = errors.New("invalid login name format")
func ParseLoginName(loginName string) (user string, namespace string, err error) {
if loginName == "" || strings.HasPrefix(loginName, "@") || strings.HasSuffix(loginName, "@") || containsInvalidLoginNameRunes(loginName) {
return "", "", ErrInvalidLoginNameFormat
}
// @ should have at least one byte before it
n := strings.IndexByte(loginName, '@')
if n < 1 {
return "", "", ErrInvalidLoginNameFormat
}
// there should not be a second @
n2 := strings.IndexByte(loginName[n+1:], '@')
if n2 != -1 {
return "", "", ErrInvalidLoginNameFormat
}
return loginName[:n], loginName[n+1:], nil
}
func containsInvalidLoginNameRunes(loginName string) bool {
// check if the name contains an invalid rune
return strings.ContainsFunc(loginName, func(r rune) bool {
return !isValidLoginNameRune(r)
})
}
func isValidLoginNameRune(r rune) bool {
switch {
case r >= '0' && r <= '9':
return true
case r >= 'a' && r <= 'z':
return true
case r == '.':
return true
case r == '-':
return true
case r == '@':
return true
default:
return false
}
}

174
utils/loginname_test.go Normal file
View File

@ -0,0 +1,174 @@
package utils
import (
"encoding/hex"
"github.com/stretchr/testify/assert"
"regexp"
"strings"
"testing"
)
var parseLoginNameTests = []struct {
User string
Namespace string
HasError bool
Input string
}{
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com", false, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com"},
{"", "", true, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com\u1111"},
{"", "", true, "\u1111aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com"},
{"", "", true, "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com"},
{"", "", true, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa@bbbbbbbbbbbbbbbbbbbbbbbbbbbb.com@"},
{"a", "b.com", false, "a@b.com"},
{"aa", "bb.com", false, "aa@bb.com"},
}
var parseLoginNameImplementations = []struct {
Name string
Func func(string) (string, string, error)
}{
{"ParseLoginName", ParseLoginName},
{"parseLoginNameLoopRunes", parseLoginNameLoopRunes},
{"parseLoginNameRegex", parseLoginNameRegex},
}
func TestParseLoginName(t *testing.T) {
for _, impl := range parseLoginNameImplementations {
t.Run(impl.Name, func(t *testing.T) {
for _, i := range parseLoginNameTests {
t.Run(i.Input, func(t *testing.T) {
user, namespace, err := ParseLoginName(i.Input)
if i.HasError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, i.User, user)
assert.Equal(t, i.Namespace, namespace)
})
}
})
}
}
func FuzzParseLoginName(f *testing.F) {
for _, i := range parseLoginNameTests {
f.Add(i.Input)
}
f.Fuzz(func(t *testing.T, s string) {
t.Log("Input: ", s, hex.EncodeToString([]byte(s)))
hasError := s == "" || strings.HasPrefix(s, "@") || strings.HasSuffix(s, "@") || strings.Count(s, "@") != 1 || strings.ContainsFunc(s, func(r rune) bool {
return !isValidLoginNameRune(r)
})
login, namespace, err := ParseLoginName(s)
if hasError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
if err == nil {
n := strings.IndexRune(s, '@')
assert.Equal(t, s[:n], login)
assert.Equal(t, s[n+1:], namespace)
} else {
assert.Equal(t, login, "", "Login should be empty if an error occurred")
assert.Equal(t, namespace, "", "Namespace should be empty if an error occurred")
}
})
}
func parseLoginBench(b *testing.B, f func(string) (string, string, error)) {
for _, i := range parseLoginNameTests {
b.Run(i.Input, func(b *testing.B) {
for range b.N {
_, _, _ = f(i.Input)
}
})
}
}
func BenchmarkParseLoginName(b *testing.B) {
b.Run("ParseLoginName", func(b *testing.B) {
parseLoginBench(b, ParseLoginName)
})
b.Run("parseLoginNameLoopRunes", func(b *testing.B) {
parseLoginBench(b, parseLoginNameLoopRunes)
})
b.Run("parseLoginNameRegex", func(b *testing.B) {
parseLoginBench(b, parseLoginNameRegex)
})
}
func parseLoginNameLoopRunes(s string) (user string, namespace string, err error) {
if s == "" || strings.HasPrefix(s, "@") || strings.HasSuffix(s, "@") {
return "", "", ErrInvalidLoginNameFormat
}
hasAt := false
atIndex := 0
for i, r := range []rune(s) {
switch {
case r == '@':
if hasAt {
return "", "", ErrInvalidLoginNameFormat
}
hasAt = true
atIndex = i
case r == '.':
continue
case r == '-':
continue
case r >= 'a' && r <= 'z':
continue
case r >= '0' && r <= '9':
continue
default:
return "", "", ErrInvalidLoginNameFormat
}
}
return s[:atIndex], s[atIndex+1:], nil
}
// regexLoginName is a regex based implementation of ParseLoginName
//
// This implementation prevents using - or . at the start and end of the user and namespace
var regexLoginName = regexp.MustCompile(`^([a-z0-9]([a-z0-9-.]*[a-z0-9]|))@([a-z0-9]([a-z0-9-.]*[a-z0-9]|))$`)
func parseLoginNameRegex(s string) (user string, namespace string, err error) {
if s == "" {
return "", "", ErrInvalidLoginNameFormat
}
matches := regexLoginName.FindStringSubmatch(s)
if matches == nil {
return "", "", ErrInvalidLoginNameFormat
}
return matches[1], matches[3], nil
}
func BenchmarkA(b *testing.B) {
b.Run("A", func(b *testing.B) {
for range b.N {
_ = isValidLoginNameRune('b')
_ = isValidLoginNameRune('4')
_ = isValidLoginNameRune('.')
_ = isValidLoginNameRune('-')
_ = isValidLoginNameRune('@')
_ = isValidLoginNameRune(' ')
}
})
b.Run("B", func(b *testing.B) {
for range b.N {
_ = isValidLoginNameRune2('b')
_ = isValidLoginNameRune2('4')
_ = isValidLoginNameRune2('.')
_ = isValidLoginNameRune2('-')
_ = isValidLoginNameRune2('@')
_ = isValidLoginNameRune2(' ')
}
})
}
func isValidLoginNameRune2(r rune) bool {
return (r >= '0' && r <= '9') || (r >= 'a' && r <= 'z') || r == '.' || r == '-' || r == '@'
}

View File

@ -1,4 +1,4 @@
package url
package utils
import (
"encoding"
@ -6,35 +6,33 @@ import (
"path"
)
type URL struct {
url.URL
}
func (u *URL) Resolve(paths ...string) *URL {
return &URL{URL: *u.URL.ResolveReference(&url.URL{Path: path.Join(paths...)})}
}
func (u URL) MarshalText() (text []byte, err error) {
return []byte(u.String()), nil
}
func (u *URL) UnmarshalText(text []byte) error {
parse, err := u.Parse(string(text))
if err != nil {
return err
}
u.URL = *parse
return nil
}
var _ encoding.TextMarshaler = (*URL)(nil)
var _ encoding.TextUnmarshaler = (*URL)(nil)
type URL struct{ *url.URL }
func MustParse(rawURL string) *URL {
u, err := url.Parse(rawURL)
if err != nil {
panic(err)
}
return &URL{*u}
return &URL{u}
}
func (u *URL) Resolve(paths ...string) *URL {
return &URL{URL: u.URL.ResolveReference(&url.URL{Path: path.Join(paths...)})}
}
func (u *URL) MarshalText() (text []byte, err error) {
return []byte(u.URL.String()), nil
}
func (u *URL) UnmarshalText(text []byte) error {
parse, err := url.Parse(string(text))
if err != nil {
return err
}
u.URL = parse
return nil
}