mirror of
https://github.com/1f349/orchid.git
synced 2025-03-10 14:03:07 +00:00
Refactor agent to copy all files for an agent dir at the same time
This commit is contained in:
parent
a23155a827
commit
fc2f8e34c6
114
agent/agent.go
114
agent/agent.go
@ -99,52 +99,50 @@ func (a *Agent) syncCheck() {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
agentErrs := make(map[int64][]error)
|
a.syncCertPairs(now, actions)
|
||||||
|
|
||||||
for _, action := range actions {
|
|
||||||
err = a.syncSingleAgentCertPair(now, action)
|
|
||||||
if err != nil {
|
|
||||||
agentErrs[action.AgentID] = append(agentErrs[action.AgentID], err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for agentId, errs := range agentErrs {
|
|
||||||
Logger.Warn("Agent sync failed", "agent", agentId, "errs", errs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: idk what to do now
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAgentToSyncRow) error {
|
type syncAgent struct {
|
||||||
certName := utils.GetCertFileName(row.CertID)
|
agentId int64
|
||||||
keyName := utils.GetKeyFileName(row.CertID)
|
address string
|
||||||
|
user string
|
||||||
|
fingerprint string
|
||||||
|
}
|
||||||
|
|
||||||
certPath := filepath.Join(a.certDir, certName)
|
func (a *Agent) syncCertPairs(startTime time.Time, rows []database.FindAgentToSyncRow) {
|
||||||
keyPath := filepath.Join(a.keyDir, keyName)
|
agentMap := make(map[syncAgent][]database.FindAgentToSyncRow)
|
||||||
|
|
||||||
// open cert and key files
|
for _, row := range rows {
|
||||||
openCert, err := os.Open(certPath)
|
a := syncAgent{
|
||||||
if err != nil {
|
agentId: row.AgentID,
|
||||||
return fmt.Errorf("open cert file: %w", err)
|
address: row.Address,
|
||||||
|
user: row.User,
|
||||||
|
fingerprint: row.Fingerprint,
|
||||||
|
}
|
||||||
|
agentMap[a] = append(agentMap[a], row)
|
||||||
}
|
}
|
||||||
defer openCert.Close()
|
|
||||||
openKey, err := os.Open(keyPath)
|
|
||||||
|
|
||||||
if err != nil {
|
for agent, certPairs := range agentMap {
|
||||||
return fmt.Errorf("open key file: %w", err)
|
err := a.syncSingleAgentCertPairs(startTime, agent, certPairs)
|
||||||
|
if err != nil {
|
||||||
|
// This agent sync is allowed to fail without stopping other agent syncs from
|
||||||
|
// occurring.
|
||||||
|
Logger.Warn("Agent sync failed", "agent", agent.agentId, "err", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
defer openKey.Close()
|
}
|
||||||
|
|
||||||
hostPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(row.Fingerprint))
|
func (a *Agent) syncSingleAgentCertPairs(startTime time.Time, agent syncAgent, rows []database.FindAgentToSyncRow) error {
|
||||||
|
hostPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(agent.fingerprint))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse fingerprint: %w", err)
|
return fmt.Errorf("failed to parse fingerprint: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := ssh.Dial("tcp", row.Address, &ssh.ClientConfig{
|
client, err := ssh.Dial("tcp", agent.address, &ssh.ClientConfig{
|
||||||
Config: ssh.Config{
|
Config: ssh.Config{
|
||||||
KeyExchanges: []string{"curve25519-sha256"},
|
KeyExchanges: []string{"curve25519-sha256"},
|
||||||
},
|
},
|
||||||
User: row.User,
|
User: agent.user,
|
||||||
Auth: []ssh.AuthMethod{
|
Auth: []ssh.AuthMethod{
|
||||||
ssh.PublicKeys(a.sshKey),
|
ssh.PublicKeys(a.sshKey),
|
||||||
},
|
},
|
||||||
@ -161,6 +159,51 @@ func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAg
|
|||||||
return fmt.Errorf("scp client: %w", err)
|
return fmt.Errorf("scp client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, row := range rows {
|
||||||
|
err := a.copySingleCertPair(scpClient, row)
|
||||||
|
if err != nil {
|
||||||
|
// This cert sync is allowed to fail without stopping other certs going to the
|
||||||
|
// same agent from copying.
|
||||||
|
err = fmt.Errorf("copySingleCertPair: %w", err)
|
||||||
|
Logger.Warn("Agent certificate sync failed", "agent", row.AgentID, "cert", row.CertID, "not after", row.CertNotAfter, "err", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update last sync to the time when the database request happened. This ensures
|
||||||
|
// that certificates updated after the database request and before the agent
|
||||||
|
// syncing are updated properly.
|
||||||
|
err = a.db.UpdateAgentLastSync(context.Background(), database.UpdateAgentLastSyncParams{
|
||||||
|
LastSync: sql.NullTime{Time: startTime, Valid: true},
|
||||||
|
ID: agent.agentId,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error updating agent last sync: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) copySingleCertPair(scpClient scp.Client, row database.FindAgentToSyncRow) error {
|
||||||
|
certName := utils.GetCertFileName(row.CertID)
|
||||||
|
keyName := utils.GetKeyFileName(row.CertID)
|
||||||
|
|
||||||
|
certPath := filepath.Join(a.certDir, certName)
|
||||||
|
keyPath := filepath.Join(a.keyDir, keyName)
|
||||||
|
|
||||||
|
// open cert and key files
|
||||||
|
openCert, err := os.Open(certPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open cert file: %w", err)
|
||||||
|
}
|
||||||
|
defer openCert.Close()
|
||||||
|
|
||||||
|
openKey, err := os.Open(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open key file: %w", err)
|
||||||
|
}
|
||||||
|
defer openKey.Close()
|
||||||
|
|
||||||
// copy cert and key to agent
|
// copy cert and key to agent
|
||||||
err = scpClient.CopyFromFile(context.Background(), *openCert, filepath.Join(row.Dir, "certificates", certName), "0600")
|
err = scpClient.CopyFromFile(context.Background(), *openCert, filepath.Join(row.Dir, "certificates", certName), "0600")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -171,15 +214,6 @@ func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAg
|
|||||||
return fmt.Errorf("copy cert file: %w", err)
|
return fmt.Errorf("copy cert file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// update last sync to the time when the database request happened
|
|
||||||
err = a.db.UpdateAgentLastSync(context.Background(), database.UpdateAgentLastSyncParams{
|
|
||||||
LastSync: sql.NullTime{Time: startTime, Valid: true},
|
|
||||||
ID: row.AgentID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error updating agent last sync: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = a.db.UpdateAgentCertNotAfter(context.Background(), database.UpdateAgentCertNotAfterParams{
|
err = a.db.UpdateAgentCertNotAfter(context.Background(), database.UpdateAgentCertNotAfterParams{
|
||||||
NotAfter: row.CertNotAfter,
|
NotAfter: row.CertNotAfter,
|
||||||
AgentID: row.AgentID,
|
AgentID: row.AgentID,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user