Allow specific cors origin value

This commit is contained in:
Melon 2024-07-14 14:01:08 +01:00
parent 852fc9b872
commit a5ffa91443
Signed by: melon
GPG Key ID: 6C9D970C50D26A25

View File

@ -9,14 +9,13 @@ import (
"time" "time"
) )
var listenFlag, targetFlag, hostFlag string var listenFlag, targetFlag, hostFlag, allowCors string
var allowCors bool
func main() { func main() {
flag.StringVar(&listenFlag, "listen", ":8080", "Listen address") flag.StringVar(&listenFlag, "listen", ":8080", "Listen address")
flag.StringVar(&targetFlag, "target", "", "Set target URL to forward to") flag.StringVar(&targetFlag, "target", "", "Set target URL to forward to")
flag.StringVar(&hostFlag, "host", "", "Set the host value") flag.StringVar(&hostFlag, "host", "", "Set the host value")
flag.BoolVar(&allowCors, "cors", false, "Allow all cross origin requests") flag.StringVar(&allowCors, "cors", "", "Set cross origin domain or '*' for all origins")
flag.Parse() flag.Parse()
target, err := url.Parse(targetFlag) target, err := url.Parse(targetFlag)
@ -33,14 +32,11 @@ func main() {
r.Out.Header = r.In.Header.Clone() r.Out.Header = r.In.Header.Clone()
}, },
ModifyResponse: func(resp *http.Response) error { ModifyResponse: func(resp *http.Response) error {
if !allowCors { if allowCors != "" {
return nil resp.Header.Set("Access-Control-Allow-Origin", allowCors)
}
resp.Header.Set("Access-Control-Allow-Origin", "*")
resp.Header.Set("Access-Control-Allow-Credentials", "true") resp.Header.Set("Access-Control-Allow-Credentials", "true")
resp.Header.Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") resp.Header.Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
resp.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") resp.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, HEAD, DELETE")
if resp.Request.Method == http.MethodOptions { if resp.Request.Method == http.MethodOptions {
resp.Header.Set("Content-Type", "text/plain; charset=utf-8") resp.Header.Set("Content-Type", "text/plain; charset=utf-8")
@ -48,6 +44,7 @@ func main() {
resp.StatusCode = http.StatusNoContent resp.StatusCode = http.StatusNoContent
resp.Status = http.StatusText(http.StatusNoContent) resp.Status = http.StatusText(http.StatusNoContent)
} }
}
return nil return nil
}, },
} }
@ -66,7 +63,7 @@ func main() {
if hostFlag != "" { if hostFlag != "" {
log.Printf("- Rewriting host to '%s'\n", hostFlag) log.Printf("- Rewriting host to '%s'\n", hostFlag)
} }
if allowCors { if allowCors != "" {
log.Println("- Allowing all cross-origin requests") log.Println("- Allowing all cross-origin requests")
} }
if err := s.ListenAndServe(); err != nil { if err := s.ListenAndServe(); err != nil {