go/secretsmgr: add support for ACME certificate issuance
This commit is contained in:
parent
b0d2782369
commit
037c6f0fd8
1 changed files with 338 additions and 4 deletions
|
@ -4,19 +4,30 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-systemd/v22/dbus"
|
||||
log "github.com/golang/glog"
|
||||
vapi "github.com/hashicorp/vault/api"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -25,10 +36,13 @@ var (
|
|||
sshHostKeyRole = flag.String("ssh_host_key_role", hostname(), "Role to use for signing SSH host keys")
|
||||
sshHostKeyPrincipals = flag.String("ssh_host_key_principals", defaultPrincipals(hostname()), "Principals to use for SSH host key certificate")
|
||||
sshHostKeyExpiryThreshold = flag.Duration("ssh_host_key_expiry_threshold", 3*24*time.Hour, "Expiry threshold for SSH host key")
|
||||
|
||||
sshHostKeyOutputDir = flag.String("ssh_host_key_output_dir", "/var/lib/secretsmgr/ssh", "Path to SSH host key output dir")
|
||||
|
||||
sshd = flag.String("sshd", "sshd", "Path to sshd")
|
||||
|
||||
acmeCertificatesConfig = flag.String("acme_certificates_config", "", "Path to ACME certificates config")
|
||||
acmeCertificatesMount = flag.String("acme_certificates_mount", "acme", "Path that ACME is mounted at")
|
||||
acmeCertificatesExpiryThreshold = flag.Duration("acme_certificates_expiry_threshold", 31*24*time.Hour, "Expiry threshold for ACME certificates")
|
||||
acmeCertificatesForceRenew = flag.Bool("acme_certificates_force_renew", false, "Ignore remaining time and force ACME certificate renewal")
|
||||
)
|
||||
|
||||
func hostname() string {
|
||||
|
@ -64,6 +78,71 @@ func sshHostKeyPaths(ctx context.Context) ([]string, error) {
|
|||
return keys, nil
|
||||
}
|
||||
|
||||
type acmeCertificate struct {
|
||||
Role string `json:"role"`
|
||||
OutputName string `json:"output_name"`
|
||||
Hostnames []string `json:"hostnames"`
|
||||
UnitsToReloadOrRestart []string `json:"units_to_reload_or_restart"`
|
||||
Group string `json:"group"`
|
||||
}
|
||||
|
||||
func (c acmeCertificate) OutputDir() string {
|
||||
return filepath.Join("/var/lib/acme", c.OutputName)
|
||||
}
|
||||
|
||||
func (c acmeCertificate) FullchainPath() string {
|
||||
return filepath.Join(c.OutputDir(), "fullchain.pem")
|
||||
}
|
||||
|
||||
func (c acmeCertificate) ChainPath() string {
|
||||
return filepath.Join(c.OutputDir(), "chain.pem")
|
||||
}
|
||||
|
||||
func (c acmeCertificate) PrivkeyPath() string {
|
||||
return filepath.Join(c.OutputDir(), "privkey.pem")
|
||||
}
|
||||
|
||||
func (c acmeCertificate) SortedHostnames() []string {
|
||||
h := make([]string, len(c.Hostnames))
|
||||
copy(h, c.Hostnames)
|
||||
sort.Strings(h)
|
||||
return h
|
||||
}
|
||||
|
||||
var acmeCertBits = regexp.MustCompile(`^(?P<role>[^/]+)/(?P<shortname>[^=]+)=(?P<hostnames>[^/]+)(?:/(?P<units>[^/]+))?$`)
|
||||
|
||||
func splitDiscardEmpty(s string, sep string) []string {
|
||||
bits := strings.Split(s, sep)
|
||||
out := make([]string, 0, len(bits))
|
||||
for _, bit := range bits {
|
||||
bit = strings.TrimSpace(bit)
|
||||
if bit == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, bit)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseACMECertificates() ([]acmeCertificate, error) {
|
||||
if *acmeCertificatesConfig == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
cfgFile, err := os.Open(*acmeCertificatesConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cfgFile.Close()
|
||||
|
||||
var cs []acmeCertificate
|
||||
if err := json.NewDecoder(cfgFile).Decode(&cs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func readFile(p string) ([]byte, error) {
|
||||
f, err := os.Open(p)
|
||||
if err != nil {
|
||||
|
@ -233,8 +312,262 @@ func checkAndRenewSSHCertificates(ctx context.Context, c *vapi.Client) bool {
|
|||
return hadErr
|
||||
}
|
||||
|
||||
func checkAndRenewACMECertificates(ctx context.Context, c *vapi.Client) bool {
|
||||
func setsSame(a, b []string) bool {
|
||||
s := map[string]int{}
|
||||
for _, v := range a {
|
||||
s[v]++
|
||||
}
|
||||
for _, v := range b {
|
||||
s[v]++
|
||||
}
|
||||
for _, n := range s {
|
||||
if n != 2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func getGroupID(groupName string) (int, error) {
|
||||
g, err := user.LookupGroup(groupName)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("looking up group %q: %w", groupName, err)
|
||||
}
|
||||
|
||||
gid, err := strconv.Atoi(g.Gid)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("converting group %q to int: %w", g.Gid, err)
|
||||
}
|
||||
return gid, nil
|
||||
}
|
||||
|
||||
func setGroup(groupName string) (func(), error) {
|
||||
if groupName == "" {
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
gid, err := getGroupID(groupName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
origGid := unix.Getegid()
|
||||
unix.Setregid(-1, gid)
|
||||
return func() {
|
||||
unix.Setregid(-1, origGid)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func makeOrFixPermissionsOnDir(p string, permissions fs.FileMode, group string) error {
|
||||
outDirFi, err := os.Stat(p)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
restoreGroup, err := setGroup(group)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting group to create outdir: %w", err)
|
||||
}
|
||||
defer restoreGroup()
|
||||
|
||||
return os.Mkdir(p, permissions)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("stating output dir %q: %w", p, err)
|
||||
}
|
||||
|
||||
if (outDirFi.Mode() & fs.ModePerm) != permissions {
|
||||
if err := os.Chmod(p, permissions); err != nil {
|
||||
return fmt.Errorf("chmodding output dir %q: %w", p, err)
|
||||
}
|
||||
}
|
||||
|
||||
outDirStat, ok := outDirFi.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return fmt.Errorf("outDirFi isn't a stat_t, is a %T", outDirFi)
|
||||
}
|
||||
|
||||
gid, err := getGroupID(group)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if outDirStat.Gid != uint32(gid) {
|
||||
if err := os.Chown(p, -1, gid); err != nil {
|
||||
return fmt.Errorf("chgrping output dir %q: %w", p, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldRenewACMECert(c acmeCertificate) (bool, error) {
|
||||
outDirFi, err := os.Stat(c.OutputDir())
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return true, fmt.Errorf("output dir %q doesn't exist", c.OutputDir())
|
||||
} else if err != nil {
|
||||
return false, fmt.Errorf("stating output dir %q: %w", c.OutputDir(), err)
|
||||
}
|
||||
|
||||
if !outDirFi.IsDir() {
|
||||
log.Warningf("output dir %q isn't a dir; deleting it")
|
||||
if err := os.Remove(c.OutputDir()); err != nil {
|
||||
return false, fmt.Errorf("couldn't remove non-dir in place of output dir %q: %w", c.OutputDir(), err)
|
||||
}
|
||||
return true, fmt.Errorf("output dir %q wasn't a dir")
|
||||
}
|
||||
|
||||
// Check that all the output files exist.
|
||||
for _, f := range []string{c.FullchainPath(), c.ChainPath(), c.PrivkeyPath()} {
|
||||
if _, err := os.Stat(f); errors.Is(err, fs.ErrNotExist) {
|
||||
return true, fmt.Errorf("output file %q doesn't exist", f)
|
||||
} else if err != nil {
|
||||
return false, fmt.Errorf("stating output file %q: %w", f, err)
|
||||
}
|
||||
}
|
||||
|
||||
fc, err := os.Open(c.FullchainPath())
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("opening fullchain %q: %w", c.FullchainPath(), err)
|
||||
}
|
||||
defer fc.Close()
|
||||
|
||||
fcBytes, err := io.ReadAll(fc)
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("reading fullchain %q: %w", c.FullchainPath(), err)
|
||||
}
|
||||
|
||||
fcB, _ := pem.Decode(fcBytes)
|
||||
if fcB == nil {
|
||||
return true, fmt.Errorf("fullchain %q contained non-PEM bits", c.FullchainPath())
|
||||
} else if fcB.Type != "CERTIFICATE" {
|
||||
return true, fmt.Errorf("fullchain %q contained non-CERTIFICATE PEM block %q", c.FullchainPath(), fcB.Type)
|
||||
}
|
||||
|
||||
crt, err := x509.ParseCertificate(fcB.Bytes)
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("fullchain %q wasn't parsable as x509", c.FullchainPath())
|
||||
}
|
||||
|
||||
// Check that the certificate hasn't expired.
|
||||
now := time.Now()
|
||||
expiryThreshold := now.Add(*acmeCertificatesExpiryThreshold)
|
||||
log.Infof("certificate in %q expires at %v (%v from now); expiry cutoff is %v", c.FullchainPath(), crt.NotAfter, crt.NotAfter.Sub(now), expiryThreshold)
|
||||
if crt.NotAfter.Before(expiryThreshold) {
|
||||
return true, fmt.Errorf("certificate in %q expires at %v (%v from now); expiry cutoff is %v", c.FullchainPath(), crt.NotAfter, crt.NotAfter.Sub(now), expiryThreshold)
|
||||
}
|
||||
|
||||
// Check that the hostnames match.
|
||||
if !setsSame(c.Hostnames, crt.DNSNames) {
|
||||
return true, fmt.Errorf("certificate DNSNames %q don't match wanted hostnames %q", crt.DNSNames, c.Hostnames)
|
||||
}
|
||||
|
||||
if *acmeCertificatesForceRenew {
|
||||
return true, fmt.Errorf("force renew flag set")
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func writeCertificate(certDef acmeCertificate, cert *vapi.Secret) error {
|
||||
restoreGroup, err := setGroup(certDef.Group)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting group to write output: %w", err)
|
||||
}
|
||||
defer restoreGroup()
|
||||
|
||||
setFiles := []struct {
|
||||
name string
|
||||
content []byte
|
||||
perm fs.FileMode
|
||||
}{{
|
||||
name: certDef.FullchainPath(),
|
||||
content: []byte(cert.Data["cert"].(string)),
|
||||
perm: 0644,
|
||||
}, {
|
||||
name: certDef.ChainPath(),
|
||||
content: []byte(cert.Data["issuer_cert"].(string)),
|
||||
perm: 0644,
|
||||
}, {
|
||||
name: certDef.PrivkeyPath(),
|
||||
content: []byte(cert.Data["private_key"].(string)),
|
||||
perm: 0640,
|
||||
}}
|
||||
|
||||
for _, sf := range setFiles {
|
||||
log.Infof("writing file %v mode %s", sf.name, sf.perm)
|
||||
|
||||
os.Remove(sf.name) // optimistically try to remove the file, we don't care if it succeeds
|
||||
// if it doesn't, we'll error when we try to open it
|
||||
|
||||
f, err := os.OpenFile(sf.name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, sf.perm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening %v: %w", sf.name, err)
|
||||
}
|
||||
|
||||
if _, err := f.Write(sf.content); err != nil {
|
||||
f.Close()
|
||||
os.Remove(sf.name)
|
||||
return fmt.Errorf("writing %v: %w", sf.name, err)
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
return fmt.Errorf("closing %v: %w", sf.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkAndRenewACMECertificates(ctx context.Context, c *vapi.Client) bool {
|
||||
certDefs, err := parseACMECertificates()
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse certificates: %v", err)
|
||||
return true
|
||||
}
|
||||
log.Infof("acme certificate defs: %#v", certDefs)
|
||||
|
||||
var hadErr bool
|
||||
for _, certDef := range certDefs {
|
||||
log.Infof("maybe issuing %v (role: %v) with hostnames %v into %v", certDef.OutputName, certDef.Role, certDef.Hostnames, certDef.OutputDir())
|
||||
|
||||
if shouldRenew, err := shouldRenewACMECert(certDef); shouldRenew && err == nil {
|
||||
log.Infof("shouldRenewACMECert for %v returned true with no reason", certDef.OutputName)
|
||||
} else if shouldRenew {
|
||||
log.Infof("shouldRenewACMECert for %v returned true: %v", certDef.OutputName, err)
|
||||
} else if err != nil {
|
||||
log.Errorf("shouldRenewACMECert for %v returned false and an error: %v", certDef.OutputName, err)
|
||||
hadErr = true
|
||||
continue
|
||||
} else if !shouldRenew {
|
||||
log.Infof("shouldRenewACMECert for %v returned false", certDef.OutputName)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := makeOrFixPermissionsOnDir(certDef.OutputDir(), 0750, certDef.Group); err != nil {
|
||||
log.Errorf("creating/fixing dir %q: %w", certDef.OutputDir(), err)
|
||||
hadErr = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Issue certificate.
|
||||
names := certDef.SortedHostnames()
|
||||
commonName := names[0]
|
||||
names = names[1:]
|
||||
cert, err := c.Logical().Write(fmt.Sprintf("%s/certs/%s", *acmeCertificatesMount, certDef.Role), map[string]interface{}{
|
||||
"common_name": commonName,
|
||||
"alternative_names": strings.Join(names, ","),
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("issuing %v (role: %v) with hostnames %v into %v: %v", certDef.OutputName, certDef.Role, certDef.Hostnames, certDef.OutputDir(), err)
|
||||
hadErr = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Write out certificate.
|
||||
if err := writeCertificate(certDef, cert); err != nil {
|
||||
log.Errorf("writing out certificate %v: %w", certDef.OutputName, err)
|
||||
hadErr = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return hadErr
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
@ -244,6 +577,7 @@ func main() {
|
|||
cfg.Address = "https://vault.int.lukegb.com"
|
||||
cfg.AgentAddress = "http://localhost:8200"
|
||||
cfg.MaxRetries = 0
|
||||
cfg.Timeout = 15 * time.Minute
|
||||
c, err := vapi.NewClient(cfg)
|
||||
if err != nil {
|
||||
log.Exitf("failed to create vault client: %v", err)
|
||||
|
|
Loading…
Reference in a new issue