mirror of
https://github.com/1f349/violet.git
synced 2024-11-24 20:31:37 +00:00
Add database reading to domains list
This commit is contained in:
parent
0e42e54f08
commit
6d83d4c860
5
.idea/codeStyles/codeStyleConfig.xml
Normal file
5
.idea/codeStyles/codeStyleConfig.xml
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
<component name="ProjectCodeStyleConfiguration">
|
||||||
|
<state>
|
||||||
|
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
|
||||||
|
</state>
|
||||||
|
</component>
|
@ -50,10 +50,10 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
allowedDomains := domains.New()
|
allowedDomains := domains.New(db)
|
||||||
reverseProxy := proxy.CreateHybridReverseProxy()
|
reverseProxy := proxy.CreateHybridReverseProxy()
|
||||||
r := router.New(reverseProxy)
|
r := router.New(reverseProxy)
|
||||||
|
|
||||||
servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{})
|
servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{allowedDomains})
|
||||||
servers.NewHttpServer(*httpListen, 0, allowedDomains)
|
servers.NewHttpServer(*httpListen, 0, allowedDomains)
|
||||||
}
|
}
|
||||||
|
@ -1,31 +1,39 @@
|
|||||||
package domains
|
package domains
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Domains struct {
|
type Domains struct {
|
||||||
s *sync.RWMutex
|
db *sql.DB
|
||||||
m map[string]struct{}
|
s *sync.RWMutex
|
||||||
|
m map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func New() *Domains {
|
func New(db *sql.DB) *Domains {
|
||||||
return &Domains{
|
return &Domains{
|
||||||
s: &sync.RWMutex{},
|
db: db,
|
||||||
m: make(map[string]struct{}),
|
s: &sync.RWMutex{},
|
||||||
|
m: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Domains) IsValid(host string) bool {
|
func (d *Domains) IsValid(host string) bool {
|
||||||
|
// remove the port
|
||||||
domain, ok := utils.GetDomainWithoutPort(host)
|
domain, ok := utils.GetDomainWithoutPort(host)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// read lock for safety
|
||||||
d.s.RLock()
|
d.s.RLock()
|
||||||
defer d.s.RUnlock()
|
defer d.s.RUnlock()
|
||||||
|
|
||||||
|
// check root domains `www.example.com`, `example.com`, `com`
|
||||||
n := strings.Split(domain, ".")
|
n := strings.Split(domain, ".")
|
||||||
for i := 0; i < len(n); i++ {
|
for i := 0; i < len(n); i++ {
|
||||||
if _, ok := d.m[strings.Join(n[i:], ".")]; ok {
|
if _, ok := d.m[strings.Join(n[i:], ".")]; ok {
|
||||||
@ -34,3 +42,43 @@ func (d *Domains) IsValid(host string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Domains) Compile() {
|
||||||
|
// async compile magic
|
||||||
|
go func() {
|
||||||
|
domainMap := make(map[string]struct{})
|
||||||
|
err := d.internalCompile(domainMap)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Domains] Compile failed: %s\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// lock while replacing the map
|
||||||
|
d.s.Lock()
|
||||||
|
d.m = domainMap
|
||||||
|
d.s.Unlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Domains) internalCompile(m map[string]struct{}) error {
|
||||||
|
log.Println("[Domains] Updating domains from database")
|
||||||
|
|
||||||
|
// sql or something?
|
||||||
|
rows, err := d.db.Query("select name from domains where enabled = true")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
// loop through rows and scan the allowed domain names
|
||||||
|
for rows.Next() {
|
||||||
|
var name string
|
||||||
|
err = rows.Scan(&name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m[name] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check for errors
|
||||||
|
return rows.Err()
|
||||||
|
}
|
||||||
|
@ -10,13 +10,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewApiServer creates and runs a *http.Server containing all the API endpoints for the software
|
// NewApiServer creates and runs a http server containing all the API
|
||||||
|
// endpoints for the software
|
||||||
//
|
//
|
||||||
// `/compile` - reloads all domains, routes and redirects from the configuration files
|
// `/compile` - reloads all domains, routes and redirects
|
||||||
func NewApiServer(listen string, verify mjwt.Provider, compileTarget utils.MultiCompilable) *http.Server {
|
func NewApiServer(listen string, verify mjwt.Provider, compileTarget utils.MultiCompilable) *http.Server {
|
||||||
r := httprouter.New()
|
r := httprouter.New()
|
||||||
|
|
||||||
// Endpoint `/compile` reloads all domains, routes and redirects from the configuration files
|
// Endpoint for compile action
|
||||||
r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
||||||
// Get bearer token
|
// Get bearer token
|
||||||
bearer := utils.GetBearer(req)
|
bearer := utils.GetBearer(req)
|
||||||
|
@ -5,22 +5,69 @@ import (
|
|||||||
"github.com/MrMelon54/violet/domains"
|
"github.com/MrMelon54/violet/domains"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewHttpServer(listen string, httpsPort uint16, domainCheck *domains.Domains) *http.Server {
|
// NewHttpServer creates and runs a http server containing the public http
|
||||||
|
// endpoints for the reverse proxy.
|
||||||
|
//
|
||||||
|
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
|
||||||
|
// acme challenges, this is used for Lets Encrypt HTTP verification.
|
||||||
|
func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains) *http.Server {
|
||||||
r := httprouter.New()
|
r := httprouter.New()
|
||||||
r.GET("/.well-known/acme-challenge/{token}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
var secureExtend string
|
||||||
if hostname, ok := utils.GetDomainWithoutPort(req.Host); ok {
|
if httpsPort != 443 {
|
||||||
|
secureExtend = fmt.Sprintf(":%d", httpsPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Endpoint for acme challenge outputs
|
||||||
|
r.GET("/.well-known/acme-challenge/{key}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||||
|
if h, ok := utils.GetDomainWithoutPort(req.Host); ok {
|
||||||
|
// check if the host is valid
|
||||||
if !domainCheck.IsValid(req.Host) {
|
if !domainCheck.IsValid(req.Host) {
|
||||||
http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420)
|
http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if tokenValue := params.ByName("token"); tokenValue != "" {
|
|
||||||
|
// check if the key is valid
|
||||||
|
key := params.ByName("key")
|
||||||
|
if key == "" {
|
||||||
rw.WriteHeader(http.StatusOK)
|
rw.WriteHeader(http.StatusOK)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
rw.WriteHeader(http.StatusNotFound)
|
rw.WriteHeader(http.StatusNotFound)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// All other paths lead here and are forwarded to HTTPS
|
||||||
|
r.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if h, ok := utils.GetDomainWithoutPort(req.Host); ok {
|
||||||
|
u := &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: h + secureExtend,
|
||||||
|
Path: req.URL.Path,
|
||||||
|
RawPath: req.URL.RawPath,
|
||||||
|
RawQuery: req.URL.RawQuery,
|
||||||
|
}
|
||||||
|
utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create and run http server
|
||||||
|
s := &http.Server{
|
||||||
|
Addr: listen,
|
||||||
|
Handler: r,
|
||||||
|
ReadTimeout: time.Minute,
|
||||||
|
ReadHeaderTimeout: time.Minute,
|
||||||
|
WriteTimeout: time.Minute,
|
||||||
|
IdleTimeout: time.Minute,
|
||||||
|
MaxHeaderBytes: 2500,
|
||||||
|
}
|
||||||
|
log.Printf("[HTTP] Starting HTTP server on: '%s'\n", s.Addr)
|
||||||
|
go utils.RunBackgroundHttp("HTTP", s)
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user