diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml new file mode 100644 index 0000000..a55e7a1 --- /dev/null +++ b/.idea/codeStyles/codeStyleConfig.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/cmd/violet/main.go b/cmd/violet/main.go index 264545f..e1bb0a8 100644 --- a/cmd/violet/main.go +++ b/cmd/violet/main.go @@ -50,10 +50,10 @@ func main() { } } - allowedDomains := domains.New() + allowedDomains := domains.New(db) reverseProxy := proxy.CreateHybridReverseProxy() r := router.New(reverseProxy) - servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{}) + servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{allowedDomains}) servers.NewHttpServer(*httpListen, 0, allowedDomains) } diff --git a/domains/domains.go b/domains/domains.go index 5029276..52dbd7f 100644 --- a/domains/domains.go +++ b/domains/domains.go @@ -1,31 +1,39 @@ package domains import ( + "database/sql" "github.com/MrMelon54/violet/utils" + "log" "strings" "sync" ) type Domains struct { - s *sync.RWMutex - m map[string]struct{} + db *sql.DB + s *sync.RWMutex + m map[string]struct{} } -func New() *Domains { +func New(db *sql.DB) *Domains { return &Domains{ - s: &sync.RWMutex{}, - m: make(map[string]struct{}), + db: db, + s: &sync.RWMutex{}, + m: make(map[string]struct{}), } } func (d *Domains) IsValid(host string) bool { + // remove the port domain, ok := utils.GetDomainWithoutPort(host) if !ok { return false } + + // read lock for safety d.s.RLock() defer d.s.RUnlock() + // check root domains `www.example.com`, `example.com`, `com` n := strings.Split(domain, ".") for i := 0; i < len(n); i++ { if _, ok := d.m[strings.Join(n[i:], ".")]; ok { @@ -34,3 +42,43 @@ func (d *Domains) IsValid(host string) bool { } return false } + +func (d *Domains) Compile() { + // async compile magic + go func() { + domainMap := make(map[string]struct{}) + err := d.internalCompile(domainMap) + if err != nil { + log.Printf("[Domains] Compile failed: %s\n", err) + return + } + // lock while replacing the map + d.s.Lock() + d.m = domainMap + d.s.Unlock() + }() +} + +func (d *Domains) internalCompile(m map[string]struct{}) error { + log.Println("[Domains] Updating domains from database") + + // sql or something? + rows, err := d.db.Query("select name from domains where enabled = true") + if err != nil { + return err + } + defer rows.Close() + + // loop through rows and scan the allowed domain names + for rows.Next() { + var name string + err = rows.Scan(&name) + if err != nil { + return err + } + m[name] = struct{}{} + } + + // check for errors + return rows.Err() +} diff --git a/servers/api.go b/servers/api.go index 7ef4e89..7c66935 100644 --- a/servers/api.go +++ b/servers/api.go @@ -10,13 +10,14 @@ import ( "time" ) -// NewApiServer creates and runs a *http.Server containing all the API endpoints for the software +// NewApiServer creates and runs a http server containing all the API +// endpoints for the software // -// `/compile` - reloads all domains, routes and redirects from the configuration files +// `/compile` - reloads all domains, routes and redirects func NewApiServer(listen string, verify mjwt.Provider, compileTarget utils.MultiCompilable) *http.Server { r := httprouter.New() - // Endpoint `/compile` reloads all domains, routes and redirects from the configuration files + // Endpoint for compile action r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { // Get bearer token bearer := utils.GetBearer(req) diff --git a/servers/http.go b/servers/http.go index 23e522d..3aa9c4c 100644 --- a/servers/http.go +++ b/servers/http.go @@ -5,22 +5,69 @@ import ( "github.com/MrMelon54/violet/domains" "github.com/MrMelon54/violet/utils" "github.com/julienschmidt/httprouter" + "log" "net/http" + "net/url" + "time" ) -func NewHttpServer(listen string, httpsPort uint16, domainCheck *domains.Domains) *http.Server { +// NewHttpServer creates and runs a http server containing the public http +// endpoints for the reverse proxy. +// +// `/.well-known/acme-challenge/{token}` is used for outputting answers for +// acme challenges, this is used for Lets Encrypt HTTP verification. +func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains) *http.Server { r := httprouter.New() - r.GET("/.well-known/acme-challenge/{token}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { - if hostname, ok := utils.GetDomainWithoutPort(req.Host); ok { + var secureExtend string + if httpsPort != 443 { + secureExtend = fmt.Sprintf(":%d", httpsPort) + } + + // Endpoint for acme challenge outputs + r.GET("/.well-known/acme-challenge/{key}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + if h, ok := utils.GetDomainWithoutPort(req.Host); ok { + // check if the host is valid if !domainCheck.IsValid(req.Host) { http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420) return } - if tokenValue := params.ByName("token"); tokenValue != "" { + + // check if the key is valid + key := params.ByName("key") + if key == "" { rw.WriteHeader(http.StatusOK) return } + } rw.WriteHeader(http.StatusNotFound) }) + + // All other paths lead here and are forwarded to HTTPS + r.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if h, ok := utils.GetDomainWithoutPort(req.Host); ok { + u := &url.URL{ + Scheme: "https", + Host: h + secureExtend, + Path: req.URL.Path, + RawPath: req.URL.RawPath, + RawQuery: req.URL.RawQuery, + } + utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect) + } + }) + + // Create and run http server + s := &http.Server{ + Addr: listen, + Handler: r, + ReadTimeout: time.Minute, + ReadHeaderTimeout: time.Minute, + WriteTimeout: time.Minute, + IdleTimeout: time.Minute, + MaxHeaderBytes: 2500, + } + log.Printf("[HTTP] Starting HTTP server on: '%s'\n", s.Addr) + go utils.RunBackgroundHttp("HTTP", s) + return s }