mirror of
https://github.com/1f349/orchid.git
synced 2025-02-05 05:56:40 +00:00
Add agent for certificate syncing
This commit is contained in:
parent
7e70331179
commit
939875ca4c
182
agent/agent.go
Normal file
182
agent/agent.go
Normal file
@ -0,0 +1,182 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"github.com/1f349/orchid/database"
|
||||
"github.com/1f349/orchid/utils"
|
||||
"github.com/bramvdbogaerde/go-scp"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
//go:embed agent_readme.md
|
||||
var agentReadme []byte
|
||||
|
||||
type agentQueries interface {
|
||||
FindAgentToSync(ctx context.Context) ([]database.FindAgentToSyncRow, error)
|
||||
UpdateAgentLastSync(ctx context.Context, row database.UpdateAgentLastSyncParams) error
|
||||
UpdateAgentCertNotAfter(ctx context.Context, arg database.UpdateAgentCertNotAfterParams) error
|
||||
}
|
||||
|
||||
func NewAgent(wg *sync.WaitGroup, db agentQueries, sshKey ssh.Signer, certDir string, keyDir string) (*Agent, error) {
|
||||
a := &Agent{
|
||||
db: db,
|
||||
ticker: time.NewTicker(time.Minute * 10),
|
||||
done: make(chan struct{}),
|
||||
syncLock: new(sync.Mutex),
|
||||
sshKey: sshKey,
|
||||
certDir: certDir,
|
||||
keyDir: keyDir,
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go a.syncRoutine(wg)
|
||||
return a, nil
|
||||
}
|
||||
|
||||
type Agent struct {
|
||||
db agentQueries
|
||||
ticker *time.Ticker
|
||||
done chan struct{}
|
||||
syncLock *sync.Mutex
|
||||
sshKey ssh.Signer
|
||||
certDir string
|
||||
keyDir string
|
||||
}
|
||||
|
||||
func (a *Agent) Shutdown() {
|
||||
Logger.Info("Shutting down agent syncing service")
|
||||
close(a.done)
|
||||
}
|
||||
|
||||
func (a *Agent) syncRoutine(wg *sync.WaitGroup) {
|
||||
Logger.Debug("Starting syncRoutine")
|
||||
|
||||
// Upon leaving the function stop the ticker and clear the WaitGroup.
|
||||
defer func() {
|
||||
a.ticker.Stop()
|
||||
Logger.Info("Stopped agent syncing service")
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-a.done:
|
||||
// Exit if done has closed
|
||||
return
|
||||
case <-a.ticker.C:
|
||||
Logger.Debug("Ticking agent syncing")
|
||||
|
||||
go a.syncCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) syncCheck() {
|
||||
// if the lock is unavailable then ignore this cycle
|
||||
if !a.syncLock.TryLock() {
|
||||
return
|
||||
}
|
||||
defer a.syncLock.Unlock()
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
actions, err := a.db.FindAgentToSync(context.Background())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
agentErrs := make(map[int64][]error)
|
||||
|
||||
for _, action := range actions {
|
||||
err = a.syncSingleAgentCertPair(now, action)
|
||||
if err != nil {
|
||||
agentErrs[action.AgentID] = append(agentErrs[action.AgentID], err)
|
||||
}
|
||||
}
|
||||
|
||||
for agentId, errs := range agentErrs {
|
||||
Logger.Warn("Agent sync failed", "agent", agentId, "errs", errs)
|
||||
}
|
||||
|
||||
// TODO: idk what to do now
|
||||
}
|
||||
|
||||
func (a *Agent) syncSingleAgentCertPair(startTime time.Time, row database.FindAgentToSyncRow) error {
|
||||
certName := utils.GetCertFileName(row.CertID)
|
||||
keyName := utils.GetKeyFileName(row.CertID)
|
||||
|
||||
certPath := filepath.Join(a.certDir, certName)
|
||||
keyPath := filepath.Join(a.keyDir, keyName)
|
||||
|
||||
// open cert and key files
|
||||
openCert, err := os.Open(certPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open cert file: %w", err)
|
||||
}
|
||||
defer openCert.Close()
|
||||
openKey, err := os.Open(keyPath)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("open key file: %w", err)
|
||||
}
|
||||
defer openKey.Close()
|
||||
|
||||
hostPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(row.Fingerprint))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse fingerprint: %w", err)
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", row.Address, &ssh.ClientConfig{
|
||||
User: row.User,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(a.sshKey),
|
||||
},
|
||||
HostKeyCallback: ssh.FixedHostKey(hostPubKey),
|
||||
Timeout: time.Second * 30,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ssh dial: %w", err)
|
||||
}
|
||||
|
||||
scpClient, err := scp.NewClientBySSH(client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("scp client: %w", err)
|
||||
}
|
||||
|
||||
// copy cert and key to agent
|
||||
err = scpClient.CopyFromFile(context.Background(), *openCert, filepath.Join(row.Dir, "certificates", certName), "0600")
|
||||
if err != nil {
|
||||
return fmt.Errorf("copy cert file: %w", err)
|
||||
}
|
||||
err = scpClient.CopyFromFile(context.Background(), *openKey, filepath.Join(row.Dir, "keys", keyName), "0600")
|
||||
if err != nil {
|
||||
return fmt.Errorf("copy cert file: %w", err)
|
||||
}
|
||||
|
||||
// update last sync to the time when the database request happened
|
||||
err = a.db.UpdateAgentLastSync(context.Background(), database.UpdateAgentLastSyncParams{
|
||||
LastSync: sql.NullTime{Time: startTime, Valid: true},
|
||||
ID: row.AgentID,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating agent last sync: %v", err)
|
||||
}
|
||||
|
||||
err = a.db.UpdateAgentCertNotAfter(context.Background(), database.UpdateAgentCertNotAfterParams{
|
||||
NotAfter: row.CertNotAfter,
|
||||
AgentID: row.AgentID,
|
||||
CertID: row.CertID,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating agent last sync: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
5
agent/agent_readme.md
Normal file
5
agent/agent_readme.md
Normal file
@ -0,0 +1,5 @@
|
||||
# Orchid Agent
|
||||
|
||||
This directory is controlled by Orchid agent configuration.
|
||||
|
||||
Certificates in this directory will be automatically updated when required.
|
378
agent/agent_test.go
Normal file
378
agent/agent_test.go
Normal file
@ -0,0 +1,378 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509/pkix"
|
||||
"database/sql"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/1f349/orchid/database"
|
||||
"github.com/1f349/orchid/logger"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/mrmelon54/certgen"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAgentSyncing(t *testing.T) {
|
||||
logger.Logger.SetLevel(log.DebugLevel)
|
||||
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping agent syncing tests in short mode")
|
||||
}
|
||||
|
||||
t.Run("agent syncing test", func(t *testing.T) {
|
||||
certDir, err := os.MkdirTemp("", "orchid-certs")
|
||||
assert.NoError(t, err)
|
||||
keyDir, err := os.MkdirTemp("", "orchid-keys")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
assert.NoError(t, os.RemoveAll(certDir))
|
||||
assert.NoError(t, os.RemoveAll(keyDir))
|
||||
}()
|
||||
|
||||
_, privKey, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sshPrivKey, err := ssh.NewSignerFromKey(privKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
agent := &Agent{
|
||||
db: &fakeAgentDb{},
|
||||
ticker: nil,
|
||||
done: nil,
|
||||
syncLock: nil,
|
||||
sshKey: sshPrivKey,
|
||||
certDir: certDir,
|
||||
keyDir: keyDir,
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
t.Run("missing cert file", func(t *testing.T) {
|
||||
err = agent.syncSingleAgentCertPair(now, database.FindAgentToSyncRow{
|
||||
AgentID: 1337,
|
||||
Address: "",
|
||||
User: "test",
|
||||
Dir: "~/hello/world",
|
||||
Fingerprint: "",
|
||||
CertID: 420,
|
||||
CertNotAfter: sql.NullTime{Time: now, Valid: true},
|
||||
})
|
||||
assert.Contains(t, err.Error(), "open cert file:")
|
||||
assert.Contains(t, err.Error(), "no such file or directory")
|
||||
})
|
||||
|
||||
// generate example certificate
|
||||
tlsCert, err := certgen.MakeServerTls(nil, 2048, pkix.Name{
|
||||
Country: []string{"GB"},
|
||||
Province: []string{"London"},
|
||||
StreetAddress: []string{"221B Baker Street"},
|
||||
PostalCode: []string{"NW1 6XE"},
|
||||
SerialNumber: "test123456",
|
||||
CommonName: "orchid-agent-test.local",
|
||||
}, big.NewInt(1234567899), func(now time.Time) time.Time {
|
||||
return now.Add(1 * time.Hour)
|
||||
}, []string{"orchid-agent-test.local"}, []net.IP{
|
||||
net.IPv6loopback,
|
||||
net.IPv4(127, 0, 0, 1),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(certDir, "420.cert.pem"), tlsCert.GetCertPem(), 0600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run("missing key file", func(t *testing.T) {
|
||||
err = agent.syncSingleAgentCertPair(now, database.FindAgentToSyncRow{
|
||||
AgentID: 1337,
|
||||
Address: "",
|
||||
User: "test",
|
||||
Dir: "~/hello/world",
|
||||
Fingerprint: "",
|
||||
CertID: 420,
|
||||
CertNotAfter: sql.NullTime{Time: now, Valid: true},
|
||||
})
|
||||
assert.Contains(t, err.Error(), "open key file:")
|
||||
assert.Contains(t, err.Error(), "no such file or directory")
|
||||
})
|
||||
|
||||
err = os.WriteFile(filepath.Join(keyDir, "420.key.pem"), tlsCert.GetKeyPem(), 0600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run("successful sync", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
server := setupFakeSSH(&wg, func(remoteAddrPort netip.AddrPort, remotePubKey ssh.PublicKey) {
|
||||
println("Attempt agent syncing")
|
||||
|
||||
err = agent.syncSingleAgentCertPair(now, database.FindAgentToSyncRow{
|
||||
AgentID: 1337,
|
||||
Address: remoteAddrPort.String(),
|
||||
User: "test",
|
||||
Dir: "~/hello/world",
|
||||
Fingerprint: string(ssh.MarshalAuthorizedKey(remotePubKey)),
|
||||
CertID: 420,
|
||||
CertNotAfter: sql.NullTime{Time: now, Valid: true},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
server.Close()
|
||||
|
||||
println("Waiting for ssh server to exit")
|
||||
|
||||
server.Wait()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func setupFakeSSH(wg *sync.WaitGroup, call func(addrPort netip.AddrPort, pubKey ssh.PublicKey)) *ssh.ServerConn {
|
||||
pubKey, privKey, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sshPubKey, err := ssh.NewPublicKey(pubKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sshSigner, err := ssh.NewSignerFromKey(privKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tcp, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv6Loopback(), 0)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
addrPort := tcp.Addr().(*net.TCPAddr).AddrPort()
|
||||
|
||||
var wg2 sync.WaitGroup
|
||||
wg2.Add(1)
|
||||
go func() {
|
||||
defer wg2.Done()
|
||||
call(addrPort, sshPubKey)
|
||||
}()
|
||||
|
||||
tcpConn, err := tcp.AcceptTCP()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
serverConfig := &ssh.ServerConfig{
|
||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
if conn.User() != "test" {
|
||||
return nil, fmt.Errorf("invalid user")
|
||||
}
|
||||
if !conn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr().IsLoopback() {
|
||||
return nil, fmt.Errorf("invalid remote address")
|
||||
}
|
||||
return &ssh.Permissions{}, nil
|
||||
},
|
||||
ServerVersion: "SSH-2.0-OrchidTester",
|
||||
}
|
||||
serverConfig.AddHostKey(sshSigner)
|
||||
|
||||
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, serverConfig)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// The incoming Request channel must be serviced.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
ssh.DiscardRequests(reqs)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// Service the incoming channel.
|
||||
for newChannel := range chans {
|
||||
// Channels have a type, depending on the application level
|
||||
// protocol intended. In the case of a shell, the type is
|
||||
// "session" and ServerShell may be used to present a simple
|
||||
// terminal interface.
|
||||
if newChannel.ChannelType() != "session" {
|
||||
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
|
||||
continue
|
||||
}
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var fullFilePath string
|
||||
|
||||
// Sessions have out-of-band requests such as "shell",
|
||||
// "pty-req" and "env". Here we handle only the
|
||||
// "shell" request.
|
||||
wg.Add(1)
|
||||
go func(in <-chan *ssh.Request) {
|
||||
for req := range in {
|
||||
req.Reply(req.Type == "exec", nil)
|
||||
if req.Type == "exec" {
|
||||
length := binary.BigEndian.Uint32(req.Payload[:4])
|
||||
if len(req.Payload) != int(length)+4 {
|
||||
panic(fmt.Errorf("invalid exec payload (expected %d but got %d)", length, len(req.Payload)))
|
||||
}
|
||||
cmd := string(req.Payload[4:])
|
||||
const scpStartStr = "scp -qt \""
|
||||
if !strings.HasPrefix(cmd, scpStartStr) {
|
||||
panic("invalid start")
|
||||
}
|
||||
if !strings.HasSuffix(cmd, "\"") {
|
||||
panic("invalid end")
|
||||
}
|
||||
filePath := cmd[len(scpStartStr) : len(cmd)-1]
|
||||
fmt.Println("Writing file:", filePath)
|
||||
fullFilePath = filePath
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
}(requests)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
channel.Close()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
var b [1024]byte
|
||||
read := must(channel.Read(b[:]))
|
||||
if read < 1 {
|
||||
panic("invalid read")
|
||||
}
|
||||
|
||||
fmt.Println(string(b[:read]))
|
||||
|
||||
r := bufio.NewReader(bytes.NewReader(b[:read]))
|
||||
if readByte(r) != 'C' {
|
||||
panic("invalid scp command")
|
||||
}
|
||||
|
||||
fileMode := readN(r, 4)
|
||||
if string(fileMode) != "0600" {
|
||||
panic("unexpected file mode")
|
||||
}
|
||||
|
||||
if readByte(r) != ' ' {
|
||||
panic("missing space")
|
||||
}
|
||||
|
||||
fileSizeStr := must(r.ReadString(' '))
|
||||
fileSize := must(strconv.Atoi(fileSizeStr[:len(fileSizeStr)-1]))
|
||||
|
||||
fileName := strings.TrimSpace(string(must(io.ReadAll(r))))
|
||||
if fileName != filepath.Base(fullFilePath) {
|
||||
panic(fmt.Errorf("invalid file name (expected \"%s\" from full path \"%s\" but got \"%s\")", filepath.Base(fullFilePath), fullFilePath, fileName))
|
||||
}
|
||||
|
||||
if fileName != "420.cert.pem" && fileName != "420.key.pem" {
|
||||
panic("invalid file name")
|
||||
}
|
||||
|
||||
channel.Write([]byte{0})
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
_, err := io.CopyN(buf, channel, int64(fileSize))
|
||||
if err != nil {
|
||||
panic("Failed to copy channel")
|
||||
}
|
||||
fmt.Println("Copied file with size:", buf.Len())
|
||||
fmt.Println(buf.String())
|
||||
|
||||
if readLastByte(r) != 0x00 {
|
||||
panic("expected ending null byte")
|
||||
}
|
||||
|
||||
channel.Write([]byte{0})
|
||||
|
||||
channel.SendRequest("exit-status", false, binary.BigEndian.AppendUint32(nil, 0))
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
wg2.Wait()
|
||||
|
||||
return sshConn
|
||||
}
|
||||
|
||||
type fakeAgentDb struct{}
|
||||
|
||||
func (f *fakeAgentDb) FindAgentToSync(ctx context.Context) ([]database.FindAgentToSyncRow, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (f *fakeAgentDb) UpdateAgentLastSync(ctx context.Context, arg database.UpdateAgentLastSyncParams) error {
|
||||
if arg.ID != 1337 {
|
||||
return fmt.Errorf("invalid agent id")
|
||||
}
|
||||
if !arg.LastSync.Valid {
|
||||
return fmt.Errorf("invalid last sync")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeAgentDb) UpdateAgentCertNotAfter(ctx context.Context, arg database.UpdateAgentCertNotAfterParams) error {
|
||||
if arg.AgentID != 1337 {
|
||||
return fmt.Errorf("invalid agent id")
|
||||
}
|
||||
if arg.CertID != 420 {
|
||||
return fmt.Errorf("invalid cert id")
|
||||
}
|
||||
if !arg.NotAfter.Valid {
|
||||
return fmt.Errorf("invalid not after")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func must[T any](t T, err error) T {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func readN(r io.Reader, n int) []byte {
|
||||
b := make([]byte, n)
|
||||
_, err := io.ReadFull(r, b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func readByte(r io.Reader) byte {
|
||||
b := readN(r, 1)
|
||||
return b[0]
|
||||
}
|
||||
|
||||
func readLastByte(r io.Reader) byte {
|
||||
var b [1]byte
|
||||
_, err := io.ReadFull(r, b[:])
|
||||
if !errors.Is(err, io.EOF) {
|
||||
panic("expected EOF")
|
||||
}
|
||||
return b[0]
|
||||
}
|
5
agent/logger.go
Normal file
5
agent/logger.go
Normal file
@ -0,0 +1,5 @@
|
||||
package agent
|
||||
|
||||
import "github.com/1f349/orchid/logger"
|
||||
|
||||
var Logger = logger.Logger.WithPrefix("Orchid Agent")
|
65
cmd/orchid/agent.go
Normal file
65
cmd/orchid/agent.go
Normal file
@ -0,0 +1,65 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
"github.com/1f349/orchid/logger"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// loadAgentPrivateKey simply attempts to load the agent ssh private key and if
|
||||
// it is missing generates a new key
|
||||
func loadAgentPrivateKey(wd string) ssh.Signer {
|
||||
// load or create a key for orchid agent
|
||||
agentPrivKeyPath := filepath.Join(wd, "agent_id_ed25519")
|
||||
agentPubKeyPath := filepath.Join(wd, "agent_id_ed25519.pub")
|
||||
agentPrivKeyBytes, err := os.ReadFile(agentPrivKeyPath)
|
||||
switch {
|
||||
case err == nil:
|
||||
break
|
||||
case os.IsNotExist(err):
|
||||
pubKey, privKey, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("Failed to generate agent private key", "err", err)
|
||||
}
|
||||
marshalPrivKey, err := ssh.MarshalPrivateKey(privKey, "orchid-agent")
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("Failed to encode private key", "err", err)
|
||||
}
|
||||
agentPrivKeyBytes = pem.EncodeToMemory(marshalPrivKey)
|
||||
|
||||
// public key
|
||||
sshPubKey, err := ssh.NewPublicKey(pubKey)
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("Failed to encode public key", "err", err)
|
||||
}
|
||||
marshalPubKey := ssh.MarshalAuthorizedKey(sshPubKey)
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("Failed to encode public key", "err", err)
|
||||
}
|
||||
|
||||
// write to files
|
||||
err = os.WriteFile(agentPrivKeyPath, agentPrivKeyBytes, 0600)
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("Failed to write agent private key", "path", agentPrivKeyPath, "err", err)
|
||||
}
|
||||
err = os.WriteFile(agentPubKeyPath, marshalPubKey, 0644)
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("Failed to write agent public key", "path", agentPubKeyPath, "err", err)
|
||||
}
|
||||
|
||||
// we can continue now
|
||||
break
|
||||
case err != nil:
|
||||
logger.Logger.Fatal("Failed to read agent private key", "path", agentPrivKeyPath, "err", err)
|
||||
}
|
||||
|
||||
privKey, err := ssh.ParsePrivateKey(agentPrivKeyBytes)
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("Failed to parse agent private key file", "path", agentPrivKeyPath, "err", err)
|
||||
}
|
||||
|
||||
return privKey
|
||||
}
|
@ -7,6 +7,7 @@ type startUpConfig struct {
|
||||
Acme acmeConfig `yaml:"acme"`
|
||||
LE renewal.LetsEncryptConfig `yaml:"letsEncrypt"`
|
||||
Domains []string `yaml:"domains"`
|
||||
AgentKey string `yaml:"agentKey"`
|
||||
}
|
||||
|
||||
type acmeConfig struct {
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"flag"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/orchid"
|
||||
"github.com/1f349/orchid/agent"
|
||||
httpAcme "github.com/1f349/orchid/http-acme"
|
||||
"github.com/1f349/orchid/logger"
|
||||
"github.com/1f349/orchid/renewal"
|
||||
@ -78,7 +79,7 @@ func runDaemon(wd string, conf startUpConfig) {
|
||||
certDir := filepath.Join(wd, "renewal-certs")
|
||||
keyDir := filepath.Join(wd, "renewal-keys")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg := new(sync.WaitGroup)
|
||||
acmeProv, err := httpAcme.NewHttpAcmeProvider(filepath.Join(wd, "tokens.yml"), conf.Acme.PresentUrl, conf.Acme.CleanUpUrl, conf.Acme.RefreshUrl)
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("HTTP Acme Error", "err", err)
|
||||
@ -87,6 +88,10 @@ func runDaemon(wd string, conf startUpConfig) {
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("Service Error", "err", err)
|
||||
}
|
||||
certAgent, err := agent.NewAgent(wg, db, loadAgentPrivateKey(wd), certDir, keyDir)
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("Failed to create agent", "err", err)
|
||||
}
|
||||
srv := servers.NewApiServer(conf.Listen, db, mJwtVerify, conf.Domains)
|
||||
logger.Logger.Info("Starting API server", "listen", srv.Addr)
|
||||
go utils.RunBackgroundHttp(logger.Logger, srv)
|
||||
@ -94,6 +99,7 @@ func runDaemon(wd string, conf startUpConfig) {
|
||||
exitReload.ExitReload("Violet", func() {}, func() {
|
||||
// stop renewal service and api server
|
||||
renewalService.Shutdown()
|
||||
certAgent.Shutdown()
|
||||
srv.Close()
|
||||
})
|
||||
}
|
||||
|
98
database/agent.sql.go
Normal file
98
database/agent.sql.go
Normal file
@ -0,0 +1,98 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.28.0
|
||||
// source: agent.sql
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
const findAgentToSync = `-- name: FindAgentToSync :many
|
||||
SELECT agents.id as agent_id, agents.address, agents.user, agents.dir, agents.fingerprint, cert.id as cert_id, cert.not_after as cert_not_after
|
||||
FROM agents
|
||||
INNER JOIN agent_certs
|
||||
ON agent_certs.agent_id = agents.id
|
||||
INNER JOIN certificates AS cert
|
||||
ON cert.id = agent_certs.cert_id
|
||||
WHERE (agents.last_sync IS NULL OR agents.last_sync < cert.updated_at)
|
||||
AND (agent_certs.not_after IS NULL OR agent_certs.not_after IS NOT cert.not_after)
|
||||
ORDER BY agents.last_sync NULLS FIRST
|
||||
`
|
||||
|
||||
type FindAgentToSyncRow struct {
|
||||
AgentID int64 `json:"agent_id"`
|
||||
Address string `json:"address"`
|
||||
User string `json:"user"`
|
||||
Dir string `json:"dir"`
|
||||
Fingerprint string `json:"fingerprint"`
|
||||
CertID int64 `json:"cert_id"`
|
||||
CertNotAfter sql.NullTime `json:"cert_not_after"`
|
||||
}
|
||||
|
||||
func (q *Queries) FindAgentToSync(ctx context.Context) ([]FindAgentToSyncRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, findAgentToSync)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []FindAgentToSyncRow
|
||||
for rows.Next() {
|
||||
var i FindAgentToSyncRow
|
||||
if err := rows.Scan(
|
||||
&i.AgentID,
|
||||
&i.Address,
|
||||
&i.User,
|
||||
&i.Dir,
|
||||
&i.Fingerprint,
|
||||
&i.CertID,
|
||||
&i.CertNotAfter,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateAgentCertNotAfter = `-- name: UpdateAgentCertNotAfter :exec
|
||||
UPDATE agent_certs
|
||||
SET not_after = ?
|
||||
WHERE agent_id = ?
|
||||
AND cert_id = ?
|
||||
`
|
||||
|
||||
type UpdateAgentCertNotAfterParams struct {
|
||||
NotAfter sql.NullTime `json:"not_after"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
CertID int64 `json:"cert_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateAgentCertNotAfter(ctx context.Context, arg UpdateAgentCertNotAfterParams) error {
|
||||
_, err := q.db.ExecContext(ctx, updateAgentCertNotAfter, arg.NotAfter, arg.AgentID, arg.CertID)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateAgentLastSync = `-- name: UpdateAgentLastSync :exec
|
||||
UPDATE agents
|
||||
SET last_sync = ?
|
||||
WHERE agents.id = ?
|
||||
`
|
||||
|
||||
type UpdateAgentLastSyncParams struct {
|
||||
LastSync sql.NullTime `json:"last_sync"`
|
||||
ID int64 `json:"id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateAgentLastSync(ctx context.Context, arg UpdateAgentLastSyncParams) error {
|
||||
_, err := q.db.ExecContext(ctx, updateAgentLastSync, arg.LastSync, arg.ID)
|
||||
return err
|
||||
}
|
0
database/migrations/20250129211955_agent.down.sql
Normal file
0
database/migrations/20250129211955_agent.down.sql
Normal file
21
database/migrations/20250129211955_agent.up.sql
Normal file
21
database/migrations/20250129211955_agent.up.sql
Normal file
@ -0,0 +1,21 @@
|
||||
CREATE TABLE IF NOT EXISTS agents
|
||||
(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
address TEXT NOT NULL,
|
||||
user TEXT NOT NULL,
|
||||
dir TEXT NOT NULL,
|
||||
fingerprint TEXT NOT NULL,
|
||||
last_sync DATETIME NULL DEFAULT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS agent_certs
|
||||
(
|
||||
agent_id INTEGER NOT NULL,
|
||||
cert_id INTEGER NOT NULL,
|
||||
not_after INTEGER NULL DEFAULT NULL,
|
||||
|
||||
PRIMARY KEY (agent_id, cert_id),
|
||||
|
||||
FOREIGN KEY (agent_id) REFERENCES agents (id),
|
||||
FOREIGN KEY (cert_id) REFERENCES certificates (id)
|
||||
);
|
@ -9,6 +9,21 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Agent struct {
|
||||
ID int64 `json:"id"`
|
||||
Address string `json:"address"`
|
||||
User string `json:"user"`
|
||||
Dir string `json:"dir"`
|
||||
Fingerprint string `json:"fingerprint"`
|
||||
LastSync sql.NullTime `json:"last_sync"`
|
||||
}
|
||||
|
||||
type AgentCert struct {
|
||||
AgentID int64 `json:"agent_id"`
|
||||
CertID int64 `json:"cert_id"`
|
||||
NotAfter sql.NullTime `json:"not_after"`
|
||||
}
|
||||
|
||||
type Certificate struct {
|
||||
ID int64 `json:"id"`
|
||||
Owner string `json:"owner"`
|
||||
@ -16,9 +31,9 @@ type Certificate struct {
|
||||
AutoRenew bool `json:"auto_renew"`
|
||||
Active bool `json:"active"`
|
||||
Renewing bool `json:"renewing"`
|
||||
NotAfter sql.NullTime `json:"not_after"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
TempParent sql.NullInt64 `json:"temp_parent"`
|
||||
NotAfter sql.NullTime `json:"not_after"`
|
||||
RenewRetry sql.NullTime `json:"renew_retry"`
|
||||
}
|
||||
|
||||
|
21
database/queries/agent.sql
Normal file
21
database/queries/agent.sql
Normal file
@ -0,0 +1,21 @@
|
||||
-- name: FindAgentToSync :many
|
||||
SELECT agents.id as agent_id, agents.address, agents.user, agents.dir, agents.fingerprint, cert.id as cert_id, cert.not_after as cert_not_after
|
||||
FROM agents
|
||||
INNER JOIN agent_certs
|
||||
ON agent_certs.agent_id = agents.id
|
||||
INNER JOIN certificates AS cert
|
||||
ON cert.id = agent_certs.cert_id
|
||||
WHERE (agents.last_sync IS NULL OR agents.last_sync < cert.updated_at)
|
||||
AND (agent_certs.not_after IS NULL OR agent_certs.not_after IS NOT cert.not_after)
|
||||
ORDER BY agents.last_sync NULLS FIRST;
|
||||
|
||||
-- name: UpdateAgentLastSync :exec
|
||||
UPDATE agents
|
||||
SET last_sync = ?
|
||||
WHERE agents.id = ?;
|
||||
|
||||
-- name: UpdateAgentCertNotAfter :exec
|
||||
UPDATE agent_certs
|
||||
SET not_after = ?
|
||||
WHERE agent_id = ?
|
||||
AND cert_id = ?;
|
3
go.mod
3
go.mod
@ -6,6 +6,7 @@ require (
|
||||
github.com/1f349/mjwt v0.4.1
|
||||
github.com/1f349/violet v0.0.14
|
||||
github.com/AlecAivazis/survey/v2 v2.3.7
|
||||
github.com/bramvdbogaerde/go-scp v1.5.0
|
||||
github.com/charmbracelet/log v0.4.0
|
||||
github.com/go-acme/lego/v4 v4.21.0
|
||||
github.com/golang-jwt/jwt/v4 v4.5.1
|
||||
@ -17,6 +18,7 @@ require (
|
||||
github.com/mrmelon54/certgen v0.0.2
|
||||
github.com/mrmelon54/exit-reload v0.0.2
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/crypto v0.32.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@ -46,7 +48,6 @@ require (
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/afero v1.12.0 // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
golang.org/x/crypto v0.32.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c // indirect
|
||||
golang.org/x/mod v0.22.0 // indirect
|
||||
golang.org/x/net v0.34.0 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -12,6 +12,8 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiE
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
|
||||
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
|
||||
github.com/bramvdbogaerde/go-scp v1.5.0 h1:a9BinAjTfQh273eh7vd3qUgmBC+bx+3TRDtkZWmIpzM=
|
||||
github.com/bramvdbogaerde/go-scp v1.5.0/go.mod h1:on2aH5AxaFb2G0N5Vsdy6B0Ml7k9HuHSwfo1y0QzAbQ=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/charmbracelet/lipgloss v1.0.0 h1:O7VkGDvqEdGi93X+DeqsQ7PKHDgtQfF8j8/O2qFMQNg=
|
||||
|
Loading…
Reference in New Issue
Block a user