diff --git a/agent/agent.go b/agent/agent.go new file mode 100644 index 0000000..07a3351 --- /dev/null +++ b/agent/agent.go @@ -0,0 +1,182 @@ +package agent + +import ( + "context" + "database/sql" + _ "embed" + "fmt" + "github.com/1f349/orchid/database" + "github.com/1f349/orchid/utils" + "github.com/bramvdbogaerde/go-scp" + "golang.org/x/crypto/ssh" + "os" + "path/filepath" + "sync" + "time" +) + +//go:embed agent_readme.md +var agentReadme []byte + +type agentQueries interface { + FindAgentToSync(ctx context.Context) ([]database.FindAgentToSyncRow, error) + UpdateAgentLastSync(ctx context.Context, row database.UpdateAgentLastSyncParams) error + UpdateAgentCertNotAfter(ctx context.Context, arg database.UpdateAgentCertNotAfterParams) error +} + +func NewAgent(wg *sync.WaitGroup, db agentQueries, sshKey ssh.Signer, certDir string, keyDir string) (*Agent, error) { + a := &Agent{ + db: db, + ticker: time.NewTicker(time.Minute * 10), + done: make(chan struct{}), + syncLock: new(sync.Mutex), + sshKey: sshKey, + certDir: certDir, + keyDir: keyDir, + } + + wg.Add(1) + go a.syncRoutine(wg) + return a, nil +} + +type Agent struct { + db agentQueries + ticker *time.Ticker + done chan struct{} + syncLock *sync.Mutex + sshKey ssh.Signer + certDir string + keyDir string +} + +func (a *Agent) Shutdown() { + Logger.Info("Shutting down agent syncing service") + close(a.done) +} + +func (a *Agent) syncRoutine(wg *sync.WaitGroup) { + Logger.Debug("Starting syncRoutine") + + // Upon leaving the function stop the ticker and clear the WaitGroup. + defer func() { + a.ticker.Stop() + Logger.Info("Stopped agent syncing service") + wg.Done() + }() + + for { + select { + case <-a.done: + // Exit if done has closed + return + case <-a.ticker.C: + Logger.Debug("Ticking agent syncing") + + go a.syncCheck() + } + } +} + +func (a *Agent) syncCheck() { + // if the lock is unavailable then ignore this cycle + if !a.syncLock.TryLock() { + return + } + defer a.syncLock.Unlock() + + now := time.Now().UTC() + + actions, err := a.db.FindAgentToSync(context.Background()) + if err != nil { + panic(err) + } + + agentErrs := make(map[int64][]error) + + for _, action := range actions { + err = a.syncSingleAgentCertPair(now, action) + if err != nil { + agentErrs[action.AgentID] = append(agentErrs[action.AgentID], err) + } + } + + for agentId, errs := range agentErrs { + Logger.Warn("Agent sync failed", "agent", agentId, "errs", errs) + } + + // TODO: idk what to do now +} + +func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAgentToSyncRow) error { + certName := utils.GetCertFileName(row.CertID) + keyName := utils.GetKeyFileName(row.CertID) + + certPath := filepath.Join(a.certDir, certName) + keyPath := filepath.Join(a.keyDir, keyName) + + // open cert and key files + openCert, err := os.Open(certPath) + if err != nil { + return fmt.Errorf("open cert file: %w", err) + } + defer openCert.Close() + openKey, err := os.Open(keyPath) + + if err != nil { + return fmt.Errorf("open key file: %w", err) + } + defer openKey.Close() + + hostPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(row.Fingerprint)) + if err != nil { + return fmt.Errorf("failed to parse fingerprint: %w", err) + } + + client, err := ssh.Dial("tcp", row.Address, &ssh.ClientConfig{ + User: row.User, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(a.sshKey), + }, + HostKeyCallback: ssh.FixedHostKey(hostPubKey), + Timeout: time.Second * 30, + }) + if err != nil { + return fmt.Errorf("ssh dial: %w", err) + } + + scpClient, err := scp.NewClientBySSH(client) + if err != nil { + return fmt.Errorf("scp client: %w", err) + } + + // copy cert and key to agent + err = scpClient.CopyFromFile(context.Background(), *openCert, filepath.Join(row.Dir, "certificates", certName), "0600") + if err != nil { + return fmt.Errorf("copy cert file: %w", err) + } + err = scpClient.CopyFromFile(context.Background(), *openKey, filepath.Join(row.Dir, "keys", keyName), "0600") + if err != nil { + return fmt.Errorf("copy cert file: %w", err) + } + + // update last sync to the time when the database request happened + err = a.db.UpdateAgentLastSync(context.Background(), database.UpdateAgentLastSyncParams{ + LastSync: sql.NullTime{Time: startTime, Valid: true}, + ID: row.AgentID, + }) + if err != nil { + return fmt.Errorf("error updating agent last sync: %v", err) + } + + err = a.db.UpdateAgentCertNotAfter(context.Background(), database.UpdateAgentCertNotAfterParams{ + NotAfter: row.CertNotAfter, + AgentID: row.AgentID, + CertID: row.CertID, + }) + if err != nil { + return fmt.Errorf("error updating agent last sync: %v", err) + } + + return nil +} diff --git a/agent/agent_readme.md b/agent/agent_readme.md new file mode 100644 index 0000000..58efe1e --- /dev/null +++ b/agent/agent_readme.md @@ -0,0 +1,5 @@ +# Orchid Agent + +This directory is controlled by Orchid agent configuration. + +Certificates in this directory will be automatically updated when required. diff --git a/agent/agent_test.go b/agent/agent_test.go new file mode 100644 index 0000000..e4e5b11 --- /dev/null +++ b/agent/agent_test.go @@ -0,0 +1,378 @@ +package agent + +import ( + "bufio" + "bytes" + "context" + "crypto/x509/pkix" + "database/sql" + "encoding/binary" + "errors" + "fmt" + "github.com/1f349/orchid/database" + "github.com/1f349/orchid/logger" + "github.com/charmbracelet/log" + "github.com/mrmelon54/certgen" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" + "golang.org/x/crypto/ssh" + "io" + "math/big" + "net" + "net/netip" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +func TestAgentSyncing(t *testing.T) { + logger.Logger.SetLevel(log.DebugLevel) + + if testing.Short() { + t.Skip("Skipping agent syncing tests in short mode") + } + + t.Run("agent syncing test", func(t *testing.T) { + certDir, err := os.MkdirTemp("", "orchid-certs") + assert.NoError(t, err) + keyDir, err := os.MkdirTemp("", "orchid-keys") + assert.NoError(t, err) + + defer func() { + assert.NoError(t, os.RemoveAll(certDir)) + assert.NoError(t, os.RemoveAll(keyDir)) + }() + + _, privKey, err := ed25519.GenerateKey(nil) + if err != nil { + panic(err) + } + sshPrivKey, err := ssh.NewSignerFromKey(privKey) + if err != nil { + panic(err) + } + + agent := &Agent{ + db: &fakeAgentDb{}, + ticker: nil, + done: nil, + syncLock: nil, + sshKey: sshPrivKey, + certDir: certDir, + keyDir: keyDir, + } + + now := time.Now().UTC() + + t.Run("missing cert file", func(t *testing.T) { + err = agent.syncSingleAgentCertPair(now, database.FindAgentToSyncRow{ + AgentID: 1337, + Address: "", + User: "test", + Dir: "~/hello/world", + Fingerprint: "", + CertID: 420, + CertNotAfter: sql.NullTime{Time: now, Valid: true}, + }) + assert.Contains(t, err.Error(), "open cert file:") + assert.Contains(t, err.Error(), "no such file or directory") + }) + + // generate example certificate + tlsCert, err := certgen.MakeServerTls(nil, 2048, pkix.Name{ + Country: []string{"GB"}, + Province: []string{"London"}, + StreetAddress: []string{"221B Baker Street"}, + PostalCode: []string{"NW1 6XE"}, + SerialNumber: "test123456", + CommonName: "orchid-agent-test.local", + }, big.NewInt(1234567899), func(now time.Time) time.Time { + return now.Add(1 * time.Hour) + }, []string{"orchid-agent-test.local"}, []net.IP{ + net.IPv6loopback, + net.IPv4(127, 0, 0, 1), + }) + assert.NoError(t, err) + + err = os.WriteFile(filepath.Join(certDir, "420.cert.pem"), tlsCert.GetCertPem(), 0600) + assert.NoError(t, err) + + t.Run("missing key file", func(t *testing.T) { + err = agent.syncSingleAgentCertPair(now, database.FindAgentToSyncRow{ + AgentID: 1337, + Address: "", + User: "test", + Dir: "~/hello/world", + Fingerprint: "", + CertID: 420, + CertNotAfter: sql.NullTime{Time: now, Valid: true}, + }) + assert.Contains(t, err.Error(), "open key file:") + assert.Contains(t, err.Error(), "no such file or directory") + }) + + err = os.WriteFile(filepath.Join(keyDir, "420.key.pem"), tlsCert.GetKeyPem(), 0600) + assert.NoError(t, err) + + t.Run("successful sync", func(t *testing.T) { + var wg sync.WaitGroup + server := setupFakeSSH(&wg, func(remoteAddrPort netip.AddrPort, remotePubKey ssh.PublicKey) { + println("Attempt agent syncing") + + err = agent.syncSingleAgentCertPair(now, database.FindAgentToSyncRow{ + AgentID: 1337, + Address: remoteAddrPort.String(), + User: "test", + Dir: "~/hello/world", + Fingerprint: string(ssh.MarshalAuthorizedKey(remotePubKey)), + CertID: 420, + CertNotAfter: sql.NullTime{Time: now, Valid: true}, + }) + assert.NoError(t, err) + }) + server.Close() + + println("Waiting for ssh server to exit") + + server.Wait() + }) + }) +} + +func setupFakeSSH(wg *sync.WaitGroup, call func(addrPort netip.AddrPort, pubKey ssh.PublicKey)) *ssh.ServerConn { + pubKey, privKey, err := ed25519.GenerateKey(nil) + if err != nil { + panic(err) + } + sshPubKey, err := ssh.NewPublicKey(pubKey) + if err != nil { + panic(err) + } + sshSigner, err := ssh.NewSignerFromKey(privKey) + if err != nil { + panic(err) + } + + tcp, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv6Loopback(), 0))) + if err != nil { + panic(err) + } + + addrPort := tcp.Addr().(*net.TCPAddr).AddrPort() + + var wg2 sync.WaitGroup + wg2.Add(1) + go func() { + defer wg2.Done() + call(addrPort, sshPubKey) + }() + + tcpConn, err := tcp.AcceptTCP() + if err != nil { + panic(err) + } + + serverConfig := &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if conn.User() != "test" { + return nil, fmt.Errorf("invalid user") + } + if !conn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr().IsLoopback() { + return nil, fmt.Errorf("invalid remote address") + } + return &ssh.Permissions{}, nil + }, + ServerVersion: "SSH-2.0-OrchidTester", + } + serverConfig.AddHostKey(sshSigner) + + sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, serverConfig) + if err != nil { + return nil + } + + // The incoming Request channel must be serviced. + wg.Add(1) + go func() { + ssh.DiscardRequests(reqs) + wg.Done() + }() + + wg.Add(1) + go func() { + defer wg.Done() + + // Service the incoming channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of a shell, the type is + // "session" and ServerShell may be used to present a simple + // terminal interface. + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + panic(err) + } + + var fullFilePath string + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "shell" request. + wg.Add(1) + go func(in <-chan *ssh.Request) { + for req := range in { + req.Reply(req.Type == "exec", nil) + if req.Type == "exec" { + length := binary.BigEndian.Uint32(req.Payload[:4]) + if len(req.Payload) != int(length)+4 { + panic(fmt.Errorf("invalid exec payload (expected %d but got %d)", length, len(req.Payload))) + } + cmd := string(req.Payload[4:]) + const scpStartStr = "scp -qt \"" + if !strings.HasPrefix(cmd, scpStartStr) { + panic("invalid start") + } + if !strings.HasSuffix(cmd, "\"") { + panic("invalid end") + } + filePath := cmd[len(scpStartStr) : len(cmd)-1] + fmt.Println("Writing file:", filePath) + fullFilePath = filePath + } + } + wg.Done() + }(requests) + + wg.Add(1) + go func() { + defer func() { + channel.Close() + wg.Done() + }() + + var b [1024]byte + read := must(channel.Read(b[:])) + if read < 1 { + panic("invalid read") + } + + fmt.Println(string(b[:read])) + + r := bufio.NewReader(bytes.NewReader(b[:read])) + if readByte(r) != 'C' { + panic("invalid scp command") + } + + fileMode := readN(r, 4) + if string(fileMode) != "0600" { + panic("unexpected file mode") + } + + if readByte(r) != ' ' { + panic("missing space") + } + + fileSizeStr := must(r.ReadString(' ')) + fileSize := must(strconv.Atoi(fileSizeStr[:len(fileSizeStr)-1])) + + fileName := strings.TrimSpace(string(must(io.ReadAll(r)))) + if fileName != filepath.Base(fullFilePath) { + panic(fmt.Errorf("invalid file name (expected \"%s\" from full path \"%s\" but got \"%s\")", filepath.Base(fullFilePath), fullFilePath, fileName)) + } + + if fileName != "420.cert.pem" && fileName != "420.key.pem" { + panic("invalid file name") + } + + channel.Write([]byte{0}) + + buf := new(bytes.Buffer) + _, err := io.CopyN(buf, channel, int64(fileSize)) + if err != nil { + panic("Failed to copy channel") + } + fmt.Println("Copied file with size:", buf.Len()) + fmt.Println(buf.String()) + + if readLastByte(r) != 0x00 { + panic("expected ending null byte") + } + + channel.Write([]byte{0}) + + channel.SendRequest("exit-status", false, binary.BigEndian.AppendUint32(nil, 0)) + }() + } + }() + + wg2.Wait() + + return sshConn +} + +type fakeAgentDb struct{} + +func (f *fakeAgentDb) FindAgentToSync(ctx context.Context) ([]database.FindAgentToSyncRow, error) { + panic("implement me") +} + +func (f *fakeAgentDb) UpdateAgentLastSync(ctx context.Context, arg database.UpdateAgentLastSyncParams) error { + if arg.ID != 1337 { + return fmt.Errorf("invalid agent id") + } + if !arg.LastSync.Valid { + return fmt.Errorf("invalid last sync") + } + return nil +} + +func (f *fakeAgentDb) UpdateAgentCertNotAfter(ctx context.Context, arg database.UpdateAgentCertNotAfterParams) error { + if arg.AgentID != 1337 { + return fmt.Errorf("invalid agent id") + } + if arg.CertID != 420 { + return fmt.Errorf("invalid cert id") + } + if !arg.NotAfter.Valid { + return fmt.Errorf("invalid not after") + } + return nil +} + +func must[T any](t T, err error) T { + if err != nil { + panic(err) + } + return t +} + +func readN(r io.Reader, n int) []byte { + b := make([]byte, n) + _, err := io.ReadFull(r, b) + if err != nil { + panic(err) + } + return b +} + +func readByte(r io.Reader) byte { + b := readN(r, 1) + return b[0] +} + +func readLastByte(r io.Reader) byte { + var b [1]byte + _, err := io.ReadFull(r, b[:]) + if !errors.Is(err, io.EOF) { + panic("expected EOF") + } + return b[0] +} diff --git a/agent/logger.go b/agent/logger.go new file mode 100644 index 0000000..89c6377 --- /dev/null +++ b/agent/logger.go @@ -0,0 +1,5 @@ +package agent + +import "github.com/1f349/orchid/logger" + +var Logger = logger.Logger.WithPrefix("Orchid Agent") diff --git a/cmd/orchid/agent.go b/cmd/orchid/agent.go new file mode 100644 index 0000000..94330f8 --- /dev/null +++ b/cmd/orchid/agent.go @@ -0,0 +1,65 @@ +package main + +import ( + "encoding/pem" + "github.com/1f349/orchid/logger" + "golang.org/x/crypto/ed25519" + "golang.org/x/crypto/ssh" + "os" + "path/filepath" +) + +// loadAgentPrivateKey simply attempts to load the agent ssh private key and if +// it is missing generates a new key +func loadAgentPrivateKey(wd string) ssh.Signer { + // load or create a key for orchid agent + agentPrivKeyPath := filepath.Join(wd, "agent_id_ed25519") + agentPubKeyPath := filepath.Join(wd, "agent_id_ed25519.pub") + agentPrivKeyBytes, err := os.ReadFile(agentPrivKeyPath) + switch { + case err == nil: + break + case os.IsNotExist(err): + pubKey, privKey, err := ed25519.GenerateKey(nil) + if err != nil { + logger.Logger.Fatal("Failed to generate agent private key", "err", err) + } + marshalPrivKey, err := ssh.MarshalPrivateKey(privKey, "orchid-agent") + if err != nil { + logger.Logger.Fatal("Failed to encode private key", "err", err) + } + agentPrivKeyBytes = pem.EncodeToMemory(marshalPrivKey) + + // public key + sshPubKey, err := ssh.NewPublicKey(pubKey) + if err != nil { + logger.Logger.Fatal("Failed to encode public key", "err", err) + } + marshalPubKey := ssh.MarshalAuthorizedKey(sshPubKey) + if err != nil { + logger.Logger.Fatal("Failed to encode public key", "err", err) + } + + // write to files + err = os.WriteFile(agentPrivKeyPath, agentPrivKeyBytes, 0600) + if err != nil { + logger.Logger.Fatal("Failed to write agent private key", "path", agentPrivKeyPath, "err", err) + } + err = os.WriteFile(agentPubKeyPath, marshalPubKey, 0644) + if err != nil { + logger.Logger.Fatal("Failed to write agent public key", "path", agentPubKeyPath, "err", err) + } + + // we can continue now + break + case err != nil: + logger.Logger.Fatal("Failed to read agent private key", "path", agentPrivKeyPath, "err", err) + } + + privKey, err := ssh.ParsePrivateKey(agentPrivKeyBytes) + if err != nil { + logger.Logger.Fatal("Failed to parse agent private key file", "path", agentPrivKeyPath, "err", err) + } + + return privKey +} diff --git a/cmd/orchid/conf.go b/cmd/orchid/conf.go index f60d012..2377d1d 100644 --- a/cmd/orchid/conf.go +++ b/cmd/orchid/conf.go @@ -3,10 +3,11 @@ package main import "github.com/1f349/orchid/renewal" type startUpConfig struct { - Listen string `yaml:"listen"` - Acme acmeConfig `yaml:"acme"` - LE renewal.LetsEncryptConfig `yaml:"letsEncrypt"` - Domains []string `yaml:"domains"` + Listen string `yaml:"listen"` + Acme acmeConfig `yaml:"acme"` + LE renewal.LetsEncryptConfig `yaml:"letsEncrypt"` + Domains []string `yaml:"domains"` + AgentKey string `yaml:"agentKey"` } type acmeConfig struct { diff --git a/cmd/orchid/main.go b/cmd/orchid/main.go index 6ab431c..fabd35c 100644 --- a/cmd/orchid/main.go +++ b/cmd/orchid/main.go @@ -4,6 +4,7 @@ import ( "flag" "github.com/1f349/mjwt" "github.com/1f349/orchid" + "github.com/1f349/orchid/agent" httpAcme "github.com/1f349/orchid/http-acme" "github.com/1f349/orchid/logger" "github.com/1f349/orchid/renewal" @@ -78,7 +79,7 @@ func runDaemon(wd string, conf startUpConfig) { certDir := filepath.Join(wd, "renewal-certs") keyDir := filepath.Join(wd, "renewal-keys") - wg := &sync.WaitGroup{} + wg := new(sync.WaitGroup) acmeProv, err := httpAcme.NewHttpAcmeProvider(filepath.Join(wd, "tokens.yml"), conf.Acme.PresentUrl, conf.Acme.CleanUpUrl, conf.Acme.RefreshUrl) if err != nil { logger.Logger.Fatal("HTTP Acme Error", "err", err) @@ -87,6 +88,10 @@ func runDaemon(wd string, conf startUpConfig) { if err != nil { logger.Logger.Fatal("Service Error", "err", err) } + certAgent, err := agent.NewAgent(wg, db, loadAgentPrivateKey(wd), certDir, keyDir) + if err != nil { + logger.Logger.Fatal("Failed to create agent", "err", err) + } srv := servers.NewApiServer(conf.Listen, db, mJwtVerify, conf.Domains) logger.Logger.Info("Starting API server", "listen", srv.Addr) go utils.RunBackgroundHttp(logger.Logger, srv) @@ -94,6 +99,7 @@ func runDaemon(wd string, conf startUpConfig) { exitReload.ExitReload("Violet", func() {}, func() { // stop renewal service and api server renewalService.Shutdown() + certAgent.Shutdown() srv.Close() }) } diff --git a/database/agent.sql.go b/database/agent.sql.go new file mode 100644 index 0000000..78b219f --- /dev/null +++ b/database/agent.sql.go @@ -0,0 +1,98 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.28.0 +// source: agent.sql + +package database + +import ( + "context" + "database/sql" +) + +const findAgentToSync = `-- name: FindAgentToSync :many +SELECT agents.id as agent_id, agents.address, agents.user, agents.dir, agents.fingerprint, cert.id as cert_id, cert.not_after as cert_not_after +FROM agents + INNER JOIN agent_certs + ON agent_certs.agent_id = agents.id + INNER JOIN certificates AS cert + ON cert.id = agent_certs.cert_id +WHERE (agents.last_sync IS NULL OR agents.last_sync < cert.updated_at) + AND (agent_certs.not_after IS NULL OR agent_certs.not_after IS NOT cert.not_after) +ORDER BY agents.last_sync NULLS FIRST +` + +type FindAgentToSyncRow struct { + AgentID int64 `json:"agent_id"` + Address string `json:"address"` + User string `json:"user"` + Dir string `json:"dir"` + Fingerprint string `json:"fingerprint"` + CertID int64 `json:"cert_id"` + CertNotAfter sql.NullTime `json:"cert_not_after"` +} + +func (q *Queries) FindAgentToSync(ctx context.Context) ([]FindAgentToSyncRow, error) { + rows, err := q.db.QueryContext(ctx, findAgentToSync) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FindAgentToSyncRow + for rows.Next() { + var i FindAgentToSyncRow + if err := rows.Scan( + &i.AgentID, + &i.Address, + &i.User, + &i.Dir, + &i.Fingerprint, + &i.CertID, + &i.CertNotAfter, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateAgentCertNotAfter = `-- name: UpdateAgentCertNotAfter :exec +UPDATE agent_certs +SET not_after = ? +WHERE agent_id = ? + AND cert_id = ? +` + +type UpdateAgentCertNotAfterParams struct { + NotAfter sql.NullTime `json:"not_after"` + AgentID int64 `json:"agent_id"` + CertID int64 `json:"cert_id"` +} + +func (q *Queries) UpdateAgentCertNotAfter(ctx context.Context, arg UpdateAgentCertNotAfterParams) error { + _, err := q.db.ExecContext(ctx, updateAgentCertNotAfter, arg.NotAfter, arg.AgentID, arg.CertID) + return err +} + +const updateAgentLastSync = `-- name: UpdateAgentLastSync :exec +UPDATE agents +SET last_sync = ? +WHERE agents.id = ? +` + +type UpdateAgentLastSyncParams struct { + LastSync sql.NullTime `json:"last_sync"` + ID int64 `json:"id"` +} + +func (q *Queries) UpdateAgentLastSync(ctx context.Context, arg UpdateAgentLastSyncParams) error { + _, err := q.db.ExecContext(ctx, updateAgentLastSync, arg.LastSync, arg.ID) + return err +} diff --git a/database/migrations/20250129211955_agent.down.sql b/database/migrations/20250129211955_agent.down.sql new file mode 100644 index 0000000..e69de29 diff --git a/database/migrations/20250129211955_agent.up.sql b/database/migrations/20250129211955_agent.up.sql new file mode 100644 index 0000000..9c6b4fc --- /dev/null +++ b/database/migrations/20250129211955_agent.up.sql @@ -0,0 +1,21 @@ +CREATE TABLE IF NOT EXISTS agents +( + id INTEGER PRIMARY KEY AUTOINCREMENT, + address TEXT NOT NULL, + user TEXT NOT NULL, + dir TEXT NOT NULL, + fingerprint TEXT NOT NULL, + last_sync DATETIME NULL DEFAULT NULL +); + +CREATE TABLE IF NOT EXISTS agent_certs +( + agent_id INTEGER NOT NULL, + cert_id INTEGER NOT NULL, + not_after INTEGER NULL DEFAULT NULL, + + PRIMARY KEY (agent_id, cert_id), + + FOREIGN KEY (agent_id) REFERENCES agents (id), + FOREIGN KEY (cert_id) REFERENCES certificates (id) +); diff --git a/database/models.go b/database/models.go index 954c8dd..4c3309f 100644 --- a/database/models.go +++ b/database/models.go @@ -9,6 +9,21 @@ import ( "time" ) +type Agent struct { + ID int64 `json:"id"` + Address string `json:"address"` + User string `json:"user"` + Dir string `json:"dir"` + Fingerprint string `json:"fingerprint"` + LastSync sql.NullTime `json:"last_sync"` +} + +type AgentCert struct { + AgentID int64 `json:"agent_id"` + CertID int64 `json:"cert_id"` + NotAfter sql.NullTime `json:"not_after"` +} + type Certificate struct { ID int64 `json:"id"` Owner string `json:"owner"` @@ -16,9 +31,9 @@ type Certificate struct { AutoRenew bool `json:"auto_renew"` Active bool `json:"active"` Renewing bool `json:"renewing"` - NotAfter sql.NullTime `json:"not_after"` UpdatedAt time.Time `json:"updated_at"` TempParent sql.NullInt64 `json:"temp_parent"` + NotAfter sql.NullTime `json:"not_after"` RenewRetry sql.NullTime `json:"renew_retry"` } diff --git a/database/queries/agent.sql b/database/queries/agent.sql new file mode 100644 index 0000000..6f98c7f --- /dev/null +++ b/database/queries/agent.sql @@ -0,0 +1,21 @@ +-- name: FindAgentToSync :many +SELECT agents.id as agent_id, agents.address, agents.user, agents.dir, agents.fingerprint, cert.id as cert_id, cert.not_after as cert_not_after +FROM agents + INNER JOIN agent_certs + ON agent_certs.agent_id = agents.id + INNER JOIN certificates AS cert + ON cert.id = agent_certs.cert_id +WHERE (agents.last_sync IS NULL OR agents.last_sync < cert.updated_at) + AND (agent_certs.not_after IS NULL OR agent_certs.not_after IS NOT cert.not_after) +ORDER BY agents.last_sync NULLS FIRST; + +-- name: UpdateAgentLastSync :exec +UPDATE agents +SET last_sync = ? +WHERE agents.id = ?; + +-- name: UpdateAgentCertNotAfter :exec +UPDATE agent_certs +SET not_after = ? +WHERE agent_id = ? + AND cert_id = ?; diff --git a/go.mod b/go.mod index b12e0f9..c9c3da4 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/1f349/mjwt v0.4.1 github.com/1f349/violet v0.0.14 github.com/AlecAivazis/survey/v2 v2.3.7 + github.com/bramvdbogaerde/go-scp v1.5.0 github.com/charmbracelet/log v0.4.0 github.com/go-acme/lego/v4 v4.21.0 github.com/golang-jwt/jwt/v4 v4.5.1 @@ -17,6 +18,7 @@ require ( github.com/mrmelon54/certgen v0.0.2 github.com/mrmelon54/exit-reload v0.0.2 github.com/stretchr/testify v1.10.0 + golang.org/x/crypto v0.32.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -46,7 +48,6 @@ require ( github.com/rivo/uniseg v0.4.7 // indirect github.com/spf13/afero v1.12.0 // indirect go.uber.org/atomic v1.11.0 // indirect - golang.org/x/crypto v0.32.0 // indirect golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c // indirect golang.org/x/mod v0.22.0 // indirect golang.org/x/net v0.34.0 // indirect diff --git a/go.sum b/go.sum index 82b683c..0a98ab5 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiE github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA= github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4= +github.com/bramvdbogaerde/go-scp v1.5.0 h1:a9BinAjTfQh273eh7vd3qUgmBC+bx+3TRDtkZWmIpzM= +github.com/bramvdbogaerde/go-scp v1.5.0/go.mod h1:on2aH5AxaFb2G0N5Vsdy6B0Ml7k9HuHSwfo1y0QzAbQ= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/charmbracelet/lipgloss v1.0.0 h1:O7VkGDvqEdGi93X+DeqsQ7PKHDgtQfF8j8/O2qFMQNg= diff --git a/sqlc.yaml b/sqlc.yaml index 3661da7..390504d 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -13,3 +13,7 @@ sql: go_type: database/sql.NullTime - column: certificates.renew_retry go_type: database/sql.NullTime + - column: agents.last_sync + go_type: database/sql.NullTime + - column: agent_certs.not_after + go_type: database/sql.NullTime