mirror of
https://github.com/1f349/lavender.git
synced 2024-12-22 07:34:06 +00:00
Add loads of test cases
This commit is contained in:
parent
01d03fef9d
commit
bf39f4421b
@ -1,15 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
loginServiceManager "github.com/1f349/lavender/issuer"
|
|
||||||
"github.com/1f349/lavender/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
type startUpConfig struct {
|
|
||||||
Listen string `json:"listen"`
|
|
||||||
BaseUrl string `json:"base_url"`
|
|
||||||
ServiceName string `json:"service_name"`
|
|
||||||
Issuer string `json:"issuer"`
|
|
||||||
SsoServices []loginServiceManager.SsoConfig `json:"sso_services"`
|
|
||||||
AllowedClients []utils.JsonUrl `json:"allowed_clients"`
|
|
||||||
}
|
|
@ -5,7 +5,6 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
"github.com/1f349/lavender/issuer"
|
|
||||||
"github.com/1f349/lavender/server"
|
"github.com/1f349/lavender/server"
|
||||||
"github.com/1f349/violet/utils"
|
"github.com/1f349/violet/utils"
|
||||||
exit_reload "github.com/MrMelon54/exit-reload"
|
exit_reload "github.com/MrMelon54/exit-reload"
|
||||||
@ -50,7 +49,7 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
|
|||||||
return subcommands.ExitFailure
|
return subcommands.ExitFailure
|
||||||
}
|
}
|
||||||
|
|
||||||
var config startUpConfig
|
var config server.Conf
|
||||||
err = json.NewDecoder(openConf).Decode(&config)
|
err = json.NewDecoder(openConf).Decode(&config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("[Lavender] Error: invalid config file: ", err)
|
log.Println("[Lavender] Error: invalid config file: ", err)
|
||||||
@ -66,18 +65,13 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
|
|||||||
return subcommands.ExitSuccess
|
return subcommands.ExitSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalLoad(startUp startUpConfig, wd string) {
|
func normalLoad(startUp server.Conf, wd string) {
|
||||||
mSign, err := mjwt.NewMJwtSignerFromFileOrCreate(startUp.Issuer, filepath.Join(wd, "lavender.private.key"), rand.Reader, 4096)
|
mSign, err := mjwt.NewMJwtSignerFromFileOrCreate(startUp.Issuer, filepath.Join(wd, "lavender.private.key"), rand.Reader, 4096)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("[Lavender] Failed to load or create MJWT signer:", err)
|
log.Fatal("[Lavender] Failed to load or create MJWT signer:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := issuer.NewManager(startUp.SsoServices)
|
srv := server.NewHttpServer(startUp, mSign)
|
||||||
if err != nil {
|
|
||||||
log.Fatal("[Lavender] Failed to create SSO service manager: ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := server.NewHttpServer(startUp.Listen, startUp.BaseUrl, startUp.ServiceName, startUp.AllowedClients, manager, mSign)
|
|
||||||
log.Printf("[Lavender] Starting HTTP server on '%s'\n", srv.Addr)
|
log.Printf("[Lavender] Starting HTTP server on '%s'\n", srv.Addr)
|
||||||
go utils.RunBackgroundHttp("HTTP", srv)
|
go utils.RunBackgroundHttp("HTTP", srv)
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ func testBody() io.ReadCloser {
|
|||||||
return io.NopCloser(strings.NewReader("{}"))
|
return io.NopCloser(strings.NewReader("{}"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_CheckIssuer(t *testing.T) {
|
func TestManager_CheckNamespace(t *testing.T) {
|
||||||
httpGet = func(url string) (resp *http.Response, err error) {
|
httpGet = func(url string) (resp *http.Response, err error) {
|
||||||
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
|
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
|
||||||
}
|
}
|
||||||
|
1
issuer/sso_test.go
Normal file
1
issuer/sso_test.go
Normal file
@ -0,0 +1 @@
|
|||||||
|
package issuer
|
15
server/conf.go
Normal file
15
server/conf.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/1f349/lavender/issuer"
|
||||||
|
"github.com/1f349/lavender/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Conf struct {
|
||||||
|
Listen string `json:"listen"`
|
||||||
|
BaseUrl string `json:"base_url"`
|
||||||
|
ServiceName string `json:"service_name"`
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
SsoServices []issuer.SsoConfig `json:"sso_services"`
|
||||||
|
AllowedClients []utils.JsonUrl `json:"allowed_clients"`
|
||||||
|
}
|
@ -18,16 +18,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
//go:embed flow-popup.go.html
|
|
||||||
flowPopupHtml string
|
|
||||||
flowPopupTemplate *template.Template
|
|
||||||
|
|
||||||
//go:embed flow-callback.go.html
|
|
||||||
flowCallbackHtml string
|
|
||||||
flowCallbackTemplate *template.Template
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
pageParse, err := template.New("pages").Parse(flowPopupHtml)
|
pageParse, err := template.New("pages").Parse(flowPopupHtml)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -44,7 +34,7 @@ func init() {
|
|||||||
|
|
||||||
func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
||||||
err := flowPopupTemplate.Execute(rw, map[string]any{
|
err := flowPopupTemplate.Execute(rw, map[string]any{
|
||||||
"ServiceName": h.serviceName,
|
"ServiceName": h.conf.ServiceName,
|
||||||
"Origin": req.URL.Query().Get("origin"),
|
"Origin": req.URL.Query().Get("origin"),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -75,7 +65,7 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _
|
|||||||
|
|
||||||
// generate oauth2 config and redirect to authorize URL
|
// generate oauth2 config and redirect to authorize URL
|
||||||
oa2conf := login.OAuth2Config
|
oa2conf := login.OAuth2Config
|
||||||
oa2conf.RedirectURL = h.baseUrl + "/callback"
|
oa2conf.RedirectURL = h.conf.BaseUrl + "/callback"
|
||||||
nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginName))
|
nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginName))
|
||||||
http.Redirect(rw, req, nextUrl, http.StatusFound)
|
http.Redirect(rw, req, nextUrl, http.StatusFound)
|
||||||
}
|
}
|
||||||
@ -101,7 +91,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
|
|||||||
}
|
}
|
||||||
|
|
||||||
oa2conf := v.sso.OAuth2Config
|
oa2conf := v.sso.OAuth2Config
|
||||||
oa2conf.RedirectURL = h.baseUrl + "/callback"
|
oa2conf.RedirectURL = h.conf.BaseUrl + "/callback"
|
||||||
exchange, err := oa2conf.Exchange(context.Background(), q.Get("code"))
|
exchange, err := oa2conf.Exchange(context.Background(), q.Get("code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Failed exchange:", err)
|
fmt.Println("Failed exchange:", err)
|
||||||
@ -157,7 +147,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
|
|||||||
}
|
}
|
||||||
|
|
||||||
_ = flowCallbackTemplate.Execute(rw, map[string]any{
|
_ = flowCallbackTemplate.Execute(rw, map[string]any{
|
||||||
"ServiceName": h.serviceName,
|
"ServiceName": h.conf.ServiceName,
|
||||||
"TargetOrigin": v.targetOrigin,
|
"TargetOrigin": v.targetOrigin,
|
||||||
"TargetMessage": v3,
|
"TargetMessage": v3,
|
||||||
"AccessToken": accessToken,
|
"AccessToken": accessToken,
|
||||||
|
1
server/flow_test.go
Normal file
1
server/flow_test.go
Normal file
@ -0,0 +1 @@
|
|||||||
|
package server
|
20
server/pages/pages.go
Normal file
20
server/pages/pages.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package pages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
_ "embed"
|
||||||
|
"html/template"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed pages/*
|
||||||
|
flowPages embed.FS
|
||||||
|
flowTemplates *template.Template
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadPages(wd string) {
|
||||||
|
wdFs := os.DirFS(wd)
|
||||||
|
|
||||||
|
flowPages.Open()
|
||||||
|
}
|
@ -4,21 +4,20 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/1f349/cache"
|
"github.com/1f349/cache"
|
||||||
"github.com/1f349/lavender/issuer"
|
"github.com/1f349/lavender/issuer"
|
||||||
"github.com/1f349/lavender/utils"
|
|
||||||
"github.com/MrMelon54/mjwt"
|
"github.com/MrMelon54/mjwt"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type HttpServer struct {
|
type HttpServer struct {
|
||||||
r *httprouter.Router
|
r *httprouter.Router
|
||||||
baseUrl string
|
conf Conf
|
||||||
serviceName string
|
manager *issuer.Manager
|
||||||
manager *issuer.Manager
|
signer mjwt.Signer
|
||||||
signer mjwt.Signer
|
flowState *cache.Cache[string, flowStateData]
|
||||||
flowState *cache.Cache[string, flowStateData]
|
services map[string]struct{}
|
||||||
services map[string]struct{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type flowStateData struct {
|
type flowStateData struct {
|
||||||
@ -26,30 +25,34 @@ type flowStateData struct {
|
|||||||
targetOrigin string
|
targetOrigin string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHttpServer(listen, baseUrl, serviceName string, clients []utils.JsonUrl, manager *issuer.Manager, signer mjwt.Signer) *http.Server {
|
func NewHttpServer(conf Conf, signer mjwt.Signer) *http.Server {
|
||||||
r := httprouter.New()
|
r := httprouter.New()
|
||||||
|
|
||||||
// remove last slash from baseUrl
|
// remove last slash from baseUrl
|
||||||
{
|
{
|
||||||
l := len(baseUrl)
|
l := len(conf.BaseUrl)
|
||||||
if baseUrl[l-1] == '/' {
|
if conf.BaseUrl[l-1] == '/' {
|
||||||
baseUrl = baseUrl[:l-1]
|
conf.BaseUrl = conf.BaseUrl[:l-1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
manager, err := issuer.NewManager(conf.SsoServices)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("[Lavender] Failed to create SSO service manager: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
services := make(map[string]struct{})
|
services := make(map[string]struct{})
|
||||||
for _, i := range clients {
|
for _, i := range conf.AllowedClients {
|
||||||
services[i.String()] = struct{}{}
|
services[i.String()] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
hs := &HttpServer{
|
hs := &HttpServer{
|
||||||
r: r,
|
r: r,
|
||||||
baseUrl: baseUrl,
|
conf: conf,
|
||||||
serviceName: serviceName,
|
manager: manager,
|
||||||
manager: manager,
|
signer: signer,
|
||||||
signer: signer,
|
flowState: cache.New[string, flowStateData](),
|
||||||
flowState: cache.New[string, flowStateData](),
|
services: services,
|
||||||
services: services,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r.GET("/", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
r.GET("/", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
||||||
@ -62,7 +65,7 @@ func NewHttpServer(listen, baseUrl, serviceName string, clients []utils.JsonUrl,
|
|||||||
r.GET("/callback", hs.flowCallback)
|
r.GET("/callback", hs.flowCallback)
|
||||||
|
|
||||||
return &http.Server{
|
return &http.Server{
|
||||||
Addr: listen,
|
Addr: conf.Listen,
|
||||||
Handler: r,
|
Handler: r,
|
||||||
ReadTimeout: time.Minute,
|
ReadTimeout: time.Minute,
|
||||||
ReadHeaderTimeout: time.Minute,
|
ReadHeaderTimeout: time.Minute,
|
||||||
|
@ -24,7 +24,7 @@ func (h *HttpServer) verifyHandler(rw http.ResponseWriter, req *http.Request, _
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check issuer against config
|
// check issuer against config
|
||||||
if b.Issuer != h.baseUrl {
|
if b.Issuer != h.conf.Issuer {
|
||||||
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
|
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
61
server/verify_test.go
Normal file
61
server/verify_test.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/mjwt/auth"
|
||||||
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestVerifyHandler(t *testing.T) {
|
||||||
|
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
invalidSigner := mjwt.NewMJwtSigner("Invalid Issuer", privKey)
|
||||||
|
h := HttpServer{
|
||||||
|
conf: Conf{Issuer: "Test Issuer"},
|
||||||
|
signer: mjwt.NewMJwtSigner("Test Issuer", privKey),
|
||||||
|
}
|
||||||
|
|
||||||
|
// test for missing bearer response
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "https://example.localhost", nil)
|
||||||
|
h.verifyHandler(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||||
|
assert.Equal(t, "Missing bearer\n", rec.Body.String())
|
||||||
|
|
||||||
|
// test for invalid token response
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
req.Header.Set("Authorization", "Bearer abcd")
|
||||||
|
h.verifyHandler(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||||
|
assert.Equal(t, "Invalid token\n", rec.Body.String())
|
||||||
|
|
||||||
|
// test for invalid issuer response
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
accessToken, err := invalidSigner.GenerateJwt("a", "a", nil, 15*time.Minute, auth.AccessTokenClaims{
|
||||||
|
Perms: claims.NewPermStorage(),
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
h.verifyHandler(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
assert.Equal(t, "Invalid issuer\n", rec.Body.String())
|
||||||
|
|
||||||
|
// test for invalid issuer response
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
accessToken, err = h.signer.GenerateJwt("a", "a", nil, 15*time.Minute, auth.AccessTokenClaims{
|
||||||
|
Perms: claims.NewPermStorage(),
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
h.verifyHandler(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user