mirror of
https://github.com/1f349/orchid.git
synced 2025-02-05 14:06:40 +00:00
379 lines
9.1 KiB
Go
379 lines
9.1 KiB
Go
package agent
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/x509/pkix"
|
|
"database/sql"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/1f349/orchid/database"
|
|
"github.com/1f349/orchid/logger"
|
|
"github.com/charmbracelet/log"
|
|
"github.com/mrmelon54/certgen"
|
|
"github.com/stretchr/testify/assert"
|
|
"golang.org/x/crypto/ed25519"
|
|
"golang.org/x/crypto/ssh"
|
|
"io"
|
|
"math/big"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestAgentSyncing(t *testing.T) {
|
|
logger.Logger.SetLevel(log.DebugLevel)
|
|
|
|
if testing.Short() {
|
|
t.Skip("Skipping agent syncing tests in short mode")
|
|
}
|
|
|
|
t.Run("agent syncing test", func(t *testing.T) {
|
|
certDir, err := os.MkdirTemp("", "orchid-certs")
|
|
assert.NoError(t, err)
|
|
keyDir, err := os.MkdirTemp("", "orchid-keys")
|
|
assert.NoError(t, err)
|
|
|
|
defer func() {
|
|
assert.NoError(t, os.RemoveAll(certDir))
|
|
assert.NoError(t, os.RemoveAll(keyDir))
|
|
}()
|
|
|
|
_, privKey, err := ed25519.GenerateKey(nil)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
sshPrivKey, err := ssh.NewSignerFromKey(privKey)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
agent := &Agent{
|
|
db: &fakeAgentDb{},
|
|
ticker: nil,
|
|
done: nil,
|
|
syncLock: nil,
|
|
sshKey: sshPrivKey,
|
|
certDir: certDir,
|
|
keyDir: keyDir,
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
|
|
t.Run("missing cert file", func(t *testing.T) {
|
|
err = agent.syncSingleAgentCertPair(now, database.FindAgentToSyncRow{
|
|
AgentID: 1337,
|
|
Address: "",
|
|
User: "test",
|
|
Dir: "~/hello/world",
|
|
Fingerprint: "",
|
|
CertID: 420,
|
|
CertNotAfter: sql.NullTime{Time: now, Valid: true},
|
|
})
|
|
assert.Contains(t, err.Error(), "open cert file:")
|
|
assert.Contains(t, err.Error(), "no such file or directory")
|
|
})
|
|
|
|
// generate example certificate
|
|
tlsCert, err := certgen.MakeServerTls(nil, 2048, pkix.Name{
|
|
Country: []string{"GB"},
|
|
Province: []string{"London"},
|
|
StreetAddress: []string{"221B Baker Street"},
|
|
PostalCode: []string{"NW1 6XE"},
|
|
SerialNumber: "test123456",
|
|
CommonName: "orchid-agent-test.local",
|
|
}, big.NewInt(1234567899), func(now time.Time) time.Time {
|
|
return now.Add(1 * time.Hour)
|
|
}, []string{"orchid-agent-test.local"}, []net.IP{
|
|
net.IPv6loopback,
|
|
net.IPv4(127, 0, 0, 1),
|
|
})
|
|
assert.NoError(t, err)
|
|
|
|
err = os.WriteFile(filepath.Join(certDir, "420.cert.pem"), tlsCert.GetCertPem(), 0600)
|
|
assert.NoError(t, err)
|
|
|
|
t.Run("missing key file", func(t *testing.T) {
|
|
err = agent.syncSingleAgentCertPair(now, database.FindAgentToSyncRow{
|
|
AgentID: 1337,
|
|
Address: "",
|
|
User: "test",
|
|
Dir: "~/hello/world",
|
|
Fingerprint: "",
|
|
CertID: 420,
|
|
CertNotAfter: sql.NullTime{Time: now, Valid: true},
|
|
})
|
|
assert.Contains(t, err.Error(), "open key file:")
|
|
assert.Contains(t, err.Error(), "no such file or directory")
|
|
})
|
|
|
|
err = os.WriteFile(filepath.Join(keyDir, "420.key.pem"), tlsCert.GetKeyPem(), 0600)
|
|
assert.NoError(t, err)
|
|
|
|
t.Run("successful sync", func(t *testing.T) {
|
|
var wg sync.WaitGroup
|
|
server := setupFakeSSH(&wg, func(remoteAddrPort netip.AddrPort, remotePubKey ssh.PublicKey) {
|
|
println("Attempt agent syncing")
|
|
|
|
err = agent.syncSingleAgentCertPair(now, database.FindAgentToSyncRow{
|
|
AgentID: 1337,
|
|
Address: remoteAddrPort.String(),
|
|
User: "test",
|
|
Dir: "~/hello/world",
|
|
Fingerprint: string(ssh.MarshalAuthorizedKey(remotePubKey)),
|
|
CertID: 420,
|
|
CertNotAfter: sql.NullTime{Time: now, Valid: true},
|
|
})
|
|
assert.NoError(t, err)
|
|
})
|
|
server.Close()
|
|
|
|
println("Waiting for ssh server to exit")
|
|
|
|
server.Wait()
|
|
})
|
|
})
|
|
}
|
|
|
|
func setupFakeSSH(wg *sync.WaitGroup, call func(addrPort netip.AddrPort, pubKey ssh.PublicKey)) *ssh.ServerConn {
|
|
pubKey, privKey, err := ed25519.GenerateKey(nil)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
sshPubKey, err := ssh.NewPublicKey(pubKey)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
sshSigner, err := ssh.NewSignerFromKey(privKey)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
tcp, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv6Loopback(), 0)))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
addrPort := tcp.Addr().(*net.TCPAddr).AddrPort()
|
|
|
|
var wg2 sync.WaitGroup
|
|
wg2.Add(1)
|
|
go func() {
|
|
defer wg2.Done()
|
|
call(addrPort, sshPubKey)
|
|
}()
|
|
|
|
tcpConn, err := tcp.AcceptTCP()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
serverConfig := &ssh.ServerConfig{
|
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
if conn.User() != "test" {
|
|
return nil, fmt.Errorf("invalid user")
|
|
}
|
|
if !conn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr().IsLoopback() {
|
|
return nil, fmt.Errorf("invalid remote address")
|
|
}
|
|
return &ssh.Permissions{}, nil
|
|
},
|
|
ServerVersion: "SSH-2.0-OrchidTester",
|
|
}
|
|
serverConfig.AddHostKey(sshSigner)
|
|
|
|
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, serverConfig)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
// The incoming Request channel must be serviced.
|
|
wg.Add(1)
|
|
go func() {
|
|
ssh.DiscardRequests(reqs)
|
|
wg.Done()
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
// Service the incoming channel.
|
|
for newChannel := range chans {
|
|
// Channels have a type, depending on the application level
|
|
// protocol intended. In the case of a shell, the type is
|
|
// "session" and ServerShell may be used to present a simple
|
|
// terminal interface.
|
|
if newChannel.ChannelType() != "session" {
|
|
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
|
|
continue
|
|
}
|
|
channel, requests, err := newChannel.Accept()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
var fullFilePath string
|
|
|
|
// Sessions have out-of-band requests such as "shell",
|
|
// "pty-req" and "env". Here we handle only the
|
|
// "shell" request.
|
|
wg.Add(1)
|
|
go func(in <-chan *ssh.Request) {
|
|
for req := range in {
|
|
req.Reply(req.Type == "exec", nil)
|
|
if req.Type == "exec" {
|
|
length := binary.BigEndian.Uint32(req.Payload[:4])
|
|
if len(req.Payload) != int(length)+4 {
|
|
panic(fmt.Errorf("invalid exec payload (expected %d but got %d)", length, len(req.Payload)))
|
|
}
|
|
cmd := string(req.Payload[4:])
|
|
const scpStartStr = "scp -qt \""
|
|
if !strings.HasPrefix(cmd, scpStartStr) {
|
|
panic("invalid start")
|
|
}
|
|
if !strings.HasSuffix(cmd, "\"") {
|
|
panic("invalid end")
|
|
}
|
|
filePath := cmd[len(scpStartStr) : len(cmd)-1]
|
|
fmt.Println("Writing file:", filePath)
|
|
fullFilePath = filePath
|
|
}
|
|
}
|
|
wg.Done()
|
|
}(requests)
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer func() {
|
|
channel.Close()
|
|
wg.Done()
|
|
}()
|
|
|
|
var b [1024]byte
|
|
read := must(channel.Read(b[:]))
|
|
if read < 1 {
|
|
panic("invalid read")
|
|
}
|
|
|
|
fmt.Println(string(b[:read]))
|
|
|
|
r := bufio.NewReader(bytes.NewReader(b[:read]))
|
|
if readByte(r) != 'C' {
|
|
panic("invalid scp command")
|
|
}
|
|
|
|
fileMode := readN(r, 4)
|
|
if string(fileMode) != "0600" {
|
|
panic("unexpected file mode")
|
|
}
|
|
|
|
if readByte(r) != ' ' {
|
|
panic("missing space")
|
|
}
|
|
|
|
fileSizeStr := must(r.ReadString(' '))
|
|
fileSize := must(strconv.Atoi(fileSizeStr[:len(fileSizeStr)-1]))
|
|
|
|
fileName := strings.TrimSpace(string(must(io.ReadAll(r))))
|
|
if fileName != filepath.Base(fullFilePath) {
|
|
panic(fmt.Errorf("invalid file name (expected \"%s\" from full path \"%s\" but got \"%s\")", filepath.Base(fullFilePath), fullFilePath, fileName))
|
|
}
|
|
|
|
if fileName != "420.cert.pem" && fileName != "420.key.pem" {
|
|
panic("invalid file name")
|
|
}
|
|
|
|
channel.Write([]byte{0})
|
|
|
|
buf := new(bytes.Buffer)
|
|
_, err := io.CopyN(buf, channel, int64(fileSize))
|
|
if err != nil {
|
|
panic("Failed to copy channel")
|
|
}
|
|
fmt.Println("Copied file with size:", buf.Len())
|
|
fmt.Println(buf.String())
|
|
|
|
if readLastByte(r) != 0x00 {
|
|
panic("expected ending null byte")
|
|
}
|
|
|
|
channel.Write([]byte{0})
|
|
|
|
channel.SendRequest("exit-status", false, binary.BigEndian.AppendUint32(nil, 0))
|
|
}()
|
|
}
|
|
}()
|
|
|
|
wg2.Wait()
|
|
|
|
return sshConn
|
|
}
|
|
|
|
// fakeAgentDb is a stateless stand-in for the agent's database dependency;
// TestAgentSyncing assigns it to the db field of Agent.
type fakeAgentDb struct{}
|
|
|
|
func (f *fakeAgentDb) FindAgentToSync(ctx context.Context) ([]database.FindAgentToSyncRow, error) {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (f *fakeAgentDb) UpdateAgentLastSync(ctx context.Context, arg database.UpdateAgentLastSyncParams) error {
|
|
if arg.ID != 1337 {
|
|
return fmt.Errorf("invalid agent id")
|
|
}
|
|
if !arg.LastSync.Valid {
|
|
return fmt.Errorf("invalid last sync")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeAgentDb) UpdateAgentCertNotAfter(ctx context.Context, arg database.UpdateAgentCertNotAfterParams) error {
|
|
if arg.AgentID != 1337 {
|
|
return fmt.Errorf("invalid agent id")
|
|
}
|
|
if arg.CertID != 420 {
|
|
return fmt.Errorf("invalid cert id")
|
|
}
|
|
if !arg.NotAfter.Valid {
|
|
return fmt.Errorf("invalid not after")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// must unwraps a (value, error) pair: it returns t when err is nil and panics
// with err otherwise.
func must[T any](t T, err error) T {
	if err == nil {
		return t
	}
	panic(err)
}
|
|
|
|
// readN reads exactly n bytes from r and returns them. It panics with the
// error reported by io.ReadFull if fewer than n bytes could be read.
func readN(r io.Reader, n int) []byte {
	buf := make([]byte, n)
	if _, err := io.ReadFull(r, buf); err != nil {
		panic(err)
	}
	return buf
}
|
|
|
|
func readByte(r io.Reader) byte {
|
|
b := readN(r, 1)
|
|
return b[0]
|
|
}
|
|
|
|
// readLastByte attempts to read one byte from r and panics with
// "expected EOF" unless that read reports io.EOF.
//
// NOTE(review): io.ReadFull only reports io.EOF when no bytes were read, so
// on the non-panicking path the returned byte is always the zero value. This
// effectively asserts that r is exhausted rather than returning a final byte.
func readLastByte(r io.Reader) byte {
	buf := make([]byte, 1)
	if _, err := io.ReadFull(r, buf); !errors.Is(err, io.EOF) {
		panic("expected EOF")
	}
	return buf[0]
}
|