package main

import (
	cryptorand "crypto/rand"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"io/fs"
	"log"
	"net"
	"os"
	"os/exec"
	"os/user"
	"path/filepath"
	"strings"
	"time"

	vapi "github.com/hashicorp/vault/api"
	"golang.org/x/crypto/ed25519"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
)

// Command-line flags.
//
// -vault_address, -ssh_mount_point and -ssh_role select the Vault server, the
// SSH mount point in Vault and the role used to sign the key. -principal is
// the principal requested in the certificate and defaults to the current
// username. -require_presence sets ConfirmBeforeUse on the key added to the
// agent; it defaults to true unless WSL_DISTRO_NAME is set.
//
// The defaults are evaluated during package initialisation, before
// flag.Parse runs in main. If currentUsername cannot look up the current
// user, it exits the program at that point, even when -principal is passed
// explicitly.
var (
	vaultAddress  = flag.String("vault_address", "https://vault.int.lukegb.com", "Address of Vault")
	sshMountPoint = flag.String("ssh_mount_point", "ssh-client", "SSH mount point in Vault")
	sshRole       = flag.String("ssh_role", "user", "SSH role")
	sshPrincipal  = flag.String("principal", currentUsername(), "principal to request in certificate")

	requirePresence = flag.Bool("require_presence", shouldRequirePresence(), "whether to require presence when using certificate")
)

// currentUsername returns the login name of the user running this process.
// It supplies the default for -principal. If the user cannot be determined,
// it terminates the program via log.Fatalf.
func currentUsername() string {
	me, err := user.Current()
	if err == nil {
		return me.Username
	}
	log.Fatalf("looking up current user: %v", err)
	return "" // unreachable: log.Fatalf exits
}

// shouldRequirePresence provides the default for -require_presence. It
// returns true unless WSL_DISTRO_NAME is set to a non-empty value, i.e.
// unless we appear to be running under WSL. Presumably the agent's
// confirmation prompt is unusable there; the original note only says that
// WSL2 makes things hard.
func shouldRequirePresence() bool {
	inWSL := os.Getenv("WSL_DISTRO_NAME") != ""
	return !inWSL
}

// sshAgentComment is attached to every key this tool adds to the agent. Keys
// carrying this comment are treated as ours and removed before a new
// certificate is added.
const sshAgentComment = "vault certificate"

// agentConn is an ssh-agent client that also owns the connection it talks
// over, so callers can Close it when they are done. Agent operations such as
// List, Add and Remove come from the embedded ExtendedAgent.
type agentConn struct {
	agent.ExtendedAgent

	conn io.ReadWriteCloser
}

// Close closes the underlying connection to the agent.
func (ac *agentConn) Close() error {
	return ac.conn.Close()
}

// connectToAgent dials the ssh-agent on the Unix socket named by
// $SSH_AUTH_SOCK. It returns a client wrapping that connection; the caller
// must Close it.
func connectToAgent() (*agentConn, error) {
	sockPath := os.Getenv("SSH_AUTH_SOCK")
	if len(sockPath) == 0 {
		return nil, fmt.Errorf("no SSH_AUTH_SOCK in environment")
	}

	sock, err := net.Dial("unix", sockPath)
	if err != nil {
		return nil, fmt.Errorf("failed to dial SSH_AUTH_SOCK = %q: %w", sockPath, err)
	}
	return &agentConn{ExtendedAgent: agent.NewClient(sock), conn: sock}, nil
}

// tokenHelper stores the Vault client token as plain text in a file,
// ~/.vault-token by default.
type tokenHelper struct {
	// Path overrides the token file location. When empty, .vault-token in the
	// user's home directory is used.
	Path string
}

// tokenPath returns the file the token is kept in: h.Path when set,
// otherwise .vault-token in the current user's home directory.
func (h tokenHelper) tokenPath() (string, error) {
	if h.Path != "" {
		return h.Path, nil
	}
	homedir, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	return filepath.Join(homedir, ".vault-token"), nil
}

// Get returns the stored token with surrounding whitespace trimmed.
//
// Errors wrap the underlying os error with %w, so callers can detect a
// missing token file with errors.Is(err, fs.ErrNotExist).
func (h tokenHelper) Get() (string, error) {
	p, err := h.tokenPath()
	if err != nil {
		return "", fmt.Errorf("getting token path: %w", err)
	}
	f, err := os.Open(p)
	if err != nil {
		return "", fmt.Errorf("opening token at %v: %w", p, err)
	}
	defer f.Close()
	tok, err := io.ReadAll(f)
	if err != nil {
		return "", fmt.Errorf("reading token at %v: %w", p, err)
	}
	return strings.TrimSpace(string(tok)), nil
}

// Put stores token, replacing any previously stored token.
//
// The old file is removed first. A new one is then created with O_EXCL and
// mode 0600, so the token always goes into a freshly created file rather than
// an existing file (or symlink) whose permissions we don't control.
//
// A missing old file is fine. Any other failure to remove it is reported
// directly; otherwise it would surface as a confusing "file exists" error
// from the exclusive create. If writing fails, the partially written file is
// removed on a best-effort basis.
func (h tokenHelper) Put(token string) error {
	p, err := h.tokenPath()
	if err != nil {
		return fmt.Errorf("getting token path: %w", err)
	}
	if err := os.Remove(p); err != nil && !errors.Is(err, fs.ErrNotExist) {
		return fmt.Errorf("removing old token at %v: %w", p, err)
	}
	f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		return fmt.Errorf("creating token at %v: %w", p, err)
	}
	if _, err := io.WriteString(f, token); err != nil {
		// Best-effort cleanup: the write error is the one worth reporting.
		f.Close()
		os.Remove(p)
		return fmt.Errorf("writing token at %v: %w", p, err)
	}
	if err := f.Close(); err != nil {
		os.Remove(p) // best effort; don't leave a possibly truncated token
		return fmt.Errorf("completing write of token at %v: %w", p, err)
	}
	return nil
}

// connectToVault builds a Vault API client for the server named by
// -vault_address.
//
// It starts from vapi.DefaultConfig, then forces the address from the flag.
// It also clears AgentAddress so that no Vault agent address, presumably
// picked up from the environment, takes precedence over it.
func connectToVault() (*vapi.Client, error) {
	config := vapi.DefaultConfig()
	config.Address, config.AgentAddress = *vaultAddress, ""

	client, err := vapi.NewClient(config)
	if err != nil {
		return nil, fmt.Errorf("vault/api.NewClient: %w", err)
	}
	return client, nil
}

// triggerVaultLogin runs an interactive `vault login -method=oidc` against
// vaultClient's address and returns the client token it produces.
//
// The vault CLI must be on $PATH. Its stdin and stderr are connected to ours
// so the user can complete the login. Its stdout is captured and parsed as
// JSON (-format=json). -no-store is passed so the CLI does not persist the
// token; the caller is responsible for storing it.
func triggerVaultLogin(vaultClient *vapi.Client) (string, error) {
	// TODO: implement actual inline login.
	c := exec.Command("vault", "login",
		"-no-store",
		"-format=json",
		"-method=oidc",
		"-address="+vaultClient.Address(),
	)
	c.Stdin = os.Stdin
	c.Stderr = os.Stderr
	out, err := c.Output()
	if err != nil {
		return "", fmt.Errorf("running vault login: %w", err)
	}
	var v struct {
		Auth struct {
			ClientToken string `json:"client_token"`
		} `json:"auth"`
	}
	if err := json.Unmarshal(out, &v); err != nil {
		return "", fmt.Errorf("parsing vault login JSON: %w", err)
	}
	if v.Auth.ClientToken == "" {
		// Include the raw output: it is the only clue to what vault returned.
		return "", fmt.Errorf("vault login returned no client token; output: %s", out)
	}
	return v.Auth.ClientToken, nil
}

// ensureLoggedIn leaves vaultClient configured with a Vault token the server
// accepts.
//
// It first tries the token stored by tokenHelper. A new login is triggered if
// either of these holds:
//   - no token file exists;
//   - looking the stored token up with LookupSelf fails with a 4xx response.
//
// In that case it runs triggerVaultLogin, stores the resulting token and sets
// it on the client. Any other error is returned.
func ensureLoggedIn(vaultClient *vapi.Client) error {
	th := tokenHelper{}

	needLogin := false
	token, err := th.Get()
	switch {
	case errors.Is(err, fs.ErrNotExist):
		// Nothing stored yet.
		needLogin = true
	case err != nil:
		return fmt.Errorf("tokenHelper.Get: %w", err)
	default:
		vaultClient.SetToken(token)

		// Check if we currently have anything in Vault. errors.As (rather
		// than a direct type assertion) still finds the ResponseError if it
		// arrives wrapped.
		_, lookupErr := vaultClient.Auth().Token().LookupSelf()
		var respErr *vapi.ResponseError
		if errors.As(lookupErr, &respErr) && respErr.StatusCode >= 400 && respErr.StatusCode < 500 {
			log.Printf("got %v from server: assuming we need to log in", respErr.StatusCode)
			needLogin = true
		} else if lookupErr != nil {
			return fmt.Errorf("checking self token: %w", lookupErr)
		}
	}

	if !needLogin {
		return nil
	}

	token, err = triggerVaultLogin(vaultClient)
	if err != nil {
		return fmt.Errorf("logging in to vault: %w", err)
	}
	if err := th.Put(token); err != nil {
		return fmt.Errorf("writing vault token to storage: %w", err)
	}
	vaultClient.SetToken(token)
	return nil
}

func main() {
	flag.Parse()

	agentClient, err := connectToAgent()
	if err != nil {
		log.Fatalf("connectToAgent: %v", err)
	}
	defer agentClient.Close()

	vaultClient, err := connectToVault()
	if err != nil {
		log.Fatalf("connectToVault: %v", err)
	}

	if err := ensureLoggedIn(vaultClient); err != nil {
		log.Fatalf("ensureLoggedIn: %v", err)
	}

	// OK, we're probably logged in. Let's go.

	// Remove any old certificates.
	keys, err := agentClient.List()
	if err != nil {
		log.Fatalf("listing keys in agent: %v", err)
	}
	for _, key := range keys {
		if key.Comment != sshAgentComment {
			continue
		}
		if err := agentClient.Remove(key); err != nil {
			log.Fatalf("removing existing key %s: %v", key, err)
		}
	}

	// Generate a new key.
	pubKey, privKey, err := ed25519.GenerateKey(cryptorand.Reader)
	if err != nil {
		log.Fatalf("generating SSH key: %v", err)
	}
	sshPubKey, err := ssh.NewPublicKey(pubKey)
	if err != nil {
		log.Fatalf("converting new SSH key: %v", err)
	}

	// Sign the key.
	vssh := vaultClient.SSHWithMountPoint(*sshMountPoint)
	sec, err := vssh.SignKey(*sshRole, map[string]interface{}{
		"public_key":       string(ssh.MarshalAuthorizedKey(sshPubKey)),
		"valid_principals": *sshPrincipal,
	})
	if err != nil {
		log.Fatalf("signing SSH key: %v", err)
	}

	// Add the certificate to the agent.
	log.Printf("signed as serial number = %v", sec.Data["serial_number"])
	signedCertStr := sec.Data["signed_key"].(string)
	signedCertPK, _, _, _, err := ssh.ParseAuthorizedKey([]byte(signedCertStr))
	if err != nil {
		log.Fatalf("parsing SSH certificate: %w", err)
	}
	signedCert := signedCertPK.(*ssh.Certificate)
	expiresAt := time.Unix(int64(signedCert.ValidBefore), 0)
	certLifetime := expiresAt.Sub(time.Now())
	log.Printf("certificate expires at %v (in %v)", expiresAt, certLifetime)

	if err := agentClient.Add(agent.AddedKey{
		PrivateKey:       privKey,
		Certificate:      signedCert,
		Comment:          sshAgentComment,
		LifetimeSecs:     uint32(certLifetime.Seconds()),
		ConfirmBeforeUse: *requirePresence,
	}); err != nil {
		log.Fatalf("adding key to agent: %w", err)
	}
}