Add loads of test cases

This commit is contained in:
Melon 2023-10-08 15:24:59 +01:00
parent 01d03fef9d
commit bf39f4421b
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
13 changed files with 131 additions and 61 deletions

View File

@ -1,15 +0,0 @@
package main
import (
loginServiceManager "github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/utils"
)
type startUpConfig struct {
Listen string `json:"listen"`
BaseUrl string `json:"base_url"`
ServiceName string `json:"service_name"`
Issuer string `json:"issuer"`
SsoServices []loginServiceManager.SsoConfig `json:"sso_services"`
AllowedClients []utils.JsonUrl `json:"allowed_clients"`
}

View File

@ -5,7 +5,6 @@ import (
"crypto/rand" "crypto/rand"
"encoding/json" "encoding/json"
"flag" "flag"
"github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/server" "github.com/1f349/lavender/server"
"github.com/1f349/violet/utils" "github.com/1f349/violet/utils"
exit_reload "github.com/MrMelon54/exit-reload" exit_reload "github.com/MrMelon54/exit-reload"
@ -50,7 +49,7 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
return subcommands.ExitFailure return subcommands.ExitFailure
} }
var config startUpConfig var config server.Conf
err = json.NewDecoder(openConf).Decode(&config) err = json.NewDecoder(openConf).Decode(&config)
if err != nil { if err != nil {
log.Println("[Lavender] Error: invalid config file: ", err) log.Println("[Lavender] Error: invalid config file: ", err)
@ -66,18 +65,13 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
return subcommands.ExitSuccess return subcommands.ExitSuccess
} }
func normalLoad(startUp startUpConfig, wd string) { func normalLoad(startUp server.Conf, wd string) {
mSign, err := mjwt.NewMJwtSignerFromFileOrCreate(startUp.Issuer, filepath.Join(wd, "lavender.private.key"), rand.Reader, 4096) mSign, err := mjwt.NewMJwtSignerFromFileOrCreate(startUp.Issuer, filepath.Join(wd, "lavender.private.key"), rand.Reader, 4096)
if err != nil { if err != nil {
log.Fatal("[Lavender] Failed to load or create MJWT signer:", err) log.Fatal("[Lavender] Failed to load or create MJWT signer:", err)
} }
manager, err := issuer.NewManager(startUp.SsoServices) srv := server.NewHttpServer(startUp, mSign)
if err != nil {
log.Fatal("[Lavender] Failed to create SSO service manager: ", err)
}
srv := server.NewHttpServer(startUp.Listen, startUp.BaseUrl, startUp.ServiceName, startUp.AllowedClients, manager, mSign)
log.Printf("[Lavender] Starting HTTP server on '%s'\n", srv.Addr) log.Printf("[Lavender] Starting HTTP server on '%s'\n", srv.Addr)
go utils.RunBackgroundHttp("HTTP", srv) go utils.RunBackgroundHttp("HTTP", srv)

View File

@ -22,7 +22,7 @@ func testBody() io.ReadCloser {
return io.NopCloser(strings.NewReader("{}")) return io.NopCloser(strings.NewReader("{}"))
} }
func TestManager_CheckIssuer(t *testing.T) { func TestManager_CheckNamespace(t *testing.T) {
httpGet = func(url string) (resp *http.Response, err error) { httpGet = func(url string) (resp *http.Response, err error) {
return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil
} }

1
issuer/sso_test.go Normal file
View File

@ -0,0 +1 @@
package issuer

15
server/conf.go Normal file
View File

@ -0,0 +1,15 @@
package server
import (
"github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/utils"
)
type Conf struct {
Listen string `json:"listen"`
BaseUrl string `json:"base_url"`
ServiceName string `json:"service_name"`
Issuer string `json:"issuer"`
SsoServices []issuer.SsoConfig `json:"sso_services"`
AllowedClients []utils.JsonUrl `json:"allowed_clients"`
}

View File

@ -18,16 +18,6 @@ import (
"time" "time"
) )
var (
//go:embed flow-popup.go.html
flowPopupHtml string
flowPopupTemplate *template.Template
//go:embed flow-callback.go.html
flowCallbackHtml string
flowCallbackTemplate *template.Template
)
func init() { func init() {
pageParse, err := template.New("pages").Parse(flowPopupHtml) pageParse, err := template.New("pages").Parse(flowPopupHtml)
if err != nil { if err != nil {
@ -44,7 +34,7 @@ func init() {
func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
err := flowPopupTemplate.Execute(rw, map[string]any{ err := flowPopupTemplate.Execute(rw, map[string]any{
"ServiceName": h.serviceName, "ServiceName": h.conf.ServiceName,
"Origin": req.URL.Query().Get("origin"), "Origin": req.URL.Query().Get("origin"),
}) })
if err != nil { if err != nil {
@ -75,7 +65,7 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _
// generate oauth2 config and redirect to authorize URL // generate oauth2 config and redirect to authorize URL
oa2conf := login.OAuth2Config oa2conf := login.OAuth2Config
oa2conf.RedirectURL = h.baseUrl + "/callback" oa2conf.RedirectURL = h.conf.BaseUrl + "/callback"
nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginName)) nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginName))
http.Redirect(rw, req, nextUrl, http.StatusFound) http.Redirect(rw, req, nextUrl, http.StatusFound)
} }
@ -101,7 +91,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
} }
oa2conf := v.sso.OAuth2Config oa2conf := v.sso.OAuth2Config
oa2conf.RedirectURL = h.baseUrl + "/callback" oa2conf.RedirectURL = h.conf.BaseUrl + "/callback"
exchange, err := oa2conf.Exchange(context.Background(), q.Get("code")) exchange, err := oa2conf.Exchange(context.Background(), q.Get("code"))
if err != nil { if err != nil {
fmt.Println("Failed exchange:", err) fmt.Println("Failed exchange:", err)
@ -157,7 +147,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
} }
_ = flowCallbackTemplate.Execute(rw, map[string]any{ _ = flowCallbackTemplate.Execute(rw, map[string]any{
"ServiceName": h.serviceName, "ServiceName": h.conf.ServiceName,
"TargetOrigin": v.targetOrigin, "TargetOrigin": v.targetOrigin,
"TargetMessage": v3, "TargetMessage": v3,
"AccessToken": accessToken, "AccessToken": accessToken,

1
server/flow_test.go Normal file
View File

@ -0,0 +1 @@
package server

20
server/pages/pages.go Normal file
View File

@ -0,0 +1,20 @@
package pages
import (
"embed"
_ "embed"
"html/template"
"os"
)
var (
//go:embed pages/*
flowPages embed.FS
flowTemplates *template.Template
)
func LoadPages(wd string) {
wdFs := os.DirFS(wd)
flowPages.Open()
}

View File

@ -4,21 +4,20 @@ import (
"fmt" "fmt"
"github.com/1f349/cache" "github.com/1f349/cache"
"github.com/1f349/lavender/issuer" "github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/utils"
"github.com/MrMelon54/mjwt" "github.com/MrMelon54/mjwt"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"log"
"net/http" "net/http"
"time" "time"
) )
type HttpServer struct { type HttpServer struct {
r *httprouter.Router r *httprouter.Router
baseUrl string conf Conf
serviceName string manager *issuer.Manager
manager *issuer.Manager signer mjwt.Signer
signer mjwt.Signer flowState *cache.Cache[string, flowStateData]
flowState *cache.Cache[string, flowStateData] services map[string]struct{}
services map[string]struct{}
} }
type flowStateData struct { type flowStateData struct {
@ -26,30 +25,34 @@ type flowStateData struct {
targetOrigin string targetOrigin string
} }
func NewHttpServer(listen, baseUrl, serviceName string, clients []utils.JsonUrl, manager *issuer.Manager, signer mjwt.Signer) *http.Server { func NewHttpServer(conf Conf, signer mjwt.Signer) *http.Server {
r := httprouter.New() r := httprouter.New()
// remove last slash from baseUrl // remove last slash from baseUrl
{ {
l := len(baseUrl) l := len(conf.BaseUrl)
if baseUrl[l-1] == '/' { if conf.BaseUrl[l-1] == '/' {
baseUrl = baseUrl[:l-1] conf.BaseUrl = conf.BaseUrl[:l-1]
} }
} }
manager, err := issuer.NewManager(conf.SsoServices)
if err != nil {
log.Fatal("[Lavender] Failed to create SSO service manager: ", err)
}
services := make(map[string]struct{}) services := make(map[string]struct{})
for _, i := range clients { for _, i := range conf.AllowedClients {
services[i.String()] = struct{}{} services[i.String()] = struct{}{}
} }
hs := &HttpServer{ hs := &HttpServer{
r: r, r: r,
baseUrl: baseUrl, conf: conf,
serviceName: serviceName, manager: manager,
manager: manager, signer: signer,
signer: signer, flowState: cache.New[string, flowStateData](),
flowState: cache.New[string, flowStateData](), services: services,
services: services,
} }
r.GET("/", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { r.GET("/", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
@ -62,7 +65,7 @@ func NewHttpServer(listen, baseUrl, serviceName string, clients []utils.JsonUrl,
r.GET("/callback", hs.flowCallback) r.GET("/callback", hs.flowCallback)
return &http.Server{ return &http.Server{
Addr: listen, Addr: conf.Listen,
Handler: r, Handler: r,
ReadTimeout: time.Minute, ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute, ReadHeaderTimeout: time.Minute,

View File

@ -24,7 +24,7 @@ func (h *HttpServer) verifyHandler(rw http.ResponseWriter, req *http.Request, _
} }
// check issuer against config // check issuer against config
if b.Issuer != h.baseUrl { if b.Issuer != h.conf.Issuer {
http.Error(rw, "Invalid issuer", http.StatusBadRequest) http.Error(rw, "Invalid issuer", http.StatusBadRequest)
return return
} }

61
server/verify_test.go Normal file
View File

@ -0,0 +1,61 @@
package server
import (
"crypto/rand"
"crypto/rsa"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/auth"
"github.com/MrMelon54/mjwt/claims"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestVerifyHandler(t *testing.T) {
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
invalidSigner := mjwt.NewMJwtSigner("Invalid Issuer", privKey)
h := HttpServer{
conf: Conf{Issuer: "Test Issuer"},
signer: mjwt.NewMJwtSigner("Test Issuer", privKey),
}
// test for missing bearer response
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "https://example.localhost", nil)
h.verifyHandler(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusForbidden, rec.Code)
assert.Equal(t, "Missing bearer\n", rec.Body.String())
// test for invalid token response
rec = httptest.NewRecorder()
req.Header.Set("Authorization", "Bearer abcd")
h.verifyHandler(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusForbidden, rec.Code)
assert.Equal(t, "Invalid token\n", rec.Body.String())
// test for invalid issuer response
rec = httptest.NewRecorder()
accessToken, err := invalidSigner.GenerateJwt("a", "a", nil, 15*time.Minute, auth.AccessTokenClaims{
Perms: claims.NewPermStorage(),
})
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+accessToken)
h.verifyHandler(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "Invalid issuer\n", rec.Body.String())
// test for invalid issuer response
rec = httptest.NewRecorder()
accessToken, err = h.signer.GenerateJwt("a", "a", nil, 15*time.Minute, auth.AccessTokenClaims{
Perms: claims.NewPermStorage(),
})
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+accessToken)
h.verifyHandler(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusOK, rec.Code)
}