depot/go/secretsmgr/secretsmgr.go

262 lines
7.1 KiB
Go
Raw Normal View History

package main
import (
"bufio"
"bytes"
"context"
"flag"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/coreos/go-systemd/v22/dbus"
log "github.com/golang/glog"
vapi "github.com/hashicorp/vault/api"
"golang.org/x/crypto/ssh"
)
// Flags controlling how this machine's SSH host keys are certified.
var (
	// Whether SSH host keys should be signed with the CA.
	signSSHHostKeys = flag.Bool("sign_ssh_host_keys", true, "Sign SSH host keys with CA")

	// Vault mount point of the SSH CA used for signing.
	sshHostKeyCAPath = flag.String("ssh_host_key_ca_path", "ssh-host", "Path that the SSH CA is mounted at")

	// Vault role passed when signing; defaults to this machine's hostname.
	sshHostKeyRole = flag.String("ssh_host_key_role", hostname(), "Role to use for signing SSH host keys")

	// Comma-separated principals requested for (and required on) the
	// certificate; defaults to defaultPrincipals(hostname()).
	sshHostKeyPrincipals = flag.String("ssh_host_key_principals", defaultPrincipals(hostname()), "Principals to use for SSH host key certificate")

	// Certificates expiring sooner than this from now are renewed.
	sshHostKeyExpiryThreshold = flag.Duration("ssh_host_key_expiry_threshold", 3*24*time.Hour, "Expiry threshold for SSH host key")

	// Directory that signed "<key>-cert.pub" files are written into.
	sshHostKeyOutputDir = flag.String("ssh_host_key_output_dir", "/var/lib/secretsmgr/ssh", "Path to SSH host key output dir")
)

// sshd is the sshd binary, invoked as `sshd -T` to discover configured host keys.
var sshd = flag.String("sshd", "sshd", "Path to sshd")
// hostname returns the host name reported by os.Hostname, or the empty
// string if it cannot be determined. It is used to compute flag defaults.
func hostname() string {
	if name, err := os.Hostname(); err == nil {
		return name
	}
	return ""
}
// defaultPrincipals returns the default comma-separated principal list for a
// host: "<hostname>.as205479.net,<hostname>.int.as205479.net".
func defaultPrincipals(hostname string) string {
	// %[1]s reuses the single argument for both names.
	return fmt.Sprintf("%[1]s.as205479.net,%[1]s.int.as205479.net", hostname)
}
// sshHostKeyPaths runs `sshd -T` (using the binary named by the -sshd flag)
// and returns, for every "hostkey <path>" line in its output, the
// corresponding public key path "<path>.pub".
//
// Errors from running sshd are wrapped with the command that failed and, when
// sshd exited unsuccessfully, with whatever it printed to stderr.
func sshHostKeyPaths(ctx context.Context) ([]string, error) {
	out, err := exec.CommandContext(ctx, *sshd, "-T").Output()
	if err != nil {
		// Output returns *exec.ExitError (unwrapped) on a non-zero exit and,
		// because cmd.Stderr is nil, captures stderr into it. sshd reports
		// config/permission problems there, so don't throw it away.
		if exitErr, ok := err.(*exec.ExitError); ok && len(exitErr.Stderr) > 0 {
			return nil, fmt.Errorf("running %v -T: %w (stderr: %q)", *sshd, err, bytes.TrimSpace(exitErr.Stderr))
		}
		return nil, fmt.Errorf("running %v -T: %w", *sshd, err)
	}

	const hostkeyPrefix = "hostkey "
	var keys []string
	scanner := bufio.NewScanner(bytes.NewReader(out))
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, hostkeyPrefix) {
			keys = append(keys, fmt.Sprintf("%s.pub", strings.TrimPrefix(line, hostkeyPrefix)))
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("reading %v -T output: %w", *sshd, err)
	}
	return keys, nil
}
// readFile returns the full contents of the file at path p. Any error from
// closing the file is ignored.
func readFile(p string) ([]byte, error) {
	file, openErr := os.Open(p)
	if openErr != nil {
		return nil, openErr
	}
	contents, readErr := io.ReadAll(file)
	file.Close()
	return contents, readErr
}
// writeFile creates (or truncates) the file at path p and writes content to
// it. A write error takes precedence; otherwise any error from closing the
// file is returned.
func writeFile(p string, content []byte) (err error) {
	f, err := os.Create(p)
	if err != nil {
		return err
	}
	defer func() {
		// Only surface the Close error when the write itself succeeded.
		if closeErr := f.Close(); err == nil {
			err = closeErr
		}
	}()
	_, err = f.Write(content)
	return err
}
// readSSHPublicKey reads the file at p and parses its contents with
// ssh.ParseAuthorizedKey, so it accepts plain public keys as well as
// certificates. Errors are wrapped with the path involved.
func readSSHPublicKey(p string) (ssh.PublicKey, error) {
	raw, readErr := readFile(p)
	if readErr != nil {
		return nil, fmt.Errorf("reading public key from %v: %w", p, readErr)
	}
	key, _, _, _, parseErr := ssh.ParseAuthorizedKey(raw)
	if parseErr != nil {
		return nil, fmt.Errorf("parsing public key in %v: %w", p, parseErr)
	}
	return key, nil
}
// shouldRenewSSHCert reports whether the certificate stored at certPath must
// be (re)issued for publicKey, an authorized_keys-format public key, so that
// it carries every entry of the comma-separated principals list.
//
// It returns true, together with an error explaining why, when the
// certificate cannot be read or parsed, is not a certificate, certifies a
// different key, expires within *sshHostKeyExpiryThreshold, or lacks one of
// the wanted principals. It returns false and an error if publicKey itself
// cannot be parsed, and false, nil when the existing certificate is fine.
func shouldRenewSSHCert(certPath string, publicKey []byte, principals string) (bool, error) {
	certPK, err := readSSHPublicKey(certPath)
	if err != nil {
		return true, err
	}
	cert, ok := certPK.(*ssh.Certificate)
	if !ok {
		return true, fmt.Errorf("file at %v was not an SSH certificate (was a %T)", certPath, certPK)
	}
	pubKey, _, _, _, err := ssh.ParseAuthorizedKey(publicKey)
	if err != nil {
		return false, err
	}

	// Verify that this is the same key.
	if !bytes.Equal(pubKey.Marshal(), cert.Key.Marshal()) {
		return true, fmt.Errorf("certificate key and public key don't match")
	}

	// Verify the certificate isn't expiring too soon. CertTimeInfinity means
	// "never expires"; converting it to int64 would wrap to -1 and log a 1969
	// date, so report it separately.
	if cert.ValidBefore == uint64(ssh.CertTimeInfinity) {
		log.Infof("certificate %v never expires", certPath)
	} else {
		unixNow := time.Now().Unix()
		expiryThreshold := uint64(unixNow + int64(sshHostKeyExpiryThreshold.Seconds()))
		certExpiryDate := time.Unix(int64(cert.ValidBefore), 0)
		log.Infof("certificate %v expires at %v (in %v)", certPath, certExpiryDate, time.Until(certExpiryDate))
		if cert.ValidBefore < expiryThreshold {
			return true, fmt.Errorf("certificate expires at %v, which is earlier than our threshold of %v", certExpiryDate, time.Unix(int64(expiryThreshold), 0))
		}
	}

	// Verify every wanted principal is present.
	hasPrincipals := make(map[string]bool, len(cert.ValidPrincipals))
	for _, p := range cert.ValidPrincipals {
		hasPrincipals[p] = true
	}
	for _, wantP := range strings.Split(principals, ",") {
		// Ignore surrounding whitespace and empty entries: otherwise an empty
		// flag value, a trailing comma, or "a, b" would yield principals ("",
		// " b") that no certificate carries, forcing a renewal on every run.
		wantP = strings.TrimSpace(wantP)
		if wantP == "" {
			continue
		}
		if !hasPrincipals[wantP] {
			return true, fmt.Errorf("certificate missing principal %v (has principals %v; want principals %v)", wantP, cert.ValidPrincipals, principals)
		}
	}
	return false, nil
}
// reloadOrTryRestartServices connects to systemd over D-Bus and issues a
// reload-or-try-restart job (mode "replace") for each unit in services.
// Every unit is attempted even if earlier ones fail; the returned error
// lists all units whose request failed. ctx is currently unused.
//
// NOTE(review): the job channel argument is nil, so the call does not wait
// for the job to finish. An error here only means systemd rejected the
// request; a reload that later fails is not detected.
func reloadOrTryRestartServices(ctx context.Context, services []string) error {
	conn, err := dbus.NewSystemdConnection()
	if err != nil {
		return fmt.Errorf("failed to connect to systemd: %w", err)
	}
	defer conn.Close()

	var failed []string
	for i := range services {
		unit := services[i]
		if _, jobErr := conn.ReloadOrTryRestartUnit(unit, "replace", nil); jobErr != nil {
			log.Errorf("unit %v failed to reload-or-try-restart: %v", unit, jobErr)
			failed = append(failed, unit)
		}
	}

	if len(failed) == 0 {
		return nil
	}
	return fmt.Errorf("some units failed to restart: %v", failed)
}
// checkAndRenewSSHCertificates makes sure every host key reported by
// `sshd -T` has an up-to-date certificate signed by the Vault SSH CA mounted
// at *sshHostKeyCAPath. Certificates are written to *sshHostKeyOutputDir as
// "<keyname>-cert.pub". If any certificate was renewed, sshd.service is
// reloaded-or-try-restarted so it can pick them up.
//
// Each key is handled independently; a failure on one does not stop the
// others. It returns true if any error occurred.
func checkAndRenewSSHCertificates(ctx context.Context, c *vapi.Client) bool {
	sshHostKeys, err := sshHostKeyPaths(ctx)
	if err != nil {
		log.Errorf("sshHostKeyPaths: %v", err)
		return true
	}
	cs := c.SSHWithMountPoint(*sshHostKeyCAPath)
	var hadErr, didRenew bool
	for _, sshHostKeyPath := range sshHostKeys {
		outPath := filepath.Join(*sshHostKeyOutputDir, fmt.Sprintf("%s-cert.pub", strings.TrimSuffix(filepath.Base(sshHostKeyPath), ".pub")))
		log.Infof("might sign %v into %v", sshHostKeyPath, outPath)
		sshHostKey, err := readFile(sshHostKeyPath)
		if err != nil {
			log.Errorf("reading host key %v: %v", sshHostKeyPath, err)
			hadErr = true
			continue
		}

		shouldRenew, err := shouldRenewSSHCert(outPath, sshHostKey, *sshHostKeyPrincipals)
		switch {
		case shouldRenew && err == nil:
			log.Infof("shouldRenewSSHCert for %v returned true with no reason", outPath)
		case shouldRenew:
			// err explains why renewal is needed (missing, expiring, ...).
			log.Infof("shouldRenewSSHCert for %v returned true: %v", outPath, err)
		case err != nil:
			log.Errorf("shouldRenewSSHCert for %v returned false and an error: %v", outPath, err)
			hadErr = true
			continue
		default:
			log.Infof("shouldRenewSSHCert for %v returned false", outPath)
			continue
		}

		secret, err := cs.SignKey(*sshHostKeyRole, map[string]interface{}{
			"cert_type":        "host",
			"public_key":       string(sshHostKey),
			"valid_principals": *sshHostKeyPrincipals,
		})
		if err != nil {
			log.Errorf("signing %v: %v", sshHostKeyPath, err)
			hadErr = true
			continue
		}
		// Check the response shape rather than asserting blindly: a nil
		// secret or a missing/non-string signed_key would otherwise panic and
		// abort processing of every remaining key.
		var signedKey string
		if secret != nil {
			signedKey, _ = secret.Data["signed_key"].(string)
		}
		if signedKey == "" {
			log.Errorf("signing %v: response did not contain a signed_key", sshHostKeyPath)
			hadErr = true
			continue
		}
		if err := writeFile(outPath, []byte(signedKey)); err != nil {
			log.Errorf("writing signed certificate to %v: %v", outPath, err)
			hadErr = true
			continue
		}
		log.Infof("renewed %v", outPath)
		didRenew = true
	}

	if didRenew {
		log.Infof("completed renewals; restarting SSHd")
		if err := reloadOrTryRestartServices(ctx, []string{"sshd.service"}); err != nil {
			log.Errorf("failed to restart sshd: %v", err)
			hadErr = true
		}
	}
	return hadErr
}
// checkAndRenewACMECertificates is a placeholder for ACME certificate
// management. It currently does nothing and always reports no error
// (false).
func checkAndRenewACMECertificates(ctx context.Context, c *vapi.Client) (hadErr bool) {
	return hadErr
}
// main builds a Vault client, runs the certificate checks and exits with
// status 1 if any of them reported an error.
func main() {
	flag.Parse()

	cfg := vapi.DefaultConfig()
	if cfg.Error != nil {
		// DefaultConfig reports problems reading its environment through
		// cfg.Error instead of failing; surface them rather than ignoring them.
		log.Warningf("vault default config: %v", cfg.Error)
	}
	cfg.Address = "https://vault.int.lukegb.com"
	cfg.AgentAddress = "http://localhost:8200"
	cfg.MaxRetries = 0
	c, err := vapi.NewClient(cfg)
	if err != nil {
		log.Exitf("failed to create vault client: %v", err)
	}

	ctx := context.Background()
	var hadErr bool
	// Honor -sign_ssh_host_keys. Previously the flag was declared but never
	// consulted, so setting it to false had no effect.
	if *signSSHHostKeys {
		hadErr = checkAndRenewSSHCertificates(ctx, c) || hadErr
	} else {
		log.Infof("SSH host key signing disabled by -sign_ssh_host_keys=false")
	}
	hadErr = checkAndRenewACMECertificates(ctx, c) || hadErr

	// glog buffers log output and only flushes it periodically. Flush
	// explicitly so a short run does not lose its logs; os.Exit would skip a
	// deferred Flush.
	log.Flush()
	if hadErr {
		os.Exit(1)
	}
}