2022-03-15 03:07:34 +00:00
package main
import (
"bufio"
"bytes"
"context"
2022-03-17 01:26:18 +00:00
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
2022-03-15 03:07:34 +00:00
"flag"
"fmt"
"io"
2022-03-17 01:26:18 +00:00
"io/fs"
2022-03-15 03:07:34 +00:00
"os"
"os/exec"
2022-03-17 01:26:18 +00:00
"os/user"
2022-03-15 03:07:34 +00:00
"path/filepath"
2022-03-17 01:26:18 +00:00
"regexp"
"sort"
"strconv"
2022-03-15 03:07:34 +00:00
"strings"
2022-03-17 01:26:18 +00:00
"syscall"
2022-03-15 03:07:34 +00:00
"time"
"github.com/coreos/go-systemd/v22/dbus"
log "github.com/golang/glog"
vapi "github.com/hashicorp/vault/api"
"golang.org/x/crypto/ssh"
2022-03-17 01:26:18 +00:00
"golang.org/x/sys/unix"
2022-03-15 03:07:34 +00:00
)
var (
2022-03-20 17:47:52 +00:00
vaultAddress = flag . String ( "vault_address" , "https://vault.int.lukegb.com" , "Address of Vault" )
vaultAgentAddress = flag . String ( "vault_agent_address" , "unix:///run/vault-agent/sock" , "Address of Vault agent" )
2022-03-15 03:07:34 +00:00
signSSHHostKeys = flag . Bool ( "sign_ssh_host_keys" , true , "Sign SSH host keys with CA" )
sshHostKeyCAPath = flag . String ( "ssh_host_key_ca_path" , "ssh-host" , "Path that the SSH CA is mounted at" )
sshHostKeyRole = flag . String ( "ssh_host_key_role" , hostname ( ) , "Role to use for signing SSH host keys" )
sshHostKeyPrincipals = flag . String ( "ssh_host_key_principals" , defaultPrincipals ( hostname ( ) ) , "Principals to use for SSH host key certificate" )
sshHostKeyExpiryThreshold = flag . Duration ( "ssh_host_key_expiry_threshold" , 3 * 24 * time . Hour , "Expiry threshold for SSH host key" )
2022-03-17 01:26:18 +00:00
sshHostKeyOutputDir = flag . String ( "ssh_host_key_output_dir" , "/var/lib/secretsmgr/ssh" , "Path to SSH host key output dir" )
2022-03-17 23:31:55 +00:00
sshDummyHostKey = flag . String ( "ssh_dummy_host_key" , "" , "Path to dummy SSH hostkey" )
2022-03-17 01:26:18 +00:00
sshd = flag . String ( "sshd" , "sshd" , "Path to sshd" )
2022-03-15 03:07:34 +00:00
2022-03-17 01:26:18 +00:00
acmeCertificatesConfig = flag . String ( "acme_certificates_config" , "" , "Path to ACME certificates config" )
acmeCertificatesMount = flag . String ( "acme_certificates_mount" , "acme" , "Path that ACME is mounted at" )
acmeCertificatesExpiryThreshold = flag . Duration ( "acme_certificates_expiry_threshold" , 31 * 24 * time . Hour , "Expiry threshold for ACME certificates" )
acmeCertificatesForceRenew = flag . Bool ( "acme_certificates_force_renew" , false , "Ignore remaining time and force ACME certificate renewal" )
2022-03-15 03:07:34 +00:00
)
// hostname returns the host name reported by os.Hostname, or "" when it
// cannot be determined. It is used to derive flag defaults.
func hostname() string {
	if name, err := os.Hostname(); err == nil {
		return name
	}
	return ""
}
// defaultPrincipals builds the default comma-separated SSH host certificate
// principals for a host: its public and internal as205479.net names.
func defaultPrincipals(hostname string) string {
	return fmt.Sprintf("%[1]s.as205479.net,%[1]s.int.as205479.net", hostname)
}
func sshHostKeyPaths ( ctx context . Context ) ( [ ] string , error ) {
2022-03-17 23:31:55 +00:00
args := [ ] string { "-T" }
if * sshDummyHostKey != "" {
args = append ( args , "-h" , * sshDummyHostKey )
}
out , err := exec . CommandContext ( ctx , * sshd , args ... ) . Output ( )
2022-03-15 03:07:34 +00:00
if err != nil {
return nil , err
}
scanner := bufio . NewScanner ( bytes . NewReader ( out ) )
var keys [ ] string
const hostkeyPrefix = "hostkey "
for scanner . Scan ( ) {
t := scanner . Text ( )
if strings . HasPrefix ( t , hostkeyPrefix ) {
2022-03-17 23:31:55 +00:00
key := strings . TrimPrefix ( t , hostkeyPrefix )
if key == * sshDummyHostKey {
continue
}
keys = append ( keys , fmt . Sprintf ( "%s.pub" , key ) )
2022-03-15 03:07:34 +00:00
}
}
if scanner . Err ( ) != nil {
return nil , scanner . Err ( )
}
return keys , nil
}
2022-03-17 01:26:18 +00:00
// acmeCertificate is one entry of the -acme_certificates_config JSON file.
type acmeCertificate struct {
	// Role is the Vault ACME role used to issue the certificate.
	Role string `json:"role"`
	// OutputName is the subdirectory of /var/lib/acme the files are written to.
	OutputName string `json:"output_name"`
	// Hostnames are the DNS names the certificate must cover.
	Hostnames []string `json:"hostnames"`
	// UnitsToReloadOrRestart are systemd units poked after a renewal.
	UnitsToReloadOrRestart []string `json:"units_to_reload_or_restart"`
	// Group, if non-empty, is the group that owns the output files.
	Group string `json:"group"`
}

// OutputDir is the directory holding this certificate's PEM files.
func (c acmeCertificate) OutputDir() string { return filepath.Join("/var/lib/acme", c.OutputName) }

// FullchainPath is where the leaf certificate (plus chain) is written.
func (c acmeCertificate) FullchainPath() string { return filepath.Join(c.OutputDir(), "fullchain.pem") }

// ChainPath is where the issuer certificate is written.
func (c acmeCertificate) ChainPath() string { return filepath.Join(c.OutputDir(), "chain.pem") }

// PrivkeyPath is where the private key is written.
func (c acmeCertificate) PrivkeyPath() string { return filepath.Join(c.OutputDir(), "privkey.pem") }

// SortedHostnames returns a sorted copy of Hostnames; the receiver's slice is
// left untouched. The result is non-nil even when there are no hostnames.
func (c acmeCertificate) SortedHostnames() []string {
	sorted := append(make([]string, 0, len(c.Hostnames)), c.Hostnames...)
	sort.Strings(sorted)
	return sorted
}
// acmeCertBits matches a compact certificate spec of the form
// "role/shortname=hostnames[/units]", exposing each part as a named group.
// NOTE(review): nothing visible in this file references it — confirm it is
// still needed.
var acmeCertBits = regexp.MustCompile(
	`^(?P<role>[^/]+)/(?P<shortname>[^=]+)=(?P<hostnames>[^/]+)(?:/(?P<units>[^/]+))?$`,
)
// splitDiscardEmpty splits s on sep, trims surrounding whitespace from each
// piece and drops pieces that end up empty. The result is never nil.
func splitDiscardEmpty(s string, sep string) []string {
	pieces := strings.Split(s, sep)
	kept := make([]string, 0, len(pieces))
	for i := range pieces {
		if trimmed := strings.TrimSpace(pieces[i]); trimmed != "" {
			kept = append(kept, trimmed)
		}
	}
	return kept
}
// parseACMECertificates decodes the JSON list of certificate definitions in
// the file named by -acme_certificates_config. An empty flag means no ACME
// certificates are managed, and (nil, nil) is returned.
func parseACMECertificates() ([]acmeCertificate, error) {
	path := *acmeCertificatesConfig
	if path == "" {
		return nil, nil
	}
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var defs []acmeCertificate
	if err = json.NewDecoder(f).Decode(&defs); err != nil {
		return nil, err
	}
	return defs, nil
}
2022-03-15 03:07:34 +00:00
// readFile returns the full contents of the file at p.
func readFile(p string) ([]byte, error) {
	file, openErr := os.Open(p)
	if openErr != nil {
		return nil, openErr
	}
	defer file.Close()
	contents, readErr := io.ReadAll(file)
	return contents, readErr
}
// writeFile creates (or truncates) the file at p and writes content to it.
// A write error takes precedence over any error from closing the file.
func writeFile(p string, content []byte) error {
	f, err := os.Create(p)
	if err != nil {
		return err
	}
	_, writeErr := f.Write(content)
	closeErr := f.Close()
	if writeErr != nil {
		return writeErr
	}
	return closeErr
}
// readSSHPublicKey loads the file at p and parses its first entry in
// authorized_keys format (which also accepts SSH certificates).
func readSSHPublicKey(p string) (ssh.PublicKey, error) {
	raw, err := readFile(p)
	if err != nil {
		return nil, fmt.Errorf("reading public key from %v: %w", p, err)
	}
	key, _, _, _, parseErr := ssh.ParseAuthorizedKey(raw)
	if parseErr != nil {
		return nil, fmt.Errorf("parsing public key in %v: %w", p, parseErr)
	}
	return key, nil
}
// shouldRenewSSHCert reports whether the SSH host certificate at certPath must
// be (re)signed for publicKey (authorized_keys format) and principals
// (comma-separated).
//
// Renewal is needed when the certificate cannot be read, isn't a
// certificate, certifies a different key, expires within
// -ssh_host_key_expiry_threshold, or lacks one of the wanted principals. In
// those cases the returned error explains why. If publicKey itself cannot be
// parsed, (false, err) is returned.
func shouldRenewSSHCert(certPath string, publicKey []byte, principals string) (bool, error) {
	certPK, err := readSSHPublicKey(certPath)
	if err != nil {
		return true, err
	}
	cert, ok := certPK.(*ssh.Certificate)
	if !ok {
		return true, fmt.Errorf("file at %v was not an SSH certificate (was a %T)", certPath, certPK)
	}
	pubKey, _, _, _, err := ssh.ParseAuthorizedKey(publicKey)
	if err != nil {
		return false, err
	}

	// Verify that this is the same key.
	if !bytes.Equal(pubKey.Marshal(), cert.Key.Marshal()) {
		return true, fmt.Errorf("certificate key and public key don't match")
	}

	// Verify the certificate isn't expiring too soon.
	now := time.Now()
	if cert.ValidBefore == uint64(ssh.CertTimeInfinity) {
		// int64(CertTimeInfinity) is -1, which would log as a 1969 expiry;
		// report the no-expiry case explicitly instead.
		log.Infof("certificate %v never expires", certPath)
	} else {
		expiryThreshold := uint64(now.Unix() + int64(sshHostKeyExpiryThreshold.Seconds()))
		certExpiryDate := time.Unix(int64(cert.ValidBefore), 0)
		log.Infof("certificate %v expires at %v (in %v)", certPath, certExpiryDate, certExpiryDate.Sub(now))
		if cert.ValidBefore < expiryThreshold {
			return true, fmt.Errorf("certificate expires at %v, which is earlier than our threshold of %v", certExpiryDate, time.Unix(int64(expiryThreshold), 0))
		}
	}

	// Verify every wanted principal is present.
	hasPrincipals := make(map[string]bool, len(cert.ValidPrincipals))
	for _, p := range cert.ValidPrincipals {
		hasPrincipals[p] = true
	}
	for _, wantP := range strings.Split(principals, ",") {
		if !hasPrincipals[wantP] {
			return true, fmt.Errorf("certificate missing principal %v (has principals %v; want principals %v)", wantP, cert.ValidPrincipals, principals)
		}
	}
	return false, nil
}
func reloadOrTryRestartServices ( ctx context . Context , services [ ] string ) error {
2022-03-17 23:31:55 +00:00
sd , err := dbus . NewWithContext ( ctx )
2022-03-15 03:07:34 +00:00
if err != nil {
return fmt . Errorf ( "failed to connect to systemd: %w" , err )
}
defer sd . Close ( )
var failedUnits [ ] string
for _ , service := range services {
_ , err := sd . ReloadOrTryRestartUnit ( service , "replace" , nil )
if err != nil {
log . Errorf ( "unit %v failed to reload-or-try-restart: %v" , service , err )
failedUnits = append ( failedUnits , service )
}
}
if len ( failedUnits ) != 0 {
return fmt . Errorf ( "some units failed to restart: %v" , failedUnits )
}
return nil
}
// checkAndRenewSSHCertificates signs, through Vault's SSH CA, every sshd host
// key whose certificate shouldRenewSSHCert says needs renewal. It writes each
// certificate to -ssh_host_key_output_dir as "<keyname>-cert.pub" and, if
// anything was renewed, reloads sshd.service.
//
// It returns true if any step failed; a failure for one key does not stop
// processing of the others.
func checkAndRenewSSHCertificates(ctx context.Context, c *vapi.Client) bool {
	sshHostKeys, err := sshHostKeyPaths(ctx)
	if err != nil {
		log.Errorf("sshHostKeyPaths: %v", err)
		return true
	}
	cs := c.SSHWithMountPoint(*sshHostKeyCAPath)

	var hadErr, didRenew bool
	for _, sshHostKeyPath := range sshHostKeys {
		outPath := filepath.Join(*sshHostKeyOutputDir, fmt.Sprintf("%s-cert.pub", strings.TrimSuffix(filepath.Base(sshHostKeyPath), ".pub")))
		log.Infof("might sign %v into %v", sshHostKeyPath, outPath)
		sshHostKey, err := readFile(sshHostKeyPath)
		if err != nil {
			log.Errorf("reading host key %v: %v", sshHostKeyPath, err)
			hadErr = true
			continue
		}

		shouldRenew, err := shouldRenewSSHCert(outPath, sshHostKey, *sshHostKeyPrincipals)
		switch {
		case shouldRenew && err == nil:
			log.Infof("shouldRenewSSHCert for %v returned true with no reason", outPath)
		case shouldRenew:
			log.Infof("shouldRenewSSHCert for %v returned true: %v", outPath, err)
		case err != nil:
			log.Errorf("shouldRenewSSHCert for %v returned false and an error: %v", outPath, err)
			hadErr = true
			continue
		default:
			log.Infof("shouldRenewSSHCert for %v returned false", outPath)
			continue
		}

		secret, err := cs.SignKey(*sshHostKeyRole, map[string]interface{}{
			"cert_type":        "host",
			"public_key":       string(sshHostKey),
			"valid_principals": *sshHostKeyPrincipals,
		})
		if err != nil {
			log.Errorf("signing %v: %v", sshHostKeyPath, err)
			hadErr = true
			continue
		}
		// Guard against a nil secret or a missing/non-string field rather than
		// panicking on an unchecked type assertion.
		signedKey, ok := "", false
		if secret != nil {
			signedKey, ok = secret.Data["signed_key"].(string)
		}
		if !ok {
			log.Errorf("signing %v: response has no string signed_key field", sshHostKeyPath)
			hadErr = true
			continue
		}

		if err := writeFile(outPath, []byte(signedKey)); err != nil {
			log.Errorf("writing signed certificate to %v: %v", outPath, err)
			hadErr = true
			continue
		}
		log.Infof("renewed %v", outPath)
		didRenew = true
	}

	if didRenew {
		log.Infof("completed renewals; restarting SSHd")
		if err := reloadOrTryRestartServices(ctx, []string{"sshd.service"}); err != nil {
			log.Errorf("failed to restart sshd: %v", err)
			hadErr = true
		}
	}
	return hadErr
}
2022-03-17 01:26:18 +00:00
// setsSame reports whether a and b contain the same set of distinct values,
// ignoring order and duplicates.
//
// Each slice is reduced to its own set. Counting occurrences across both
// slices together miscounts when a slice contains duplicates (e.g. ["x","x"]
// vs []).
func setsSame(a, b []string) bool {
	inA := make(map[string]bool, len(a))
	for _, v := range a {
		inA[v] = true
	}
	inB := make(map[string]bool, len(b))
	for _, v := range b {
		if !inA[v] {
			return false
		}
		inB[v] = true
	}
	// Every element of b is in a; equal set sizes then mean equal sets.
	return len(inA) == len(inB)
}
// getGroupID resolves groupName to its numeric GID using the os/user group
// database lookup.
func getGroupID(groupName string) (int, error) {
	grp, lookupErr := user.LookupGroup(groupName)
	if lookupErr != nil {
		return 0, fmt.Errorf("looking up group %q: %w", groupName, lookupErr)
	}
	id, convErr := strconv.Atoi(grp.Gid)
	if convErr != nil {
		return 0, fmt.Errorf("converting group %q to int: %w", grp.Gid, convErr)
	}
	return id, nil
}
func setGroup ( groupName string ) ( func ( ) , error ) {
if groupName == "" {
return func ( ) { } , nil
}
gid , err := getGroupID ( groupName )
if err != nil {
return nil , err
}
origGid := unix . Getegid ( )
2022-03-17 23:31:55 +00:00
if err := unix . Setregid ( - 1 , gid ) ; err != nil {
return nil , err
}
2022-03-17 01:26:18 +00:00
return func ( ) {
unix . Setregid ( - 1 , origGid )
} , nil
}
// makeOrFixPermissionsOnDir ensures directory p exists with exactly the
// permission bits in permissions and, when group is non-empty, is owned by
// that group.
//
// A missing directory is created while the effective group is switched to
// group. An existing directory has its mode and group corrected in place.
// An empty group leaves ownership alone: looking up "" would always fail.
func makeOrFixPermissionsOnDir(p string, permissions fs.FileMode, group string) error {
	outDirFi, err := os.Stat(p)
	if errors.Is(err, fs.ErrNotExist) {
		restoreGroup, err := setGroup(group)
		if err != nil {
			return fmt.Errorf("setting group to create outdir: %w", err)
		}
		defer restoreGroup()
		if err := os.Mkdir(p, permissions); err != nil {
			return err
		}
		// Mkdir's mode is filtered through the process umask; set it
		// explicitly so the directory gets exactly the requested bits.
		if err := os.Chmod(p, permissions); err != nil {
			return fmt.Errorf("chmodding output dir %q: %w", p, err)
		}
		return nil
	} else if err != nil {
		return fmt.Errorf("stating output dir %q: %w", p, err)
	}

	if outDirFi.Mode()&fs.ModePerm != permissions {
		if err := os.Chmod(p, permissions); err != nil {
			return fmt.Errorf("chmodding output dir %q: %w", p, err)
		}
	}

	if group == "" {
		return nil
	}
	outDirStat, ok := outDirFi.Sys().(*syscall.Stat_t)
	if !ok {
		return fmt.Errorf("stat info for %q isn't a *syscall.Stat_t, is a %T", p, outDirFi.Sys())
	}
	gid, err := getGroupID(group)
	if err != nil {
		return err
	}
	if outDirStat.Gid != uint32(gid) {
		if err := os.Chown(p, -1, gid); err != nil {
			return fmt.Errorf("chgrping output dir %q: %w", p, err)
		}
	}
	return nil
}
// shouldRenewACMECert reports whether the certificate described by c must be
// (re)issued.
//
// Renewal is needed when any of these hold:
//   - the output dir or any output file is missing;
//   - the output path isn't a directory (it is removed);
//   - the fullchain's first PEM block isn't a parsable certificate;
//   - the certificate expires within -acme_certificates_expiry_threshold;
//   - its DNS names differ from c.Hostnames;
//   - -acme_certificates_force_renew is set.
//
// Every true result carries a non-nil error explaining why. A false result
// with a non-nil error means the state could not be determined.
func shouldRenewACMECert(c acmeCertificate) (bool, error) {
	outDir := c.OutputDir()
	outDirFi, err := os.Stat(outDir)
	if errors.Is(err, fs.ErrNotExist) {
		return true, fmt.Errorf("output dir %q doesn't exist", outDir)
	} else if err != nil {
		return false, fmt.Errorf("stating output dir %q: %w", outDir, err)
	}
	if !outDirFi.IsDir() {
		log.Warningf("output dir %q isn't a dir; deleting it", outDir)
		if err := os.Remove(outDir); err != nil {
			return false, fmt.Errorf("couldn't remove non-dir in place of output dir %q: %w", outDir, err)
		}
		return true, fmt.Errorf("output dir %q wasn't a dir", outDir)
	}

	// Check that all the output files exist.
	for _, f := range []string{c.FullchainPath(), c.ChainPath(), c.PrivkeyPath()} {
		if _, err := os.Stat(f); errors.Is(err, fs.ErrNotExist) {
			return true, fmt.Errorf("output file %q doesn't exist", f)
		} else if err != nil {
			return false, fmt.Errorf("stating output file %q: %w", f, err)
		}
	}

	fullchainPath := c.FullchainPath()
	fcBytes, err := readFile(fullchainPath)
	if err != nil {
		return true, fmt.Errorf("reading fullchain %q: %w", fullchainPath, err)
	}
	// The first PEM block is expected to be the leaf certificate.
	fcB, _ := pem.Decode(fcBytes)
	if fcB == nil {
		return true, fmt.Errorf("fullchain %q contained non-PEM bits", fullchainPath)
	} else if fcB.Type != "CERTIFICATE" {
		return true, fmt.Errorf("fullchain %q contained non-CERTIFICATE PEM block %q", fullchainPath, fcB.Type)
	}
	crt, err := x509.ParseCertificate(fcB.Bytes)
	if err != nil {
		return true, fmt.Errorf("fullchain %q wasn't parsable as x509: %w", fullchainPath, err)
	}

	// Check that the certificate hasn't expired.
	now := time.Now()
	expiryThreshold := now.Add(*acmeCertificatesExpiryThreshold)
	log.Infof("certificate in %q expires at %v (%v from now); expiry cutoff is %v", fullchainPath, crt.NotAfter, crt.NotAfter.Sub(now), expiryThreshold)
	if crt.NotAfter.Before(expiryThreshold) {
		return true, fmt.Errorf("certificate in %q expires at %v (%v from now); expiry cutoff is %v", fullchainPath, crt.NotAfter, crt.NotAfter.Sub(now), expiryThreshold)
	}

	// Check that the hostnames match.
	if !setsSame(c.Hostnames, crt.DNSNames) {
		return true, fmt.Errorf("certificate DNSNames %q don't match wanted hostnames %q", crt.DNSNames, c.Hostnames)
	}

	if *acmeCertificatesForceRenew {
		return true, fmt.Errorf("force renew flag set")
	}
	return false, nil
}
// writeCertificate writes the certificate material in cert into certDef's
// output directory:
//   - the "cert" field to fullchain.pem (0644);
//   - the "issuer_cert" field to chain.pem (0644);
//   - the "private_key" field to privkey.pem (0640).
//
// All three fields are validated before anything is touched, so a nil or
// malformed response returns an error and leaves existing files alone
// (previously an unchecked type assertion would panic). Files are created
// while the effective group is switched to certDef.Group, if set.
func writeCertificate(certDef acmeCertificate, cert *vapi.Secret) error {
	if cert == nil {
		return fmt.Errorf("no certificate returned for %v", certDef.OutputName)
	}
	field := func(key string) ([]byte, error) {
		v, ok := cert.Data[key].(string)
		if !ok {
			return nil, fmt.Errorf("certificate response for %v has no string field %q", certDef.OutputName, key)
		}
		return []byte(v), nil
	}
	fullchain, err := field("cert")
	if err != nil {
		return err
	}
	chain, err := field("issuer_cert")
	if err != nil {
		return err
	}
	privkey, err := field("private_key")
	if err != nil {
		return err
	}

	restoreGroup, err := setGroup(certDef.Group)
	if err != nil {
		return fmt.Errorf("setting group to write output: %w", err)
	}
	defer restoreGroup()

	setFiles := []struct {
		name    string
		content []byte
		perm    fs.FileMode
	}{
		{name: certDef.FullchainPath(), content: fullchain, perm: 0644},
		{name: certDef.ChainPath(), content: chain, perm: 0644},
		{name: certDef.PrivkeyPath(), content: privkey, perm: 0640},
	}
	for _, sf := range setFiles {
		log.Infof("writing file %v mode %s", sf.name, sf.perm)
		// Optimistically remove any existing file; if that fails, the O_EXCL
		// open below reports the error. A freshly created file picks up
		// sf.perm (subject to umask) and the current effective group.
		os.Remove(sf.name)
		f, err := os.OpenFile(sf.name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, sf.perm)
		if err != nil {
			return fmt.Errorf("opening %v: %w", sf.name, err)
		}
		if _, err := f.Write(sf.content); err != nil {
			f.Close()
			os.Remove(sf.name)
			return fmt.Errorf("writing %v: %w", sf.name, err)
		}
		if err := f.Close(); err != nil {
			return fmt.Errorf("closing %v: %w", sf.name, err)
		}
	}
	return nil
}
2022-03-17 23:31:55 +00:00
// uniqueStrings returns the distinct values of vs in order of first
// appearance. The result is never nil.
//
// Preserving input order, instead of relying on map iteration order, makes
// the result (and the order units are reloaded in) deterministic.
func uniqueStrings(vs []string) []string {
	seen := make(map[string]struct{}, len(vs))
	out := make([]string, 0, len(vs))
	for _, v := range vs {
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}
	return out
}
2022-03-15 03:07:34 +00:00
func checkAndRenewACMECertificates ( ctx context . Context , c * vapi . Client ) bool {
2022-03-17 01:26:18 +00:00
certDefs , err := parseACMECertificates ( )
if err != nil {
log . Errorf ( "failed to parse certificates: %v" , err )
return true
}
log . Infof ( "acme certificate defs: %#v" , certDefs )
var hadErr bool
2022-03-17 23:31:55 +00:00
var unitsToReloadOrRestart [ ] string
2022-03-17 01:26:18 +00:00
for _ , certDef := range certDefs {
log . Infof ( "maybe issuing %v (role: %v) with hostnames %v into %v" , certDef . OutputName , certDef . Role , certDef . Hostnames , certDef . OutputDir ( ) )
if shouldRenew , err := shouldRenewACMECert ( certDef ) ; shouldRenew && err == nil {
log . Infof ( "shouldRenewACMECert for %v returned true with no reason" , certDef . OutputName )
} else if shouldRenew {
log . Infof ( "shouldRenewACMECert for %v returned true: %v" , certDef . OutputName , err )
} else if err != nil {
log . Errorf ( "shouldRenewACMECert for %v returned false and an error: %v" , certDef . OutputName , err )
hadErr = true
continue
} else if ! shouldRenew {
log . Infof ( "shouldRenewACMECert for %v returned false" , certDef . OutputName )
continue
}
if err := makeOrFixPermissionsOnDir ( certDef . OutputDir ( ) , 0750 , certDef . Group ) ; err != nil {
2022-03-17 23:31:55 +00:00
log . Errorf ( "creating/fixing dir %q: %v" , certDef . OutputDir ( ) , err )
2022-03-17 01:26:18 +00:00
hadErr = true
continue
}
// Issue certificate.
names := certDef . SortedHostnames ( )
commonName := names [ 0 ]
names = names [ 1 : ]
cert , err := c . Logical ( ) . Write ( fmt . Sprintf ( "%s/certs/%s" , * acmeCertificatesMount , certDef . Role ) , map [ string ] interface { } {
"common_name" : commonName ,
"alternative_names" : strings . Join ( names , "," ) ,
} )
if err != nil {
log . Errorf ( "issuing %v (role: %v) with hostnames %v into %v: %v" , certDef . OutputName , certDef . Role , certDef . Hostnames , certDef . OutputDir ( ) , err )
hadErr = true
continue
}
// Write out certificate.
if err := writeCertificate ( certDef , cert ) ; err != nil {
log . Errorf ( "writing out certificate %v: %w" , certDef . OutputName , err )
hadErr = true
continue
}
2022-03-17 23:31:55 +00:00
unitsToReloadOrRestart = append ( unitsToReloadOrRestart , certDef . UnitsToReloadOrRestart ... )
}
unitsToReloadOrRestart = uniqueStrings ( unitsToReloadOrRestart )
if len ( unitsToReloadOrRestart ) > 0 {
log . Infof ( "reloading/restarting %d units: %v" , len ( unitsToReloadOrRestart ) , unitsToReloadOrRestart )
if err := reloadOrTryRestartServices ( ctx , unitsToReloadOrRestart ) ; err != nil {
log . Errorf ( "failed to restart units: %v" , err )
hadErr = true
}
2022-03-17 01:26:18 +00:00
}
return hadErr
2022-03-15 03:07:34 +00:00
}
func main ( ) {
flag . Parse ( )
cfg := vapi . DefaultConfig ( )
cfg . Address = "https://vault.int.lukegb.com"
2022-03-20 17:47:52 +00:00
cfg . AgentAddress = "unix:///run/vault-agent/sock"
2022-03-15 03:07:34 +00:00
cfg . MaxRetries = 0
2022-03-17 01:26:18 +00:00
cfg . Timeout = 15 * time . Minute
2022-03-15 03:07:34 +00:00
c , err := vapi . NewClient ( cfg )
if err != nil {
log . Exitf ( "failed to create vault client: %v" , err )
}
ctx := context . Background ( )
var hadErr bool
hadErr = checkAndRenewSSHCertificates ( ctx , c ) || hadErr
hadErr = checkAndRenewACMECertificates ( ctx , c ) || hadErr
if hadErr {
os . Exit ( 1 )
}
}