Add database reading to domains list

This commit is contained in:
Melon 2023-04-21 15:49:01 +01:00
parent 0e42e54f08
commit 6d83d4c860
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
5 changed files with 115 additions and 14 deletions

View File

@ -0,0 +1,5 @@
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
</state>
</component>

View File

@ -50,10 +50,10 @@ func main() {
}
}
allowedDomains := domains.New()
allowedDomains := domains.New(db)
reverseProxy := proxy.CreateHybridReverseProxy()
r := router.New(reverseProxy)
servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{})
servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{allowedDomains})
servers.NewHttpServer(*httpListen, 0, allowedDomains)
}

View File

@ -1,31 +1,39 @@
package domains
import (
"database/sql"
"github.com/MrMelon54/violet/utils"
"log"
"strings"
"sync"
)
type Domains struct {
db *sql.DB
s *sync.RWMutex
m map[string]struct{}
}
func New() *Domains {
func New(db *sql.DB) *Domains {
return &Domains{
db: db,
s: &sync.RWMutex{},
m: make(map[string]struct{}),
}
}
func (d *Domains) IsValid(host string) bool {
// remove the port
domain, ok := utils.GetDomainWithoutPort(host)
if !ok {
return false
}
// read lock for safety
d.s.RLock()
defer d.s.RUnlock()
// check root domains `www.example.com`, `example.com`, `com`
n := strings.Split(domain, ".")
for i := 0; i < len(n); i++ {
if _, ok := d.m[strings.Join(n[i:], ".")]; ok {
@ -34,3 +42,43 @@ func (d *Domains) IsValid(host string) bool {
}
return false
}
func (d *Domains) Compile() {
// async compile magic
go func() {
domainMap := make(map[string]struct{})
err := d.internalCompile(domainMap)
if err != nil {
log.Printf("[Domains] Compile failed: %s\n", err)
return
}
// lock while replacing the map
d.s.Lock()
d.m = domainMap
d.s.Unlock()
}()
}
func (d *Domains) internalCompile(m map[string]struct{}) error {
log.Println("[Domains] Updating domains from database")
// sql or something?
rows, err := d.db.Query("select name from domains where enabled = true")
if err != nil {
return err
}
defer rows.Close()
// loop through rows and scan the allowed domain names
for rows.Next() {
var name string
err = rows.Scan(&name)
if err != nil {
return err
}
m[name] = struct{}{}
}
// check for errors
return rows.Err()
}

View File

@ -10,13 +10,14 @@ import (
"time"
)
// NewApiServer creates and runs a *http.Server containing all the API endpoints for the software
// NewApiServer creates and runs a http server containing all the API
// endpoints for the software
//
// `/compile` - reloads all domains, routes and redirects from the configuration files
// `/compile` - reloads all domains, routes and redirects
func NewApiServer(listen string, verify mjwt.Provider, compileTarget utils.MultiCompilable) *http.Server {
r := httprouter.New()
// Endpoint `/compile` reloads all domains, routes and redirects from the configuration files
// Endpoint for compile action
r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
// Get bearer token
bearer := utils.GetBearer(req)

View File

@ -5,22 +5,69 @@ import (
"github.com/MrMelon54/violet/domains"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"log"
"net/http"
"net/url"
"time"
)
func NewHttpServer(listen string, httpsPort uint16, domainCheck *domains.Domains) *http.Server {
// NewHttpServer creates and runs a http server containing the public http
// endpoints for the reverse proxy.
//
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
// acme challenges, this is used for Lets Encrypt HTTP verification.
func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains) *http.Server {
r := httprouter.New()
r.GET("/.well-known/acme-challenge/{token}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if hostname, ok := utils.GetDomainWithoutPort(req.Host); ok {
var secureExtend string
if httpsPort != 443 {
secureExtend = fmt.Sprintf(":%d", httpsPort)
}
// Endpoint for acme challenge outputs
r.GET("/.well-known/acme-challenge/{key}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if h, ok := utils.GetDomainWithoutPort(req.Host); ok {
// check if the host is valid
if !domainCheck.IsValid(req.Host) {
http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420)
return
}
if tokenValue := params.ByName("token"); tokenValue != "" {
// check if the key is valid
key := params.ByName("key")
if key == "" {
rw.WriteHeader(http.StatusOK)
return
}
}
rw.WriteHeader(http.StatusNotFound)
})
// All other paths lead here and are forwarded to HTTPS
r.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if h, ok := utils.GetDomainWithoutPort(req.Host); ok {
u := &url.URL{
Scheme: "https",
Host: h + secureExtend,
Path: req.URL.Path,
RawPath: req.URL.RawPath,
RawQuery: req.URL.RawQuery,
}
utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect)
}
})
// Create and run http server
s := &http.Server{
Addr: listen,
Handler: r,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,
WriteTimeout: time.Minute,
IdleTimeout: time.Minute,
MaxHeaderBytes: 2500,
}
log.Printf("[HTTP] Starting HTTP server on: '%s'\n", s.Addr)
go utils.RunBackgroundHttp("HTTP", s)
return s
}