depot/go/access/access.go

271 lines
6.8 KiB
Go

package main
import (
cryptorand "crypto/rand"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"io/fs"
"log"
"net"
"os"
"os/exec"
"os/user"
"path/filepath"
"strings"
"time"
vapi "github.com/hashicorp/vault/api"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// Command-line flags.
var (
	// Where Vault lives, and which SSH mount and role in it sign our key.
	vaultAddress  = flag.String("vault_address", "https://vault.int.lukegb.com", "Address of Vault")
	sshMountPoint = flag.String("ssh_mount_point", "ssh-client", "SSH mount point in Vault")
	sshRole       = flag.String("ssh_role", "user", "SSH role")

	// sshPrincipal defaults to the login name of the user running the tool.
	sshPrincipal = flag.String("principal", currentUsername(), "principal to request in certificate")

	// requirePresence defaults to true unless running under WSL
	// (see shouldRequirePresence).
	requirePresence = flag.Bool("require_presence", shouldRequirePresence(), "whether to require presence when using certificate")
)
// currentUsername returns the login name of the user running this process.
//
// It is evaluated while the package-level flags are declared (as the default
// for -principal), so a lookup failure terminates the program via log.Fatalf.
func currentUsername() string {
	me, err := user.Current()
	if err == nil {
		return me.Username
	}
	log.Fatalf("looking up current user: %v", err)
	return "" // Unreachable: log.Fatalf exits the process.
}
// shouldRequirePresence reports the default for -require_presence: true,
// unless WSL_DISTRO_NAME is set to a non-empty value (i.e. we appear to be
// running under WSL2, where confirming key use is awkward).
func shouldRequirePresence() bool {
	inWSL := os.Getenv("WSL_DISTRO_NAME") != ""
	return !inWSL
}
// sshAgentComment is the comment attached to keys this tool adds to the SSH
// agent. Any agent key carrying it is treated as ours and removed before the
// fresh certificate is added.
const sshAgentComment = "vault certificate"
// agentConn bundles an SSH agent client with the connection it speaks over,
// so the caller can release the socket when finished.
type agentConn struct {
	// ExtendedAgent is the protocol client; its requests travel over conn.
	agent.ExtendedAgent
	// conn is the raw connection to the agent socket.
	conn io.ReadWriteCloser
}

// Close closes the underlying agent connection.
func (ac *agentConn) Close() error {
	return ac.conn.Close()
}
// connectToAgent dials the SSH agent advertised by $SSH_AUTH_SOCK (a Unix
// socket path) and returns a client for it. The caller must Close the result.
func connectToAgent() (*agentConn, error) {
	sockPath := os.Getenv("SSH_AUTH_SOCK")
	if sockPath == "" {
		return nil, errors.New("no SSH_AUTH_SOCK in environment")
	}

	c, err := net.Dial("unix", sockPath)
	if err != nil {
		return nil, fmt.Errorf("failed to dial SSH_AUTH_SOCK = %q: %w", sockPath, err)
	}
	return &agentConn{
		ExtendedAgent: agent.NewClient(c),
		conn:          c,
	}, nil
}
// tokenHelper stores a Vault token in a plain file on local disk.
type tokenHelper struct {
	// Path overrides the token file location; when empty the token lives in
	// ~/.vault-token.
	Path string
}

// tokenPath returns the file the token is read from and written to.
func (h tokenHelper) tokenPath() (string, error) {
	if h.Path != "" {
		return h.Path, nil
	}
	homedir, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	return filepath.Join(homedir, ".vault-token"), nil
}

// Get returns the stored token with surrounding whitespace trimmed.
//
// If no token file exists, the returned error matches fs.ErrNotExist
// (check with errors.Is).
func (h tokenHelper) Get() (string, error) {
	p, err := h.tokenPath()
	if err != nil {
		return "", fmt.Errorf("getting token path: %w", err)
	}
	tok, err := os.ReadFile(p)
	if err != nil {
		return "", fmt.Errorf("reading token at %v: %w", p, err)
	}
	return strings.TrimSpace(string(tok)), nil
}

// Put replaces the stored token with token.
//
// The old file is deleted and a new one is created exclusively (O_EXCL) with
// mode 0600. This way the token is never written into a pre-existing file
// that might carry looser permissions.
func (h tokenHelper) Put(token string) error {
	p, err := h.tokenPath()
	if err != nil {
		return fmt.Errorf("getting token path: %w", err)
	}
	// A missing file is expected on first use. Any other failure would
	// otherwise only resurface as a confusing "file exists" error from the
	// exclusive create below, so report it directly.
	if err := os.Remove(p); err != nil && !errors.Is(err, fs.ErrNotExist) {
		return fmt.Errorf("removing old token at %v: %w", p, err)
	}
	f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		return fmt.Errorf("creating token at %v: %w", p, err)
	}
	if _, err := io.WriteString(f, token); err != nil {
		f.Close()
		return fmt.Errorf("writing token at %v: %w", p, err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("completing write of token at %v: %w", p, err)
	}
	return nil
}
// connectToVault builds a Vault API client for the -vault_address server.
// No token is set here; see ensureLoggedIn.
func connectToVault() (*vapi.Client, error) {
	config := vapi.DefaultConfig()
	config.Address = *vaultAddress
	// Blank out any agent address DefaultConfig may have filled in
	// (presumably from the environment — confirm), so requests go straight to
	// -vault_address.
	config.AgentAddress = ""

	client, err := vapi.NewClient(config)
	if err != nil {
		return nil, fmt.Errorf("vault/api.NewClient: %w", err)
	}
	return client, nil
}
// triggerVaultLogin runs `vault login -method=oidc` against vaultClient's
// address and returns the client token it prints.
//
// The CLI's stdin and stderr are wired to this process so the user can
// complete the interactive login. The command is run with -no-store and
// -format=json; the token is parsed from the JSON on stdout, and persisting it
// is left to the caller.
func triggerVaultLogin(vaultClient *vapi.Client) (string, error) {
	// TODO: implement actual inline login.
	c := exec.Command("vault", "login", "-no-store", "-format=json", "-method=oidc", "-address="+vaultClient.Address())
	c.Stdin = os.Stdin
	c.Stderr = os.Stderr
	out, err := c.Output()
	if err != nil {
		return "", fmt.Errorf("running vault login: %w", err)
	}
	return parseVaultLoginToken(out)
}

// parseVaultLoginToken extracts auth.client_token from `vault login
// -format=json` output, failing if the field is missing or empty.
func parseVaultLoginToken(out []byte) (string, error) {
	var v struct {
		Auth struct {
			ClientToken string `json:"client_token"`
		} `json:"auth"`
	}
	if err := json.Unmarshal(out, &v); err != nil {
		return "", fmt.Errorf("parsing vault login JSON: %w", err)
	}
	if v.Auth.ClientToken == "" {
		return "", fmt.Errorf("vault login output has no auth.client_token: %s", out)
	}
	return v.Auth.ClientToken, nil
}
// ensureLoggedIn leaves vaultClient holding a Vault token that the server
// accepts.
//
// It first tries the token stored by tokenHelper. If there is no stored token,
// or Vault rejects it with a 4xx status on token lookup-self, it runs an
// interactive login (triggerVaultLogin). It then persists the new token and
// sets it on vaultClient. Any other error aborts.
func ensureLoggedIn(vaultClient *vapi.Client) error {
	th := tokenHelper{}
	needLogin := false

	token, err := th.Get()
	switch {
	case errors.Is(err, fs.ErrNotExist):
		needLogin = true
	case err != nil:
		return fmt.Errorf("tokenHelper.Get: %w", err)
	default:
		vaultClient.SetToken(token)
		// Check whether Vault still accepts the stored token. errors.As (not
		// a bare type assertion) also finds a ResponseError that has been
		// wrapped by the client library.
		_, lookupErr := vaultClient.Auth().Token().LookupSelf()
		var respErr *vapi.ResponseError
		if errors.As(lookupErr, &respErr) && respErr.StatusCode >= 400 && respErr.StatusCode < 500 {
			log.Printf("got %v from server: assuming we need to log in", respErr.StatusCode)
			needLogin = true
		} else if lookupErr != nil {
			return fmt.Errorf("checking self token: %w", lookupErr)
		}
	}

	if !needLogin {
		return nil
	}
	token, err = triggerVaultLogin(vaultClient)
	if err != nil {
		return fmt.Errorf("logging in to vault: %w", err)
	}
	if err := th.Put(token); err != nil {
		return fmt.Errorf("writing vault token to storage: %w", err)
	}
	vaultClient.SetToken(token)
	return nil
}
func main() {
flag.Parse()
agentClient, err := connectToAgent()
if err != nil {
log.Fatalf("connectToAgent: %v", err)
}
defer agentClient.Close()
vaultClient, err := connectToVault()
if err != nil {
log.Fatalf("connectToVault: %v", err)
}
if err := ensureLoggedIn(vaultClient); err != nil {
log.Fatalf("ensureLoggedIn: %v", err)
}
// OK, we're probably logged in. Let's go.
// Remove any old certificates.
keys, err := agentClient.List()
if err != nil {
log.Fatalf("listing keys in agent: %v", err)
}
for _, key := range keys {
if key.Comment != sshAgentComment {
continue
}
if err := agentClient.Remove(key); err != nil {
log.Fatalf("removing existing key %s: %v", key, err)
}
}
// Generate a new key.
pubKey, privKey, err := ed25519.GenerateKey(cryptorand.Reader)
if err != nil {
log.Fatalf("generating SSH key: %v", err)
}
sshPubKey, err := ssh.NewPublicKey(pubKey)
if err != nil {
log.Fatalf("converting new SSH key: %v", err)
}
// Sign the key.
vssh := vaultClient.SSHWithMountPoint(*sshMountPoint)
sec, err := vssh.SignKey(*sshRole, map[string]interface{}{
"public_key": string(ssh.MarshalAuthorizedKey(sshPubKey)),
"valid_principals": *sshPrincipal,
})
if err != nil {
log.Fatalf("signing SSH key: %v", err)
}
// Add the certificate to the agent.
log.Printf("signed as serial number = %v", sec.Data["serial_number"])
signedCertStr := sec.Data["signed_key"].(string)
signedCertPK, _, _, _, err := ssh.ParseAuthorizedKey([]byte(signedCertStr))
if err != nil {
log.Fatalf("parsing SSH certificate: %w", err)
}
signedCert := signedCertPK.(*ssh.Certificate)
expiresAt := time.Unix(int64(signedCert.ValidBefore), 0)
certLifetime := expiresAt.Sub(time.Now())
log.Printf("certificate expires at %v (in %v)", expiresAt, certLifetime)
if err := agentClient.Add(agent.AddedKey{
PrivateKey: privKey,
Certificate: signedCert,
Comment: sshAgentComment,
LifetimeSecs: uint32(certLifetime.Seconds()),
ConfirmBeforeUse: *requirePresence,
}); err != nil {
log.Fatalf("adding key to agent: %w", err)
}
}