mirror of
https://github.com/1f349/violet.git
synced 2024-11-21 10:51:40 +00:00
Add database reading to domains list
This commit is contained in:
parent
0e42e54f08
commit
6d83d4c860
5
.idea/codeStyles/codeStyleConfig.xml
Normal file
5
.idea/codeStyles/codeStyleConfig.xml
Normal file
@ -0,0 +1,5 @@
|
||||
<component name="ProjectCodeStyleConfiguration">
|
||||
<state>
|
||||
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
|
||||
</state>
|
||||
</component>
|
@ -50,10 +50,10 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
allowedDomains := domains.New()
|
||||
allowedDomains := domains.New(db)
|
||||
reverseProxy := proxy.CreateHybridReverseProxy()
|
||||
r := router.New(reverseProxy)
|
||||
|
||||
servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{})
|
||||
servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{allowedDomains})
|
||||
servers.NewHttpServer(*httpListen, 0, allowedDomains)
|
||||
}
|
||||
|
@ -1,31 +1,39 @@
|
||||
package domains
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Domains struct {
|
||||
s *sync.RWMutex
|
||||
m map[string]struct{}
|
||||
db *sql.DB
|
||||
s *sync.RWMutex
|
||||
m map[string]struct{}
|
||||
}
|
||||
|
||||
func New() *Domains {
|
||||
func New(db *sql.DB) *Domains {
|
||||
return &Domains{
|
||||
s: &sync.RWMutex{},
|
||||
m: make(map[string]struct{}),
|
||||
db: db,
|
||||
s: &sync.RWMutex{},
|
||||
m: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Domains) IsValid(host string) bool {
|
||||
// remove the port
|
||||
domain, ok := utils.GetDomainWithoutPort(host)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// read lock for safety
|
||||
d.s.RLock()
|
||||
defer d.s.RUnlock()
|
||||
|
||||
// check root domains `www.example.com`, `example.com`, `com`
|
||||
n := strings.Split(domain, ".")
|
||||
for i := 0; i < len(n); i++ {
|
||||
if _, ok := d.m[strings.Join(n[i:], ".")]; ok {
|
||||
@ -34,3 +42,43 @@ func (d *Domains) IsValid(host string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *Domains) Compile() {
|
||||
// async compile magic
|
||||
go func() {
|
||||
domainMap := make(map[string]struct{})
|
||||
err := d.internalCompile(domainMap)
|
||||
if err != nil {
|
||||
log.Printf("[Domains] Compile failed: %s\n", err)
|
||||
return
|
||||
}
|
||||
// lock while replacing the map
|
||||
d.s.Lock()
|
||||
d.m = domainMap
|
||||
d.s.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
func (d *Domains) internalCompile(m map[string]struct{}) error {
|
||||
log.Println("[Domains] Updating domains from database")
|
||||
|
||||
// sql or something?
|
||||
rows, err := d.db.Query("select name from domains where enabled = true")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// loop through rows and scan the allowed domain names
|
||||
for rows.Next() {
|
||||
var name string
|
||||
err = rows.Scan(&name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m[name] = struct{}{}
|
||||
}
|
||||
|
||||
// check for errors
|
||||
return rows.Err()
|
||||
}
|
||||
|
@ -10,13 +10,14 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewApiServer creates and runs a *http.Server containing all the API endpoints for the software
|
||||
// NewApiServer creates and runs a http server containing all the API
|
||||
// endpoints for the software
|
||||
//
|
||||
// `/compile` - reloads all domains, routes and redirects from the configuration files
|
||||
// `/compile` - reloads all domains, routes and redirects
|
||||
func NewApiServer(listen string, verify mjwt.Provider, compileTarget utils.MultiCompilable) *http.Server {
|
||||
r := httprouter.New()
|
||||
|
||||
// Endpoint `/compile` reloads all domains, routes and redirects from the configuration files
|
||||
// Endpoint for compile action
|
||||
r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
||||
// Get bearer token
|
||||
bearer := utils.GetBearer(req)
|
||||
|
@ -5,22 +5,69 @@ import (
|
||||
"github.com/MrMelon54/violet/domains"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewHttpServer(listen string, httpsPort uint16, domainCheck *domains.Domains) *http.Server {
|
||||
// NewHttpServer creates and runs a http server containing the public http
|
||||
// endpoints for the reverse proxy.
|
||||
//
|
||||
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
|
||||
// acme challenges, this is used for Lets Encrypt HTTP verification.
|
||||
func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains) *http.Server {
|
||||
r := httprouter.New()
|
||||
r.GET("/.well-known/acme-challenge/{token}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
if hostname, ok := utils.GetDomainWithoutPort(req.Host); ok {
|
||||
var secureExtend string
|
||||
if httpsPort != 443 {
|
||||
secureExtend = fmt.Sprintf(":%d", httpsPort)
|
||||
}
|
||||
|
||||
// Endpoint for acme challenge outputs
|
||||
r.GET("/.well-known/acme-challenge/{key}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
if h, ok := utils.GetDomainWithoutPort(req.Host); ok {
|
||||
// check if the host is valid
|
||||
if !domainCheck.IsValid(req.Host) {
|
||||
http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420)
|
||||
return
|
||||
}
|
||||
if tokenValue := params.ByName("token"); tokenValue != "" {
|
||||
|
||||
// check if the key is valid
|
||||
key := params.ByName("key")
|
||||
if key == "" {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
|
||||
// All other paths lead here and are forwarded to HTTPS
|
||||
r.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if h, ok := utils.GetDomainWithoutPort(req.Host); ok {
|
||||
u := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: h + secureExtend,
|
||||
Path: req.URL.Path,
|
||||
RawPath: req.URL.RawPath,
|
||||
RawQuery: req.URL.RawQuery,
|
||||
}
|
||||
utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect)
|
||||
}
|
||||
})
|
||||
|
||||
// Create and run http server
|
||||
s := &http.Server{
|
||||
Addr: listen,
|
||||
Handler: r,
|
||||
ReadTimeout: time.Minute,
|
||||
ReadHeaderTimeout: time.Minute,
|
||||
WriteTimeout: time.Minute,
|
||||
IdleTimeout: time.Minute,
|
||||
MaxHeaderBytes: 2500,
|
||||
}
|
||||
log.Printf("[HTTP] Starting HTTP server on: '%s'\n", s.Addr)
|
||||
go utils.RunBackgroundHttp("HTTP", s)
|
||||
return s
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user