// ssl-certs-checker/utils.go
//
// 160 lines
// 3.2 KiB
// Go

package main
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"os"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/charmbracelet/lipgloss"
"github.com/jedib0t/go-pretty/v6/table"
"github.com/jedib0t/go-pretty/v6/text"
"gopkg.in/yaml.v3"
)
// Bold lipgloss styles for the "Not After" column, one per expiry state as
// used by renderNotAfter: already expired (ANSI-256 color 196), fewer than
// 30 days remaining (color 202), and otherwise valid (color 46).
var (
	notAfterFail  = lipgloss.NewStyle().Foreground(lipgloss.Color("196")).Bold(true)
	notAfter30day = lipgloss.NewStyle().Foreground(lipgloss.Color("202")).Bold(true)
	notAfterValid = lipgloss.NewStyle().Foreground(lipgloss.Color("46")).Bold(true)
)
// Config is the shape of the YAML file decoded by readConfig.
type Config struct {
	// Hosts lists the targets to check; each entry is "host" or "host:port".
	Hosts []string `yaml:"hosts"`
}
// readConfig reads the YAML file at path config and decodes it into a Config.
//
// Failure is fatal: the error is written to stderr (keeping stdout clean for
// the rendered table) and the process exits with status 1, so callers never
// receive a partially populated Config.
func readConfig(config string) Config {
	var c Config
	// ioutil.ReadFile is equivalent to os.ReadFile; its error already names
	// the path, so no extra context is added here.
	y, err := ioutil.ReadFile(config)
	if err != nil {
		fmt.Fprintf(os.Stderr, "fatal: %s\n", err)
		os.Exit(1)
	}
	// yaml errors do not mention the file, so prefix the path.
	if err := yaml.Unmarshal(y, &c); err != nil {
		fmt.Fprintf(os.Stderr, "fatal: %s: %s\n", config, err)
		os.Exit(1)
	}
	return c
}
// getPeerCertificates opens a TLS connection to h on the given port and
// returns the certificates the server presented, in the order it sent them
// (leaf first).
//
// timeout is in seconds and bounds the TCP dial and the TLS handshake as a
// whole. The connection uses the default verification for ServerName h, so a
// server whose chain fails verification (e.g. an expired certificate) yields
// an error instead of its certificates.
//
// protocol is a package-level network name defined elsewhere in the package
// (presumably "tcp" — confirm).
func getPeerCertificates(h string, port int, timeout int) ([]*x509.Certificate, error) {
	dialer := &net.Dialer{Timeout: time.Duration(timeout) * time.Second}
	// JoinHostPort brackets IPv6 literals ("[::1]:443"); plain string
	// concatenation would yield an address the dialer cannot parse.
	addr := net.JoinHostPort(h, strconv.Itoa(port))
	// DialWithDialer completes the handshake before returning, so no explicit
	// conn.Handshake() call is needed.
	conn, err := tls.DialWithDialer(dialer, protocol, addr, &tls.Config{ServerName: h})
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	return conn.ConnectionState().PeerCertificates, nil
}
// getCells fetches the certificates served by host:port and converts each
// non-CA certificate into one table row with the columns:
// host:port, subject CN, DNS SANs (newline-separated), NotBefore,
// rendered NotAfter, public-key algorithm, issuer CN.
//
// It always calls wg.Done before returning. If the connection or
// verification fails, the error is reported on stderr, prefixed with the
// target so concurrent failures can be told apart, and nil is returned so
// the host is skipped.
func getCells(host string, port, timeout int, wg *sync.WaitGroup) []table.Row {
	defer wg.Done()
	target := host + ":" + strconv.Itoa(port)
	certs, err := getPeerCertificates(host, port, timeout)
	if err != nil {
		fmt.Fprintf(os.Stderr, "err: %s: %s\n", target, err)
		return nil // skip if target host invalid
	}
	now := time.Now()
	rows := make([]table.Row, 0, len(certs))
	for _, c := range certs {
		// Certificates marked as CAs (intermediates/roots in the chain) are
		// not listed.
		if c.IsCA {
			continue
		}
		rows = append(rows, table.Row{
			target,
			c.Subject.CommonName,
			strings.Join(c.DNSNames, "\n"),
			c.NotBefore,
			renderNotAfter(c.NotAfter, now),
			c.PublicKeyAlgorithm.String(),
			c.Issuer.CommonName,
		})
	}
	return rows
}
// renderNotAfter formats the expiry time t relative to the reference time n.
//
// If t is already past (t.Sub(n) < 0), only the timestamp is returned, in
// the fail style. Otherwise the timestamp is followed by the remaining time
// in whole days (rounded to the nearest day). It uses the 30-day warning
// style when less than 30 days remain and the valid style otherwise.
func renderNotAfter(t time.Time, n time.Time) string {
	left := t.Sub(n)
	stamp := t.String()
	if left < 0 {
		return notAfterFail.Render(stamp)
	}

	style := notAfterValid
	if left < 30*24*time.Hour {
		style = notAfter30day
	}
	days := int(left.Round(24*time.Hour).Hours() / 24)
	return fmt.Sprintf("%s (%d days)", style.Render(stamp), days)
}
// prettyPrintCertsInfo reads the host list from the YAML file at config,
// fetches every host's certificates concurrently, and renders a single table
// to stdout. The table has one row per non-CA certificate, sorted by the
// "host:port" column.
//
// Each configured target is "host" or "host:port" (IPv6 literals as
// "[addr]:port"). A target with no port, or with a port that is not a
// number, is checked on defaultPort; the non-numeric case also prints a
// warning to stderr. timeout is passed through to each dial, in seconds.
func prettyPrintCertsInfo(config string, timeout int) {
	rc := readConfig(config)
	if len(rc.Hosts) == 0 {
		fmt.Printf("key not found, or empty input\n")
		return
	}
	t := table.NewWriter()
	t.SetOutputMirror(os.Stdout)
	t.AppendHeader(table.Row{
		"Host",
		"Common Name",
		"DNS Names",
		"Not Before",
		"Not After",
		"PublicKeyAlgorithm",
		"Issuer",
	})

	var wg sync.WaitGroup
	var mu sync.Mutex // guards rows
	rows := make([]table.Row, 0, len(rc.Hosts))
	for _, target := range rc.Hosts {
		host, port := target, defaultPort
		// SplitHostPort fails for a bare host (no port), in which case the
		// whole target is the host and defaultPort applies.
		if h, ps, err := net.SplitHostPort(target); err == nil {
			host = h
			if tp, err := strconv.Atoi(ps); err == nil {
				port = tp
			} else {
				// The original fmt.Errorf result was discarded, so this
				// warning was never shown.
				fmt.Fprintf(os.Stderr, "err: invalid port in %q, assuming port %d\n", target, defaultPort)
			}
		}
		wg.Add(1)
		go func(host string, port int) {
			cells := getCells(host, port, timeout, &wg) // getCells calls wg.Done
			mu.Lock()
			rows = append(rows, cells...)
			mu.Unlock()
		}(host, port)
	}
	wg.Wait()

	// Each host's rows are appended contiguously. A stable sort keeps them in
	// the order getCells produced them instead of shuffling equal keys.
	sort.SliceStable(rows, func(i, j int) bool {
		return rows[i][0].(string) < rows[j][0].(string)
	})
	t.AppendRows(rows)
	t.Style().Format.Header = text.FormatDefault
	t.Render()
}