mirror of
https://github.com/1f349/violet.git
synced 2024-11-24 20:31:37 +00:00
133 lines
2.8 KiB
Go
133 lines
2.8 KiB
Go
package domains
|
|
|
|
import (
|
|
"database/sql"
|
|
_ "embed"
|
|
"github.com/MrMelon54/rescheduler"
|
|
"github.com/MrMelon54/violet/utils"
|
|
"log"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// createTableDomains holds the contents of create-table-domains.sql, embedded
// into the binary at build time. New executes it to initialise the 'domains'
// table before anything else touches the database.
//
//go:embed create-table-domains.sql
var createTableDomains string
|
|
|
|
// Domains is the domain list and management system.
//
// It keeps an in-memory set of active domain names, rebuilt from the database
// by Compile, so that IsValid can answer lookups without querying the database.
type Domains struct {
	// db is the database holding the 'domains' table.
	db *sql.DB

	// s guards m. IsValid takes the read lock; threadCompile takes the write
	// lock to swap in a new map, and Put/Delete also hold the write lock while
	// writing to db.
	s *sync.RWMutex

	// m is the set of domain names whose row has active = 1, as of the last
	// successful compile.
	m map[string]struct{}

	// r is the rescheduler that runs threadCompile when Compile is called.
	r *rescheduler.Rescheduler
}
|
|
|
|
// New creates a new domain list
|
|
func New(db *sql.DB) *Domains {
|
|
a := &Domains{
|
|
db: db,
|
|
s: &sync.RWMutex{},
|
|
m: make(map[string]struct{}),
|
|
}
|
|
a.r = rescheduler.NewRescheduler(a.threadCompile)
|
|
|
|
// init domains table
|
|
_, err := a.db.Exec(createTableDomains)
|
|
if err != nil {
|
|
log.Printf("[WARN] Failed to generate 'domains' table\n")
|
|
return nil
|
|
}
|
|
return a
|
|
}
|
|
|
|
// IsValid returns true if a domain is valid.
|
|
func (d *Domains) IsValid(host string) bool {
|
|
domain, _, _ := utils.SplitDomainPort(host, 0)
|
|
|
|
// read lock for safety
|
|
d.s.RLock()
|
|
defer d.s.RUnlock()
|
|
|
|
// check root domains `www.example.com`, `example.com`, `com`
|
|
for len(domain) > 0 {
|
|
if _, ok := d.m[domain]; ok {
|
|
return true
|
|
}
|
|
n := strings.IndexByte(domain, '.')
|
|
if n == -1 {
|
|
break
|
|
}
|
|
domain = domain[n+1:]
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Compile downloads the list of domains from the database and loads them into
|
|
// memory for faster lookups.
|
|
//
|
|
// This method makes use of the rescheduler instead of just ignoring multiple
|
|
// calls.
|
|
func (d *Domains) Compile() {
|
|
d.r.Run()
|
|
}
|
|
|
|
func (d *Domains) threadCompile() {
|
|
// new map
|
|
domainMap := make(map[string]struct{})
|
|
|
|
// compile map and check errors
|
|
err := d.internalCompile(domainMap)
|
|
if err != nil {
|
|
log.Printf("[Domains] Compile failed: %s\n", err)
|
|
return
|
|
}
|
|
|
|
// lock while replacing the map
|
|
d.s.Lock()
|
|
d.m = domainMap
|
|
d.s.Unlock()
|
|
}
|
|
|
|
// internalCompile is a hidden internal method for querying the database during
|
|
// the Compile() method.
|
|
func (d *Domains) internalCompile(m map[string]struct{}) error {
|
|
log.Println("[Domains] Updating domains from database")
|
|
|
|
// sql or something?
|
|
rows, err := d.db.Query(`select domain from domains where active = 1`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
// loop through rows and scan the allowed domain names
|
|
for rows.Next() {
|
|
var name string
|
|
err = rows.Scan(&name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m[name] = struct{}{}
|
|
}
|
|
|
|
// check for errors
|
|
return rows.Err()
|
|
}
|
|
|
|
func (d *Domains) Put(domain string, active bool) {
|
|
d.s.Lock()
|
|
defer d.s.Unlock()
|
|
_, err := d.db.Exec("INSERT OR REPLACE INTO domains (domain, active) VALUES (?, ?)", domain, active)
|
|
if err != nil {
|
|
log.Printf("[Violet] Database error: %s\n", err)
|
|
}
|
|
}
|
|
|
|
func (d *Domains) Delete(domain string) {
|
|
d.s.Lock()
|
|
defer d.s.Unlock()
|
|
_, err := d.db.Exec("INSERT OR REPLACE INTO domains (domain, active) VALUES (?, ?)", domain, false)
|
|
if err != nil {
|
|
log.Printf("[Violet] Database error: %s\n", err)
|
|
}
|
|
}
|