diff --git a/http-dev.go b/http-dev.go index b679ccc..af34ed4 100644 --- a/http-dev.go +++ b/http-dev.go @@ -9,14 +9,13 @@ import ( "time" ) -var listenFlag, targetFlag, hostFlag string -var allowCors bool +var listenFlag, targetFlag, hostFlag, allowCors string func main() { flag.StringVar(&listenFlag, "listen", ":8080", "Listen address") flag.StringVar(&targetFlag, "target", "", "Set target URL to forward to") flag.StringVar(&hostFlag, "host", "", "Set the host value") - flag.BoolVar(&allowCors, "cors", false, "Allow all cross origin requests") + flag.StringVar(&allowCors, "cors", "", "Set cross origin domain or '*' for all origins") flag.Parse() target, err := url.Parse(targetFlag) @@ -33,20 +32,18 @@ func main() { r.Out.Header = r.In.Header.Clone() }, ModifyResponse: func(resp *http.Response) error { - if !allowCors { - return nil - } + if allowCors != "" { + resp.Header.Set("Access-Control-Allow-Origin", allowCors) + resp.Header.Set("Access-Control-Allow-Credentials", "true") + resp.Header.Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") + resp.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, HEAD, DELETE") - resp.Header.Set("Access-Control-Allow-Origin", "*") - resp.Header.Set("Access-Control-Allow-Credentials", "true") - resp.Header.Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") - resp.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") - - if resp.Request.Method == http.MethodOptions { - resp.Header.Set("Content-Type", "text/plain; charset=utf-8") - resp.Header.Set("X-Content-Type-Options", "nosniff") - resp.StatusCode = http.StatusNoContent - resp.Status = http.StatusText(http.StatusNoContent) + if resp.Request.Method == http.MethodOptions { + resp.Header.Set("Content-Type", "text/plain; charset=utf-8") + resp.Header.Set("X-Content-Type-Options", "nosniff") + resp.StatusCode = http.StatusNoContent + resp.Status = http.StatusText(http.StatusNoContent) + } } return nil }, @@ -66,7 +63,7 @@ func main() { if hostFlag != "" { log.Printf("- Rewriting host to '%s'\n", hostFlag) } - if allowCors { + if allowCors != "" { log.Println("- Allowing all cross-origin requests") } if err := s.ListenAndServe(); err != nil {