orchid/agent/agent.go
2025-02-18 21:33:15 +00:00

228 lines
5.7 KiB
Go

package agent
import (
"context"
"database/sql"
_ "embed"
"fmt"
"github.com/1f349/orchid/database"
"github.com/1f349/orchid/utils"
"github.com/bramvdbogaerde/go-scp"
"golang.org/x/crypto/ssh"
"os"
"path/filepath"
"sync"
"time"
)
// agentReadme holds the raw bytes of agent_readme.md, which the go:embed
// directive below compiles into the binary at build time.
//
//go:embed agent_readme.md
var agentReadme []byte
// agentQueries is the set of database queries the agent syncing service calls.
// NewAgent accepts any value that implements these methods.
type agentQueries interface {
// FindAgentToSync returns the rows that syncCheck hands to syncCertPairs; each
// row carries the agent connection details and the certificate to copy.
FindAgentToSync(ctx context.Context) ([]database.FindAgentToSyncRow, error)
// UpdateAgentLastSync records the last sync time for an agent; it is called
// once per agent after its certificate pairs have been processed.
UpdateAgentLastSync(ctx context.Context, row database.UpdateAgentLastSyncParams) error
// UpdateAgentCertNotAfter records the not-after value of a certificate for an
// agent; it is called after the cert and key files were copied successfully.
UpdateAgentCertNotAfter(ctx context.Context, arg database.UpdateAgentCertNotAfterParams) error
}
// NewAgent constructs an Agent and immediately starts its background sync
// goroutine.
//
// wg is incremented before the goroutine is launched and is released when
// syncRoutine returns, so callers can wait on it after calling Shutdown. db
// supplies the queries used during syncing, sshKey authenticates against the
// remote agents, and certDir/keyDir are the local directories the certificate
// and key files are read from.
//
// The returned error is always nil in the current implementation.
func NewAgent(wg *sync.WaitGroup, db agentQueries, sshKey ssh.Signer, certDir string, keyDir string) (*Agent, error) {
	agent := new(Agent)
	agent.db = db
	agent.sshKey = sshKey
	agent.certDir = certDir
	agent.keyDir = keyDir
	agent.syncLock = &sync.Mutex{}
	agent.done = make(chan struct{})
	// A sync check is triggered every ten minutes.
	agent.ticker = time.NewTicker(10 * time.Minute)

	// Register the goroutine with the WaitGroup before it starts running.
	wg.Add(1)
	go agent.syncRoutine(wg)

	return agent, nil
}
// Agent is the agent syncing service: it periodically copies certificate and
// key files to remote agents over SSH and records the result in the database.
type Agent struct {
// db provides the queries used to find pending syncs and to record results.
db agentQueries
// ticker fires to trigger each periodic sync check; syncRoutine stops it on exit.
ticker *time.Ticker
// done is closed by Shutdown to tell syncRoutine to exit.
done chan struct{}
// syncLock guards syncCheck so that only one sync runs at a time; a cycle
// that cannot take the lock is skipped.
syncLock *sync.Mutex
// sshKey is the signer used for public-key authentication against agents.
sshKey ssh.Signer
// certDir is the local directory certificate files are read from.
certDir string
// keyDir is the local directory key files are read from.
keyDir string
}
// Shutdown signals the background sync routine to stop by closing the done
// channel. It does not wait for the routine to exit; callers wait on the
// WaitGroup that was passed to NewAgent for that.
//
// Calling Shutdown again after it has already run is a no-op rather than a
// panic. The check is not atomic, so Shutdown must still not be invoked from
// several goroutines at the same time.
func (a *Agent) Shutdown() {
	select {
	case <-a.done:
		// Already shut down: closing a closed channel would panic.
		return
	default:
	}

	Logger.Info("Shutting down agent syncing service")
	close(a.done)
}
// syncRoutine is the body of the background goroutine started by NewAgent.
//
// It performs one sync check right away, then launches another check on every
// ticker tick until the done channel is closed. On exit it stops the ticker,
// logs that the service has stopped and finally releases wg.
func (a *Agent) syncRoutine(wg *sync.WaitGroup) {
	Logger.Debug("Starting syncRoutine")

	// Deferred calls run in reverse order: the ticker is stopped and the stop
	// message logged first, then the WaitGroup is released.
	defer wg.Done()
	defer func() {
		a.ticker.Stop()
		Logger.Info("Stopped agent syncing service")
	}()

	// The first check runs synchronously, before the periodic loop begins.
	Logger.Info("Doing quick agent check before starting...")
	a.syncCheck()
	Logger.Info("Initial check complete, continually checking every 10 minutes...")

	for {
		select {
		case <-a.ticker.C:
			Logger.Debug("Ticking agent syncing")
			// Run the check in its own goroutine so this loop keeps
			// listening for the shutdown signal while a sync is in progress.
			go a.syncCheck()
		case <-a.done:
			// Shutdown was requested.
			return
		}
	}
}
// syncCheck runs a single sync cycle: it asks the database which agents need
// syncing and passes the resulting rows to syncCertPairs.
//
// If another cycle is still running (the lock cannot be taken) this cycle is
// skipped. A failing database query is logged and the cycle is abandoned; the
// next tick tries again. It must not panic here, because this method runs on
// background goroutines where an unrecovered panic would take down the whole
// process.
func (a *Agent) syncCheck() {
	// If the lock is unavailable then ignore this cycle.
	if !a.syncLock.TryLock() {
		return
	}
	defer a.syncLock.Unlock()

	// Capture the time before querying so that it can later be stored as the
	// agents' last sync time.
	now := time.Now().UTC()

	actions, err := a.db.FindAgentToSync(context.Background())
	if err != nil {
		Logger.Warn("Failed to find agents to sync", "err", err)
		return
	}

	a.syncCertPairs(now, actions)
}
// syncAgent identifies one remote agent together with its connection details.
// It is used as the map key in syncCertPairs to group rows by agent.
type syncAgent struct {
// agentId is the agent's ID, copied from the row's AgentID.
agentId int64
// address is the network address handed to ssh.Dial.
address string
// user is the SSH login user.
user string
// fingerprint is the agent's host public key text; it is parsed with
// ssh.ParseAuthorizedKey and used to verify the host on connect.
fingerprint string
}
// syncCertPairs groups the given rows by the agent they belong to and syncs
// each agent in turn.
//
// startTime is forwarded unchanged to syncSingleAgentCertPairs. A failure for
// one agent is logged as a warning and does not prevent the remaining agents
// from being synced.
func (a *Agent) syncCertPairs(startTime time.Time, rows []database.FindAgentToSyncRow) {
	// Bucket the rows by agent identity and connection details.
	grouped := make(map[syncAgent][]database.FindAgentToSyncRow)
	for i := range rows {
		key := syncAgent{
			agentId:     rows[i].AgentID,
			address:     rows[i].Address,
			user:        rows[i].User,
			fingerprint: rows[i].Fingerprint,
		}
		grouped[key] = append(grouped[key], rows[i])
	}

	for target, pairs := range grouped {
		if err := a.syncSingleAgentCertPairs(startTime, target, pairs); err != nil {
			// This agent sync is allowed to fail without stopping other agent
			// syncs from occurring.
			Logger.Warn("Agent sync failed", "agent", target.agentId, "err", err)
		}
	}
}
// syncSingleAgentCertPairs connects to a single agent over SSH, copies every
// certificate pair in rows to it and then records startTime as the agent's
// last sync time.
//
// The host is verified against the public key parsed from agent.fingerprint.
// A failure to copy an individual certificate pair is logged and skipped; only
// connection setup errors and the final database update are returned to the
// caller. The SSH connection is closed before the function returns.
func (a *Agent) syncSingleAgentCertPairs(startTime time.Time, agent syncAgent, rows []database.FindAgentToSyncRow) error {
	hostPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(agent.fingerprint))
	if err != nil {
		return fmt.Errorf("failed to parse fingerprint: %w", err)
	}

	client, err := ssh.Dial("tcp", agent.address, &ssh.ClientConfig{
		Config: ssh.Config{
			KeyExchanges: []string{"curve25519-sha256"},
		},
		User: agent.user,
		Auth: []ssh.AuthMethod{
			ssh.PublicKeys(a.sshKey),
		},
		HostKeyAlgorithms: []string{"ssh-ed25519"},
		HostKeyCallback:   ssh.FixedHostKey(hostPubKey),
		Timeout:           time.Second * 30,
	})
	if err != nil {
		return fmt.Errorf("ssh dial: %w", err)
	}
	// Release the connection once this agent is done; without this every sync
	// cycle would leak one SSH connection per agent. A close error is not
	// actionable at this point, so it is deliberately ignored.
	defer client.Close()

	scpClient, err := scp.NewClientBySSH(client)
	if err != nil {
		return fmt.Errorf("scp client: %w", err)
	}

	for _, row := range rows {
		if copyErr := a.copySingleCertPair(&scpClient, row); copyErr != nil {
			// This cert sync is allowed to fail without stopping other certs
			// going to the same agent from copying.
			copyErr = fmt.Errorf("copySingleCertPair: %w", copyErr)
			Logger.Warn("Agent certificate sync failed", "agent", row.AgentID, "cert", row.CertID, "not after", row.CertNotAfter, "err", copyErr)
		}
	}

	// Update last sync to the time when the database request happened. This
	// ensures that certificates updated after the database request and before
	// the agent syncing are updated properly.
	err = a.db.UpdateAgentLastSync(context.Background(), database.UpdateAgentLastSyncParams{
		LastSync: sql.NullTime{Time: startTime, Valid: true},
		ID:       agent.agentId,
	})
	if err != nil {
		return fmt.Errorf("error updating agent last sync: %w", err)
	}
	return nil
}
// copySingleCertPair copies the certificate and key files for row.CertID to
// the agent behind scpClient and, once both copies succeed, records the
// certificate's not-after value for that agent in the database.
//
// The local files are read from a.certDir and a.keyDir. On the agent they are
// written to "<row.Dir>/certificates/<cert name>" and "<row.Dir>/keys/<key
// name>", both with mode 0600. The first failing step is returned as a wrapped
// error; later steps are not attempted.
func (a *Agent) copySingleCertPair(scpClient *scp.Client, row database.FindAgentToSyncRow) error {
	certName := utils.GetCertFileName(row.CertID)
	keyName := utils.GetKeyFileName(row.CertID)

	// Open the local cert and key files.
	openCert, err := os.Open(filepath.Join(a.certDir, certName))
	if err != nil {
		return fmt.Errorf("open cert file: %w", err)
	}
	defer openCert.Close()

	openKey, err := os.Open(filepath.Join(a.keyDir, keyName))
	if err != nil {
		return fmt.Errorf("open key file: %w", err)
	}
	defer openKey.Close()

	// Remote paths are sent to the agent as-is, so always use forward slashes
	// regardless of the separator used by the local operating system.
	remoteCert := filepath.ToSlash(filepath.Join(row.Dir, "certificates", certName))
	remoteKey := filepath.ToSlash(filepath.Join(row.Dir, "keys", keyName))

	// Copy cert and key to the agent.
	err = scpClient.CopyFromFile(context.Background(), *openCert, remoteCert, "0600")
	if err != nil {
		return fmt.Errorf("copy cert file: %w", err)
	}
	err = scpClient.CopyFromFile(context.Background(), *openKey, remoteKey, "0600")
	if err != nil {
		return fmt.Errorf("copy key file: %w", err)
	}

	// Only record the not-after value after both files have arrived.
	err = a.db.UpdateAgentCertNotAfter(context.Background(), database.UpdateAgentCertNotAfterParams{
		NotAfter: row.CertNotAfter,
		AgentID:  row.AgentID,
		CertID:   row.CertID,
	})
	if err != nil {
		return fmt.Errorf("error updating agent cert not after: %w", err)
	}
	return nil
}