// orchid/agent/agent_test.go (383 lines, 9.1 KiB, Go)

package agent
import (
"bufio"
"bytes"
"context"
"crypto/x509/pkix"
"database/sql"
"encoding/binary"
"errors"
"fmt"
"github.com/1f349/orchid/database"
"github.com/1f349/orchid/logger"
"github.com/charmbracelet/log"
"github.com/mrmelon54/certgen"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh"
"io"
"math/big"
"net"
"net/netip"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
"time"
)
// TestAgentSyncing exercises Agent.syncSingleAgentCertPair against on-disk
// certificate/key files and, for the success case, a fake SSH server that
// emulates an scp upload. It is skipped in -short mode because it generates
// key material and opens loopback network connections.
func TestAgentSyncing(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping agent syncing tests in short mode")
	}
	logger.Logger.SetLevel(log.DebugLevel)

	t.Run("agent syncing test", func(t *testing.T) {
		// t.TempDir removes these directories automatically once this subtest
		// and all of its subtests have finished.
		certDir := t.TempDir()
		keyDir := t.TempDir()

		_, privKey, err := ed25519.GenerateKey(nil)
		if err != nil {
			t.Fatalf("generate ssh client key: %v", err)
		}
		sshPrivKey, err := ssh.NewSignerFromKey(privKey)
		if err != nil {
			t.Fatalf("create ssh signer: %v", err)
		}

		agent := &Agent{
			db:       &fakeAgentDb{},
			ticker:   nil,
			done:     nil,
			syncLock: nil,
			sshKey:   sshPrivKey,
			certDir:  certDir,
			keyDir:   keyDir,
		}
		now := time.Now().UTC()

		// syncRow builds the row handed to syncSingleAgentCertPair. Every case
		// uses agent 1337 and certificate 420; only the remote address and the
		// expected host key fingerprint vary.
		syncRow := func(address, fingerprint string) database.FindAgentToSyncRow {
			return database.FindAgentToSyncRow{
				AgentID:      1337,
				Address:      address,
				User:         "test",
				Dir:          "~/hello/world",
				Fingerprint:  fingerprint,
				CertID:       420,
				CertNotAfter: sql.NullTime{Time: now, Valid: true},
			}
		}

		t.Run("missing cert file", func(t *testing.T) {
			err := agent.syncSingleAgentCertPair(now, syncRow("", ""))
			// Guard against a nil error so a regression fails the test instead
			// of panicking on err.Error().
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), "open cert file:")
				assert.Contains(t, err.Error(), "no such file or directory")
			}
		})

		// generate example certificate
		tlsCert, err := certgen.MakeServerTls(nil, 2048, pkix.Name{
			Country:       []string{"GB"},
			Province:      []string{"London"},
			StreetAddress: []string{"221B Baker Street"},
			PostalCode:    []string{"NW1 6XE"},
			SerialNumber:  "test123456",
			CommonName:    "orchid-agent-test.local",
		}, big.NewInt(1234567899), func(now time.Time) time.Time {
			return now.Add(1 * time.Hour)
		}, []string{"orchid-agent-test.local"}, []net.IP{
			net.IPv6loopback,
			net.IPv4(127, 0, 0, 1),
		})
		if err != nil {
			t.Fatalf("generate test certificate: %v", err)
		}
		if err := os.WriteFile(filepath.Join(certDir, "420.cert.pem"), tlsCert.GetCertPem(), 0600); err != nil {
			t.Fatalf("write cert file: %v", err)
		}

		t.Run("missing key file", func(t *testing.T) {
			err := agent.syncSingleAgentCertPair(now, syncRow("", ""))
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), "open key file:")
				assert.Contains(t, err.Error(), "no such file or directory")
			}
		})

		if err := os.WriteFile(filepath.Join(keyDir, "420.key.pem"), tlsCert.GetKeyPem(), 0600); err != nil {
			t.Fatalf("write key file: %v", err)
		}

		t.Run("successful sync", func(t *testing.T) {
			var wg sync.WaitGroup
			// setupFakeSSH runs this callback on its own goroutine. It declares
			// its own err so it does not write to state shared with the test.
			server := setupFakeSSH(&wg, func(remoteAddrPort netip.AddrPort, remotePubKey ssh.PublicKey) {
				t.Log("Attempt agent syncing")
				err := agent.syncSingleAgentCertPair(now, syncRow(
					remoteAddrPort.String(),
					string(ssh.MarshalAuthorizedKey(remotePubKey)),
				))
				assert.NoError(t, err)
			})
			_ = server.Close()
			t.Log("Waiting for ssh server to exit")
			_ = server.Wait()
			// Also wait for the fake server's handler goroutines. This ensures
			// their protocol checks have run before this subtest returns.
			wg.Wait()
		})
	})
}
// setupFakeSSH starts a throwaway SSH server on an ephemeral IPv6 loopback port
// that emulates the receiving end of `scp -qt`. It then runs call on its own
// goroutine, passing the server's address and host public key so the code
// under test can connect to it.
//
// The server accepts exactly one TCP connection and:
//   - authenticates only user "test" connecting from a loopback address (the
//     client's public key itself is not checked);
//   - accepts only "session" channels;
//   - within a session, accepts only an "exec" request of the form
//     `scp -qt "<path>"`, then expects a single file named 420.cert.pem or
//     420.key.pem with mode 0600 to be uploaded to that path.
//
// Any protocol violation panics, which aborts the test binary.
//
// setupFakeSSH returns once call has returned. The caller owns the returned
// connection and should Close and Wait on it. Every goroutine servicing the
// connection is registered with wg, so the caller can wait on wg afterwards
// to be sure the server-side checks have finished.
func setupFakeSSH(wg *sync.WaitGroup, call func(addrPort netip.AddrPort, pubKey ssh.PublicKey)) *ssh.ServerConn {
	pubKey, privKey, err := ed25519.GenerateKey(nil)
	if err != nil {
		panic(err)
	}
	sshPubKey, err := ssh.NewPublicKey(pubKey)
	if err != nil {
		panic(err)
	}
	sshSigner, err := ssh.NewSignerFromKey(privKey)
	if err != nil {
		panic(err)
	}

	tcp, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv6Loopback(), 0)))
	if err != nil {
		panic(err)
	}
	addrPort := tcp.Addr().(*net.TCPAddr).AddrPort()

	var callDone sync.WaitGroup
	callDone.Add(1)
	go func() {
		defer callDone.Done()
		// Release the listener once the client is finished. If call returned
		// without ever dialling, this also unblocks AcceptTCP below instead of
		// hanging the test. On the success path the connection has already
		// been accepted by then, because the SSH handshake needs the server.
		defer tcp.Close()
		call(addrPort, sshPubKey)
	}()

	tcpConn, err := tcp.AcceptTCP()
	if err != nil {
		panic(fmt.Errorf("accept ssh client connection: %w", err))
	}

	serverConfig := &ssh.ServerConfig{
		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
			if conn.User() != "test" {
				return nil, fmt.Errorf("invalid user")
			}
			if !conn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr().IsLoopback() {
				return nil, fmt.Errorf("invalid remote address")
			}
			return &ssh.Permissions{}, nil
		},
		ServerVersion: "SSH-2.0-OrchidTester",
	}
	serverConfig.AddHostKey(sshSigner)
	sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, serverConfig)
	if err != nil {
		panic(err)
	}

	// Connection-level (global) requests must be serviced or the connection stalls.
	wg.Add(1)
	go func() {
		defer wg.Done()
		ssh.DiscardRequests(reqs)
	}()

	wg.Add(1)
	go func() {
		defer wg.Done()
		for newChannel := range chans {
			if newChannel.ChannelType() != "session" {
				_ = newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
				continue
			}
			channel, requests, err := newChannel.Accept()
			if err != nil {
				panic(err)
			}

			// The exec handler publishes the remote target path here before it
			// acknowledges the exec request. The upload handler reads it only
			// after receiving the scp header, which the client sends after that
			// acknowledgement. This ordering removes the race on the path.
			pathCh := make(chan string, 1)

			// Out-of-band session requests: only "exec" of `scp -qt "<path>"` is
			// accepted; anything else ("env", "pty-req", "shell", ...) is refused.
			wg.Add(1)
			go func(in <-chan *ssh.Request) {
				defer wg.Done()
				for req := range in {
					if req.Type != "exec" {
						_ = req.Reply(false, nil)
						continue
					}
					if len(req.Payload) < 4 {
						panic(fmt.Errorf("exec payload too short (%d bytes)", len(req.Payload)))
					}
					length := binary.BigEndian.Uint32(req.Payload[:4])
					if len(req.Payload) != int(length)+4 {
						panic(fmt.Errorf("invalid exec payload (expected %d but got %d)", int(length)+4, len(req.Payload)))
					}
					cmd := string(req.Payload[4:])
					const scpStartStr = "scp -qt \""
					if !strings.HasPrefix(cmd, scpStartStr) {
						panic("invalid start")
					}
					// Check the closing quote after the prefix so the opening
					// quote cannot double as the closing one.
					quoted := cmd[len(scpStartStr):]
					if !strings.HasSuffix(quoted, "\"") {
						panic("invalid end")
					}
					filePath := quoted[:len(quoted)-1]
					fmt.Println("Writing file:", filePath)
					select {
					case pathCh <- filePath:
					default:
						panic("multiple exec requests on one session")
					}
					_ = req.Reply(true, nil)
				}
			}(requests)

			// Upload handler: the sink side of a single-file scp transfer.
			wg.Add(1)
			go func() {
				defer func() {
					_ = channel.Close()
					wg.Done()
				}()

				// Header line: "C<mode> <size> <name>\n". The client waits for
				// our acknowledgement before sending anything else, so the
				// first read carries exactly this line.
				var b [1024]byte
				read := must(channel.Read(b[:]))
				if read < 1 {
					panic("invalid read")
				}
				fmt.Println(string(b[:read]))
				r := bufio.NewReader(bytes.NewReader(b[:read]))
				if readByte(r) != 'C' {
					panic("invalid scp command")
				}
				if string(readN(r, 4)) != "0600" {
					panic("unexpected file mode")
				}
				if readByte(r) != ' ' {
					panic("missing space")
				}
				fileSize := must(strconv.Atoi(strings.TrimSuffix(must(r.ReadString(' ')), " ")))
				if fileSize < 0 {
					panic(fmt.Errorf("negative file size %d", fileSize))
				}
				fileName := strings.TrimSpace(must(r.ReadString('\n')))
				// Nothing may follow the header line in this packet.
				readLastByte(r)

				fullFilePath := <-pathCh
				if fileName != filepath.Base(fullFilePath) {
					panic(fmt.Errorf("invalid file name (expected \"%s\" from full path \"%s\" but got \"%s\")", filepath.Base(fullFilePath), fullFilePath, fileName))
				}
				if fileName != "420.cert.pem" && fileName != "420.key.pem" {
					panic("invalid file name")
				}

				// Acknowledge the header; the client then streams the contents.
				must(channel.Write([]byte{0}))
				buf := new(bytes.Buffer)
				if _, err := io.CopyN(buf, channel, int64(fileSize)); err != nil {
					panic(fmt.Errorf("copy file contents: %w", err))
				}
				fmt.Println("Copied file with size:", buf.Len())
				fmt.Println(buf.String())

				// The scp source terminates the contents with a single NUL byte,
				// then waits for our final acknowledgement.
				if readByte(channel) != 0x00 {
					panic("expected ending null byte")
				}
				must(channel.Write([]byte{0}))
				_, _ = channel.SendRequest("exit-status", false, binary.BigEndian.AppendUint32(nil, 0))
			}()
		}
	}()

	callDone.Wait()
	return sshConn
}
// fakeAgentDb is an in-memory stand-in for the agent's database dependency.
// It accepts updates only for agent ID 1337 and certificate ID 420, the IDs
// used by the sync rows in this file's tests, and rejects rows whose
// timestamps are not valid. FindAgentToSync is not exercised and panics.
type fakeAgentDb struct{}

// FindAgentToSync is not needed by these tests; calling it panics.
func (f *fakeAgentDb) FindAgentToSync(_ context.Context) ([]database.FindAgentToSyncRow, error) {
	panic("implement me")
}

// UpdateAgentLastSync succeeds only for agent 1337 with a valid LastSync.
func (f *fakeAgentDb) UpdateAgentLastSync(_ context.Context, arg database.UpdateAgentLastSyncParams) error {
	switch {
	case arg.ID != 1337:
		return errors.New("invalid agent id")
	case !arg.LastSync.Valid:
		return errors.New("invalid last sync")
	default:
		return nil
	}
}

// UpdateAgentCertNotAfter succeeds only for agent 1337, certificate 420 and a
// valid NotAfter; the checks run in that order.
func (f *fakeAgentDb) UpdateAgentCertNotAfter(_ context.Context, arg database.UpdateAgentCertNotAfterParams) error {
	switch {
	case arg.AgentID != 1337:
		return errors.New("invalid agent id")
	case arg.CertID != 420:
		return errors.New("invalid cert id")
	case !arg.NotAfter.Valid:
		return errors.New("invalid not after")
	default:
		return nil
	}
}
// must returns v when err is nil and panics with err otherwise. It is a test
// helper for unwrapping (value, error) pairs inline.
func must[T any](v T, err error) T {
	if err == nil {
		return v
	}
	panic(err)
}
// readN reads exactly n bytes from r and returns them. It panics with the
// error from io.ReadFull if fewer than n bytes are available.
func readN(r io.Reader, n int) []byte {
	buf := make([]byte, n)
	if _, err := io.ReadFull(r, buf); err != nil {
		panic(err)
	}
	return buf
}
// readByte reads exactly one byte from r and returns it. It panics with the
// error from io.ReadFull if r has no more data.
func readByte(r io.Reader) byte {
	var one [1]byte
	if _, err := io.ReadFull(r, one[:]); err != nil {
		panic(err)
	}
	return one[0]
}
// readLastByte asserts that r has been fully consumed. It attempts to read one
// more byte and panics unless that read reports io.EOF. Despite its name it
// never returns data from r: on success it always returns 0x00. This lets
// callers write `readLastByte(r) != 0x00` as an end-of-stream check.
//
// The panic value names the unexpected trailing byte, or wraps the non-EOF
// read error, so a failure is diagnosable.
func readLastByte(r io.Reader) byte {
	var b [1]byte
	n, err := io.ReadFull(r, b[:])
	if errors.Is(err, io.EOF) {
		return 0x00
	}
	if n > 0 {
		panic(fmt.Errorf("expected EOF but found trailing byte 0x%02x", b[0]))
	}
	panic(fmt.Errorf("expected EOF: %w", err))
}