package main import ( "bufio" "bytes" "context" "flag" "fmt" "io" "os" "os/exec" "path/filepath" "strings" "time" "github.com/coreos/go-systemd/v22/dbus" log "github.com/golang/glog" vapi "github.com/hashicorp/vault/api" "golang.org/x/crypto/ssh" ) var ( signSSHHostKeys = flag.Bool("sign_ssh_host_keys", true, "Sign SSH host keys with CA") sshHostKeyCAPath = flag.String("ssh_host_key_ca_path", "ssh-host", "Path that the SSH CA is mounted at") sshHostKeyRole = flag.String("ssh_host_key_role", hostname(), "Role to use for signing SSH host keys") sshHostKeyPrincipals = flag.String("ssh_host_key_principals", defaultPrincipals(hostname()), "Principals to use for SSH host key certificate") sshHostKeyExpiryThreshold = flag.Duration("ssh_host_key_expiry_threshold", 3*24*time.Hour, "Expiry threshold for SSH host key") sshHostKeyOutputDir = flag.String("ssh_host_key_output_dir", "/var/lib/secretsmgr/ssh", "Path to SSH host key output dir") sshd = flag.String("sshd", "sshd", "Path to sshd") ) func hostname() string { name, err := os.Hostname() if err != nil { return "" } return name } func defaultPrincipals(hostname string) string { return fmt.Sprintf("%s.as205479.net,%s.int.as205479.net", hostname, hostname) } func sshHostKeyPaths(ctx context.Context) ([]string, error) { out, err := exec.CommandContext(ctx, *sshd, "-T").Output() if err != nil { return nil, err } scanner := bufio.NewScanner(bytes.NewReader(out)) var keys []string const hostkeyPrefix = "hostkey " for scanner.Scan() { t := scanner.Text() if strings.HasPrefix(t, hostkeyPrefix) { keys = append(keys, fmt.Sprintf("%s.pub", strings.TrimPrefix(t, hostkeyPrefix))) } } if scanner.Err() != nil { return nil, scanner.Err() } return keys, nil } func readFile(p string) ([]byte, error) { f, err := os.Open(p) if err != nil { return nil, err } defer f.Close() return io.ReadAll(f) } func writeFile(p string, content []byte) error { f, err := os.Create(p) if err != nil { return err } if _, err := f.Write(content); err != nil { f.Close() return err } if err := f.Close(); err != nil { return err } return nil } func readSSHPublicKey(p string) (ssh.PublicKey, error) { bs, err := readFile(p) if err != nil { return nil, fmt.Errorf("reading public key from %v: %w", p, err) } out, _, _, _, err := ssh.ParseAuthorizedKey(bs) if err != nil { return nil, fmt.Errorf("parsing public key in %v: %w", p, err) } return out, nil } func shouldRenewSSHCert(certPath string, publicKey []byte, principals string) (bool, error) { certPK, err := readSSHPublicKey(certPath) if err != nil { return true, err } cert, ok := certPK.(*ssh.Certificate) if !ok { return true, fmt.Errorf("file at %v was not an SSH certificate (was a %T)", certPath, certPK) } pubKey, _, _, _, err := ssh.ParseAuthorizedKey(publicKey) if err != nil { return false, err } // Verify that this is the same key. if !bytes.Equal(pubKey.Marshal(), cert.Key.Marshal()) { return true, fmt.Errorf("certificate key and public key don't match") } // Verify the certificate isn't expiring too soon. unixNow := time.Now().Unix() expiryThreshold := uint64(unixNow + int64(sshHostKeyExpiryThreshold.Seconds())) certExpiryDate := time.Unix(int64(cert.ValidBefore), 0) log.Infof("certificate %v expires at %v (in %v)", certPath, certExpiryDate, certExpiryDate.Sub(time.Now())) if cert.ValidBefore != uint64(ssh.CertTimeInfinity) && cert.ValidBefore < expiryThreshold { return true, fmt.Errorf("certificate expires at %v, which is earlier than our threshold of %v", certExpiryDate, time.Unix(int64(expiryThreshold), 0)) } hasPrincipals := map[string]bool{} for _, p := range cert.ValidPrincipals { hasPrincipals[p] = true } for _, wantP := range strings.Split(principals, ",") { if !hasPrincipals[wantP] { return true, fmt.Errorf("certificate missing principal %v (has principals %v; want principals %v)", wantP, cert.ValidPrincipals, principals) } } return false, nil } func reloadOrTryRestartServices(ctx context.Context, services []string) error { sd, err := dbus.NewSystemdConnection() if err != nil { return fmt.Errorf("failed to connect to systemd: %w", err) } defer sd.Close() var failedUnits []string for _, service := range services { _, err := sd.ReloadOrTryRestartUnit(service, "replace", nil) if err != nil { log.Errorf("unit %v failed to reload-or-try-restart: %v", service, err) failedUnits = append(failedUnits, service) } } if len(failedUnits) != 0 { return fmt.Errorf("some units failed to restart: %v", failedUnits) } return nil } func checkAndRenewSSHCertificates(ctx context.Context, c *vapi.Client) bool { sshHostKeys, err := sshHostKeyPaths(ctx) if err != nil { log.Errorf("sshHostKeyPaths: %v", err) return true } cs := c.SSHWithMountPoint(*sshHostKeyCAPath) var hadErr bool var didRenew bool for _, sshHostKeyPath := range sshHostKeys { outPath := filepath.Join(*sshHostKeyOutputDir, fmt.Sprintf("%s-cert.pub", strings.TrimSuffix(filepath.Base(sshHostKeyPath), ".pub"))) log.Infof("might sign %v into %v", sshHostKeyPath, outPath) sshHostKey, err := readFile(sshHostKeyPath) if err != nil { log.Errorf("reading host key %v: %v", sshHostKeyPath, err) hadErr = true continue } if shouldRenew, err := shouldRenewSSHCert(outPath, sshHostKey, *sshHostKeyPrincipals); shouldRenew && err == nil { log.Infof("shouldRenewSSHCert for %v returned true with no reason", outPath) } else if shouldRenew { log.Infof("shouldRenewSSHCert for %v returned true: %v", outPath, err) } else if err != nil { log.Errorf("shouldRenewSSHCert for %v returned false and an error: %v", outPath, err) hadErr = true continue } else if !shouldRenew { log.Infof("shouldRenewSSHCert for %v returned false", outPath) continue } secret, err := cs.SignKey(*sshHostKeyRole, map[string]interface{}{ "cert_type": "host", "public_key": string(sshHostKey), "valid_principals": *sshHostKeyPrincipals, }) if err != nil { log.Errorf("signing %v: %v", sshHostKeyPath, err) hadErr = true continue } if err := writeFile(outPath, []byte(secret.Data["signed_key"].(string))); err != nil { log.Errorf("writing signed certificate to %v: %v", outPath, err) hadErr = true continue } log.Infof("renewed %v", outPath) didRenew = true } if didRenew { log.Infof("completed renewals; restarting SSHd") if err := reloadOrTryRestartServices(ctx, []string{"sshd.service"}); err != nil { log.Errorf("failed to restart sshd: %v", err) hadErr = true } } return hadErr } func checkAndRenewACMECertificates(ctx context.Context, c *vapi.Client) bool { return false } func main() { flag.Parse() cfg := vapi.DefaultConfig() cfg.Address = "https://vault.int.lukegb.com" cfg.AgentAddress = "http://localhost:8200" cfg.MaxRetries = 0 c, err := vapi.NewClient(cfg) if err != nil { log.Exitf("failed to create vault client: %v", err) } ctx := context.Background() var hadErr bool hadErr = checkAndRenewSSHCertificates(ctx, c) || hadErr hadErr = checkAndRenewACMECertificates(ctx, c) || hadErr if hadErr { os.Exit(1) } }