package main

import (
	"bufio"
	"bytes"
	"context"
	"crypto/x509"
	"encoding/json"
	"encoding/pem"
	"errors"
	"flag"
	"fmt"
	"io"
	"io/fs"
	"os"
	"os/exec"
	"os/user"
	"path/filepath"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"syscall"
	"time"

	"github.com/coreos/go-systemd/v22/dbus"
	log "github.com/golang/glog"
	vapi "github.com/hashicorp/vault/api"
	"golang.org/x/crypto/ssh"
	"golang.org/x/sys/unix"
)

// Command-line flags. The defaults that depend on the machine's host name are
// computed when this block is initialised, by calling hostname().
var (
	// Vault connection settings.
	vaultAddress      = flag.String("vault_address", "https://vault.int.lukegb.com", "Address of Vault")
	vaultAgentAddress = flag.String("vault_agent_address", "unix:///run/vault-agent/sock", "Address of Vault agent")

	// SSH host key certificate signing.
	signSSHHostKeys           = flag.Bool("sign_ssh_host_keys", true, "Sign SSH host keys with CA")
	sshHostKeyCAPath          = flag.String("ssh_host_key_ca_path", "ssh-host", "Path that the SSH CA is mounted at")
	sshHostKeyRole            = flag.String("ssh_host_key_role", hostname(), "Role to use for signing SSH host keys")
	sshHostKeyPrincipals      = flag.String("ssh_host_key_principals", defaultPrincipals(hostname()), "Principals to use for SSH host key certificate")
	sshHostKeyExpiryThreshold = flag.Duration("ssh_host_key_expiry_threshold", 3*24*time.Hour, "Expiry threshold for SSH host key")
	sshHostKeyOutputDir       = flag.String("ssh_host_key_output_dir", "/var/lib/secretsmgr/ssh", "Path to SSH host key output dir")
	sshDummyHostKey           = flag.String("ssh_dummy_host_key", "", "Path to dummy SSH hostkey")
	sshd                      = flag.String("sshd", "sshd", "Path to sshd")

	// ACME certificate issuance.
	acmeCertificatesConfig          = flag.String("acme_certificates_config", "", "Path to ACME certificates config")
	acmeCertificatesMount           = flag.String("acme_certificates_mount", "acme", "Path that ACME is mounted at")
	acmeCertificatesExpiryThreshold = flag.Duration("acme_certificates_expiry_threshold", 31*24*time.Hour, "Expiry threshold for ACME certificates")
	acmeCertificatesForceRenew      = flag.Bool("acme_certificates_force_renew", false, "Ignore remaining time and force ACME certificate renewal")
)

// hostname reports the host name returned by os.Hostname, or the empty string
// when it cannot be determined.
func hostname() string {
	if h, err := os.Hostname(); err == nil {
		return h
	}
	return ""
}

// defaultPrincipals builds the default comma-separated principal list for a
// host: its name under as205479.net followed by its name under
// int.as205479.net.
func defaultPrincipals(hostname string) string {
	const layout = "%[1]s.as205479.net,%[1]s.int.as205479.net"
	return fmt.Sprintf(layout, hostname)
}

// sshHostKeyPaths runs `sshd -T` (adding `-h <dummy>` when the dummy host key
// flag is set) and returns the public key path ("<hostkey>.pub") for every
// "hostkey " line in its output, excluding the dummy key itself.
func sshHostKeyPaths(ctx context.Context) ([]string, error) {
	dummy := *sshDummyHostKey

	sshdArgs := []string{"-T"}
	if dummy != "" {
		sshdArgs = append(sshdArgs, "-h", dummy)
	}
	cmd := exec.CommandContext(ctx, *sshd, sshdArgs...)
	dump, err := cmd.Output()
	if err != nil {
		return nil, err
	}

	const hostkeyPrefix = "hostkey "
	var pubKeys []string
	lines := bufio.NewScanner(bytes.NewReader(dump))
	for lines.Scan() {
		line := lines.Text()
		if !strings.HasPrefix(line, hostkeyPrefix) {
			continue
		}
		keyPath := line[len(hostkeyPrefix):]
		if keyPath == dummy {
			// The dummy key only exists to let sshd start; never sign it.
			continue
		}
		pubKeys = append(pubKeys, fmt.Sprintf("%s.pub", keyPath))
	}
	if scanErr := lines.Err(); scanErr != nil {
		return nil, scanErr
	}
	return pubKeys, nil
}

// acmeCertificate is one entry of the JSON ACME certificates config: a
// certificate to request and where to put it.
type acmeCertificate struct {
	// Role is the role name used in the Vault issuance path.
	Role string `json:"role"`
	// OutputName is the directory name under /var/lib/acme that receives the files.
	OutputName string `json:"output_name"`
	// Hostnames are the DNS names the certificate should cover.
	Hostnames []string `json:"hostnames"`
	// UnitsToReloadOrRestart are systemd units to poke after the certificate is written.
	UnitsToReloadOrRestart []string `json:"units_to_reload_or_restart"`
	// Group is the group that should own the output directory and files; may be empty.
	Group string `json:"group"`
}

// OutputDir returns the directory that holds this certificate's files.
func (c acmeCertificate) OutputDir() string {
	dir := filepath.Join("/var/lib/acme", c.OutputName)
	return dir
}

// outputFile returns the path of the named file inside OutputDir.
func (c acmeCertificate) outputFile(name string) string {
	return filepath.Join(c.OutputDir(), name)
}

// FullchainPath returns the path of fullchain.pem inside OutputDir.
func (c acmeCertificate) FullchainPath() string {
	return c.outputFile("fullchain.pem")
}

// ChainPath returns the path of chain.pem inside OutputDir.
func (c acmeCertificate) ChainPath() string {
	return c.outputFile("chain.pem")
}

// PrivkeyPath returns the path of privkey.pem inside OutputDir.
func (c acmeCertificate) PrivkeyPath() string {
	return c.outputFile("privkey.pem")
}

// SortedHostnames returns a sorted copy of Hostnames; the receiver's slice is
// left untouched.
func (c acmeCertificate) SortedHostnames() []string {
	sorted := append(make([]string, 0, len(c.Hostnames)), c.Hostnames...)
	sort.Sort(sort.StringSlice(sorted))
	return sorted
}

// acmeCertBits matches a compact one-line certificate spec of the form
// "role/shortname=hostnames" with an optional trailing "/units" segment,
// exposing each part as a named capture group.
var acmeCertBits = regexp.MustCompile(`^(?P<role>[^/]+)/(?P<shortname>[^=]+)=(?P<hostnames>[^/]+)(?:/(?P<units>[^/]+))?$`)

// splitDiscardEmpty splits s on sep, trims surrounding whitespace from every
// piece, and returns the pieces that are non-empty after trimming. The result
// is never nil.
func splitDiscardEmpty(s string, sep string) []string {
	pieces := strings.Split(s, sep)
	kept := make([]string, 0, len(pieces))
	for i := range pieces {
		if trimmed := strings.TrimSpace(pieces[i]); trimmed != "" {
			kept = append(kept, trimmed)
		}
	}
	return kept
}

// parseACMECertificates loads the list of certificate definitions from the
// JSON file named by the acme_certificates_config flag. With the flag unset
// it returns (nil, nil).
func parseACMECertificates() ([]acmeCertificate, error) {
	cfgPath := *acmeCertificatesConfig
	if cfgPath == "" {
		return nil, nil
	}

	f, openErr := os.Open(cfgPath)
	if openErr != nil {
		return nil, openErr
	}
	defer f.Close()

	var certs []acmeCertificate
	dec := json.NewDecoder(f)
	if decodeErr := dec.Decode(&certs); decodeErr != nil {
		return nil, decodeErr
	}
	return certs, nil
}

// readFile returns the entire contents of the file at p.
func readFile(p string) ([]byte, error) {
	fh, openErr := os.Open(p)
	if openErr != nil {
		return nil, openErr
	}
	defer fh.Close()

	data, readErr := io.ReadAll(fh)
	return data, readErr
}

// writeFile creates (or truncates) the file at p and writes content to it,
// reporting the first error from create, write or close.
func writeFile(p string, content []byte) error {
	out, err := os.Create(p)
	if err != nil {
		return err
	}

	_, writeErr := out.Write(content)
	if writeErr != nil {
		// The write error is the one worth reporting; the close result is ignored.
		out.Close()
		return writeErr
	}
	return out.Close()
}

// readSSHPublicKey reads the file at p and parses its contents with
// ssh.ParseAuthorizedKey, returning the parsed key. Errors are wrapped with
// the path that was being read or parsed.
func readSSHPublicKey(p string) (ssh.PublicKey, error) {
	raw, readErr := readFile(p)
	if readErr != nil {
		return nil, fmt.Errorf("reading public key from %v: %w", p, readErr)
	}

	key, _, _, _, parseErr := ssh.ParseAuthorizedKey(raw)
	if parseErr != nil {
		return nil, fmt.Errorf("parsing public key in %v: %w", p, parseErr)
	}
	return key, nil
}

// shouldRenewSSHCert decides whether the SSH certificate stored at certPath
// needs to be (re)issued for the host key publicKey (authorized_keys format)
// with the comma-separated principals.
//
// The boolean says whether to renew. When it is true the error, if any, is
// the reason (unreadable/invalid certificate, key mismatch, expiry within
// ssh_host_key_expiry_threshold, or a missing principal). A false result with
// a non-nil error means publicKey itself could not be parsed, so signing it
// would be pointless.
func shouldRenewSSHCert(certPath string, publicKey []byte, principals string) (bool, error) {
	certPK, err := readSSHPublicKey(certPath)
	if err != nil {
		return true, err
	}

	cert, ok := certPK.(*ssh.Certificate)
	if !ok {
		return true, fmt.Errorf("file at %v was not an SSH certificate (was a %T)", certPath, certPK)
	}

	pubKey, _, _, _, err := ssh.ParseAuthorizedKey(publicKey)
	if err != nil {
		return false, err
	}

	// Verify that this is the same key.
	if !bytes.Equal(pubKey.Marshal(), cert.Key.Marshal()) {
		return true, fmt.Errorf("certificate key and public key don't match")
	}

	// Verify the certificate isn't expiring too soon. A certificate with
	// infinite validity is handled separately: converting the sentinel to a
	// time.Time would produce a meaningless date in the log line.
	if cert.ValidBefore == uint64(ssh.CertTimeInfinity) {
		log.Infof("certificate %v never expires", certPath)
	} else {
		expiryThreshold := uint64(time.Now().Unix() + int64(sshHostKeyExpiryThreshold.Seconds()))
		certExpiryDate := time.Unix(int64(cert.ValidBefore), 0)
		log.Infof("certificate %v expires at %v (in %v)", certPath, certExpiryDate, time.Until(certExpiryDate))
		if cert.ValidBefore < expiryThreshold {
			return true, fmt.Errorf("certificate expires at %v, which is earlier than our threshold of %v", certExpiryDate, time.Unix(int64(expiryThreshold), 0))
		}
	}

	// Verify every wanted principal is present on the certificate.
	hasPrincipals := make(map[string]bool, len(cert.ValidPrincipals))
	for _, p := range cert.ValidPrincipals {
		hasPrincipals[p] = true
	}
	for _, wantP := range strings.Split(principals, ",") {
		// Tolerate "a, b" and stray commas: an untrimmed or empty entry would
		// otherwise never match and force a renewal on every run.
		// NOTE(review): presumably the signing backend normalises the list the
		// same way — confirm.
		wantP = strings.TrimSpace(wantP)
		if wantP == "" {
			continue
		}
		if !hasPrincipals[wantP] {
			return true, fmt.Errorf("certificate missing principal %v (has principals %v; want principals %v)", wantP, cert.ValidPrincipals, principals)
		}
	}

	return false, nil
}

// reloadOrTryRestartServices asks systemd to reload-or-try-restart each unit
// in services, using job mode "replace". Every unit is attempted even if an
// earlier one fails; failures are logged individually and summarised in the
// returned error. An empty list is a no-op.
func reloadOrTryRestartServices(ctx context.Context, services []string) error {
	if len(services) == 0 {
		return nil
	}

	sd, err := dbus.NewWithContext(ctx)
	if err != nil {
		return fmt.Errorf("failed to connect to systemd: %w", err)
	}
	defer sd.Close()

	var failedUnits []string
	for _, service := range services {
		// Use the context-aware call so cancellation of ctx applies to the
		// request as well as to the connection. No result channel is passed.
		// NOTE(review): a nil channel presumably means job completion is not
		// awaited — confirm against the go-systemd documentation.
		if _, err := sd.ReloadOrTryRestartUnitContext(ctx, service, "replace", nil); err != nil {
			log.Errorf("unit %v failed to reload-or-try-restart: %v", service, err)
			failedUnits = append(failedUnits, service)
		}
	}

	if len(failedUnits) != 0 {
		return fmt.Errorf("some units failed to restart: %v", failedUnits)
	}
	return nil
}

// checkAndRenewSSHCertificates walks every host key reported by
// sshHostKeyPaths, and for each one whose certificate needs renewing asks the
// Vault SSH CA to sign it and writes the result to
// <ssh_host_key_output_dir>/<key>-cert.pub. If anything was renewed,
// sshd.service is reloaded or restarted.
//
// It returns true if any step failed; individual failures are logged and do
// not stop the remaining keys from being processed.
func checkAndRenewSSHCertificates(ctx context.Context, c *vapi.Client) bool {
	sshHostKeys, err := sshHostKeyPaths(ctx)
	if err != nil {
		log.Errorf("sshHostKeyPaths: %v", err)
		return true
	}

	cs := c.SSHWithMountPoint(*sshHostKeyCAPath)
	var hadErr bool
	var didRenew bool
	for _, sshHostKeyPath := range sshHostKeys {
		outPath := filepath.Join(*sshHostKeyOutputDir, fmt.Sprintf("%s-cert.pub", strings.TrimSuffix(filepath.Base(sshHostKeyPath), ".pub")))
		log.Infof("might sign %v into %v", sshHostKeyPath, outPath)

		sshHostKey, err := readFile(sshHostKeyPath)
		if err != nil {
			log.Errorf("reading host key %v: %v", sshHostKeyPath, err)
			hadErr = true
			continue
		}

		// shouldRenew==true with an error means "renew, and here is why";
		// shouldRenew==false with an error is a genuine failure.
		shouldRenew, err := shouldRenewSSHCert(outPath, sshHostKey, *sshHostKeyPrincipals)
		switch {
		case shouldRenew && err == nil:
			log.Infof("shouldRenewSSHCert for %v returned true with no reason", outPath)
		case shouldRenew:
			log.Infof("shouldRenewSSHCert for %v returned true: %v", outPath, err)
		case err != nil:
			log.Errorf("shouldRenewSSHCert for %v returned false and an error: %v", outPath, err)
			hadErr = true
			continue
		default:
			log.Infof("shouldRenewSSHCert for %v returned false", outPath)
			continue
		}

		secret, err := cs.SignKey(*sshHostKeyRole, map[string]interface{}{
			"cert_type":        "host",
			"public_key":       string(sshHostKey),
			"valid_principals": *sshHostKeyPrincipals,
		})
		if err != nil {
			log.Errorf("signing %v: %v", sshHostKeyPath, err)
			hadErr = true
			continue
		}

		// Don't trust the response shape: a nil secret or a missing/non-string
		// field must be reported as an error rather than panic.
		var signedKey string
		if secret != nil {
			signedKey, _ = secret.Data["signed_key"].(string)
		}
		if signedKey == "" {
			log.Errorf("signing %v: response did not contain a signed_key string", sshHostKeyPath)
			hadErr = true
			continue
		}

		if err := writeFile(outPath, []byte(signedKey)); err != nil {
			log.Errorf("writing signed certificate to %v: %v", outPath, err)
			hadErr = true
			continue
		}
		log.Infof("renewed %v", outPath)
		didRenew = true
	}

	if didRenew {
		log.Infof("completed renewals; restarting SSHd")
		if err := reloadOrTryRestartServices(ctx, []string{"sshd.service"}); err != nil {
			log.Errorf("failed to restart sshd: %v", err)
			hadErr = true
		}
	}

	return hadErr
}

// setsSame reports whether a and b contain the same set of strings, ignoring
// order and duplicate entries.
func setsSame(a, b []string) bool {
	inA := make(map[string]bool, len(a))
	for _, v := range a {
		inA[v] = true
	}

	inB := make(map[string]bool, len(b))
	for _, v := range b {
		if !inA[v] {
			return false
		}
		inB[v] = true
	}

	// Every element of b is in a; equal distinct counts means the reverse
	// holds too.
	return len(inA) == len(inB)
}

// getGroupID resolves a group name to its numeric group ID via the os/user
// package. Lookup and conversion failures are returned wrapped.
func getGroupID(groupName string) (int, error) {
	grp, lookupErr := user.LookupGroup(groupName)
	if lookupErr != nil {
		return 0, fmt.Errorf("looking up group %q: %w", groupName, lookupErr)
	}

	id, convErr := strconv.Atoi(grp.Gid)
	if convErr != nil {
		return 0, fmt.Errorf("converting group %q to int: %w", grp.Gid, convErr)
	}
	return id, nil
}

// setGroup switches the effective group ID to that of groupName and returns a
// function that switches it back to the previous effective GID. With an empty
// groupName nothing is changed and the returned function does nothing.
func setGroup(groupName string) (func(), error) {
	noop := func() {}
	if groupName == "" {
		return noop, nil
	}

	target, err := getGroupID(groupName)
	if err != nil {
		return nil, err
	}

	// Remember the current effective GID before changing it.
	previous := unix.Getegid()
	if err := unix.Setregid(-1, target); err != nil {
		return nil, err
	}

	restore := func() {
		unix.Setregid(-1, previous)
	}
	return restore, nil
}

// makeOrFixPermissionsOnDir ensures directory p exists with the given
// permission bits and group.
//
// If p does not exist it is created with the effective group temporarily
// switched to group. If it already exists, its mode is chmodded when the
// permission bits differ and its group is changed when it differs. An empty
// group means "leave group ownership alone" in both cases.
func makeOrFixPermissionsOnDir(p string, permissions fs.FileMode, group string) error {
	outDirFi, err := os.Stat(p)
	if errors.Is(err, fs.ErrNotExist) {
		restoreGroup, err := setGroup(group)
		if err != nil {
			return fmt.Errorf("setting group to create outdir: %w", err)
		}
		defer restoreGroup()

		return os.Mkdir(p, permissions)
	}
	if err != nil {
		return fmt.Errorf("stating output dir %q: %w", p, err)
	}

	if (outDirFi.Mode() & fs.ModePerm) != permissions {
		if err := os.Chmod(p, permissions); err != nil {
			return fmt.Errorf("chmodding output dir %q: %w", p, err)
		}
	}

	// No group configured: nothing to look up or chgrp. Without this guard an
	// existing directory with an empty group would always fail the lookup
	// below, even though creation accepts an empty group.
	if group == "" {
		return nil
	}

	outDirStat, ok := outDirFi.Sys().(*syscall.Stat_t)
	if !ok {
		return fmt.Errorf("outDirFi isn't a stat_t, is a %T", outDirFi.Sys())
	}

	gid, err := getGroupID(group)
	if err != nil {
		return err
	}

	if outDirStat.Gid != uint32(gid) {
		if err := os.Chown(p, -1, gid); err != nil {
			return fmt.Errorf("chgrping output dir %q: %w", p, err)
		}
	}

	return nil
}

// shouldRenewACMECert decides whether the certificate described by c needs to
// be (re)issued.
//
// The boolean says whether to renew. When it is true the accompanying error
// is the reason: missing output directory or files, unparsable fullchain,
// expiry within acme_certificates_expiry_threshold, hostname mismatch, or the
// force-renew flag. A false result with a non-nil error means the on-disk
// state could not be inspected (or cleaned up) at all.
//
// Side effect: if something that is not a directory sits where the output
// directory should be, it is removed.
func shouldRenewACMECert(c acmeCertificate) (bool, error) {
	outDirFi, err := os.Stat(c.OutputDir())
	if errors.Is(err, fs.ErrNotExist) {
		return true, fmt.Errorf("output dir %q doesn't exist", c.OutputDir())
	} else if err != nil {
		return false, fmt.Errorf("stating output dir %q: %w", c.OutputDir(), err)
	}

	if !outDirFi.IsDir() {
		log.Warningf("output dir %q isn't a dir; deleting it", c.OutputDir())
		if err := os.Remove(c.OutputDir()); err != nil {
			return false, fmt.Errorf("couldn't remove non-dir in place of output dir %q: %w", c.OutputDir(), err)
		}
		return true, fmt.Errorf("output dir %q wasn't a dir", c.OutputDir())
	}

	// Check that all the output files exist.
	for _, f := range []string{c.FullchainPath(), c.ChainPath(), c.PrivkeyPath()} {
		if _, err := os.Stat(f); errors.Is(err, fs.ErrNotExist) {
			return true, fmt.Errorf("output file %q doesn't exist", f)
		} else if err != nil {
			return false, fmt.Errorf("stating output file %q: %w", f, err)
		}
	}

	fc, err := os.Open(c.FullchainPath())
	if err != nil {
		return true, fmt.Errorf("opening fullchain %q: %w", c.FullchainPath(), err)
	}
	defer fc.Close()

	fcBytes, err := io.ReadAll(fc)
	if err != nil {
		return true, fmt.Errorf("reading fullchain %q: %w", c.FullchainPath(), err)
	}

	// Only the first PEM block is inspected.
	fcB, _ := pem.Decode(fcBytes)
	if fcB == nil {
		return true, fmt.Errorf("fullchain %q contained non-PEM bits", c.FullchainPath())
	} else if fcB.Type != "CERTIFICATE" {
		return true, fmt.Errorf("fullchain %q contained non-CERTIFICATE PEM block %q", c.FullchainPath(), fcB.Type)
	}

	crt, err := x509.ParseCertificate(fcB.Bytes)
	if err != nil {
		return true, fmt.Errorf("fullchain %q wasn't parsable as x509: %w", c.FullchainPath(), err)
	}

	// Check that the certificate isn't expiring within the threshold.
	now := time.Now()
	expiryThreshold := now.Add(*acmeCertificatesExpiryThreshold)
	log.Infof("certificate in %q expires at %v (%v from now); expiry cutoff is %v", c.FullchainPath(), crt.NotAfter, crt.NotAfter.Sub(now), expiryThreshold)
	if crt.NotAfter.Before(expiryThreshold) {
		return true, fmt.Errorf("certificate in %q expires at %v (%v from now); expiry cutoff is %v", c.FullchainPath(), crt.NotAfter, crt.NotAfter.Sub(now), expiryThreshold)
	}

	// Check that the hostnames match.
	if !setsSame(c.Hostnames, crt.DNSNames) {
		return true, fmt.Errorf("certificate DNSNames %q don't match wanted hostnames %q", crt.DNSNames, c.Hostnames)
	}

	if *acmeCertificatesForceRenew {
		return true, fmt.Errorf("force renew flag set")
	}
	return false, nil
}

// writeCertificate writes the issued certificate material in cert to the
// output directory of certDef: "cert" to fullchain.pem (0644), "issuer_cert"
// to chain.pem (0644) and "private_key" to privkey.pem (0640). Files are
// created while the effective group is switched to certDef.Group.
//
// The response is validated and the group switched before any existing file
// is removed, so a malformed response or a group lookup failure leaves the
// previous certificate in place. Once writing starts, all old files are
// removed first; a failure part-way therefore leaves files missing rather
// than a mix of old and new material.
func writeCertificate(certDef acmeCertificate, cert *vapi.Secret) error {
	if cert == nil {
		return fmt.Errorf("no certificate data returned for %v", certDef.OutputName)
	}

	setFiles := []struct {
		name    string
		key     string
		content []byte
		perm    fs.FileMode
	}{{
		name: certDef.FullchainPath(),
		key:  "cert",
		perm: 0644,
	}, {
		name: certDef.ChainPath(),
		key:  "issuer_cert",
		perm: 0644,
	}, {
		name: certDef.PrivkeyPath(),
		key:  "private_key",
		perm: 0640,
	}}

	// Pull the fields out with checked assertions; a missing or non-string
	// field must be an error, not a panic.
	for i := range setFiles {
		s, ok := cert.Data[setFiles[i].key].(string)
		if !ok {
			return fmt.Errorf("certificate response field %q is missing or not a string", setFiles[i].key)
		}
		setFiles[i].content = []byte(s)
	}

	restoreGroup, err := setGroup(certDef.Group)
	if err != nil {
		return fmt.Errorf("setting group to write output: %w", err)
	}
	defer restoreGroup()

	for _, sf := range setFiles {
		os.Remove(sf.name) // optimistically try to remove the file, we don't care if it succeeds
		// if it doesn't, we'll error when we try to open it
	}

	for _, sf := range setFiles {
		log.Infof("writing file %v mode %s group %s", sf.name, sf.perm, certDef.Group)

		// O_EXCL: the file must be freshly created so it picks up the mode
		// and group set here instead of inheriting a stale file's.
		f, err := os.OpenFile(sf.name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, sf.perm)
		if err != nil {
			return fmt.Errorf("opening %v: %w", sf.name, err)
		}

		if _, err := f.Write(sf.content); err != nil {
			f.Close()
			os.Remove(sf.name)
			return fmt.Errorf("writing %v: %w", sf.name, err)
		}

		if err := f.Close(); err != nil {
			// A failed close may mean the data never reached disk; don't
			// leave a possibly truncated file behind.
			os.Remove(sf.name)
			return fmt.Errorf("closing %v: %w", sf.name, err)
		}
	}

	return nil
}

// uniqueStrings returns the distinct values of vs in order of first
// appearance. The result is never nil.
func uniqueStrings(vs []string) []string {
	seen := make(map[string]bool, len(vs))
	out := make([]string, 0, len(vs))
	for _, v := range vs {
		if seen[v] {
			continue
		}
		seen[v] = true
		out = append(out, v)
	}
	return out
}

// checkAndRenewACMECertificates loads the ACME certificate definitions and,
// for each one that needs renewing, prepares its output directory, requests a
// certificate from Vault at <acme_certificates_mount>/certs/<role> and writes
// it out. Afterwards every unit listed by a successfully renewed certificate
// is reloaded or restarted once.
//
// It returns true if any step failed; individual failures are logged and do
// not stop the remaining certificates from being processed.
func checkAndRenewACMECertificates(ctx context.Context, c *vapi.Client) bool {
	certDefs, err := parseACMECertificates()
	if err != nil {
		log.Errorf("failed to parse certificates: %v", err)
		return true
	}
	log.Infof("acme certificate defs: %#v", certDefs)

	var hadErr bool
	var unitsToReloadOrRestart []string
	for _, certDef := range certDefs {
		log.Infof("maybe issuing %v (role: %v) with hostnames %v into %v", certDef.OutputName, certDef.Role, certDef.Hostnames, certDef.OutputDir())

		// shouldRenew==true with an error means "renew, and here is why";
		// shouldRenew==false with an error is a genuine failure.
		shouldRenew, err := shouldRenewACMECert(certDef)
		switch {
		case shouldRenew && err == nil:
			log.Infof("shouldRenewACMECert for %v returned true with no reason", certDef.OutputName)
		case shouldRenew:
			log.Infof("shouldRenewACMECert for %v returned true: %v", certDef.OutputName, err)
		case err != nil:
			log.Errorf("shouldRenewACMECert for %v returned false and an error: %v", certDef.OutputName, err)
			hadErr = true
			continue
		default:
			log.Infof("shouldRenewACMECert for %v returned false", certDef.OutputName)
			continue
		}

		// The first sorted hostname becomes the common name, so there must be
		// at least one; indexing an empty list would panic.
		names := certDef.SortedHostnames()
		if len(names) == 0 {
			log.Errorf("certificate %v has no hostnames configured", certDef.OutputName)
			hadErr = true
			continue
		}

		if err := makeOrFixPermissionsOnDir(certDef.OutputDir(), 0750, certDef.Group); err != nil {
			log.Errorf("creating/fixing dir %q: %v", certDef.OutputDir(), err)
			hadErr = true
			continue
		}

		// Issue certificate.
		commonName := names[0]
		names = names[1:]
		cert, err := c.Logical().Write(fmt.Sprintf("%s/certs/%s", *acmeCertificatesMount, certDef.Role), map[string]interface{}{
			"common_name":       commonName,
			"alternative_names": strings.Join(names, ","),
		})
		if err != nil {
			log.Errorf("issuing %v (role: %v) with hostnames %v into %v: %v", certDef.OutputName, certDef.Role, certDef.Hostnames, certDef.OutputDir(), err)
			hadErr = true
			continue
		}
		if cert == nil {
			log.Errorf("issuing %v (role: %v): no certificate data returned", certDef.OutputName, certDef.Role)
			hadErr = true
			continue
		}

		// Write out certificate.
		if err := writeCertificate(certDef, cert); err != nil {
			log.Errorf("writing out certificate %v: %v", certDef.OutputName, err)
			hadErr = true
			continue
		}

		unitsToReloadOrRestart = append(unitsToReloadOrRestart, certDef.UnitsToReloadOrRestart...)
	}
	unitsToReloadOrRestart = uniqueStrings(unitsToReloadOrRestart)
	if len(unitsToReloadOrRestart) > 0 {
		log.Infof("reloading/restarting %d units: %v", len(unitsToReloadOrRestart), unitsToReloadOrRestart)
		if err := reloadOrTryRestartServices(ctx, unitsToReloadOrRestart); err != nil {
			log.Errorf("failed to restart units: %v", err)
			hadErr = true
		}
	}

	return hadErr
}

// main parses flags, builds a Vault client from them, then renews SSH host
// certificates (unless disabled with -sign_ssh_host_keys=false) and ACME
// certificates. The process exits with status 1 if any renewal step reported
// an error.
func main() {
	flag.Parse()

	cfg := vapi.DefaultConfig()
	// Honour the command-line flags; their defaults are the values that used
	// to be hard-coded here.
	cfg.Address = *vaultAddress
	cfg.AgentAddress = *vaultAgentAddress
	cfg.MaxRetries = 0
	cfg.Timeout = 15 * time.Minute
	c, err := vapi.NewClient(cfg)
	if err != nil {
		log.Exitf("failed to create vault client: %v", err)
	}

	ctx := context.Background()

	var hadErr bool
	if *signSSHHostKeys {
		hadErr = checkAndRenewSSHCertificates(ctx, c) || hadErr
	}
	hadErr = checkAndRenewACMECertificates(ctx, c) || hadErr

	// Flush explicitly so pending log output is written before the process
	// ends; a deferred call would not run on the os.Exit path below.
	log.Flush()
	if hadErr {
		os.Exit(1)
	}
}