Refactor agent to copy all files for an agent dir at the same time

This commit is contained in:
Melon 2025-02-18 21:17:09 +00:00
parent a23155a827
commit fc2f8e34c6
Signed by: melon
GPG Key ID: 6C9D970C50D26A25

View File

@ -99,52 +99,50 @@ func (a *Agent) syncCheck() {
panic(err) panic(err)
} }
agentErrs := make(map[int64][]error) a.syncCertPairs(now, actions)
for _, action := range actions {
err = a.syncSingleAgentCertPair(now, action)
if err != nil {
agentErrs[action.AgentID] = append(agentErrs[action.AgentID], err)
}
}
for agentId, errs := range agentErrs {
Logger.Warn("Agent sync failed", "agent", agentId, "errs", errs)
}
// TODO: idk what to do now
} }
func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAgentToSyncRow) error { type syncAgent struct {
certName := utils.GetCertFileName(row.CertID) agentId int64
keyName := utils.GetKeyFileName(row.CertID) address string
user string
fingerprint string
}
certPath := filepath.Join(a.certDir, certName) func (a *Agent) syncCertPairs(startTime time.Time, rows []database.FindAgentToSyncRow) {
keyPath := filepath.Join(a.keyDir, keyName) agentMap := make(map[syncAgent][]database.FindAgentToSyncRow)
// open cert and key files for _, row := range rows {
openCert, err := os.Open(certPath) a := syncAgent{
if err != nil { agentId: row.AgentID,
return fmt.Errorf("open cert file: %w", err) address: row.Address,
user: row.User,
fingerprint: row.Fingerprint,
}
agentMap[a] = append(agentMap[a], row)
} }
defer openCert.Close()
openKey, err := os.Open(keyPath)
if err != nil { for agent, certPairs := range agentMap {
return fmt.Errorf("open key file: %w", err) err := a.syncSingleAgentCertPairs(startTime, agent, certPairs)
if err != nil {
// This agent sync is allowed to fail without stopping other agent syncs from
// occurring.
Logger.Warn("Agent sync failed", "agent", agent.agentId, "err", err)
}
} }
defer openKey.Close() }
hostPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(row.Fingerprint)) func (a *Agent) syncSingleAgentCertPairs(startTime time.Time, agent syncAgent, rows []database.FindAgentToSyncRow) error {
hostPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(agent.fingerprint))
if err != nil { if err != nil {
return fmt.Errorf("failed to parse fingerprint: %w", err) return fmt.Errorf("failed to parse fingerprint: %w", err)
} }
client, err := ssh.Dial("tcp", row.Address, &ssh.ClientConfig{ client, err := ssh.Dial("tcp", agent.address, &ssh.ClientConfig{
Config: ssh.Config{ Config: ssh.Config{
KeyExchanges: []string{"curve25519-sha256"}, KeyExchanges: []string{"curve25519-sha256"},
}, },
User: row.User, User: agent.user,
Auth: []ssh.AuthMethod{ Auth: []ssh.AuthMethod{
ssh.PublicKeys(a.sshKey), ssh.PublicKeys(a.sshKey),
}, },
@ -161,6 +159,51 @@ func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAg
return fmt.Errorf("scp client: %w", err) return fmt.Errorf("scp client: %w", err)
} }
for _, row := range rows {
err := a.copySingleCertPair(scpClient, row)
if err != nil {
// This cert sync is allowed to fail without stopping other certs going to the
// same agent from copying.
err = fmt.Errorf("copySingleCertPair: %w", err)
Logger.Warn("Agent certificate sync failed", "agent", row.AgentID, "cert", row.CertID, "not after", row.CertNotAfter, "err", err)
continue
}
}
// Update last sync to the time when the database request happened. This ensures
// that certificates updated after the database request and before the agent
// syncing are updated properly.
err = a.db.UpdateAgentLastSync(context.Background(), database.UpdateAgentLastSyncParams{
LastSync: sql.NullTime{Time: startTime, Valid: true},
ID: agent.agentId,
})
if err != nil {
return fmt.Errorf("error updating agent last sync: %v", err)
}
return nil
}
func (a *Agent) copySingleCertPair(scpClient scp.Client, row database.FindAgentToSyncRow) error {
certName := utils.GetCertFileName(row.CertID)
keyName := utils.GetKeyFileName(row.CertID)
certPath := filepath.Join(a.certDir, certName)
keyPath := filepath.Join(a.keyDir, keyName)
// open cert and key files
openCert, err := os.Open(certPath)
if err != nil {
return fmt.Errorf("open cert file: %w", err)
}
defer openCert.Close()
openKey, err := os.Open(keyPath)
if err != nil {
return fmt.Errorf("open key file: %w", err)
}
defer openKey.Close()
// copy cert and key to agent // copy cert and key to agent
err = scpClient.CopyFromFile(context.Background(), *openCert, filepath.Join(row.Dir, "certificates", certName), "0600") err = scpClient.CopyFromFile(context.Background(), *openCert, filepath.Join(row.Dir, "certificates", certName), "0600")
if err != nil { if err != nil {
@ -171,15 +214,6 @@ func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAg
return fmt.Errorf("copy cert file: %w", err) return fmt.Errorf("copy cert file: %w", err)
} }
// update last sync to the time when the database request happened
err = a.db.UpdateAgentLastSync(context.Background(), database.UpdateAgentLastSyncParams{
LastSync: sql.NullTime{Time: startTime, Valid: true},
ID: row.AgentID,
})
if err != nil {
return fmt.Errorf("error updating agent last sync: %v", err)
}
err = a.db.UpdateAgentCertNotAfter(context.Background(), database.UpdateAgentCertNotAfterParams{ err = a.db.UpdateAgentCertNotAfter(context.Background(), database.UpdateAgentCertNotAfterParams{
NotAfter: row.CertNotAfter, NotAfter: row.CertNotAfter,
AgentID: row.AgentID, AgentID: row.AgentID,