Add database reading to domains list

This commit is contained in:
Melon 2023-04-21 15:49:01 +01:00
parent 0e42e54f08
commit 6d83d4c860
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
5 changed files with 115 additions and 14 deletions

View File

@ -0,0 +1,5 @@
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
</state>
</component>

View File

@ -50,10 +50,10 @@ func main() {
} }
} }
allowedDomains := domains.New() allowedDomains := domains.New(db)
reverseProxy := proxy.CreateHybridReverseProxy() reverseProxy := proxy.CreateHybridReverseProxy()
r := router.New(reverseProxy) r := router.New(reverseProxy)
servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{}) servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{allowedDomains})
servers.NewHttpServer(*httpListen, 0, allowedDomains) servers.NewHttpServer(*httpListen, 0, allowedDomains)
} }

View File

@ -1,31 +1,39 @@
package domains package domains
import ( import (
"database/sql"
"github.com/MrMelon54/violet/utils" "github.com/MrMelon54/violet/utils"
"log"
"strings" "strings"
"sync" "sync"
) )
type Domains struct { type Domains struct {
s *sync.RWMutex db *sql.DB
m map[string]struct{} s *sync.RWMutex
m map[string]struct{}
} }
func New() *Domains { func New(db *sql.DB) *Domains {
return &Domains{ return &Domains{
s: &sync.RWMutex{}, db: db,
m: make(map[string]struct{}), s: &sync.RWMutex{},
m: make(map[string]struct{}),
} }
} }
func (d *Domains) IsValid(host string) bool { func (d *Domains) IsValid(host string) bool {
// remove the port
domain, ok := utils.GetDomainWithoutPort(host) domain, ok := utils.GetDomainWithoutPort(host)
if !ok { if !ok {
return false return false
} }
// read lock for safety
d.s.RLock() d.s.RLock()
defer d.s.RUnlock() defer d.s.RUnlock()
// check root domains `www.example.com`, `example.com`, `com`
n := strings.Split(domain, ".") n := strings.Split(domain, ".")
for i := 0; i < len(n); i++ { for i := 0; i < len(n); i++ {
if _, ok := d.m[strings.Join(n[i:], ".")]; ok { if _, ok := d.m[strings.Join(n[i:], ".")]; ok {
@ -34,3 +42,43 @@ func (d *Domains) IsValid(host string) bool {
} }
return false return false
} }
func (d *Domains) Compile() {
// async compile magic
go func() {
domainMap := make(map[string]struct{})
err := d.internalCompile(domainMap)
if err != nil {
log.Printf("[Domains] Compile failed: %s\n", err)
return
}
// lock while replacing the map
d.s.Lock()
d.m = domainMap
d.s.Unlock()
}()
}
func (d *Domains) internalCompile(m map[string]struct{}) error {
log.Println("[Domains] Updating domains from database")
// sql or something?
rows, err := d.db.Query("select name from domains where enabled = true")
if err != nil {
return err
}
defer rows.Close()
// loop through rows and scan the allowed domain names
for rows.Next() {
var name string
err = rows.Scan(&name)
if err != nil {
return err
}
m[name] = struct{}{}
}
// check for errors
return rows.Err()
}

View File

@ -10,13 +10,14 @@ import (
"time" "time"
) )
// NewApiServer creates and runs a *http.Server containing all the API endpoints for the software // NewApiServer creates and runs a http server containing all the API
// endpoints for the software
// //
// `/compile` - reloads all domains, routes and redirects from the configuration files // `/compile` - reloads all domains, routes and redirects
func NewApiServer(listen string, verify mjwt.Provider, compileTarget utils.MultiCompilable) *http.Server { func NewApiServer(listen string, verify mjwt.Provider, compileTarget utils.MultiCompilable) *http.Server {
r := httprouter.New() r := httprouter.New()
// Endpoint `/compile` reloads all domains, routes and redirects from the configuration files // Endpoint for compile action
r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
// Get bearer token // Get bearer token
bearer := utils.GetBearer(req) bearer := utils.GetBearer(req)

View File

@ -5,22 +5,69 @@ import (
"github.com/MrMelon54/violet/domains" "github.com/MrMelon54/violet/domains"
"github.com/MrMelon54/violet/utils" "github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"log"
"net/http" "net/http"
"net/url"
"time"
) )
func NewHttpServer(listen string, httpsPort uint16, domainCheck *domains.Domains) *http.Server { // NewHttpServer creates and runs a http server containing the public http
// endpoints for the reverse proxy.
//
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
// acme challenges, this is used for Lets Encrypt HTTP verification.
func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains) *http.Server {
r := httprouter.New() r := httprouter.New()
r.GET("/.well-known/acme-challenge/{token}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { var secureExtend string
if hostname, ok := utils.GetDomainWithoutPort(req.Host); ok { if httpsPort != 443 {
secureExtend = fmt.Sprintf(":%d", httpsPort)
}
// Endpoint for acme challenge outputs
r.GET("/.well-known/acme-challenge/{key}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if h, ok := utils.GetDomainWithoutPort(req.Host); ok {
// check if the host is valid
if !domainCheck.IsValid(req.Host) { if !domainCheck.IsValid(req.Host) {
http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420) http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420)
return return
} }
if tokenValue := params.ByName("token"); tokenValue != "" {
// check if the key is valid
key := params.ByName("key")
if key == "" {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
return return
} }
} }
rw.WriteHeader(http.StatusNotFound) rw.WriteHeader(http.StatusNotFound)
}) })
// All other paths lead here and are forwarded to HTTPS
r.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if h, ok := utils.GetDomainWithoutPort(req.Host); ok {
u := &url.URL{
Scheme: "https",
Host: h + secureExtend,
Path: req.URL.Path,
RawPath: req.URL.RawPath,
RawQuery: req.URL.RawQuery,
}
utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect)
}
})
// Create and run http server
s := &http.Server{
Addr: listen,
Handler: r,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,
WriteTimeout: time.Minute,
IdleTimeout: time.Minute,
MaxHeaderBytes: 2500,
}
log.Printf("[HTTP] Starting HTTP server on: '%s'\n", s.Addr)
go utils.RunBackgroundHttp("HTTP", s)
return s
} }