diff --git a/agent/agent.go b/agent/agent.go index 9186289..f7ece1d 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -99,52 +99,50 @@ func (a *Agent) syncCheck() { panic(err) } - agentErrs := make(map[int64][]error) - - for _, action := range actions { - err = a.syncSingleAgentCertPair(now, action) - if err != nil { - agentErrs[action.AgentID] = append(agentErrs[action.AgentID], err) - } - } - - for agentId, errs := range agentErrs { - Logger.Warn("Agent sync failed", "agent", agentId, "errs", errs) - } - - // TODO: idk what to do now + a.syncCertPairs(now, actions) } -func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAgentToSyncRow) error { - certName := utils.GetCertFileName(row.CertID) - keyName := utils.GetKeyFileName(row.CertID) +type syncAgent struct { + agentId int64 + address string + user string + fingerprint string +} - certPath := filepath.Join(a.certDir, certName) - keyPath := filepath.Join(a.keyDir, keyName) +func (a *Agent) syncCertPairs(startTime time.Time, rows []database.FindAgentToSyncRow) { + agentMap := make(map[syncAgent][]database.FindAgentToSyncRow) - // open cert and key files - openCert, err := os.Open(certPath) - if err != nil { - return fmt.Errorf("open cert file: %w", err) + for _, row := range rows { + a := syncAgent{ + agentId: row.AgentID, + address: row.Address, + user: row.User, + fingerprint: row.Fingerprint, + } + agentMap[a] = append(agentMap[a], row) } - defer openCert.Close() - openKey, err := os.Open(keyPath) - if err != nil { - return fmt.Errorf("open key file: %w", err) + for agent, certPairs := range agentMap { + err := a.syncSingleAgentCertPairs(startTime, agent, certPairs) + if err != nil { + // This agent sync is allowed to fail without stopping other agent syncs from + // occurring. + Logger.Warn("Agent sync failed", "agent", agent.agentId, "err", err) + } } - defer openKey.Close() +} - hostPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(row.Fingerprint)) +func (a *Agent) syncSingleAgentCertPairs(startTime time.Time, agent syncAgent, rows []database.FindAgentToSyncRow) error { + hostPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(agent.fingerprint)) if err != nil { return fmt.Errorf("failed to parse fingerprint: %w", err) } - client, err := ssh.Dial("tcp", row.Address, &ssh.ClientConfig{ + client, err := ssh.Dial("tcp", agent.address, &ssh.ClientConfig{ Config: ssh.Config{ KeyExchanges: []string{"curve25519-sha256"}, }, - User: row.User, + User: agent.user, Auth: []ssh.AuthMethod{ ssh.PublicKeys(a.sshKey), }, @@ -161,6 +159,51 @@ func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAg return fmt.Errorf("scp client: %w", err) } + for _, row := range rows { + err := a.copySingleCertPair(scpClient, row) + if err != nil { + // This cert sync is allowed to fail without stopping other certs going to the + // same agent from copying. + err = fmt.Errorf("copySingleCertPair: %w", err) + Logger.Warn("Agent certificate sync failed", "agent", row.AgentID, "cert", row.CertID, "not after", row.CertNotAfter, "err", err) + continue + } + } + + // Update last sync to the time when the database request happened. This ensures + // that certificates updated after the database request and before the agent + // syncing are updated properly. + err = a.db.UpdateAgentLastSync(context.Background(), database.UpdateAgentLastSyncParams{ + LastSync: sql.NullTime{Time: startTime, Valid: true}, + ID: agent.agentId, + }) + if err != nil { + return fmt.Errorf("error updating agent last sync: %v", err) + } + + return nil +} + +func (a *Agent) copySingleCertPair(scpClient scp.Client, row database.FindAgentToSyncRow) error { + certName := utils.GetCertFileName(row.CertID) + keyName := utils.GetKeyFileName(row.CertID) + + certPath := filepath.Join(a.certDir, certName) + keyPath := filepath.Join(a.keyDir, keyName) + + // open cert and key files + openCert, err := os.Open(certPath) + if err != nil { + return fmt.Errorf("open cert file: %w", err) + } + defer openCert.Close() + + openKey, err := os.Open(keyPath) + if err != nil { + return fmt.Errorf("open key file: %w", err) + } + defer openKey.Close() + // copy cert and key to agent err = scpClient.CopyFromFile(context.Background(), *openCert, filepath.Join(row.Dir, "certificates", certName), "0600") if err != nil { @@ -171,15 +214,6 @@ func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAg return fmt.Errorf("copy cert file: %w", err) } - // update last sync to the time when the database request happened - err = a.db.UpdateAgentLastSync(context.Background(), database.UpdateAgentLastSyncParams{ - LastSync: sql.NullTime{Time: startTime, Valid: true}, - ID: row.AgentID, - }) - if err != nil { - return fmt.Errorf("error updating agent last sync: %v", err) - } - err = a.db.UpdateAgentCertNotAfter(context.Background(), database.UpdateAgentCertNotAfterParams{ NotAfter: row.CertNotAfter, AgentID: row.AgentID,