mirror of
https://github.com/1f349/orchid.git
synced 2025-03-10 05:53:11 +00:00
Refactor agent to copy all files for an agent dir at the same time
This commit is contained in:
parent
a23155a827
commit
fc2f8e34c6
116
agent/agent.go
116
agent/agent.go
@ -99,52 +99,50 @@ func (a *Agent) syncCheck() {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
agentErrs := make(map[int64][]error)
|
||||
|
||||
for _, action := range actions {
|
||||
err = a.syncSingleAgentCertPair(now, action)
|
||||
if err != nil {
|
||||
agentErrs[action.AgentID] = append(agentErrs[action.AgentID], err)
|
||||
}
|
||||
}
|
||||
|
||||
for agentId, errs := range agentErrs {
|
||||
Logger.Warn("Agent sync failed", "agent", agentId, "errs", errs)
|
||||
}
|
||||
|
||||
// TODO: idk what to do now
|
||||
a.syncCertPairs(now, actions)
|
||||
}
|
||||
|
||||
func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAgentToSyncRow) error {
|
||||
certName := utils.GetCertFileName(row.CertID)
|
||||
keyName := utils.GetKeyFileName(row.CertID)
|
||||
type syncAgent struct {
|
||||
agentId int64
|
||||
address string
|
||||
user string
|
||||
fingerprint string
|
||||
}
|
||||
|
||||
certPath := filepath.Join(a.certDir, certName)
|
||||
keyPath := filepath.Join(a.keyDir, keyName)
|
||||
func (a *Agent) syncCertPairs(startTime time.Time, rows []database.FindAgentToSyncRow) {
|
||||
agentMap := make(map[syncAgent][]database.FindAgentToSyncRow)
|
||||
|
||||
// open cert and key files
|
||||
openCert, err := os.Open(certPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open cert file: %w", err)
|
||||
for _, row := range rows {
|
||||
a := syncAgent{
|
||||
agentId: row.AgentID,
|
||||
address: row.Address,
|
||||
user: row.User,
|
||||
fingerprint: row.Fingerprint,
|
||||
}
|
||||
defer openCert.Close()
|
||||
openKey, err := os.Open(keyPath)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("open key file: %w", err)
|
||||
agentMap[a] = append(agentMap[a], row)
|
||||
}
|
||||
defer openKey.Close()
|
||||
|
||||
hostPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(row.Fingerprint))
|
||||
for agent, certPairs := range agentMap {
|
||||
err := a.syncSingleAgentCertPairs(startTime, agent, certPairs)
|
||||
if err != nil {
|
||||
// This agent sync is allowed to fail without stopping other agent syncs from
|
||||
// occurring.
|
||||
Logger.Warn("Agent sync failed", "agent", agent.agentId, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) syncSingleAgentCertPairs(startTime time.Time, agent syncAgent, rows []database.FindAgentToSyncRow) error {
|
||||
hostPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(agent.fingerprint))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse fingerprint: %w", err)
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", row.Address, &ssh.ClientConfig{
|
||||
client, err := ssh.Dial("tcp", agent.address, &ssh.ClientConfig{
|
||||
Config: ssh.Config{
|
||||
KeyExchanges: []string{"curve25519-sha256"},
|
||||
},
|
||||
User: row.User,
|
||||
User: agent.user,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(a.sshKey),
|
||||
},
|
||||
@ -161,6 +159,51 @@ func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAg
|
||||
return fmt.Errorf("scp client: %w", err)
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
err := a.copySingleCertPair(scpClient, row)
|
||||
if err != nil {
|
||||
// This cert sync is allowed to fail without stopping other certs going to the
|
||||
// same agent from copying.
|
||||
err = fmt.Errorf("copySingleCertPair: %w", err)
|
||||
Logger.Warn("Agent certificate sync failed", "agent", row.AgentID, "cert", row.CertID, "not after", row.CertNotAfter, "err", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Update last sync to the time when the database request happened. This ensures
|
||||
// that certificates updated after the database request and before the agent
|
||||
// syncing are updated properly.
|
||||
err = a.db.UpdateAgentLastSync(context.Background(), database.UpdateAgentLastSyncParams{
|
||||
LastSync: sql.NullTime{Time: startTime, Valid: true},
|
||||
ID: agent.agentId,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating agent last sync: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Agent) copySingleCertPair(scpClient scp.Client, row database.FindAgentToSyncRow) error {
|
||||
certName := utils.GetCertFileName(row.CertID)
|
||||
keyName := utils.GetKeyFileName(row.CertID)
|
||||
|
||||
certPath := filepath.Join(a.certDir, certName)
|
||||
keyPath := filepath.Join(a.keyDir, keyName)
|
||||
|
||||
// open cert and key files
|
||||
openCert, err := os.Open(certPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open cert file: %w", err)
|
||||
}
|
||||
defer openCert.Close()
|
||||
|
||||
openKey, err := os.Open(keyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open key file: %w", err)
|
||||
}
|
||||
defer openKey.Close()
|
||||
|
||||
// copy cert and key to agent
|
||||
err = scpClient.CopyFromFile(context.Background(), *openCert, filepath.Join(row.Dir, "certificates", certName), "0600")
|
||||
if err != nil {
|
||||
@ -171,15 +214,6 @@ func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAg
|
||||
return fmt.Errorf("copy cert file: %w", err)
|
||||
}
|
||||
|
||||
// update last sync to the time when the database request happened
|
||||
err = a.db.UpdateAgentLastSync(context.Background(), database.UpdateAgentLastSyncParams{
|
||||
LastSync: sql.NullTime{Time: startTime, Valid: true},
|
||||
ID: row.AgentID,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating agent last sync: %v", err)
|
||||
}
|
||||
|
||||
err = a.db.UpdateAgentCertNotAfter(context.Background(), database.UpdateAgentCertNotAfterParams{
|
||||
NotAfter: row.CertNotAfter,
|
||||
AgentID: row.AgentID,
|
||||
|
Loading…
x
Reference in New Issue
Block a user