From 037c6f0fd8280bdbf9cfab83ae7e5041f10bb477 Mon Sep 17 00:00:00 2001
From: Luke Granger-Brown
Date: Thu, 17 Mar 2022 01:26:18 +0000
Subject: [PATCH] go/secretsmgr: add support for ACME certificate issuance

---
 go/secretsmgr/secretsmgr.go | 420 +++++++++++++++++++++++++++++++++++-
 1 file changed, 416 insertions(+), 4 deletions(-)

diff --git a/go/secretsmgr/secretsmgr.go b/go/secretsmgr/secretsmgr.go
index 85ce0de0db..4f210826c0 100644
--- a/go/secretsmgr/secretsmgr.go
+++ b/go/secretsmgr/secretsmgr.go
@@ -4,19 +4,30 @@ import (
 	"bufio"
 	"bytes"
 	"context"
+	"crypto/x509"
+	"encoding/json"
+	"encoding/pem"
+	"errors"
 	"flag"
 	"fmt"
 	"io"
+	"io/fs"
 	"os"
 	"os/exec"
+	"os/user"
 	"path/filepath"
+	"regexp"
+	"sort"
+	"strconv"
 	"strings"
+	"syscall"
 	"time"
 
 	"github.com/coreos/go-systemd/v22/dbus"
 	log "github.com/golang/glog"
 	vapi "github.com/hashicorp/vault/api"
 	"golang.org/x/crypto/ssh"
+	"golang.org/x/sys/unix"
 )
 
 var (
@@ -25,10 +36,13 @@ var (
 	sshHostKeyRole            = flag.String("ssh_host_key_role", hostname(), "Role to use for signing SSH host keys")
 	sshHostKeyPrincipals      = flag.String("ssh_host_key_principals", defaultPrincipals(hostname()), "Principals to use for SSH host key certificate")
 	sshHostKeyExpiryThreshold = flag.Duration("ssh_host_key_expiry_threshold", 3*24*time.Hour, "Expiry threshold for SSH host key")
+	sshHostKeyOutputDir       = flag.String("ssh_host_key_output_dir", "/var/lib/secretsmgr/ssh", "Path to SSH host key output dir")
+	sshd                      = flag.String("sshd", "sshd", "Path to sshd")
 
-	sshHostKeyOutputDir = flag.String("ssh_host_key_output_dir", "/var/lib/secretsmgr/ssh", "Path to SSH host key output dir")
-
-	sshd = flag.String("sshd", "sshd", "Path to sshd")
+	acmeCertificatesConfig          = flag.String("acme_certificates_config", "", "Path to ACME certificates config")
+	acmeCertificatesMount           = flag.String("acme_certificates_mount", "acme", "Path that ACME is mounted at")
+	acmeCertificatesExpiryThreshold = flag.Duration("acme_certificates_expiry_threshold", 31*24*time.Hour, "Expiry threshold for ACME certificates")
+	acmeCertificatesForceRenew      = flag.Bool("acme_certificates_force_renew", false, "Ignore remaining time and force ACME certificate renewal")
 )
 
 func hostname() string {
@@ -64,6 +78,90 @@ func sshHostKeyPaths(ctx context.Context) ([]string, error) {
 	return keys, nil
 }
 
+// acmeCertificate describes one certificate to request from Vault's ACME
+// mount, as decoded from the JSON file named by -acme_certificates_config.
+type acmeCertificate struct {
+	Role                   string   `json:"role"`
+	OutputName             string   `json:"output_name"`
+	Hostnames              []string `json:"hostnames"`
+	UnitsToReloadOrRestart []string `json:"units_to_reload_or_restart"`
+	Group                  string   `json:"group"`
+}
+
+// OutputDir returns the directory that holds this certificate's files.
+func (c acmeCertificate) OutputDir() string {
+	return filepath.Join("/var/lib/acme", c.OutputName)
+}
+
+// FullchainPath returns the path of fullchain.pem inside OutputDir.
+func (c acmeCertificate) FullchainPath() string {
+	return filepath.Join(c.OutputDir(), "fullchain.pem")
+}
+
+// ChainPath returns the path of chain.pem inside OutputDir.
+func (c acmeCertificate) ChainPath() string {
+	return filepath.Join(c.OutputDir(), "chain.pem")
+}
+
+// PrivkeyPath returns the path of privkey.pem inside OutputDir.
+func (c acmeCertificate) PrivkeyPath() string {
+	return filepath.Join(c.OutputDir(), "privkey.pem")
+}
+
+// SortedHostnames returns a sorted copy of Hostnames; the receiver's slice is
+// left untouched.
+func (c acmeCertificate) SortedHostnames() []string {
+	h := make([]string, len(c.Hostnames))
+	copy(h, c.Hostnames)
+	sort.Strings(h)
+	return h
+}
+
+// acmeCertBits matches a compact "role/output_name=hostnames[/units]"
+// certificate spec.
+//
+// NOTE(review): the capture-group names did not survive in the copy of this
+// patch that was reviewed ("(?P[^/]+)" is not valid RE2 and would make
+// MustCompile panic at init); the names below are reconstructed from the
+// acmeCertificate fields. Confirm them against the original commit.
+var acmeCertBits = regexp.MustCompile(`^(?P<role>[^/]+)/(?P<output_name>[^=]+)=(?P<hostnames>[^/]+)(?:/(?P<units>[^/]+))?$`)
+
+// splitDiscardEmpty splits s on sep, trims whitespace from every piece and
+// drops the pieces that end up empty.
+func splitDiscardEmpty(s string, sep string) []string {
+	bits := strings.Split(s, sep)
+	out := make([]string, 0, len(bits))
+	for _, bit := range bits {
+		bit = strings.TrimSpace(bit)
+		if bit == "" {
+			continue
+		}
+		out = append(out, bit)
+	}
+	return out
+}
+
+// parseACMECertificates loads the certificate definitions from the JSON file
+// named by -acme_certificates_config. An unset flag yields (nil, nil).
+func parseACMECertificates() ([]acmeCertificate, error) {
+	if *acmeCertificatesConfig == "" {
+		return nil, nil
+	}
+
+	cfgFile, err := os.Open(*acmeCertificatesConfig)
+	if err != nil {
+		return nil, err
+	}
+	defer cfgFile.Close()
+
+	var cs []acmeCertificate
+	if err := json.NewDecoder(cfgFile).Decode(&cs); err != nil {
+		return nil, fmt.Errorf("decoding %q: %w", *acmeCertificatesConfig, err)
+	}
+
+	return cs, nil
+}
+
 func readFile(p string) ([]byte, error) {
 	f, err := os.Open(p)
 	if err != nil {
@@ -233,8 +331,321 @@ func checkAndRenewSSHCertificates(ctx context.Context, c *vapi.Client) bool {
 	return hadErr
 }
 
+// setsSame reports whether a and b contain the same set of strings, ignoring
+// order and repeated entries.
+func setsSame(a, b []string) bool {
+	const inA, inB = 1, 2
+	seen := make(map[string]uint8, len(a))
+	for _, v := range a {
+		seen[v] |= inA
+	}
+	for _, v := range b {
+		seen[v] |= inB
+	}
+	for _, where := range seen {
+		if where != inA|inB {
+			return false
+		}
+	}
+	return true
+}
+
+// getGroupID resolves groupName to its numeric group ID.
+func getGroupID(groupName string) (int, error) {
+	g, err := user.LookupGroup(groupName)
+	if err != nil {
+		return 0, fmt.Errorf("looking up group %q: %w", groupName, err)
+	}
+
+	gid, err := strconv.Atoi(g.Gid)
+	if err != nil {
+		return 0, fmt.Errorf("converting group %q to int: %w", g.Gid, err)
+	}
+	return gid, nil
+}
+
+// setGroup switches the effective GID to that of groupName and returns a
+// function that switches it back. An empty groupName changes nothing and
+// returns a no-op.
+func setGroup(groupName string) (func(), error) {
+	if groupName == "" {
+		return func() {}, nil
+	}
+
+	gid, err := getGroupID(groupName)
+	if err != nil {
+		return nil, err
+	}
+
+	origGid := unix.Getegid()
+	if err := unix.Setregid(-1, gid); err != nil {
+		return nil, fmt.Errorf("setting effective gid to %d (group %q): %w", gid, groupName, err)
+	}
+	return func() {
+		if err := unix.Setregid(-1, origGid); err != nil {
+			log.Errorf("restoring effective gid to %d: %v", origGid, err)
+		}
+	}, nil
+}
+
+// makeOrFixPermissionsOnDir makes sure directory p exists with the given
+// permission bits and group. A missing directory is created while the
+// effective GID is switched to group; an existing one is chmodded and
+// chgrped only where it differs. An empty group leaves group ownership alone.
+func makeOrFixPermissionsOnDir(p string, permissions fs.FileMode, group string) error {
+	outDirFi, err := os.Stat(p)
+	if errors.Is(err, fs.ErrNotExist) {
+		restoreGroup, err := setGroup(group)
+		if err != nil {
+			return fmt.Errorf("setting group to create outdir: %w", err)
+		}
+		defer restoreGroup()
+
+		return os.Mkdir(p, permissions)
+	} else if err != nil {
+		return fmt.Errorf("stating output dir %q: %w", p, err)
+	}
+
+	if (outDirFi.Mode() & fs.ModePerm) != permissions {
+		if err := os.Chmod(p, permissions); err != nil {
+			return fmt.Errorf("chmodding output dir %q: %w", p, err)
+		}
+	}
+
+	// setGroup treats an empty group as "don't care"; do the same here rather
+	// than failing to look up a group named "".
+	if group == "" {
+		return nil
+	}
+
+	outDirStat, ok := outDirFi.Sys().(*syscall.Stat_t)
+	if !ok {
+		return fmt.Errorf("outDirFi.Sys() isn't a stat_t, is a %T", outDirFi.Sys())
+	}
+
+	gid, err := getGroupID(group)
+	if err != nil {
+		return err
+	}
+
+	if outDirStat.Gid != uint32(gid) {
+		if err := os.Chown(p, -1, gid); err != nil {
+			return fmt.Errorf("chgrping output dir %q: %w", p, err)
+		}
+	}
+
+	return nil
+}
+
+// shouldRenewACMECert reports whether the certificate described by c needs to
+// be (re)issued. When it returns true the accompanying error, if any, is the
+// reason for renewing; when it returns false a non-nil error means the check
+// itself could not be completed.
+func shouldRenewACMECert(c acmeCertificate) (bool, error) {
+	outDirFi, err := os.Stat(c.OutputDir())
+	if errors.Is(err, fs.ErrNotExist) {
+		return true, fmt.Errorf("output dir %q doesn't exist", c.OutputDir())
+	} else if err != nil {
+		return false, fmt.Errorf("stating output dir %q: %w", c.OutputDir(), err)
+	}
+
+	if !outDirFi.IsDir() {
+		log.Warningf("output dir %q isn't a dir; deleting it", c.OutputDir())
+		if err := os.Remove(c.OutputDir()); err != nil {
+			return false, fmt.Errorf("couldn't remove non-dir in place of output dir %q: %w", c.OutputDir(), err)
+		}
+		return true, fmt.Errorf("output dir %q wasn't a dir", c.OutputDir())
+	}
+
+	// Check that all the output files exist.
+	for _, f := range []string{c.FullchainPath(), c.ChainPath(), c.PrivkeyPath()} {
+		if _, err := os.Stat(f); errors.Is(err, fs.ErrNotExist) {
+			return true, fmt.Errorf("output file %q doesn't exist", f)
+		} else if err != nil {
+			return false, fmt.Errorf("stating output file %q: %w", f, err)
+		}
+	}
+
+	fc, err := os.Open(c.FullchainPath())
+	if err != nil {
+		return true, fmt.Errorf("opening fullchain %q: %w", c.FullchainPath(), err)
+	}
+	defer fc.Close()
+
+	fcBytes, err := io.ReadAll(fc)
+	if err != nil {
+		return true, fmt.Errorf("reading fullchain %q: %w", c.FullchainPath(), err)
+	}
+
+	fcB, _ := pem.Decode(fcBytes)
+	if fcB == nil {
+		return true, fmt.Errorf("fullchain %q contained non-PEM bits", c.FullchainPath())
+	} else if fcB.Type != "CERTIFICATE" {
+		return true, fmt.Errorf("fullchain %q contained non-CERTIFICATE PEM block %q", c.FullchainPath(), fcB.Type)
+	}
+
+	crt, err := x509.ParseCertificate(fcB.Bytes)
+	if err != nil {
+		return true, fmt.Errorf("fullchain %q wasn't parsable as x509: %w", c.FullchainPath(), err)
+	}
+
+	// Check that the certificate isn't within the renewal threshold of expiring.
+	now := time.Now()
+	expiryThreshold := now.Add(*acmeCertificatesExpiryThreshold)
+	log.Infof("certificate in %q expires at %v (%v from now); expiry cutoff is %v", c.FullchainPath(), crt.NotAfter, crt.NotAfter.Sub(now), expiryThreshold)
+	if crt.NotAfter.Before(expiryThreshold) {
+		return true, fmt.Errorf("certificate in %q expires at %v (%v from now); expiry cutoff is %v", c.FullchainPath(), crt.NotAfter, crt.NotAfter.Sub(now), expiryThreshold)
+	}
+
+	// Check that the hostnames match.
+	if !setsSame(c.Hostnames, crt.DNSNames) {
+		return true, fmt.Errorf("certificate DNSNames %q don't match wanted hostnames %q", crt.DNSNames, c.Hostnames)
+	}
+
+	if *acmeCertificatesForceRenew {
+		return true, fmt.Errorf("force renew flag set")
+	}
+	return false, nil
+}
+
+// secretString returns the string stored under key in a Vault response, or an
+// error if the response is nil or the field is missing or not a string.
+func secretString(s *vapi.Secret, key string) (string, error) {
+	if s == nil {
+		return "", errors.New("vault returned an empty response")
+	}
+	v, ok := s.Data[key].(string)
+	if !ok {
+		return "", fmt.Errorf("vault response field %q is missing or not a string", key)
+	}
+	return v, nil
+}
+
+// writeCertificate writes the issued certificate, its issuer chain and its
+// private key into certDef's output directory. Each file is removed and then
+// re-created exclusively, with the effective GID switched to certDef.Group for
+// the duration.
+func writeCertificate(certDef acmeCertificate, cert *vapi.Secret) error {
+	// Pull everything out of the response before touching the disk, so that a
+	// malformed response can't leave a half-updated set of files behind.
+	setFiles := []struct {
+		name    string
+		key     string
+		content []byte
+		perm    fs.FileMode
+	}{
+		{name: certDef.FullchainPath(), key: "cert", perm: 0644},
+		{name: certDef.ChainPath(), key: "issuer_cert", perm: 0644},
+		{name: certDef.PrivkeyPath(), key: "private_key", perm: 0640},
+	}
+	for i := range setFiles {
+		content, err := secretString(cert, setFiles[i].key)
+		if err != nil {
+			return fmt.Errorf("reading %q from issued certificate: %w", setFiles[i].key, err)
+		}
+		setFiles[i].content = []byte(content)
+	}
+
+	restoreGroup, err := setGroup(certDef.Group)
+	if err != nil {
+		return fmt.Errorf("setting group to write output: %w", err)
+	}
+	defer restoreGroup()
+
+	for _, sf := range setFiles {
+		log.Infof("writing file %v mode %s", sf.name, sf.perm)
+
+		os.Remove(sf.name) // optimistically try to remove the file, we don't care if it succeeds
+		// if it doesn't, we'll error when we try to open it
+
+		f, err := os.OpenFile(sf.name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, sf.perm)
+		if err != nil {
+			return fmt.Errorf("opening %v: %w", sf.name, err)
+		}
+
+		if _, err := f.Write(sf.content); err != nil {
+			f.Close()
+			os.Remove(sf.name)
+			return fmt.Errorf("writing %v: %w", sf.name, err)
+		}
+
+		if err := f.Close(); err != nil {
+			return fmt.Errorf("closing %v: %w", sf.name, err)
+		}
+	}
+
+	return nil
+}
+
+// checkAndRenewACMECertificates walks the configured ACME certificates and
+// issues, through the Vault client c, every one that shouldRenewACMECert says
+// needs it. It reports whether any certificate hit an error.
 func checkAndRenewACMECertificates(ctx context.Context, c *vapi.Client) bool {
-	return false
+	certDefs, err := parseACMECertificates()
+	if err != nil {
+		log.Errorf("failed to parse certificates: %v", err)
+		return true
+	}
+	log.Infof("acme certificate defs: %#v", certDefs)
+
+	var hadErr bool
+	for _, certDef := range certDefs {
+		log.Infof("maybe issuing %v (role: %v) with hostnames %v into %v", certDef.OutputName, certDef.Role, certDef.Hostnames, certDef.OutputDir())
+
+		if len(certDef.Hostnames) == 0 {
+			log.Errorf("certificate %v has no hostnames configured; skipping", certDef.OutputName)
+			hadErr = true
+			continue
+		}
+
+		shouldRenew, err := shouldRenewACMECert(certDef)
+		switch {
+		case shouldRenew && err == nil:
+			log.Infof("shouldRenewACMECert for %v returned true with no reason", certDef.OutputName)
+		case shouldRenew:
+			log.Infof("shouldRenewACMECert for %v returned true: %v", certDef.OutputName, err)
+		case err != nil:
+			log.Errorf("shouldRenewACMECert for %v returned false and an error: %v", certDef.OutputName, err)
+			hadErr = true
+			continue
+		default:
+			log.Infof("shouldRenewACMECert for %v returned false", certDef.OutputName)
+			continue
+		}
+
+		if err := makeOrFixPermissionsOnDir(certDef.OutputDir(), 0750, certDef.Group); err != nil {
+			log.Errorf("creating/fixing dir %q: %v", certDef.OutputDir(), err)
+			hadErr = true
+			continue
+		}
+
+		// Issue certificate. The first hostname in sorted order becomes the
+		// common name and the rest become alternative names.
+		names := certDef.SortedHostnames()
+		commonName := names[0]
+		names = names[1:]
+		cert, err := c.Logical().Write(fmt.Sprintf("%s/certs/%s", *acmeCertificatesMount, certDef.Role), map[string]interface{}{
+			"common_name":       commonName,
+			"alternative_names": strings.Join(names, ","),
+		})
+		if err != nil {
+			log.Errorf("issuing %v (role: %v) with hostnames %v into %v: %v", certDef.OutputName, certDef.Role, certDef.Hostnames, certDef.OutputDir(), err)
+			hadErr = true
+			continue
+		}
+
+		// Write out certificate.
+		if err := writeCertificate(certDef, cert); err != nil {
+			log.Errorf("writing out certificate %v: %v", certDef.OutputName, err)
+			hadErr = true
+			continue
+		}
+
+		// NOTE(review): certDef.UnitsToReloadOrRestart is parsed from the config
+		// but nothing in this patch reloads or restarts those units yet.
+	}
+
+	return hadErr
 }
 
 func main() {
@@ -244,6 +655,7 @@ func main() {
 	cfg.Address = "https://vault.int.lukegb.com"
 	cfg.AgentAddress = "http://localhost:8200"
 	cfg.MaxRetries = 0
+	cfg.Timeout = 15 * time.Minute
 	c, err := vapi.NewClient(cfg)
 	if err != nil {
 		log.Exitf("failed to create vault client: %v", err)