package main import ( cryptorand "crypto/rand" "encoding/json" "errors" "flag" "fmt" "io" "io/fs" "log" "net" "os" "os/exec" "os/user" "path/filepath" "strings" "time" vapi "github.com/hashicorp/vault/api" "golang.org/x/crypto/ed25519" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" ) var ( vaultAddress = flag.String("vault_address", "https://vault.int.lukegb.com", "Address of Vault") sshMountPoint = flag.String("ssh_mount_point", "ssh-client", "SSH mount point in Vault") sshRole = flag.String("ssh_role", "user", "SSH role") sshPrincipal = flag.String("principal", currentUsername(), "principal to request in certificate") requirePresence = flag.Bool("require_presence", shouldRequirePresence(), "whether to require presence when using certificate") ) func currentUsername() string { // What's our principal? u, err := user.Current() if err != nil { log.Fatalf("looking up current user: %v", err) } return u.Username } func shouldRequirePresence() bool { // WSL2 makes things hard. return os.Getenv("WSL_DISTRO_NAME") == "" } const ( sshAgentComment = "vault certificate" ) type agentConn struct { agent.ExtendedAgent conn io.ReadWriteCloser } func (c *agentConn) Close() error { return c.conn.Close() } func connectToAgent() (*agentConn, error) { socket := os.Getenv("SSH_AUTH_SOCK") if socket == "" { return nil, fmt.Errorf("no SSH_AUTH_SOCK in environment") } conn, err := net.Dial("unix", socket) if err != nil { return nil, fmt.Errorf("failed to dial SSH_AUTH_SOCK = %q: %w", socket, err) } return &agentConn{agent.NewClient(conn), conn}, nil } type tokenHelper struct { Path string } func (h tokenHelper) tokenPath() (string, error) { if h.Path != "" { return h.Path, nil } homedir, err := os.UserHomeDir() if err != nil { return "", err } return filepath.Join(homedir, ".vault-token"), nil } func (h tokenHelper) Get() (string, error) { p, err := h.tokenPath() if err != nil { return "", fmt.Errorf("getting token path: %w", err) } f, err := os.Open(p) if err != nil { return "", fmt.Errorf("opening token at %v: %w", p, err) } defer f.Close() tok, err := io.ReadAll(f) if err != nil { return "", fmt.Errorf("reading token at %v: %w", p, err) } return strings.TrimSpace(string(tok)), nil } func (h tokenHelper) Put(token string) error { p, err := h.tokenPath() if err != nil { return fmt.Errorf("getting token path: %w", err) } os.Remove(p) f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600) if err != nil { return fmt.Errorf("creating token at %v: %w", p, err) } if _, err := fmt.Fprintf(f, "%s", token); err != nil { f.Close() return fmt.Errorf("writing token at %v: %w", p, err) } if err := f.Close(); err != nil { return fmt.Errorf("completing write of token at %v: %w", p, err) } return nil } func connectToVault() (*vapi.Client, error) { cfg := vapi.DefaultConfig() cfg.Address = *vaultAddress cfg.AgentAddress = "" vaultClient, err := vapi.NewClient(cfg) if err != nil { return nil, fmt.Errorf("vault/api.NewClient: %w", err) } return vaultClient, nil } func triggerVaultLogin(vaultClient *vapi.Client) (string, error) { // TODO: implement actual inline login. c := exec.Command("vault", "login", "-no-store", "-format=json", "-method=oidc", fmt.Sprintf("-address=%s", vaultClient.Address())) c.Stdin = os.Stdin c.Stderr = os.Stderr out, err := c.Output() if err != nil { return "", fmt.Errorf("running vault login: %w", err) } var v struct { Auth struct { ClientToken string `json:"client_token"` } `json:"auth"` } if err := json.Unmarshal(out, &v); err != nil { return "", fmt.Errorf("parsing vault login JSON: %w", err) } if v.Auth.ClientToken == "" { return "", fmt.Errorf("some error: %v", string(out)) } return v.Auth.ClientToken, nil } func ensureLoggedIn(vaultClient *vapi.Client) error { needLogin := false th := tokenHelper{} token, err := th.Get() if errors.Is(err, fs.ErrNotExist) { needLogin = true } else if err != nil { return fmt.Errorf("tokenHelper.Get: %w", err) } else { vaultClient.SetToken(token) // Check if we currently have anything in Vault. _, err := vaultClient.Auth().Token().LookupSelf() if err, ok := err.(*vapi.ResponseError); ok && err.StatusCode >= 400 && err.StatusCode < 500 { log.Printf("got %v from server: assuming we need to log in", err.StatusCode) needLogin = true } else if err != nil { return fmt.Errorf("checking self token: %w", err) } } if needLogin { token, err = triggerVaultLogin(vaultClient) if err != nil { return fmt.Errorf("logging in to vault: %w", err) } if err := th.Put(token); err != nil { return fmt.Errorf("writing vault token to storage: %w", err) } vaultClient.SetToken(token) } return nil } func main() { flag.Parse() agentClient, err := connectToAgent() if err != nil { log.Fatalf("connectToAgent: %v", err) } defer agentClient.Close() vaultClient, err := connectToVault() if err != nil { log.Fatalf("connectToVault: %v", err) } if err := ensureLoggedIn(vaultClient); err != nil { log.Fatalf("ensureLoggedIn: %v", err) } // OK, we're probably logged in. Let's go. // Remove any old certificates. keys, err := agentClient.List() if err != nil { log.Fatalf("listing keys in agent: %v", err) } for _, key := range keys { if key.Comment != sshAgentComment { continue } if err := agentClient.Remove(key); err != nil { log.Fatalf("removing existing key %s: %v", key, err) } } // Generate a new key. pubKey, privKey, err := ed25519.GenerateKey(cryptorand.Reader) if err != nil { log.Fatalf("generating SSH key: %v", err) } sshPubKey, err := ssh.NewPublicKey(pubKey) if err != nil { log.Fatalf("converting new SSH key: %v", err) } // Sign the key. vssh := vaultClient.SSHWithMountPoint(*sshMountPoint) sec, err := vssh.SignKey(*sshRole, map[string]interface{}{ "public_key": string(ssh.MarshalAuthorizedKey(sshPubKey)), "valid_principals": *sshPrincipal, }) if err != nil { log.Fatalf("signing SSH key: %v", err) } // Add the certificate to the agent. log.Printf("signed as serial number = %v", sec.Data["serial_number"]) signedCertStr := sec.Data["signed_key"].(string) signedCertPK, _, _, _, err := ssh.ParseAuthorizedKey([]byte(signedCertStr)) if err != nil { log.Fatalf("parsing SSH certificate: %w", err) } signedCert := signedCertPK.(*ssh.Certificate) expiresAt := time.Unix(int64(signedCert.ValidBefore), 0) certLifetime := expiresAt.Sub(time.Now()) log.Printf("certificate expires at %v (in %v)", expiresAt, certLifetime) if err := agentClient.Add(agent.AddedKey{ PrivateKey: privKey, Certificate: signedCert, Comment: sshAgentComment, LifetimeSecs: uint32(certLifetime.Seconds()), ConfirmBeforeUse: *requirePresence, }); err != nil { log.Fatalf("adding key to agent: %w", err) } }