Continued working on OAuth flow

This commit is contained in:
Melon 2023-10-04 14:51:38 +01:00
parent 1efd68b0eb
commit d772d14041
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
9 changed files with 193 additions and 61 deletions

View File

@ -8,7 +8,6 @@ import (
type startUpConfig struct { type startUpConfig struct {
Listen string `json:"listen"` Listen string `json:"listen"`
BaseUrl string `json:"base_url"` BaseUrl string `json:"base_url"`
PrivateKey string `json:"private_key"`
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
SsoServices []loginServiceManager.SsoConfig `json:"sso_services"` SsoServices []loginServiceManager.SsoConfig `json:"sso_services"`
AllowedClients []utils.JsonUrl `json:"allowed_clients"` AllowedClients []utils.JsonUrl `json:"allowed_clients"`

View File

@ -74,10 +74,10 @@ func normalLoad(startUp startUpConfig, wd string) {
manager, err := issuer.NewManager(startUp.SsoServices) manager, err := issuer.NewManager(startUp.SsoServices)
if err != nil { if err != nil {
log.Fatal("[Lavender] Failed to create SSO service manager") log.Fatal("[Lavender] Failed to create SSO service manager: ", err)
} }
srv := server.NewHttpServer(startUp.Listen, startUp.BaseUrl, manager, mSign) srv := server.NewHttpServer(startUp.Listen, startUp.BaseUrl, startUp.AllowedClients, manager, mSign)
log.Printf("[Lavender] Starting HTTP server on '%s'\n", srv.Addr) log.Printf("[Lavender] Starting HTTP server on '%s'\n", srv.Addr)
go utils.RunBackgroundHttp("HTTP", srv) go utils.RunBackgroundHttp("HTTP", srv)

3
go.mod
View File

@ -10,15 +10,18 @@ require (
github.com/google/subcommands v1.2.0 github.com/google/subcommands v1.2.0
github.com/google/uuid v1.3.1 github.com/google/uuid v1.3.1
github.com/julienschmidt/httprouter v1.3.0 github.com/julienschmidt/httprouter v1.3.0
github.com/stretchr/testify v1.8.4
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
) )
require ( require (
github.com/MrMelon54/rescheduler v0.0.2 // indirect github.com/MrMelon54/rescheduler v0.0.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/golang-jwt/jwt/v4 v4.5.0 // indirect
github.com/golang/protobuf v1.4.2 // indirect github.com/golang/protobuf v1.4.2 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/net v0.9.0 // indirect golang.org/x/net v0.9.0 // indirect
google.golang.org/appengine v1.6.6 // indirect google.golang.org/appengine v1.6.6 // indirect
google.golang.org/protobuf v1.23.0 // indirect google.golang.org/protobuf v1.23.0 // indirect

View File

@ -1,5 +1,13 @@
package issuer package issuer
import (
"fmt"
"regexp"
"strings"
)
var isValidNamespace = regexp.MustCompile("^[0-9a-z.]+$")
type Manager struct { type Manager struct {
m map[string]*WellKnownOIDC m map[string]*WellKnownOIDC
} }
@ -7,23 +15,36 @@ type Manager struct {
func NewManager(services []SsoConfig) (*Manager, error) { func NewManager(services []SsoConfig) (*Manager, error) {
l := &Manager{m: make(map[string]*WellKnownOIDC)} l := &Manager{m: make(map[string]*WellKnownOIDC)}
for _, i := range services { for _, i := range services {
if !isValidNamespace.MatchString(i.Namespace) {
return nil, fmt.Errorf("invalid namespace: %s", i.Namespace)
}
conf, err := i.FetchConfig() conf, err := i.FetchConfig()
if err != nil { if err != nil {
return nil, err return nil, err
} }
// save by issuer // save by namespace
l.m[conf.Issuer] = conf l.m[i.Namespace] = conf
} }
return l, nil return l, nil
} }
func (l *Manager) CheckIssuer(issuer string) bool { func (l *Manager) CheckNamespace(namespace string) bool {
_, ok := l.m[issuer] _, ok := l.m[namespace]
return ok return ok
} }
func (l *Manager) FindServiceFromLogin(login string) *WellKnownOIDC { func (l *Manager) FindServiceFromLogin(login string) *WellKnownOIDC {
// @ should have at least one byte before it
return l.m[namespace] n := strings.IndexByte(login, '@')
if n < 1 {
return nil
}
// there should not be a second @
n2 := strings.IndexByte(login[n+1:], '@')
if n2 != -1 {
return nil
}
return l.m[login[n+1:]]
} }

53
issuer/manager_test.go Normal file
View File

@ -0,0 +1,53 @@
package issuer
import (
"github.com/1f349/lavender/utils"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/url"
"strings"
"testing"
)
var testAddrUrl = func() utils.JsonUrl {
a, err := url.Parse("https://example.com")
if err != nil {
panic(err)
}
return utils.JsonUrl{URL: a}
}()
func testBody() io.ReadCloser {
return io.NopCloser(strings.NewReader("{}"))
}
func TestManager_CheckIssuer(t *testing.T) {
httpGet = func(url string) (resp *http.Response, err error) {
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
}
manager, err := NewManager([]SsoConfig{
{
Addr: testAddrUrl,
Namespace: "example.com",
},
})
assert.NoError(t, err)
assert.True(t, manager.CheckNamespace("example.com"))
assert.False(t, manager.CheckNamespace("missing.example.com"))
}
func TestManager_FindServiceFromLogin(t *testing.T) {
httpGet = func(url string) (resp *http.Response, err error) {
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
}
manager, err := NewManager([]SsoConfig{
{
Addr: testAddrUrl,
Namespace: "example.com",
},
})
assert.NoError(t, err)
assert.Equal(t, manager.FindServiceFromLogin("jane@example.com"), manager.m["example.com"])
assert.Nil(t, manager.FindServiceFromLogin("jane@missing.example.com"))
}

View File

@ -12,6 +12,8 @@ import (
"strings" "strings"
) )
var httpGet = http.Get
// SsoConfig is the base URL for an OAUTH/OPENID/SSO login service // SsoConfig is the base URL for an OAUTH/OPENID/SSO login service
// The path `/.well-known/openid-configuration` should be available // The path `/.well-known/openid-configuration` should be available
type SsoConfig struct { type SsoConfig struct {
@ -33,7 +35,7 @@ func (s SsoConfig) FetchConfig() (*WellKnownOIDC, error) {
u += ".well-known/openid-configuration" u += ".well-known/openid-configuration"
// fetch metadata // fetch metadata
get, err := http.Get(u) get, err := httpGet(u)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -41,7 +43,20 @@ func (s SsoConfig) FetchConfig() (*WellKnownOIDC, error) {
var c WellKnownOIDC var c WellKnownOIDC
err = json.NewDecoder(get.Body).Decode(&c) err = json.NewDecoder(get.Body).Decode(&c)
return &c, err if err != nil {
return nil, err
}
c.OAuth2Config = oauth2.Config{
ClientID: c.Config.Client.ID,
ClientSecret: c.Config.Client.Secret,
Endpoint: oauth2.Endpoint{
AuthURL: c.AuthorizationEndpoint,
TokenURL: c.TokenEndpoint,
AuthStyle: oauth2.AuthStyleInHeader,
},
Scopes: c.Config.Client.Scopes,
}
return &c, nil
} }
type WellKnownOIDC struct { type WellKnownOIDC struct {
@ -54,6 +69,7 @@ type WellKnownOIDC struct {
ScopesSupported []string `json:"scopes_supported"` ScopesSupported []string `json:"scopes_supported"`
ClaimsSupported []string `json:"claims_supported"` ClaimsSupported []string `json:"claims_supported"`
GrantTypesSupported []string `json:"grant_types_supported"` GrantTypesSupported []string `json:"grant_types_supported"`
OAuth2Config oauth2.Config `json:"-"`
} }
func (o WellKnownOIDC) Validate() error { func (o WellKnownOIDC) Validate() error {
@ -90,19 +106,6 @@ func (o WellKnownOIDC) Validate() error {
return nil return nil
} }
func (o WellKnownOIDC) Oauth2Config() oauth2.Config {
return oauth2.Config{
ClientID: o.Config.Client.ID,
ClientSecret: o.Config.Client.Secret,
Endpoint: oauth2.Endpoint{
AuthURL: o.AuthorizationEndpoint,
TokenURL: o.TokenEndpoint,
AuthStyle: oauth2.AuthStyleInHeader,
},
Scopes: o.Config.Client.Scopes,
}
}
func (o WellKnownOIDC) ValidReturnUrl(u *url.URL) bool { func (o WellKnownOIDC) ValidReturnUrl(u *url.URL) bool {
o.Config.Addr return o.Config.Addr.Scheme == u.Scheme && o.Config.Addr.Host == u.Host
} }

View File

@ -0,0 +1,21 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>{{.ServiceName}}</title>
</head>
<script>
let loginData = {target:{{.TargetOrigin}}, message:{{.LoginData}}};
document.addEventListener("load", function () {
postMessage(loginData.message, loginData.target);
setTimeout(function () {
window.close();
}, 2000);
});
</script>
<body>
<header>
<h1>{{.ServiceName}}</h1>
</header>
<main>Loading...</main>
</body>
</html>

View File

@ -2,13 +2,13 @@ package server
import ( import (
_ "embed" _ "embed"
"encoding/json"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"html/template" "html/template"
"log" "log"
"net/http" "net/http"
"net/url" "strings"
"regexp"
"time" "time"
) )
@ -17,7 +17,9 @@ var (
flowPopupHtml string flowPopupHtml string
flowPopupTemplate *template.Template flowPopupTemplate *template.Template
isValidState = regexp.MustCompile("^[a-z.]+%[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$") //go:embed flow-callback.go.html
flowCallbackHtml string
flowCallbackTemplate *template.Template
) )
func init() { func init() {
@ -26,12 +28,17 @@ func init() {
log.Fatal("flow.go: Failed to parse flow popup HTML:", err) log.Fatal("flow.go: Failed to parse flow popup HTML:", err)
} }
flowPopupTemplate = pageParse flowPopupTemplate = pageParse
pageParse, err = template.New("pages").Parse(flowCallbackHtml)
if err != nil {
log.Fatal("flow.go: Failed to parse flow callback HTML:", err)
}
flowCallbackTemplate = pageParse
} }
func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
err := flowPopupTemplate.Execute(rw, map[string]any{ err := flowPopupTemplate.Execute(rw, map[string]any{
"ServiceName": flowPopupTemplate, "ServiceName": flowPopupTemplate,
"Return": req.URL.Query().Get("return"), "Origin": req.URL.Query().Get("origin"),
}) })
if err != nil { if err != nil {
log.Printf("Failed to render page: %s\n", err) log.Printf("Failed to render page: %s\n", err)
@ -45,13 +52,9 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _
return return
} }
returnUrl, err := url.Parse(req.PostFormValue("return")) targetOrigin := req.PostFormValue("origin")
if err != nil { if _, found := h.services[targetOrigin]; !found {
http.Error(rw, "Invalid return URL", http.StatusBadRequest) http.Error(rw, "Invalid target origin", http.StatusBadRequest)
return
}
if !login.ValidReturnUrl(returnUrl) {
http.Error(rw, "Invalid return URL for this application", http.StatusBadRequest)
return return
} }
@ -59,11 +62,11 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _
state := login.Config.Namespace + "%" + uuid.NewString() state := login.Config.Namespace + "%" + uuid.NewString()
h.flowState.Set(state, flowStateData{ h.flowState.Set(state, flowStateData{
login, login,
returnUrl, targetOrigin,
}, time.Now().Add(15*time.Minute)) }, time.Now().Add(15*time.Minute))
// generate oauth2 config and redirect to authorize URL // generate oauth2 config and redirect to authorize URL
oa2conf := login.Oauth2Config() oa2conf := login.OAuth2Config
oa2conf.RedirectURL = h.baseUrl + "/callback" oa2conf.RedirectURL = h.baseUrl + "/callback"
nextUrl := oa2conf.AuthCodeURL(state) nextUrl := oa2conf.AuthCodeURL(state)
http.Redirect(rw, req, nextUrl, http.StatusFound) http.Redirect(rw, req, nextUrl, http.StatusFound)
@ -78,11 +81,8 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
q := req.URL.Query() q := req.URL.Query()
state := q.Get("state") state := q.Get("state")
if !isValidState.MatchString(state) { n := strings.IndexByte(state, '%')
http.Error(rw, "Invalid state", http.StatusBadRequest) if !h.manager.CheckNamespace(state[:n]) {
return
}
if !h.manager.CheckIssuer(state) {
http.Error(rw, "Invalid state", http.StatusBadRequest) http.Error(rw, "Invalid state", http.StatusBadRequest)
return return
} }
@ -92,5 +92,30 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
return return
} }
// TODO: process flow callback exchange, err := v.sso.OAuth2Config.Exchange(req.Context(), q.Get("code"))
if err != nil {
http.Error(rw, "Failed to exchange code", http.StatusInternalServerError)
return
}
client := v.sso.OAuth2Config.Client(req.Context(), exchange)
v2, err := client.Get(v.sso.UserInfoEndpoint)
if err != nil {
http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError)
return
}
defer v2.Body.Close()
if v2.StatusCode != http.StatusOK {
http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError)
return
}
var v3 any
if json.NewDecoder(v2.Body).Decode(&v3) != nil {
http.Error(rw, "Failed to decode userinfo JSON", http.StatusInternalServerError)
return
}
_ = flowCallbackTemplate.Execute(rw, map[string]any{
"TargetOrigin": v.targetOrigin,
"TargetMessage": v3,
})
} }

View File

@ -4,10 +4,10 @@ import (
"fmt" "fmt"
"github.com/1f349/cache" "github.com/1f349/cache"
"github.com/1f349/lavender/issuer" "github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/utils"
"github.com/MrMelon54/mjwt" "github.com/MrMelon54/mjwt"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"net/http" "net/http"
"net/url"
"time" "time"
) )
@ -17,14 +17,15 @@ type HttpServer struct {
manager *issuer.Manager manager *issuer.Manager
signer mjwt.Signer signer mjwt.Signer
flowState *cache.Cache[string, flowStateData] flowState *cache.Cache[string, flowStateData]
services map[string]struct{}
} }
type flowStateData struct { type flowStateData struct {
sso *issuer.WellKnownOIDC sso *issuer.WellKnownOIDC
returnUrl *url.URL targetOrigin string
} }
func NewHttpServer(listen, baseUrl string, manager *issuer.Manager, signer mjwt.Signer) *http.Server { func NewHttpServer(listen, baseUrl string, clients []utils.JsonUrl, manager *issuer.Manager, signer mjwt.Signer) *http.Server {
r := httprouter.New() r := httprouter.New()
// remove last slash from baseUrl // remove last slash from baseUrl
@ -35,11 +36,17 @@ func NewHttpServer(listen, baseUrl string, manager *issuer.Manager, signer mjwt.
} }
} }
services := make(map[string]struct{})
for _, i := range clients {
services[i.Host] = struct{}{}
}
hs := &HttpServer{ hs := &HttpServer{
r: r, r: r,
baseUrl: baseUrl, baseUrl: baseUrl,
manager: manager, manager: manager,
signer: signer, signer: signer,
services: services,
} }
r.GET("/", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { r.GET("/", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {