violet/certs/certs.go

131 lines
2.6 KiB
Go
Raw Normal View History

2023-04-22 18:11:21 +01:00
package certs
import (
"code.mrmelon54.com/melon/certgen"
"crypto/tls"
"fmt"
"github.com/MrMelon54/violet/utils"
"io/fs"
"log"
"path/filepath"
"sync"
)
2023-04-24 01:35:23 +01:00
// Certs is the certificate loader and management system. It keeps an
// in-memory table from DNS name to TLS certificate, which Compile rebuilds
// from the certificate and key directories.
type Certs struct {
	// cDir holds the certificate files and kDir the matching private keys.
	cDir, kDir fs.FS

	// s guards m against concurrent lookups and table replacement.
	s *sync.RWMutex

	// m maps each DNS name to the certificate that serves it.
	m map[string]*tls.Certificate
}
2023-04-24 01:35:23 +01:00
// New creates a new cert list
2023-04-22 18:11:21 +01:00
func New(certDir fs.FS, keyDir fs.FS) *Certs {
a := &Certs{
cDir: certDir,
kDir: keyDir,
s: &sync.RWMutex{},
m: make(map[string]*tls.Certificate),
}
2023-04-24 01:35:23 +01:00
// run compile to get the initial data
2023-04-22 18:11:21 +01:00
a.Compile()
return a
}
// GetCertForDomain returns the certificate registered for domain.
//
// If there is no exact match, it looks up the wildcard name produced by
// utils.ReplaceSubdomainWithWildcard (presumably "*." plus the parent domain;
// confirm against utils). It returns nil when neither entry exists.
//
// The read lock makes this safe to call while Compile swaps in a new table.
func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
	c.s.RLock()
	defer c.s.RUnlock()

	if exact, found := c.m[domain]; found {
		return exact
	}

	wildcard, ok := utils.ReplaceSubdomainWithWildcard(domain)
	if !ok {
		return nil
	}
	// a missing key yields the zero value, i.e. nil
	return c.m[wildcard]
}
func (c *Certs) Compile() {
// async compile magic
go func() {
2023-04-22 22:18:39 +01:00
// new map
2023-04-22 18:11:21 +01:00
certMap := make(map[string]*tls.Certificate)
2023-04-22 22:18:39 +01:00
// compile map and check errors
2023-04-22 18:11:21 +01:00
err := c.internalCompile(certMap)
if err != nil {
log.Printf("[Certs] Compile failed: %s\n", err)
return
}
2023-04-24 01:35:23 +01:00
2023-04-22 18:11:21 +01:00
// lock while replacing the map
c.s.Lock()
c.m = certMap
c.s.Unlock()
}()
}
2023-04-24 01:35:23 +01:00
// internalCompile is a hidden internal method for loading the certificate and
// key files
2023-04-22 18:11:21 +01:00
func (c *Certs) internalCompile(m map[string]*tls.Certificate) error {
// try to read dir
files, err := fs.ReadDir(c.cDir, "")
if err != nil {
return fmt.Errorf("failed to read cert dir: %w", err)
}
log.Printf("[Certs] Compiling lookup table for %d certificates\n", len(files))
// find and parse certs
for _, i := range files {
// skip dirs
if i.IsDir() {
continue
}
// get file name and extension
name := i.Name()
ext := filepath.Ext(name)
keyName := name[:len(name)-len(ext)] + "key"
// try to read cert file
certData, err := fs.ReadFile(c.cDir, name)
if err != nil {
return fmt.Errorf("failed to read cert file '%s': %w", name, err)
}
// try to read key file
keyData, err := fs.ReadFile(c.kDir, keyName)
if err != nil {
return fmt.Errorf("failed to read key file '%s': %w", keyName, err)
}
// load key pair
pair, err := tls.X509KeyPair(certData, keyData)
if err != nil {
return fmt.Errorf("failed to load x509 key pair '%s + %s': %w", name, keyName, err)
}
// load tls leaf
cert := &pair
leaf := certgen.TlsLeaf(cert)
// save in map under each dns name
for _, j := range leaf.DNSNames {
m[j] = cert
}
}
// well no errors happened
return nil
}