mirror of
https://github.com/1f349/orchid.git
synced 2025-01-20 22:26:33 +00:00
Too many hours wasted on these tests
This commit is contained in:
parent
2805b72094
commit
fdfdc6c716
5
go.mod
5
go.mod
@ -3,8 +3,12 @@ module github.com/MrMelon54/orchid
|
|||||||
go 1.20
|
go 1.20
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/MrMelon54/certgen v0.0.1
|
||||||
github.com/MrMelon54/mjwt v0.1.0
|
github.com/MrMelon54/mjwt v0.1.0
|
||||||
github.com/go-acme/lego/v4 v4.12.3
|
github.com/go-acme/lego/v4 v4.12.3
|
||||||
|
github.com/google/uuid v1.3.0
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.17
|
||||||
|
github.com/miekg/dns v1.1.50
|
||||||
github.com/stretchr/testify v1.8.4
|
github.com/stretchr/testify v1.8.4
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -14,7 +18,6 @@ require (
|
|||||||
github.com/go-jose/go-jose/v3 v3.0.0 // indirect
|
github.com/go-jose/go-jose/v3 v3.0.0 // indirect
|
||||||
github.com/golang-jwt/jwt/v4 v4.4.3 // indirect
|
github.com/golang-jwt/jwt/v4 v4.4.3 // indirect
|
||||||
github.com/google/go-querystring v1.1.0 // indirect
|
github.com/google/go-querystring v1.1.0 // indirect
|
||||||
github.com/miekg/dns v1.1.50 // indirect
|
|
||||||
github.com/nrdcg/namesilo v0.2.1 // indirect
|
github.com/nrdcg/namesilo v0.2.1 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
6
go.sum
6
go.sum
@ -1,3 +1,5 @@
|
|||||||
|
github.com/MrMelon54/certgen v0.0.1 h1:ycWdZ2RlxQ5qSuejeBVv4aXjGo5hdqqL4j4EjrXnFMk=
|
||||||
|
github.com/MrMelon54/certgen v0.0.1/go.mod h1:GHflVlSbtFLJZLpN1oWyUvDBRrR8qCWiwZLXCCnS2Gc=
|
||||||
github.com/MrMelon54/mjwt v0.1.0 h1:x1wBrh9l2CowRekHecxcZaH2zy9Hvqwlp4ppmW1P1OA=
|
github.com/MrMelon54/mjwt v0.1.0 h1:x1wBrh9l2CowRekHecxcZaH2zy9Hvqwlp4ppmW1P1OA=
|
||||||
github.com/MrMelon54/mjwt v0.1.0/go.mod h1:oYrDBWK09Hju98xb+bRQ0wy+RuAzacxYvKYOZchR2Tk=
|
github.com/MrMelon54/mjwt v0.1.0/go.mod h1:oYrDBWK09Hju98xb+bRQ0wy+RuAzacxYvKYOZchR2Tk=
|
||||||
github.com/cenkalti/backoff/v4 v4.2.0 h1:HN5dHm3WBOgndBH6E8V0q2jIYIR3s9yglV8k/+MN3u4=
|
github.com/cenkalti/backoff/v4 v4.2.0 h1:HN5dHm3WBOgndBH6E8V0q2jIYIR3s9yglV8k/+MN3u4=
|
||||||
@ -17,6 +19,10 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||||||
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
|
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
|
||||||
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
||||||
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
||||||
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA=
|
github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA=
|
||||||
github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
|
github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
|
||||||
github.com/nrdcg/namesilo v0.2.1 h1:kLjCjsufdW/IlC+iSfAqj0iQGgKjlbUUeDJio5Y6eMg=
|
github.com/nrdcg/namesilo v0.2.1 h1:kLjCjsufdW/IlC+iSfAqj0iQGgKjlbUUeDJio5Y6eMg=
|
||||||
|
@ -1,10 +0,0 @@
|
|||||||
//go:build !DEBUG
|
|
||||||
|
|
||||||
package pebble_dev
|
|
||||||
|
|
||||||
import "log"
|
|
||||||
|
|
||||||
func GetPebbleCert() []byte {
|
|
||||||
log.Fatalln("[Renewal] Pebble is selected as the certificate source but this binary was not compiled in debug mode")
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,9 +1,3 @@
|
|||||||
//go:build DEBUG
|
|
||||||
|
|
||||||
package pebble_dev
|
|
||||||
|
|
||||||
func GetPebbleCert() []byte {
|
|
||||||
return []byte(`
|
|
||||||
-----BEGIN CERTIFICATE-----
|
-----BEGIN CERTIFICATE-----
|
||||||
MIIDCTCCAfGgAwIBAgIIJOLbes8sTr4wDQYJKoZIhvcNAQELBQAwIDEeMBwGA1UE
|
MIIDCTCCAfGgAwIBAgIIJOLbes8sTr4wDQYJKoZIhvcNAQELBQAwIDEeMBwGA1UE
|
||||||
AxMVbWluaWNhIHJvb3QgY2EgMjRlMmRiMCAXDTE3MTIwNjE5NDIxMFoYDzIxMTcx
|
AxMVbWluaWNhIHJvb3QgY2EgMjRlMmRiMCAXDTE3MTIwNjE5NDIxMFoYDzIxMTcx
|
||||||
@ -23,5 +17,3 @@ Mfn3qEb9BXSk0Q3prNV5sOV3vgjEtB4THfDxSz9z3+DepVnW3vbbqwEbkXdk3j82
|
|||||||
2muVldgOUgTwK8eT+XdofVdntzU/kzygSAtAQwLJfn51fS1GvEcYGBc1bDryIqmF
|
2muVldgOUgTwK8eT+XdofVdntzU/kzygSAtAQwLJfn51fS1GvEcYGBc1bDryIqmF
|
||||||
p9BI7gVKtWSZYegicA==
|
p9BI7gVKtWSZYegicA==
|
||||||
-----END CERTIFICATE-----
|
-----END CERTIFICATE-----
|
||||||
`)
|
|
||||||
}
|
|
19
pebble/asset/pebble-config.json
Normal file
19
pebble/asset/pebble-config.json
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
{
|
||||||
|
"pebble": {
|
||||||
|
"listenAddress": "0.0.0.0:14000",
|
||||||
|
"managementListenAddress": "0.0.0.0:15000",
|
||||||
|
"certificate": "certs/localhost/cert.pem",
|
||||||
|
"privateKey": "certs/localhost/key.pem",
|
||||||
|
"httpPort": 5002,
|
||||||
|
"tlsPort": 5001,
|
||||||
|
"ocspResponderURL": "",
|
||||||
|
"externalAccountBindingRequired": false,
|
||||||
|
"domainBlocklist": [
|
||||||
|
"blocked-domain.example"
|
||||||
|
],
|
||||||
|
"retryAfter": {
|
||||||
|
"authz": 3,
|
||||||
|
"order": 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
12
pebble/pebble.go
Normal file
12
pebble/pebble.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
//go:build !DEBUG
|
||||||
|
|
||||||
|
package pebble
|
||||||
|
|
||||||
|
import _ "embed"
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed asset/pebble-cert.pem
|
||||||
|
RawCert []byte
|
||||||
|
//go:embed asset/pebble-config.json
|
||||||
|
RawConfig []byte
|
||||||
|
)
|
@ -7,4 +7,5 @@ type LetsEncryptConfig struct {
|
|||||||
} `yaml:"account"`
|
} `yaml:"account"`
|
||||||
Directory string `yaml:"directory"`
|
Directory string `yaml:"directory"`
|
||||||
Certificate string `yaml:"certificate"`
|
Certificate string `yaml:"certificate"`
|
||||||
|
insecure bool
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
select cert.id, certdata.data_id, certdata.not_after, dns.type, dns.token
|
select cert.id, cert.not_after, dns.type, dns.token
|
||||||
from certificates as cert
|
from certificates as cert
|
||||||
left outer join certificate_data as certdata on cert.id = certdata.meta_id
|
|
||||||
left outer join dns on cert.dns = dns.id
|
left outer join dns on cert.dns = dns.id
|
||||||
where cert.active = 1
|
where cert.active = 1
|
||||||
and cert.auto_renew = 1
|
and cert.auto_renew = 1
|
||||||
and cert.renewing = 0
|
and cert.renewing = 0
|
||||||
and cert.renew_failed = 0
|
and cert.renew_failed = 0
|
||||||
and (certdata.ready IS NULL or certdata.ready = 1)
|
and (cert.not_after IS NULL or DATETIME(cert.not_after, 'utc', '-30 days') < DATETIME())
|
||||||
and (certdata.not_after IS NULL or DATETIME(certdata.not_after, 'utc', '-30 days') < DATETIME())
|
order by cert.not_after DESC NULLS FIRST
|
||||||
order by certdata.not_after DESC NULLS FIRST
|
|
||||||
|
@ -1,17 +1,16 @@
|
|||||||
package renewal
|
package renewal
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
// Contains local types for the renewal service
|
// Contains local types for the renewal service
|
||||||
type localCertData struct {
|
type localCertData struct {
|
||||||
id uint64
|
id uint64
|
||||||
dns struct {
|
dns struct {
|
||||||
name string
|
name sql.NullString
|
||||||
token string
|
token sql.NullString
|
||||||
}
|
}
|
||||||
cert struct {
|
notAfter sql.NullTime
|
||||||
current uint64
|
domains []string
|
||||||
notAfter time.Time
|
|
||||||
}
|
|
||||||
domains []string
|
|
||||||
}
|
}
|
||||||
|
@ -1,17 +1,19 @@
|
|||||||
package renewal
|
package renewal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/MrMelon54/orchid/http-acme"
|
"github.com/MrMelon54/orchid/pebble"
|
||||||
"github.com/MrMelon54/orchid/pebble-dev"
|
|
||||||
"github.com/go-acme/lego/v4/certificate"
|
"github.com/go-acme/lego/v4/certificate"
|
||||||
"github.com/go-acme/lego/v4/challenge"
|
"github.com/go-acme/lego/v4/challenge"
|
||||||
|
"github.com/go-acme/lego/v4/challenge/dns01"
|
||||||
"github.com/go-acme/lego/v4/lego"
|
"github.com/go-acme/lego/v4/lego"
|
||||||
"github.com/go-acme/lego/v4/providers/dns/namesilo"
|
"github.com/go-acme/lego/v4/providers/dns/namesilo"
|
||||||
"github.com/go-acme/lego/v4/registration"
|
"github.com/go-acme/lego/v4/registration"
|
||||||
@ -32,23 +34,27 @@ var (
|
|||||||
createTableCertificates string
|
createTableCertificates string
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var testDnsOptions interface {
|
||||||
|
challenge.Provider
|
||||||
|
GetDnsAddrs() []string
|
||||||
|
}
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
httpAcme *http_acme.HttpAcmeProvider
|
httpAcme challenge.Provider
|
||||||
certTicker *time.Ticker
|
certTicker *time.Ticker
|
||||||
certDone chan struct{}
|
certDone chan struct{}
|
||||||
caAddr string
|
caAddr string
|
||||||
caCert []byte
|
caCert []byte
|
||||||
transport *http.Transport
|
transport http.RoundTripper
|
||||||
renewLock *sync.Mutex
|
renewLock *sync.Mutex
|
||||||
leAccount *Account
|
leAccount *Account
|
||||||
certDir string
|
certDir string
|
||||||
keyDir string
|
keyDir string
|
||||||
|
insecure bool
|
||||||
//notify
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRenewalService(wg *sync.WaitGroup, db *sql.DB, httpAcme *http_acme.HttpAcmeProvider, leConfig LetsEncryptConfig) (*Service, error) {
|
func NewRenewalService(wg *sync.WaitGroup, db *sql.DB, httpAcme challenge.Provider, leConfig LetsEncryptConfig, certDir, keyDir string) (*Service, error) {
|
||||||
r := &Service{
|
r := &Service{
|
||||||
db: db,
|
db: db,
|
||||||
httpAcme: httpAcme,
|
httpAcme: httpAcme,
|
||||||
@ -57,19 +63,37 @@ func NewRenewalService(wg *sync.WaitGroup, db *sql.DB, httpAcme *http_acme.HttpA
|
|||||||
renewLock: &sync.Mutex{},
|
renewLock: &sync.Mutex{},
|
||||||
leAccount: &Account{
|
leAccount: &Account{
|
||||||
email: leConfig.Account.Email,
|
email: leConfig.Account.Email,
|
||||||
key: leConfig.Account.PrivateKey,
|
|
||||||
},
|
},
|
||||||
|
certDir: certDir,
|
||||||
|
keyDir: keyDir,
|
||||||
|
insecure: leConfig.insecure,
|
||||||
|
}
|
||||||
|
|
||||||
|
// make certDir and keyDir
|
||||||
|
err := os.MkdirAll(certDir, os.ModePerm)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create certDir '%s': %w", certDir, err)
|
||||||
|
}
|
||||||
|
err = os.MkdirAll(keyDir, os.ModePerm)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create certDir '%s': %w", certDir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// load lets encrypt private key
|
||||||
|
err = r.resolveLEPrivKey(leConfig.Account.PrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to resolve LetsEncrypt account private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// init domains table
|
// init domains table
|
||||||
_, err := r.db.Exec(createTableCertificates)
|
_, err = r.db.Exec(createTableCertificates)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create certificates table: %w", err)
|
return nil, fmt.Errorf("failed to create certificates table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolve CA information
|
// resolve CA information
|
||||||
r.resolveCADirectory(leConfig)
|
r.resolveCADirectory(leConfig.Directory)
|
||||||
err = r.resolveCACertificate(leConfig)
|
err = r.resolveCACertificate(leConfig.Certificate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to resolve CA certificate: %w", err)
|
return nil, fmt.Errorf("failed to resolve CA certificate: %w", err)
|
||||||
}
|
}
|
||||||
@ -84,41 +108,55 @@ func (s *Service) Shutdown() {
|
|||||||
close(s.certDone)
|
close(s.certDone)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) resolveCADirectory(conf LetsEncryptConfig) {
|
func (s *Service) resolveLEPrivKey(a string) error {
|
||||||
switch conf.Directory {
|
key, err := x509.ParsePKCS1PrivateKey([]byte(a))
|
||||||
|
if err != nil {
|
||||||
|
bytes, err := os.ReadFile(a)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
key, err = x509.ParsePKCS1PrivateKey(bytes)
|
||||||
|
}
|
||||||
|
s.leAccount.key = key
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) resolveCADirectory(dir string) {
|
||||||
|
switch dir {
|
||||||
case "production", "prod":
|
case "production", "prod":
|
||||||
s.caAddr = lego.LEDirectoryProduction
|
s.caAddr = lego.LEDirectoryProduction
|
||||||
case "staging":
|
case "staging":
|
||||||
s.caAddr = lego.LEDirectoryStaging
|
s.caAddr = lego.LEDirectoryStaging
|
||||||
default:
|
default:
|
||||||
s.caAddr = conf.Directory
|
s.caAddr = dir
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) resolveCACertificate(conf LetsEncryptConfig) error {
|
func (s *Service) resolveCACertificate(cert string) error {
|
||||||
switch conf.Certificate {
|
switch cert {
|
||||||
case "default":
|
case "default":
|
||||||
// no nothing
|
// no nothing
|
||||||
case "pebble":
|
case "pebble":
|
||||||
s.caCert = pebble_dev.GetPebbleCert()
|
s.caCert = pebble.RawCert
|
||||||
|
case "insecure":
|
||||||
|
s.caCert = []byte{0x00}
|
||||||
default:
|
default:
|
||||||
caGet, err := http.Get(conf.Certificate)
|
s.caCert = []byte(cert)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to download CA certificate: %w", err)
|
|
||||||
}
|
|
||||||
s.caCert, err = io.ReadAll(caGet.Body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read CA certificate: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if s.caCert != nil {
|
if s.caCert != nil {
|
||||||
caPool := x509.NewCertPool()
|
if bytes.Compare([]byte{0x00}, s.caCert) == 0 {
|
||||||
if !caPool.AppendCertsFromPEM(s.caCert) {
|
t := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
return fmt.Errorf("failed to add certificate to CA cert pool")
|
t.TLSClientConfig.InsecureSkipVerify = true
|
||||||
|
s.transport = t
|
||||||
|
} else {
|
||||||
|
caPool := x509.NewCertPool()
|
||||||
|
if !caPool.AppendCertsFromPEM(s.caCert) {
|
||||||
|
return fmt.Errorf("failed to add certificate to CA cert pool")
|
||||||
|
}
|
||||||
|
t := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
t.TLSClientConfig = &tls.Config{RootCAs: caPool}
|
||||||
|
s.transport = t
|
||||||
}
|
}
|
||||||
t := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
t.TLSClientConfig = &tls.Config{RootCAs: caPool}
|
|
||||||
s.transport = t
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -171,14 +209,19 @@ func (s *Service) renewalCheck() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
s.renewCert(localData)
|
err = s.renewCert(localData)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Printf("[Renewal] Updated certificate %d successfully\n", localData.id)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) findNextCertificateToRenew() (*localCertData, error) {
|
func (s *Service) findNextCertificateToRenew() (*localCertData, error) {
|
||||||
d := &localCertData{}
|
d := &localCertData{}
|
||||||
|
|
||||||
row := s.db.QueryRow(findNextCertSql)
|
row := s.db.QueryRow(findNextCertSql)
|
||||||
err := row.Scan(&d.id, &d.cert.current, &d.cert.notAfter, &d.dns.name, &d.dns.token)
|
err := row.Scan(&d.id, &d.notAfter, &d.dns.name, &d.dns.token)
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
// no nothing
|
// no nothing
|
||||||
@ -220,19 +263,29 @@ func (s *Service) setupLegoClient(localData *localCertData) (*lego.Client, error
|
|||||||
if s.transport != nil {
|
if s.transport != nil {
|
||||||
config.HTTPClient.Transport = s.transport
|
config.HTTPClient.Transport = s.transport
|
||||||
}
|
}
|
||||||
dnsProv, err := s.getDnsProvider(localData.dns.name, localData.dns.token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to resolve dns provider: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := lego.NewClient(config)
|
client, err := lego.NewClient(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate client: %w", err)
|
return nil, fmt.Errorf("failed to generate client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// set providers - always returns nil so ignore the error
|
// set http challenge provider
|
||||||
_ = client.Challenge.SetHTTP01Provider(s.httpAcme)
|
_ = client.Challenge.SetHTTP01Provider(s.httpAcme)
|
||||||
_ = client.Challenge.SetDNS01Provider(dnsProv)
|
|
||||||
|
// if testDnsOptions is defined then set up the test provider
|
||||||
|
if testDnsOptions != nil {
|
||||||
|
dnsAddrs := testDnsOptions.GetDnsAddrs()
|
||||||
|
log.Printf("Using testDnsOptions with DNS server: %v\n", dnsAddrs)
|
||||||
|
_ = client.Challenge.SetDNS01Provider(testDnsOptions, dns01.AddRecursiveNameservers(dnsAddrs), dns01.DisableCompletePropagationRequirement())
|
||||||
|
} else if localData.dns.name.Valid && localData.dns.token.Valid {
|
||||||
|
// if the dns name and token are "valid" meaning non-null in this case
|
||||||
|
// set up the specific dns provider requested
|
||||||
|
dnsProv, err := s.getDnsProvider(localData.dns.name.String, localData.dns.token.String)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to resolve dns provider: %w", err)
|
||||||
|
}
|
||||||
|
_ = client.Challenge.SetDNS01Provider(dnsProv)
|
||||||
|
}
|
||||||
|
|
||||||
register, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
|
register, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -260,45 +313,26 @@ func (s *Service) getPrivateKey(id uint64) (*rsa.PrivateKey, error) {
|
|||||||
return x509.ParsePKCS1PrivateKey(privKeyBytes)
|
return x509.ParsePKCS1PrivateKey(privKeyBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) renewCert(localData *localCertData) {
|
func (s *Service) renewCert(localData *localCertData) error {
|
||||||
s.setRenewing(localData.id, true, false)
|
s.setRenewing(localData.id, true, false)
|
||||||
|
|
||||||
cert, certBytes, err := s.renewCertInternal(localData)
|
cert, certBytes, err := s.renewCertInternal(localData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[Renewal Failed to renew cert %d: %s\n", localData.id, err)
|
|
||||||
s.setRenewing(localData.id, false, true)
|
s.setRenewing(localData.id, false, true)
|
||||||
return
|
return fmt.Errorf("failed to renew cert %d: %w", localData.id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.db.Exec(`UPDATE certificates SET renewing = 0, renew_failed = 0, not_after = ?, updated_at = ? WHERE id = ?`, cert.NotAfter, cert.NotBefore, localData.id)
|
_, err = s.db.Exec(`UPDATE certificates SET renewing = 0, renew_failed = 0, not_after = ?, updated_at = ? WHERE id = ?`, cert.NotAfter, cert.NotBefore, localData.id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[Renewal] Failed to update certificate %d in database: %s\n", localData.id, err)
|
return fmt.Errorf("failed to update cert %d in database: %w", localData.id, err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
oldPath := filepath.Join(s.certDir, fmt.Sprintf("%d-old.cert.pem", localData.id))
|
err = s.writeCertFile(localData.id, certBytes)
|
||||||
newPath := filepath.Join(s.certDir, fmt.Sprintf("%d.cert.pem", localData.id))
|
|
||||||
|
|
||||||
err = os.Rename(newPath, oldPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[Renewal] Failed to rename certificate file '%s' => '%s': %s\n", newPath, oldPath, err)
|
return fmt.Errorf("failed to write cert file: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
openCertFile, err := os.Create(newPath)
|
return nil
|
||||||
if err != nil {
|
|
||||||
log.Printf("[Renewal] Failed to create certificate file '%s': %s\n", newPath, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer openCertFile.Close()
|
|
||||||
|
|
||||||
_, err = openCertFile.Write(certBytes)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[Renewal] Failed to write certificate file '%s': %s\n", newPath, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("[Renewal] Updated certificate %d successfully\n", localData.id)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) renewCertInternal(localData *localCertData) (*x509.Certificate, []byte, error) {
|
func (s *Service) renewCertInternal(localData *localCertData) (*x509.Certificate, []byte, error) {
|
||||||
@ -320,6 +354,7 @@ func (s *Service) renewCertInternal(localData *localCertData) (*x509.Certificate
|
|||||||
return nil, nil, fmt.Errorf("failed to generate a client: %w", err)
|
return nil, nil, fmt.Errorf("failed to generate a client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// obtain new certificate - this call will hang until a certificate is ready
|
||||||
obtain, err := client.Certificate.Obtain(certificate.ObtainRequest{
|
obtain, err := client.Certificate.Obtain(certificate.ObtainRequest{
|
||||||
Domains: domains,
|
Domains: domains,
|
||||||
PrivateKey: privKey,
|
PrivateKey: privKey,
|
||||||
@ -329,11 +364,19 @@ func (s *Service) renewCertInternal(localData *localCertData) (*x509.Certificate
|
|||||||
return nil, nil, fmt.Errorf("failed to obtain replacement certificate: %w", err)
|
return nil, nil, fmt.Errorf("failed to obtain replacement certificate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
parseCert, err := x509.ParseCertificate(obtain.Certificate)
|
// extract the certificate data from pem encoding
|
||||||
|
p, _ := pem.Decode(obtain.Certificate)
|
||||||
|
if p.Type != "CERTIFICATE" {
|
||||||
|
return nil, nil, fmt.Errorf("invalid certificate type '%s'", p.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse the obtained certificate
|
||||||
|
parseCert, err := x509.ParseCertificate(p.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("failed to parse new certificate: %w", err)
|
return nil, nil, fmt.Errorf("failed to parse new certificate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// return the parsed and raw bytes
|
||||||
return parseCert, obtain.Certificate, nil
|
return parseCert, obtain.Certificate, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -343,3 +386,29 @@ func (s *Service) setRenewing(id uint64, renewing, failed bool) {
|
|||||||
log.Printf("[Renewal] Failed to set renewing/failed mode in database %d: %s\n", id, err)
|
log.Printf("[Renewal] Failed to set renewing/failed mode in database %d: %s\n", id, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) writeCertFile(id uint64, certBytes []byte) error {
|
||||||
|
oldPath := filepath.Join(s.certDir, fmt.Sprintf("%d-old.cert.pem", id))
|
||||||
|
newPath := filepath.Join(s.certDir, fmt.Sprintf("%d.cert.pem", id))
|
||||||
|
|
||||||
|
// move certificate file to old name
|
||||||
|
err := os.Rename(newPath, oldPath)
|
||||||
|
if err != nil && !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("failed to rename cert file '%s' => '%s': %w", newPath, oldPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create new certificate file
|
||||||
|
openCertFile, err := os.Create(newPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create cert file '%s': %s", newPath, err)
|
||||||
|
}
|
||||||
|
defer openCertFile.Close()
|
||||||
|
|
||||||
|
// write certificate bytes
|
||||||
|
_, err = openCertFile.Write(certBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write cert file '%s': %s", newPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -1 +1,217 @@
|
|||||||
package renewal
|
package renewal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"github.com/MrMelon54/certgen"
|
||||||
|
"github.com/MrMelon54/orchid/pebble"
|
||||||
|
"github.com/MrMelon54/orchid/test"
|
||||||
|
"github.com/go-acme/lego/v4/lego"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"go/build"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const pebbleUrl = "https://localhost:5000"
|
||||||
|
|
||||||
|
func TestService_resolveCADirectory(t *testing.T) {
|
||||||
|
s := &Service{}
|
||||||
|
s.resolveCADirectory("production")
|
||||||
|
assert.Equal(t, lego.LEDirectoryProduction, s.caAddr)
|
||||||
|
s.resolveCADirectory("prod")
|
||||||
|
assert.Equal(t, lego.LEDirectoryProduction, s.caAddr)
|
||||||
|
s.resolveCADirectory("staging")
|
||||||
|
assert.Equal(t, lego.LEDirectoryStaging, s.caAddr)
|
||||||
|
s.resolveCADirectory(pebbleUrl)
|
||||||
|
assert.Equal(t, pebbleUrl, s.caAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestService_resolveCACertificate(t *testing.T) {
|
||||||
|
s := &Service{}
|
||||||
|
assert.NoError(t, s.resolveCACertificate("default"))
|
||||||
|
assert.Nil(t, s.caCert)
|
||||||
|
assert.NoError(t, s.resolveCACertificate("pebble"))
|
||||||
|
assert.Equal(t, 0, bytes.Compare(pebble.RawCert, s.caCert))
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupPebbleSuite(tb testing.TB) (*certgen.CertGen, func()) {
|
||||||
|
fmt.Println("Running pebble")
|
||||||
|
pebbleTmp, err := os.MkdirTemp("", "pebble")
|
||||||
|
assert.NoError(tb, err)
|
||||||
|
assert.NoError(tb, os.WriteFile(filepath.Join(pebbleTmp, "pebble-config.json"), pebble.RawConfig, os.ModePerm))
|
||||||
|
|
||||||
|
serverTls, err := certgen.MakeServerTls(nil, 2048, pkix.Name{
|
||||||
|
Country: []string{"GB"},
|
||||||
|
Organization: []string{"Orchid"},
|
||||||
|
OrganizationalUnit: []string{"Test"},
|
||||||
|
SerialNumber: "0",
|
||||||
|
CommonName: "localhost",
|
||||||
|
}, big.NewInt(1), func(now time.Time) time.Time {
|
||||||
|
return now.AddDate(10, 0, 0)
|
||||||
|
}, []string{"localhost", "pebble"}, []net.IP{net.IPv4(127, 0, 0, 1)})
|
||||||
|
assert.NoError(tb, err)
|
||||||
|
assert.NoError(tb, os.MkdirAll(filepath.Join(pebbleTmp, "certs", "localhost"), os.ModePerm))
|
||||||
|
assert.NoError(tb, os.WriteFile(filepath.Join(pebbleTmp, "certs", "localhost", "cert.pem"), serverTls.GetCertPem(), os.ModePerm))
|
||||||
|
assert.NoError(tb, os.WriteFile(filepath.Join(pebbleTmp, "certs", "localhost", "key.pem"), serverTls.GetKeyPem(), os.ModePerm))
|
||||||
|
|
||||||
|
// hack default resolver
|
||||||
|
net.DefaultResolver.PreferGo = true
|
||||||
|
net.DefaultResolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
tb.Logf("Custom Resolver %s - %s\n", network, address)
|
||||||
|
return nil, fmt.Errorf("ha failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServer := test.MakeFakeDnsProv("127.0.0.34:5053") // 127.0.0.34:53
|
||||||
|
dnsServer.AddRecursiveSOA("example.test.")
|
||||||
|
go dnsServer.Start()
|
||||||
|
testDnsOptions = dnsServer
|
||||||
|
|
||||||
|
pebbleFile := filepath.Join(build.Default.GOPATH, "bin", "pebble")
|
||||||
|
command := exec.Command(pebbleFile, "-config", filepath.Join(pebbleTmp, "pebble-config.json"), "-dnsserver", "127.0.0.34:5053")
|
||||||
|
command.Env = append(command.Env, "PEBBLE_VA_ALWAYS_VALID=1")
|
||||||
|
command.Dir = pebbleTmp
|
||||||
|
|
||||||
|
if command.Start() != nil {
|
||||||
|
instCmd := exec.Command("go", "install", "github.com/letsencrypt/pebble/cmd/pebble@latest")
|
||||||
|
assert.NoError(tb, instCmd.Run(), "Failed to start pebble make sure it is installed... go install github.com/letsencrypt/pebble/cmd/pebble@latest")
|
||||||
|
assert.NoError(tb, command.Start(), "failed to start pebble again")
|
||||||
|
}
|
||||||
|
|
||||||
|
return serverTls, func() {
|
||||||
|
// unhack default resolver
|
||||||
|
net.DefaultResolver.PreferGo = false
|
||||||
|
net.DefaultResolver.Dial = nil
|
||||||
|
|
||||||
|
fmt.Println("Killing pebble")
|
||||||
|
if command != nil && command.Process != nil {
|
||||||
|
assert.NoError(tb, command.Process.Kill())
|
||||||
|
}
|
||||||
|
dnsServer.Shutdown()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) *Service {
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
dbFile := fmt.Sprintf("file:%s?mode=memory&cache=shared", uuid.NewString())
|
||||||
|
db, err := sql.Open("sqlite3", dbFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
log.Println("DB File:", dbFile)
|
||||||
|
|
||||||
|
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
tr.TLSClientConfig.InsecureSkipVerify = true
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://localhost:14000/root", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
res, err := tr.RoundTrip(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
certRaw, err := io.ReadAll(res.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
fmt.Println("Cert:", string(certRaw))
|
||||||
|
|
||||||
|
certDir, err := os.MkdirTemp("", "orchid-certs")
|
||||||
|
keyDir, err := os.MkdirTemp("", "orchid-keys")
|
||||||
|
|
||||||
|
lePrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
acmeProv := test.MakeFakeAcmeProv(serverTls.GetCertPem())
|
||||||
|
service, err := NewRenewalService(wg, db, acmeProv, LetsEncryptConfig{
|
||||||
|
Account: struct {
|
||||||
|
Email string `yaml:"email"`
|
||||||
|
PrivateKey string `yaml:"privateKey"`
|
||||||
|
}{
|
||||||
|
Email: "webmaster@example.test",
|
||||||
|
PrivateKey: string(x509.MarshalPKCS1PrivateKey(lePrivKey)),
|
||||||
|
},
|
||||||
|
Directory: "https://localhost:14000/dir",
|
||||||
|
Certificate: string(certRaw),
|
||||||
|
insecure: true,
|
||||||
|
}, certDir, keyDir)
|
||||||
|
service.transport = acmeProv
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NoError(t, os.WriteFile(filepath.Join(keyDir, "1.key.pem"), x509.MarshalPKCS1PrivateKey(privKey), os.ModePerm))
|
||||||
|
|
||||||
|
return service
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPebbleRenewal(t *testing.T) {
|
||||||
|
serverTls, cancel := setupPebbleSuite(t)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
domains []string
|
||||||
|
}{
|
||||||
|
{"Test", []string{"hello.example.test"}},
|
||||||
|
{"Test with multiple certificates", []string{"example.test", "world.example.test"}},
|
||||||
|
{"Test with wildcard certificate", []string{"example.test", "*.example.test"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, i := range tests {
|
||||||
|
t.Run(i.name, func(t *testing.T) {
|
||||||
|
//t.Parallel()
|
||||||
|
service := setupPebbleTest(t, serverTls)
|
||||||
|
//goland:noinspection SqlWithoutWhere
|
||||||
|
_, err := service.db.Exec("DELETE FROM certificate_domains")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = service.db.Exec(`INSERT INTO certificates (owner, dns, auto_renew, active, renewing, renew_failed, not_after, updated_at) VALUES (1, 1, 1, 1, 0, 0, NULL, NULL)`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
for _, j := range i.domains {
|
||||||
|
_, err = service.db.Exec(`INSERT INTO certificate_domains (cert_id, domain) VALUES (1, ?)`, j)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Database rows")
|
||||||
|
fmt.Println("=============")
|
||||||
|
query, err := service.db.Query("SELECT cert_id, domain from certificate_domains")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
for query.Next() {
|
||||||
|
var a uint64
|
||||||
|
var b string
|
||||||
|
assert.NoError(t, query.Scan(&a, &b))
|
||||||
|
|
||||||
|
fmt.Println(a, b)
|
||||||
|
}
|
||||||
|
fmt.Println("=============")
|
||||||
|
|
||||||
|
assert.NoError(t, service.renewalCheck())
|
||||||
|
certFilePath := filepath.Join(service.certDir, "1.cert.pem")
|
||||||
|
certFileRaw, err := os.ReadFile(certFilePath)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
p, _ := pem.Decode(certFileRaw)
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
if p == nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
assert.Equal(t, "CERTIFICATE", p.Type)
|
||||||
|
outCert, err := x509.ParseCertificate(p.Bytes)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, i.domains, outCert.DNSNames)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
35
test/fakeAcmeProv.go
Normal file
35
test/fakeAcmeProv.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"github.com/go-acme/lego/v4/challenge"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeAcmeProv an acme provider to emulate
|
||||||
|
type fakeAcmeProv struct {
|
||||||
|
t http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeFakeAcmeProv(cert []byte) interface {
|
||||||
|
challenge.Provider
|
||||||
|
http.RoundTripper
|
||||||
|
} {
|
||||||
|
cp := x509.NewCertPool()
|
||||||
|
cp.AppendCertsFromPEM(cert)
|
||||||
|
t := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
t.TLSClientConfig = &tls.Config{RootCAs: cp}
|
||||||
|
return &fakeAcmeProv{t: t}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeAcmeProv) Present(string, string, string) error { return nil }
|
||||||
|
func (f *fakeAcmeProv) CleanUp(string, string, string) error { return nil }
|
||||||
|
func (f *fakeAcmeProv) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
// use transport with custom CertPool for pebble requests
|
||||||
|
if req.URL.Host == "localhost:14000" {
|
||||||
|
return f.t.RoundTrip(req)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("invalid fakeAcmeProv.RoundTrip call to '%s'", req.URL.String())
|
||||||
|
}
|
130
test/fakeDnsProv.go
Normal file
130
test/fakeDnsProv.go
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/go-acme/lego/v4/challenge"
|
||||||
|
"github.com/go-acme/lego/v4/challenge/dns01"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeDnsProv struct {
|
||||||
|
Addr string
|
||||||
|
mTxt map[string]string
|
||||||
|
srv *dns.Server
|
||||||
|
mSoa map[string][2]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeFakeDnsProv(addr string) interface {
|
||||||
|
challenge.Provider
|
||||||
|
GetDnsAddrs() []string
|
||||||
|
Start()
|
||||||
|
Shutdown()
|
||||||
|
AddRecursiveSOA(fqdn string)
|
||||||
|
} {
|
||||||
|
return &fakeDnsProv{
|
||||||
|
Addr: addr,
|
||||||
|
mTxt: make(map[string]string),
|
||||||
|
mSoa: make(map[string][2]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDnsProv) AddRecursiveSOA(fqdn string) {
|
||||||
|
n := fqdn
|
||||||
|
for {
|
||||||
|
f.mSoa[n] = [2]string{"ns." + n, "webmaster." + n}
|
||||||
|
|
||||||
|
// find next subdomain separator and trim the fqdn
|
||||||
|
ni := strings.IndexByte(n, '.')
|
||||||
|
if ni <= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
n = n[ni+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDnsProv) Present(domain, _, keyAuth string) error {
|
||||||
|
info := dns01.GetChallengeInfo(domain, keyAuth)
|
||||||
|
f.mTxt[info.EffectiveFQDN] = info.Value
|
||||||
|
log.Printf("fakeDnsProv.Present(%s TXT %s)\n", info.EffectiveFQDN, info.Value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (f *fakeDnsProv) CleanUp(domain, _, keyAuth string) error {
|
||||||
|
info := dns01.GetChallengeInfo(domain, keyAuth)
|
||||||
|
delete(f.mTxt, info.EffectiveFQDN)
|
||||||
|
log.Printf("fakeDnsProv.CleanUp(%s TXT %s)\n", info.EffectiveFQDN, info.Value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (f *fakeDnsProv) GetDnsAddrs() []string {
|
||||||
|
fmt.Printf("Get dns addrs: %v\n", f.Addr)
|
||||||
|
return []string{f.Addr}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDnsProv) parseQuery(m *dns.Msg) {
|
||||||
|
for _, q := range m.Question {
|
||||||
|
switch q.Qtype {
|
||||||
|
case dns.TypeSOA:
|
||||||
|
log.Printf("Looking up %s SOA record\n", q.Name)
|
||||||
|
n := q.Name
|
||||||
|
for strings.Count(n, ".") > 3 {
|
||||||
|
// find next subdomain separator and trim the fqdn
|
||||||
|
ni := strings.IndexByte(n, '.')
|
||||||
|
if ni <= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
n = n[ni+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// find an answer if possible
|
||||||
|
if strings.Count(q.Name, ".") == 3 {
|
||||||
|
rr, err := dns.NewRR(fmt.Sprintf("%s 32600 IN SOA %s %s 1687993787 86400 7200 4000000 11200", n, "ns.example.com.", "hostmaster.example.com."))
|
||||||
|
if err == nil {
|
||||||
|
m.Answer = append(m.Answer, rr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case dns.TypeTXT:
|
||||||
|
log.Printf("Looking up %s TXT record\n", q.Name)
|
||||||
|
txt := f.mTxt[q.Name]
|
||||||
|
if txt != "" {
|
||||||
|
rr, err := dns.NewRR(fmt.Sprintf("%s 32600 IN TXT \"%s\"", q.Name, txt))
|
||||||
|
if err == nil {
|
||||||
|
fmt.Println("response:", rr.String())
|
||||||
|
m.Answer = append(m.Answer, rr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
log.Printf("Looking up %d for %s\n", q.Qtype, q.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDnsProv) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetReply(r)
|
||||||
|
m.Compress = false
|
||||||
|
|
||||||
|
switch r.Opcode {
|
||||||
|
case dns.OpcodeQuery:
|
||||||
|
f.parseQuery(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = w.WriteMsg(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDnsProv) Start() {
|
||||||
|
// attach request handler func
|
||||||
|
dns.HandleFunc(".", f.handleDnsRequest)
|
||||||
|
|
||||||
|
// start server
|
||||||
|
f.srv = &dns.Server{Addr: f.Addr, Net: "udp"}
|
||||||
|
log.Printf("Starting fake dns service at %s\n", f.srv.Addr)
|
||||||
|
err := f.srv.ListenAndServe()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to start server: %s\n ", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDnsProv) Shutdown() {
|
||||||
|
_ = f.srv.Shutdown()
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user