252 lines
6.2 KiB
Go
252 lines
6.2 KiB
Go
package main
|
|
|
|
import (
|
|
cryptorand "crypto/rand"
|
|
"encoding/json"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
vapi "github.com/hashicorp/vault/api"
|
|
"golang.org/x/crypto/ed25519"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/agent"
|
|
)
|
|
|
|
// vaultAddress is the base URL of the Vault server (-vault_address); it is
// used as the Vault client's Address.
var vaultAddress = flag.String("vault_address", "https://vault.int.lukegb.com", "Address of Vault")

// sshMountPoint is the mount path of the SSH secrets engine in Vault
// (-ssh_mount_point) that signs our public key.
var sshMountPoint = flag.String("ssh_mount_point", "ssh-client", "SSH mount point in Vault")

// sshRole is the role under sshMountPoint used for signing (-ssh_role).
var sshRole = flag.String("ssh_role", "user", "SSH role")
|
|
|
|
// sshAgentComment is attached to every key this tool adds to ssh-agent; keys
// carrying it are treated as ours and removed before a new one is added.
const sshAgentComment = "vault certificate"
|
|
|
|
type agentConn struct {
|
|
agent.ExtendedAgent
|
|
|
|
conn io.ReadWriteCloser
|
|
}
|
|
|
|
func (c *agentConn) Close() error { return c.conn.Close() }
|
|
|
|
func connectToAgent() (*agentConn, error) {
|
|
socket := os.Getenv("SSH_AUTH_SOCK")
|
|
if socket == "" {
|
|
return nil, fmt.Errorf("no SSH_AUTH_SOCK in environment")
|
|
}
|
|
conn, err := net.Dial("unix", socket)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to dial SSH_AUTH_SOCK = %q: %w", socket, err)
|
|
}
|
|
|
|
return &agentConn{agent.NewClient(conn), conn}, nil
|
|
}
|
|
|
|
// tokenHelper stores a Vault token as a plain-text file.
type tokenHelper struct {
	// Path is the token file; when empty, ~/.vault-token is used.
	Path string
}

// tokenPath returns the token file location: h.Path if set, otherwise
// .vault-token in the current user's home directory.
func (h tokenHelper) tokenPath() (string, error) {
	if h.Path != "" {
		return h.Path, nil
	}
	homedir, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	return filepath.Join(homedir, ".vault-token"), nil
}

// Get returns the stored token with surrounding whitespace trimmed.
//
// If no token file exists, the returned error wraps fs.ErrNotExist; callers
// (ensureLoggedIn) use that to decide a fresh login is needed.
func (h tokenHelper) Get() (string, error) {
	p, err := h.tokenPath()
	if err != nil {
		return "", fmt.Errorf("getting token path: %w", err)
	}
	tok, err := os.ReadFile(p)
	if err != nil {
		return "", fmt.Errorf("reading token at %v: %w", p, err)
	}
	return strings.TrimSpace(string(tok)), nil
}

// Put stores token, replacing any previously stored token.
//
// The token is first written to a new temporary file in the same directory
// (os.CreateTemp creates it with mode 0600) and then renamed over the target.
// A failed write therefore never deletes or truncates the existing token, and
// the final file is always a fresh 0600 file owned by us rather than whatever
// previously sat at the path. The temporary file is removed on failure.
func (h tokenHelper) Put(token string) error {
	p, err := h.tokenPath()
	if err != nil {
		return fmt.Errorf("getting token path: %w", err)
	}

	f, err := os.CreateTemp(filepath.Dir(p), filepath.Base(p)+".tmp*")
	if err != nil {
		return fmt.Errorf("creating token at %v: %w", p, err)
	}
	tmp := f.Name()
	committed := false
	defer func() {
		if !committed {
			// Best-effort cleanup; the original error is what matters.
			os.Remove(tmp)
		}
	}()

	if _, err := io.WriteString(f, token); err != nil {
		f.Close()
		return fmt.Errorf("writing token at %v: %w", p, err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("completing write of token at %v: %w", p, err)
	}
	if err := os.Rename(tmp, p); err != nil {
		return fmt.Errorf("replacing token at %v: %w", p, err)
	}
	committed = true
	return nil
}
|
|
|
|
func connectToVault() (*vapi.Client, error) {
|
|
cfg := vapi.DefaultConfig()
|
|
cfg.Address = *vaultAddress
|
|
cfg.AgentAddress = ""
|
|
vaultClient, err := vapi.NewClient(cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("vault/api.NewClient: %w", err)
|
|
}
|
|
|
|
return vaultClient, nil
|
|
}
|
|
|
|
func triggerVaultLogin(vaultClient *vapi.Client) (string, error) {
|
|
// TODO: implement actual inline login.
|
|
c := exec.Command("vault", "login", "-no-store", "-format=json", "-method=oidc", fmt.Sprintf("-address=%s", vaultClient.Address()))
|
|
c.Stdin = os.Stdin
|
|
c.Stderr = os.Stderr
|
|
out, err := c.Output()
|
|
if err != nil {
|
|
return "", fmt.Errorf("running vault login: %w", err)
|
|
}
|
|
var v struct {
|
|
Auth struct {
|
|
ClientToken string `json:"client_token"`
|
|
} `json:"auth"`
|
|
}
|
|
if err := json.Unmarshal(out, &v); err != nil {
|
|
return "", fmt.Errorf("parsing vault login JSON: %w", err)
|
|
}
|
|
if v.Auth.ClientToken == "" {
|
|
return "", fmt.Errorf("some error: %v", string(out))
|
|
}
|
|
return v.Auth.ClientToken, nil
|
|
}
|
|
|
|
func ensureLoggedIn(vaultClient *vapi.Client) error {
|
|
needLogin := false
|
|
th := tokenHelper{}
|
|
token, err := th.Get()
|
|
if errors.Is(err, fs.ErrNotExist) {
|
|
needLogin = true
|
|
} else if err != nil {
|
|
return fmt.Errorf("tokenHelper.Get: %w", err)
|
|
} else {
|
|
vaultClient.SetToken(token)
|
|
|
|
// Check if we currently have anything in Vault.
|
|
_, err := vaultClient.Auth().Token().LookupSelf()
|
|
if err, ok := err.(*vapi.ResponseError); ok && err.StatusCode >= 400 && err.StatusCode < 500 {
|
|
log.Printf("got %v from server: assuming we need to log in", err.StatusCode)
|
|
needLogin = true
|
|
} else if err != nil {
|
|
return fmt.Errorf("checking self token: %w", err)
|
|
}
|
|
}
|
|
|
|
if needLogin {
|
|
token, err = triggerVaultLogin(vaultClient)
|
|
if err != nil {
|
|
return fmt.Errorf("logging in to vault: %w", err)
|
|
}
|
|
if err := th.Put(token); err != nil {
|
|
return fmt.Errorf("writing vault token to storage: %w", err)
|
|
}
|
|
vaultClient.SetToken(token)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
|
|
agentClient, err := connectToAgent()
|
|
if err != nil {
|
|
log.Fatalf("connectToAgent: %v", err)
|
|
}
|
|
defer agentClient.Close()
|
|
|
|
vaultClient, err := connectToVault()
|
|
if err != nil {
|
|
log.Fatalf("connectToVault: %v", err)
|
|
}
|
|
|
|
if err := ensureLoggedIn(vaultClient); err != nil {
|
|
log.Fatalf("ensureLoggedIn: %v", err)
|
|
}
|
|
|
|
// OK, we're probably logged in. Let's go.
|
|
|
|
// Remove any old certificates.
|
|
keys, err := agentClient.List()
|
|
if err != nil {
|
|
log.Fatalf("listing keys in agent: %v", err)
|
|
}
|
|
for _, key := range keys {
|
|
if key.Comment != sshAgentComment {
|
|
continue
|
|
}
|
|
if err := agentClient.Remove(key); err != nil {
|
|
log.Fatalf("removing existing key %s: %v", key, err)
|
|
}
|
|
}
|
|
|
|
// Generate a new key.
|
|
pubKey, privKey, err := ed25519.GenerateKey(cryptorand.Reader)
|
|
if err != nil {
|
|
log.Fatalf("generating SSH key: %v", err)
|
|
}
|
|
sshPubKey, err := ssh.NewPublicKey(pubKey)
|
|
if err != nil {
|
|
log.Fatalf("converting new SSH key: %v", err)
|
|
}
|
|
|
|
// Sign the key.
|
|
vssh := vaultClient.SSHWithMountPoint(*sshMountPoint)
|
|
sec, err := vssh.SignKey(*sshRole, map[string]interface{}{
|
|
"public_key": string(ssh.MarshalAuthorizedKey(sshPubKey)),
|
|
})
|
|
if err != nil {
|
|
log.Fatalf("signing SSH key: %v", err)
|
|
}
|
|
|
|
// Add the certificate to the agent.
|
|
log.Printf("signed as serial number = %v", sec.Data["serial_number"])
|
|
signedCertStr := sec.Data["signed_key"].(string)
|
|
signedCertPK, _, _, _, err := ssh.ParseAuthorizedKey([]byte(signedCertStr))
|
|
if err != nil {
|
|
log.Fatalf("parsing SSH certificate: %w", err)
|
|
}
|
|
signedCert := signedCertPK.(*ssh.Certificate)
|
|
expiresAt := time.Unix(int64(signedCert.ValidBefore), 0)
|
|
certLifetime := expiresAt.Sub(time.Now())
|
|
log.Printf("certificate expires at %v (in %v)", expiresAt, certLifetime)
|
|
|
|
if err := agentClient.Add(agent.AddedKey{
|
|
PrivateKey: privKey,
|
|
Certificate: signedCert,
|
|
Comment: sshAgentComment,
|
|
LifetimeSecs: uint32(certLifetime.Seconds()),
|
|
ConfirmBeforeUse: true,
|
|
}); err != nil {
|
|
log.Fatalf("adding key to agent: %w", err)
|
|
}
|
|
}
|