# SPDX-FileCopyrightText: 2020 Luke Granger-Brown <depot@lukegb.com>
#
# SPDX-License-Identifier: Apache-2.0

{ pkgs, config, depot, lib, ... }:
let
  inherit (lib) mkOption types mkBefore optionalAttrs mkDefault;

  # The configured certificates as a list (ordered by attribute name); each
  # element is the submodule value with its attribute name added as `name`.
  acmeCertificates = builtins.attrValues
    (lib.mapAttrs (name: cert: cert // { inherit name; }) config.my.vault.acmeCertificates);

  # Resolve where each certificate file is written and which group owns it, so
  # directories and files can be given the right permissions.

  # `p` is a per-file spec (fullchain/chain/key): use its explicit `path` if
  # set, otherwise /var/lib/acme/<certificate name>/<suffix>.
  pathFor = p: c: suffix:
    if p.path != null then p.path else "/var/lib/acme/${c.name}/${suffix}";
  fullchainPath = c: pathFor c.fullchain c "fullchain.pem";
  chainPath = c: pathFor c.chain c "chain.pem";
  keyPath = c: pathFor c.key c "privkey.pem";

  # A certificate is "for nginx" when it names at least one virtual host.
  isNginx = c: c.nginxVirtualHosts != [ ];
  # Certificate-level group: explicit `group`, else "nginx" for nginx
  # certificates, else "acme".
  defaultGroup = c:
    if c.group != null then c.group
    else if isNginx c then "nginx"
    else "acme";
  # Per-file group `p`, falling back to the certificate-level group.
  groupOrDefault = p: c: if p != null then p else defaultGroup c;

  # Units to reload-or-restart after a new certificate is written: nginx.service
  # when the certificate serves nginx virtual hosts, followed by the user's list.
  reloadOrRestartUnits = c:
    (if isNginx c then [ "nginx.service" ] else [ ]) ++ c.reloadOrRestartUnits;

  # Every distinct, non-empty group that owns one of the written files; added to
  # vault-agent's SupplementaryGroups so it can write files with that group.
  acmeCertificatesGroups =
    let
      groupsOf = c: map (spec: groupOrDefault spec.group c) [ c.fullchain c.chain c.key ];
      everyGroup = lib.concatMap groupsOf acmeCertificates;
    in
    lib.unique (builtins.filter (g: g != "") everyGroup);

  # vault-agent template stanzas, one per certificate. Each reads the
  # `acme/certs/<role>` secret with the attribute name as common_name and the
  # sorted extraNames as alternative_names, then writes the certificate, issuer
  # chain and private key (owner vault-agent, per-file group and mode).
  # The consumer (`my.vault.settings.template` in `config`) is currently
  # commented out.
  acmeCertificatesTemplate = map (c: {
    contents = ''
      {{with secret "acme/certs/${c.role}" "common_name=${c.name}" "alternative_names=${builtins.concatStringsSep "," (builtins.sort builtins.lessThan c.extraNames)}"}}
      {{ .Data.cert        | writeToFile "${fullchainPath c}" "vault-agent" "${groupOrDefault c.fullchain.group c}" "${c.fullchain.mode}" "newline" }}
      {{ .Data.issuer_cert | writeToFile "${chainPath c}"     "vault-agent" "${groupOrDefault c.chain.group c}"     "${c.chain.mode}"     "newline" }}
      {{ .Data.private_key | writeToFile "${keyPath c}"       "vault-agent" "${groupOrDefault c.key.group c}"       "${c.key.mode}"       "newline" }}
      {{ end }}
    '';
    # Output file for the rendered template itself; the certificate files are
    # written by the writeToFile calls above.
    destination = "/var/lib/acme/${c.name}/token";
    perms = "0600";
    # Post-render hook: reload-or-restart the units from reloadOrRestartUnits,
    # restart restartUnits, then run the user-supplied command (if any).
    command = pkgs.writeShellScript "post-${c.name}-crt" ''
      ${lib.concatMapStringsSep "\n" (x: ''
        /run/current-system/sw/bin/systemctl reload-or-restart ${x}
      '') (reloadOrRestartUnits c)}
      ${lib.concatMapStringsSep "\n" (x: ''
        /run/current-system/sw/bin/systemctl restart ${x}
      '') c.restartUnits}
      ${lib.optionalString (c.command != "") c.command}
    '';
  }) acmeCertificates;

  # Directories vault-agent writes into, exposed via ReadWritePaths: the parent
  # directory of every output file, plus /var/lib/acme/<name>, which holds the
  # template destination (`token`) even when all three files use custom paths.
  # NOTE(review): this presumably matters because the vault-agent unit is
  # filesystem-sandboxed in ./vault-agent.nix — confirm there.
  extraWritableDirs = lib.unique (builtins.concatMap (c: [
    (dirOf (fullchainPath c))
    (dirOf (chainPath c))
    (dirOf (keyPath c))
    "/var/lib/acme/${c.name}"
  ]) acmeCertificates);
  # systemd-tmpfiles rules creating each directory certificates are written to
  # (when that file's makeDir is set) and /var/lib/acme/<name>. Directories are
  # owned by vault-agent with mode 0750.
  acmeCertificatesTmpdirs = lib.unique (builtins.concatMap (c:
    let
      fullchainDir = dirOf (fullchainPath c);
      chainDir = dirOf (chainPath c);
      keyDir = dirOf (keyPath c);

      fullchainGroup = groupOrDefault c.fullchain.group c;
      chainGroup = groupOrDefault c.chain.group c;
      keyGroup = groupOrDefault c.key.group c;

      # When all three files live in one directory that we create, that
      # directory gets a single group: the shared group if all three agree,
      # otherwise "-" because no single group fits every file.
      sameDir = fullchainDir == keyDir && chainDir == keyDir;
      allMakeDir = c.fullchain.makeDir && c.chain.makeDir && c.key.makeDir;
      sameGroup = fullchainGroup == keyGroup && fullchainGroup == chainGroup;
      dirGroup =
        if !(sameDir && allMakeDir) then null
        else if sameGroup then fullchainGroup
        else "-";
      groupFor = g: if isNull dirGroup then g else dirGroup;

      # No alignment padding: rules for the same directory must be identical
      # strings so lib.unique collapses them.
      dirRule = dir: group: "d ${dir} 0750 vault-agent ${group} - -";

      fileDirRules =
        lib.optional c.fullchain.makeDir (dirRule fullchainDir (groupFor fullchainGroup))
        ++ lib.optional c.chain.makeDir (dirRule chainDir (groupFor chainGroup))
        ++ lib.optional c.key.makeDir (dirRule keyDir (groupFor keyGroup));

      # /var/lib/acme/<name> (the template destination's directory) is the
      # default file directory, so only add a separate rule when no rule above
      # already creates it; systemd-tmpfiles ignores, with a warning, a later
      # rule for the same path that conflicts with an earlier one.
      tokenDir = "/var/lib/acme/${c.name}";
      tokenDirCovered =
        (c.fullchain.makeDir && fullchainDir == tokenDir)
        || (c.chain.makeDir && chainDir == tokenDir)
        || (c.key.makeDir && keyDir == tokenDir);
    in
    fileDirRules
    ++ lib.optional (!tokenDirCovered) "d /var/lib/acme/${c.name} 0750 vault-agent - -"
  ) acmeCertificates);

  # Every distinct unit any certificate's post-render hook may restart or
  # reload-or-restart; this is the allow-list embedded in the polkit rule.
  allRestartableUnits = lib.pipe acmeCertificates [
    (builtins.concatMap (c: reloadOrRestartUnits c ++ c.restartUnits))
    lib.unique
  ];
in
{
  imports = [
    ./vault-agent.nix
  ];

  options.my.vault.acmeCertificates = mkOption {
    description = "Certificates to obtain from Vault via vault-agent, keyed by the certificate's common name.";
    type = with types; attrsOf (submodule {
      options = let
        # Settings shared by each output file (full chain, chain, private key).
        fileType = what: defaultMode: submodule {
          options = {
            path = mkOption {
              type = nullOr path;
              default = null;
              description = "Path to put the ${what}. If null, a file under /var/lib/acme/<name>/ is used.";
            };
            mode = mkOption {
              type = str;
              default = defaultMode;
              description = "Mode to set for the ${what}.";
            };

            group = mkOption {
              type = nullOr str;
              default = null;
              description = "Owner group to set for the ${what}. If null, the certificate-level group (or its default) is used.";
            };

            makeDir = mkOption {
              type = bool;
              default = true;
              description = "If true, creates the parent directory (owned by vault-agent, mode 0750) via systemd-tmpfiles.";
            };
          };
        };
      in {
        role = mkOption {
          type = str;
          default = "letsencrypt-cloudflare";
          description = "Which role to use for certificate issuance.";
        };
        extraNames = mkOption {
          type = listOf str;
          default = [];
          description = "Additional hostnames to include as subject alternative names. May be empty.";
        };

        command = mkOption {
          type = lines;
          default = "";
          description = "Shell commands to run after generating the certificate, after the listed units have been reloaded/restarted.";
        };
        reloadOrRestartUnits = mkOption {
          type = listOf str;
          default = [];
          description = "List of systemd units to reload/restart after obtaining a new certificate. nginx.service is added automatically when nginxVirtualHosts is non-empty.";
        };
        restartUnits = mkOption {
          type = listOf str;
          default = [];
          description = "List of systemd units to restart after obtaining a new certificate.";
        };

        nginxVirtualHosts = mkOption {
          type = listOf str;
          default = [];
          description = "List of nginx virtual hosts to apply SSL to.";
        };

        group = mkOption {
          type = nullOr str;
          default = null;
          description = "Owner group to set for the generated files. Defaults to 'acme' unless nginxVirtualHosts is set, in which case it defaults to 'nginx'.";
        };

        fullchain = mkOption {
          type = fileType "certificate's full chain" "0644";
          default = {};
          description = "Where and how to write the certificate's full chain (default fullchain.pem).";
        };
        chain = mkOption {
          type = fileType "certificate chain only" "0644";
          default = {};
          description = "Where and how to write the issuer chain only (default chain.pem).";
        };
        key = mkOption {
          type = fileType "certificate's key" "0640";
          default = {};
          description = "Where and how to write the certificate's private key (default privkey.pem).";
        };
      };
    });
    default = {};
  };

  config = {
    my.vault.settings = {
      # TODO: lukegb: figure out how to not get this to DoS Let's Encrypt.
      #template = mkBefore acmeCertificatesTemplate;
    };

    # Give vault-agent the groups and writable directories it needs, and create
    # the certificate directories at boot.
    systemd = optionalAttrs config.my.vault.enable {
      services.vault-agent = {
        serviceConfig = {
          SupplementaryGroups = mkBefore acmeCertificatesGroups;
          ReadWritePaths = mkBefore extraWritableDirs;
        };
      };

      tmpfiles.rules = acmeCertificatesTmpdirs;
    };

    # Point each listed nginx virtual host at its certificate files. mkDefault
    # lets a virtual host's own configuration override these.
    services.nginx = optionalAttrs config.my.vault.enable {
      virtualHosts = builtins.listToAttrs (builtins.concatMap (certData: let
        fullchain = fullchainPath certData;
        chain     = chainPath certData;
        key       = keyPath certData;
      in map (hostName: lib.nameValuePair hostName {
        sslCertificate        = mkDefault fullchain;
        sslCertificateKey     = mkDefault key;
        sslTrustedCertificate = mkDefault chain;
      }) certData.nginxVirtualHosts) acmeCertificates);
    };

    # Let the vault-agent user "restart" / "reload-or-restart" exactly the units
    # named by the certificates' post-render hooks. Like the sections above, this
    # is only emitted when vault is enabled, and only if there is a unit to allow.
    security.polkit.extraConfig = lib.mkIf (config.my.vault.enable && allRestartableUnits != [ ]) (lib.mkAfter ''
      // NixOS module: depot/lib/vault-agent-acme.nix
      polkit.addRule(function(action, subject) {
        if (action.id !== "org.freedesktop.systemd1.manage-units" ||
            subject.user !== "vault-agent") {
          return polkit.Result.NOT_HANDLED;
        }

        var verb = action.lookup("verb");
        if (verb !== "restart" && verb !== "reload-or-restart") {
          return polkit.Result.NOT_HANDLED;
        }

        var allowedUnits = ${builtins.toJSON allRestartableUnits};
        var unit = action.lookup("unit");
        for (var i = 0; i < allowedUnits.length; i++) {
          if (allowedUnits[i] === unit) {
            return polkit.Result.YES;
          }
        }
        return polkit.Result.NOT_HANDLED;
      });
    '');
  };
}