// Mirror of https://github.com/1f349/lavender.git
// Synced 2025-04-16 07:47:55 +01:00 · 227 lines · 4.9 KiB · Go
package issuer
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"github.com/1f349/lavender/database"
|
|
"github.com/1f349/lavender/logger"
|
|
"github.com/robfig/cron"
|
|
"net/http"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
const (
|
|
fetchPoolSize = 4
|
|
fetchTimeout = 2 * time.Minute
|
|
fetchRetryDelay = 15 * time.Second
|
|
fetchRetryAfterTimeout = 15 * time.Minute
|
|
fetchRetryCount = 3
|
|
)
|
|
|
|
// isValidNamespace matches a non-empty namespace consisting solely of
// lowercase ASCII letters, digits and dots.
var isValidNamespace = regexp.MustCompile(`^[0-9a-z.]+$`)
|
|
|
|
// MeWellKnown is the well-known configuration registered for this service's own
// namespace when a Manager is created. The reload logic recognises this exact
// pointer and never removes the entry that holds it.
var MeWellKnown = new(WellKnownOIDC)
|
|
|
|
type managerReloadDB interface {
|
|
GetOAuthSources(ctx context.Context) ([]database.OauthSource, error)
|
|
}
|
|
|
|
type Manager struct {
|
|
db managerReloadDB
|
|
|
|
client *http.Client
|
|
|
|
sourceMu *sync.RWMutex
|
|
sourceMap map[string]*WellKnownOIDC
|
|
|
|
errMu *sync.RWMutex
|
|
errMap map[string]FetchErrorDetails
|
|
|
|
shutdownCall context.CancelFunc
|
|
}
|
|
|
|
// FetchErrorDetails describes a failed attempt to load a source's well-known
// configuration during a reload.
type FetchErrorDetails struct {
	Retry  time.Time // failure time plus fetchRetryAfterTimeout; NOTE(review): not read anywhere in this file
	Errors []error   // one wrapped error per failed fetch attempt, in attempt order
}
|
|
|
|
// managerRoundTripper is an http.RoundTripper that supplies a default
// Lavender User-Agent for outgoing requests that do not set one themselves.
type managerRoundTripper struct {
	// Base performs the actual round trip; http.DefaultTransport is used when nil.
	Base http.RoundTripper
}

// RoundTrip implements http.RoundTripper. If the request carries no non-blank
// User-Agent, a copy of the request with "Lavender/1.0 (OAuth Manager)" set is
// forwarded to Base; otherwise the request is forwarded unchanged.
func (a *managerRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
	base := a.Base
	if base == nil {
		base = http.DefaultTransport
	}

	userAgent := strings.Join(request.Header.Values("User-Agent"), " ")
	if strings.TrimSpace(userAgent) == "" {
		// The RoundTripper contract forbids modifying the caller's request, so
		// the header is set on a clone (Clone deep-copies the Header map).
		request = request.Clone(request.Context())
		if request.Header == nil {
			request.Header = make(http.Header)
		}
		request.Header.Set("User-Agent", "Lavender/1.0 (OAuth Manager)")
	}
	return base.RoundTrip(request)
}
|
|
|
|
func NewManager(db managerReloadDB, myNamespace string) (*Manager, error) {
|
|
return NewManagerWithClient(db, myNamespace, &http.Client{
|
|
Transport: &managerRoundTripper{
|
|
Base: http.DefaultTransport,
|
|
},
|
|
})
|
|
}
|
|
|
|
func NewManagerWithClient(db managerReloadDB, myNamespace string, client *http.Client) (*Manager, error) {
|
|
l := &Manager{
|
|
db: db,
|
|
client: client,
|
|
sourceMu: new(sync.RWMutex),
|
|
sourceMap: make(map[string]*WellKnownOIDC),
|
|
errMap: make(map[string]FetchErrorDetails),
|
|
}
|
|
l.sourceMap[myNamespace] = MeWellKnown
|
|
|
|
// reload should run blocking on start up
|
|
l.internalReload(context.Background())
|
|
|
|
reloadDaily := make(chan struct{}, 1)
|
|
|
|
tab := cron.New()
|
|
err := tab.AddFunc("0 2 * * *", func() {
|
|
reloadDaily <- struct{}{}
|
|
})
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go l.reloadWorker(ctx, reloadDaily)
|
|
|
|
l.shutdownCall = cancel
|
|
|
|
return l, nil
|
|
}
|
|
|
|
// Shutdown invokes the shutdown function installed by NewManagerWithClient.
// That function cancels the context of the background reload worker.
func (m *Manager) Shutdown() {
	stop := m.shutdownCall
	stop()
}
|
|
|
|
func (m *Manager) reloadWorker(ctx context.Context, reload <-chan struct{}) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
logger.Logger.Info("Shutting down oauth manager reload worker")
|
|
return
|
|
case <-reload:
|
|
m.internalReload(ctx)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Manager) internalReload(ctx context.Context) {
|
|
logger.Logger.Info("Reloading oauth sources")
|
|
dbCtx, cancel := context.WithTimeout(ctx, fetchTimeout)
|
|
sources, err := m.db.GetOAuthSources(dbCtx)
|
|
cancel()
|
|
if err != nil {
|
|
// TODO: send email to admin
|
|
logger.Logger.Warn("Failed to load OAuth sources from the database", "err", err)
|
|
return
|
|
}
|
|
|
|
// construct map of db oauth sources
|
|
toAdd := make(map[string]database.OauthSource)
|
|
for _, source := range sources {
|
|
toAdd[source.Namespace] = source
|
|
}
|
|
|
|
// remove sources from the active map
|
|
m.sourceMu.Lock()
|
|
for k, v := range m.sourceMap {
|
|
// ignore my own well-known
|
|
if v == MeWellKnown {
|
|
continue
|
|
}
|
|
|
|
if _, shouldAdd := toAdd[k]; !shouldAdd {
|
|
delete(m.sourceMap, k)
|
|
}
|
|
}
|
|
m.sourceMu.Unlock()
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(fetchPoolSize)
|
|
|
|
fetchChan := make(chan database.OauthSource, fetchPoolSize)
|
|
|
|
for range fetchPoolSize {
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
for source := range fetchChan {
|
|
itemCtx, cancel := context.WithTimeout(ctx, fetchTimeout)
|
|
errs := m.reloadSource(itemCtx, source)
|
|
cancel()
|
|
|
|
// this also resets the error if it was set previously
|
|
m.sourceMu.Lock()
|
|
if len(errs) > 0 {
|
|
m.errMap[source.Namespace] = FetchErrorDetails{
|
|
Retry: time.Now().Add(fetchRetryAfterTimeout),
|
|
Errors: errs,
|
|
}
|
|
// TODO: send email to admin
|
|
} else {
|
|
delete(m.errMap, source.Namespace)
|
|
// TODO: send resolved email to admin
|
|
}
|
|
m.sourceMu.Unlock()
|
|
}
|
|
}()
|
|
}
|
|
|
|
// reload sources
|
|
for _, source := range toAdd {
|
|
fetchChan <- source
|
|
}
|
|
|
|
close(fetchChan)
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func (m *Manager) reloadSource(ctx context.Context, source database.OauthSource) []error {
|
|
var errs []error
|
|
|
|
for range fetchRetryCount {
|
|
oidc, err := fetchConfig(ctx, m.client, source)
|
|
fmt.Println(oidc, err)
|
|
if err == nil {
|
|
// fetch was successful
|
|
m.sourceMu.Lock()
|
|
m.sourceMap[source.Namespace] = oidc
|
|
m.sourceMu.Unlock()
|
|
return nil
|
|
}
|
|
errs = append(errs, fmt.Errorf("failed to fetch OIDC well-known configuration: %w", err))
|
|
|
|
// wait before retrying
|
|
select {
|
|
case <-ctx.Done():
|
|
return errs
|
|
case <-time.After(fetchRetryDelay):
|
|
continue
|
|
}
|
|
}
|
|
|
|
return errs
|
|
}
|
|
|
|
func (m *Manager) CheckNamespace(namespace string) bool {
|
|
m.sourceMu.RLock()
|
|
defer m.sourceMu.RUnlock()
|
|
_, ok := m.sourceMap[namespace]
|
|
return ok
|
|
}
|
|
|
|
func (m *Manager) GetService(namespace string) *WellKnownOIDC {
|
|
m.sourceMu.RLock()
|
|
defer m.sourceMu.RUnlock()
|
|
return m.sourceMap[namespace]
|
|
}
|