package main import ( "bufio" "bytes" "context" "crypto/x509" "encoding/json" "encoding/pem" "errors" "flag" "fmt" "io" "io/fs" "net" "net/http" "os" "os/exec" "os/user" "path/filepath" "regexp" "sort" "strconv" "strings" "syscall" "time" "github.com/coreos/go-systemd/v22/dbus" log "github.com/golang/glog" vapi "github.com/hashicorp/vault/api" "golang.org/x/crypto/ssh" "golang.org/x/sys/unix" ) var ( vaultAddress = flag.String("vault_address", "https://vault.int.lukegb.com", "Address of Vault") vaultAgentAddress = flag.String("vault_agent_address", "unix:///run/vault-agent/sock", "Address of Vault agent") signSSHHostKeys = flag.Bool("sign_ssh_host_keys", true, "Sign SSH host keys with CA") sshHostKeyCAPath = flag.String("ssh_host_key_ca_path", "ssh-host", "Path that the SSH CA is mounted at") sshHostKeyRole = flag.String("ssh_host_key_role", hostname(), "Role to use for signing SSH host keys") sshHostKeyPrincipals = flag.String("ssh_host_key_principals", defaultPrincipals(hostname()), "Principals to use for SSH host key certificate") sshHostKeyExpiryThreshold = flag.Duration("ssh_host_key_expiry_threshold", 3*24*time.Hour, "Expiry threshold for SSH host key") sshHostKeyOutputDir = flag.String("ssh_host_key_output_dir", "/var/lib/secretsmgr/ssh", "Path to SSH host key output dir") sshDummyHostKey = flag.String("ssh_dummy_host_key", "", "Path to dummy SSH hostkey") sshd = flag.String("sshd", "sshd", "Path to sshd") acmeCertificatesConfig = flag.String("acme_certificates_config", "", "Path to ACME certificates config") acmeCertificatesMount = flag.String("acme_certificates_mount", "acme", "Path that ACME is mounted at") acmeCertificatesExpiryThreshold = flag.Duration("acme_certificates_expiry_threshold", 31*24*time.Hour, "Expiry threshold for ACME certificates") acmeCertificatesForceRenew = flag.Bool("acme_certificates_force_renew", false, "Ignore remaining time and force ACME certificate renewal") ) func hostname() string { name, err := os.Hostname() if err != nil { return "" } return name } func defaultPrincipals(hostname string) string { return fmt.Sprintf("%s,%s.as205479.net,%s.int.as205479.net,%s.otter-acoustic.ts.net", hostname, hostname, hostname, hostname) } func sshHostKeyPaths(ctx context.Context) ([]string, error) { args := []string{"-T"} if *sshDummyHostKey != "" { args = append(args, "-h", *sshDummyHostKey) } out, err := exec.CommandContext(ctx, *sshd, args...).Output() if err != nil { return nil, err } scanner := bufio.NewScanner(bytes.NewReader(out)) var keys []string const hostkeyPrefix = "hostkey " for scanner.Scan() { t := scanner.Text() if strings.HasPrefix(t, hostkeyPrefix) { key := strings.TrimPrefix(t, hostkeyPrefix) if key == *sshDummyHostKey { continue } keys = append(keys, fmt.Sprintf("%s.pub", key)) } } if scanner.Err() != nil { return nil, scanner.Err() } return keys, nil } type acmeCertificate struct { Role string `json:"role"` OutputName string `json:"output_name"` Hostnames []string `json:"hostnames"` UnitsToReloadOrRestart []string `json:"units_to_reload_or_restart"` Group string `json:"group"` } func (c acmeCertificate) OutputDir() string { return filepath.Join("/var/lib/acme", c.OutputName) } func (c acmeCertificate) FullchainPath() string { return filepath.Join(c.OutputDir(), "fullchain.pem") } func (c acmeCertificate) ChainPath() string { return filepath.Join(c.OutputDir(), "chain.pem") } func (c acmeCertificate) PrivkeyPath() string { return filepath.Join(c.OutputDir(), "privkey.pem") } func (c acmeCertificate) SortedHostnames() []string { h := make([]string, len(c.Hostnames)) copy(h, c.Hostnames) sort.Strings(h) return h } var acmeCertBits = regexp.MustCompile(`^(?P[^/]+)/(?P[^=]+)=(?P[^/]+)(?:/(?P[^/]+))?$`) func splitDiscardEmpty(s string, sep string) []string { bits := strings.Split(s, sep) out := make([]string, 0, len(bits)) for _, bit := range bits { bit = strings.TrimSpace(bit) if bit == "" { continue } out = append(out, bit) } return out } func parseACMECertificates() ([]acmeCertificate, error) { if *acmeCertificatesConfig == "" { return nil, nil } cfgFile, err := os.Open(*acmeCertificatesConfig) if err != nil { return nil, err } defer cfgFile.Close() var cs []acmeCertificate if err := json.NewDecoder(cfgFile).Decode(&cs); err != nil { return nil, err } return cs, nil } func readFile(p string) ([]byte, error) { f, err := os.Open(p) if err != nil { return nil, err } defer f.Close() return io.ReadAll(f) } func writeFile(p string, content []byte) error { f, err := os.Create(p) if err != nil { return err } if _, err := f.Write(content); err != nil { f.Close() return err } if err := f.Close(); err != nil { return err } return nil } func readSSHPublicKey(p string) (ssh.PublicKey, error) { bs, err := readFile(p) if err != nil { return nil, fmt.Errorf("reading public key from %v: %w", p, err) } out, _, _, _, err := ssh.ParseAuthorizedKey(bs) if err != nil { return nil, fmt.Errorf("parsing public key in %v: %w", p, err) } return out, nil } func shouldRenewSSHCert(certPath string, publicKey []byte, principals string) (bool, error) { certPK, err := readSSHPublicKey(certPath) if err != nil { return true, err } cert, ok := certPK.(*ssh.Certificate) if !ok { return true, fmt.Errorf("file at %v was not an SSH certificate (was a %T)", certPath, certPK) } pubKey, _, _, _, err := ssh.ParseAuthorizedKey(publicKey) if err != nil { return false, err } // Verify that this is the same key. if !bytes.Equal(pubKey.Marshal(), cert.Key.Marshal()) { return true, fmt.Errorf("certificate key and public key don't match") } // Verify the certificate isn't expiring too soon. unixNow := time.Now().Unix() expiryThreshold := uint64(unixNow + int64(sshHostKeyExpiryThreshold.Seconds())) certExpiryDate := time.Unix(int64(cert.ValidBefore), 0) log.Infof("certificate %v expires at %v (in %v)", certPath, certExpiryDate, certExpiryDate.Sub(time.Now())) if cert.ValidBefore != uint64(ssh.CertTimeInfinity) && cert.ValidBefore < expiryThreshold { return true, fmt.Errorf("certificate expires at %v, which is earlier than our threshold of %v", certExpiryDate, time.Unix(int64(expiryThreshold), 0)) } hasPrincipals := map[string]bool{} for _, p := range cert.ValidPrincipals { hasPrincipals[p] = true } for _, wantP := range strings.Split(principals, ",") { if !hasPrincipals[wantP] { return true, fmt.Errorf("certificate missing principal %v (has principals %v; want principals %v)", wantP, cert.ValidPrincipals, principals) } } return false, nil } func reloadOrTryRestartServices(ctx context.Context, services []string) error { sd, err := dbus.NewWithContext(ctx) if err != nil { return fmt.Errorf("failed to connect to systemd: %w", err) } defer sd.Close() var failedUnits []string for _, service := range services { _, err := sd.ReloadOrTryRestartUnit(service, "replace", nil) if err != nil { log.Errorf("unit %v failed to reload-or-try-restart: %v", service, err) failedUnits = append(failedUnits, service) } } if len(failedUnits) != 0 { return fmt.Errorf("some units failed to restart: %v", failedUnits) } return nil } func checkAndRenewSSHCertificates(ctx context.Context, c *vapi.Client) bool { sshHostKeys, err := sshHostKeyPaths(ctx) if err != nil { log.Errorf("sshHostKeyPaths: %v", err) return true } cs := c.SSHWithMountPoint(*sshHostKeyCAPath) var hadErr bool var didRenew bool for _, sshHostKeyPath := range sshHostKeys { outPath := filepath.Join(*sshHostKeyOutputDir, fmt.Sprintf("%s-cert.pub", strings.TrimSuffix(filepath.Base(sshHostKeyPath), ".pub"))) log.Infof("might sign %v into %v", sshHostKeyPath, outPath) sshHostKey, err := readFile(sshHostKeyPath) if err != nil { log.Errorf("reading host key %v: %v", sshHostKeyPath, err) hadErr = true continue } if shouldRenew, err := shouldRenewSSHCert(outPath, sshHostKey, *sshHostKeyPrincipals); shouldRenew && err == nil { log.Infof("shouldRenewSSHCert for %v returned true with no reason", outPath) } else if shouldRenew { log.Infof("shouldRenewSSHCert for %v returned true: %v", outPath, err) } else if err != nil { log.Errorf("shouldRenewSSHCert for %v returned false and an error: %v", outPath, err) hadErr = true continue } else if !shouldRenew { log.Infof("shouldRenewSSHCert for %v returned false", outPath) continue } secret, err := cs.SignKey(*sshHostKeyRole, map[string]interface{}{ "cert_type": "host", "public_key": string(sshHostKey), "valid_principals": *sshHostKeyPrincipals, }) if err != nil { log.Errorf("signing %v: %v", sshHostKeyPath, err) hadErr = true continue } if err := writeFile(outPath, []byte(secret.Data["signed_key"].(string))); err != nil { log.Errorf("writing signed certificate to %v: %v", outPath, err) hadErr = true continue } log.Infof("renewed %v", outPath) didRenew = true } if didRenew { log.Infof("completed renewals; restarting SSHd") if err := reloadOrTryRestartServices(ctx, []string{"sshd.service"}); err != nil { log.Errorf("failed to restart sshd: %v", err) hadErr = true } } return hadErr } func setsSame(a, b []string) bool { s := map[string]int{} for _, v := range a { s[v]++ } for _, v := range b { s[v]++ } for _, n := range s { if n != 2 { return false } } return true } func getGroupID(groupName string) (int, error) { g, err := user.LookupGroup(groupName) if err != nil { return 0, fmt.Errorf("looking up group %q: %w", groupName, err) } gid, err := strconv.Atoi(g.Gid) if err != nil { return 0, fmt.Errorf("converting group %q to int: %w", g.Gid, err) } return gid, nil } func setGroup(groupName string) (func(), error) { if groupName == "" { return func() {}, nil } gid, err := getGroupID(groupName) if err != nil { return nil, err } origGid := unix.Getegid() if err := unix.Setregid(-1, gid); err != nil { return nil, err } return func() { unix.Setregid(-1, origGid) }, nil } func makeOrFixPermissionsOnDir(p string, permissions fs.FileMode, group string) error { outDirFi, err := os.Stat(p) if errors.Is(err, fs.ErrNotExist) { restoreGroup, err := setGroup(group) if err != nil { return fmt.Errorf("setting group to create outdir: %w", err) } defer restoreGroup() return os.Mkdir(p, permissions) } else if err != nil { return fmt.Errorf("stating output dir %q: %w", p, err) } if (outDirFi.Mode() & fs.ModePerm) != permissions { if err := os.Chmod(p, permissions); err != nil { return fmt.Errorf("chmodding output dir %q: %w", p, err) } } outDirStat, ok := outDirFi.Sys().(*syscall.Stat_t) if !ok { return fmt.Errorf("outDirFi isn't a stat_t, is a %T", outDirFi) } gid, err := getGroupID(group) if err != nil { return err } if outDirStat.Gid != uint32(gid) { if err := os.Chown(p, -1, gid); err != nil { return fmt.Errorf("chgrping output dir %q: %w", p, err) } } return nil } func shouldRenewACMECert(c acmeCertificate) (bool, error) { outDirFi, err := os.Stat(c.OutputDir()) if errors.Is(err, fs.ErrNotExist) { return true, fmt.Errorf("output dir %q doesn't exist", c.OutputDir()) } else if err != nil { return false, fmt.Errorf("stating output dir %q: %w", c.OutputDir(), err) } if !outDirFi.IsDir() { log.Warningf("output dir %q isn't a dir; deleting it") if err := os.Remove(c.OutputDir()); err != nil { return false, fmt.Errorf("couldn't remove non-dir in place of output dir %q: %w", c.OutputDir(), err) } return true, fmt.Errorf("output dir %q wasn't a dir") } // Check that all the output files exist. for _, f := range []string{c.FullchainPath(), c.ChainPath(), c.PrivkeyPath()} { if _, err := os.Stat(f); errors.Is(err, fs.ErrNotExist) { return true, fmt.Errorf("output file %q doesn't exist", f) } else if err != nil { return false, fmt.Errorf("stating output file %q: %w", f, err) } } fc, err := os.Open(c.FullchainPath()) if err != nil { return true, fmt.Errorf("opening fullchain %q: %w", c.FullchainPath(), err) } defer fc.Close() fcBytes, err := io.ReadAll(fc) if err != nil { return true, fmt.Errorf("reading fullchain %q: %w", c.FullchainPath(), err) } fcB, _ := pem.Decode(fcBytes) if fcB == nil { return true, fmt.Errorf("fullchain %q contained non-PEM bits", c.FullchainPath()) } else if fcB.Type != "CERTIFICATE" { return true, fmt.Errorf("fullchain %q contained non-CERTIFICATE PEM block %q", c.FullchainPath(), fcB.Type) } crt, err := x509.ParseCertificate(fcB.Bytes) if err != nil { return true, fmt.Errorf("fullchain %q wasn't parsable as x509", c.FullchainPath()) } // Check that the certificate hasn't expired. now := time.Now() expiryThreshold := now.Add(*acmeCertificatesExpiryThreshold) log.Infof("certificate in %q expires at %v (%v from now); expiry cutoff is %v", c.FullchainPath(), crt.NotAfter, crt.NotAfter.Sub(now), expiryThreshold) if crt.NotAfter.Before(expiryThreshold) { return true, fmt.Errorf("certificate in %q expires at %v (%v from now); expiry cutoff is %v", c.FullchainPath(), crt.NotAfter, crt.NotAfter.Sub(now), expiryThreshold) } // Check that the hostnames match. if !setsSame(c.Hostnames, crt.DNSNames) { return true, fmt.Errorf("certificate DNSNames %q don't match wanted hostnames %q", crt.DNSNames, c.Hostnames) } if *acmeCertificatesForceRenew { return true, fmt.Errorf("force renew flag set") } return false, nil } func writeCertificate(certDef acmeCertificate, cert *vapi.Secret) error { setFiles := []struct { name string content []byte perm fs.FileMode }{{ name: certDef.FullchainPath(), content: []byte(cert.Data["cert"].(string)), perm: 0644, }, { name: certDef.ChainPath(), content: []byte(cert.Data["issuer_cert"].(string)), perm: 0644, }, { name: certDef.PrivkeyPath(), content: []byte(cert.Data["private_key"].(string)), perm: 0640, }} for _, sf := range setFiles { os.Remove(sf.name) // optimistically try to remove the file, we don't care if it succeeds // if it doesn't, we'll error when we try to open it } restoreGroup, err := setGroup(certDef.Group) if err != nil { return fmt.Errorf("setting group to write output: %w", err) } defer restoreGroup() for _, sf := range setFiles { log.Infof("writing file %v mode %s group %s", sf.name, sf.perm, certDef.Group) f, err := os.OpenFile(sf.name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, sf.perm) if err != nil { return fmt.Errorf("opening %v: %w", sf.name, err) } if _, err := f.Write(sf.content); err != nil { f.Close() os.Remove(sf.name) return fmt.Errorf("writing %v: %w", sf.name, err) } if err := f.Close(); err != nil { return fmt.Errorf("closing %v: %w", sf.name, err) } } return nil } func uniqueStrings(vs []string) []string { m := map[string]bool{} for _, v := range vs { m[v] = true } x := make([]string, 0, len(m)) for v := range m { x = append(x, v) } return x } func checkAndRenewACMECertificates(ctx context.Context, c *vapi.Client) bool { certDefs, err := parseACMECertificates() if err != nil { log.Errorf("failed to parse certificates: %v", err) return true } log.Infof("acme certificate defs: %#v", certDefs) var hadErr bool var unitsToReloadOrRestart []string for _, certDef := range certDefs { log.Infof("maybe issuing %v (role: %v) with hostnames %v into %v", certDef.OutputName, certDef.Role, certDef.Hostnames, certDef.OutputDir()) if shouldRenew, err := shouldRenewACMECert(certDef); shouldRenew && err == nil { log.Infof("shouldRenewACMECert for %v returned true with no reason", certDef.OutputName) } else if shouldRenew { log.Infof("shouldRenewACMECert for %v returned true: %v", certDef.OutputName, err) } else if err != nil { log.Errorf("shouldRenewACMECert for %v returned false and an error: %v", certDef.OutputName, err) hadErr = true continue } else if !shouldRenew { log.Infof("shouldRenewACMECert for %v returned false", certDef.OutputName) continue } if err := makeOrFixPermissionsOnDir(certDef.OutputDir(), 0750, certDef.Group); err != nil { log.Errorf("creating/fixing dir %q: %v", certDef.OutputDir(), err) hadErr = true continue } // Issue certificate. names := certDef.SortedHostnames() commonName := names[0] names = names[1:] cert, err := c.Logical().Write(fmt.Sprintf("%s/certs/%s", *acmeCertificatesMount, certDef.Role), map[string]interface{}{ "common_name": commonName, "alternative_names": strings.Join(names, ","), }) if err != nil { log.Errorf("issuing %v (role: %v) with hostnames %v into %v: %v", certDef.OutputName, certDef.Role, certDef.Hostnames, certDef.OutputDir(), err) hadErr = true continue } // Write out certificate. if err := writeCertificate(certDef, cert); err != nil { log.Errorf("writing out certificate %v: %w", certDef.OutputName, err) hadErr = true continue } unitsToReloadOrRestart = append(unitsToReloadOrRestart, certDef.UnitsToReloadOrRestart...) } unitsToReloadOrRestart = uniqueStrings(unitsToReloadOrRestart) if len(unitsToReloadOrRestart) > 0 { log.Infof("reloading/restarting %d units: %v", len(unitsToReloadOrRestart), unitsToReloadOrRestart) if err := reloadOrTryRestartServices(ctx, unitsToReloadOrRestart); err != nil { log.Errorf("failed to restart units: %v", err) hadErr = true } } return hadErr } func main() { flag.Parse() d := &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, } agentPath := strings.TrimPrefix(*vaultAgentAddress, "unix://") agentDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { if !strings.HasPrefix(*vaultAgentAddress, "unix://") { return http.DefaultClient.Transport.(*http.Transport).DialContext(ctx, network, addr) } // Ignore what they want. return d.DialContext(ctx, "unix", agentPath) } vcfg := vapi.DefaultConfig() vcfg.AgentAddress = "http://vault-agent" vcfg.MaxRetries = 0 vcfg.Timeout = 15 * time.Minute vcfg.HttpClient.Transport.(*http.Transport).DialContext = agentDialer c, err := vapi.NewClient(vcfg) if err != nil { log.Exitf("creating vault client against %v: %v", *vaultAgentAddress, err) } ctx := context.Background() var hadErr bool hadErr = checkAndRenewSSHCertificates(ctx, c) || hadErr hadErr = checkAndRenewACMECertificates(ctx, c) || hadErr if hadErr { os.Exit(1) } }