diff --git a/go/access/access.go b/go/access/access.go new file mode 100644 index 0000000000..908b27fca0 --- /dev/null +++ b/go/access/access.go @@ -0,0 +1,252 @@ +package main + +import ( + cryptorand "crypto/rand" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "io/fs" + "log" + "net" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + vapi "github.com/hashicorp/vault/api" + "golang.org/x/crypto/ed25519" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +var ( + vaultAddress = flag.String("vault_address", "https://vault.int.lukegb.com", "Address of Vault") + sshMountPoint = flag.String("ssh_mount_point", "ssh-client", "SSH mount point in Vault") + sshRole = flag.String("ssh_role", "user", "SSH role") +) + +const ( + sshAgentComment = "vault certificate" +) + +type agentConn struct { + agent.ExtendedAgent + + conn io.ReadWriteCloser +} + +func (c *agentConn) Close() error { return c.conn.Close() } + +func connectToAgent() (*agentConn, error) { + socket := os.Getenv("SSH_AUTH_SOCK") + if socket == "" { + return nil, fmt.Errorf("no SSH_AUTH_SOCK in environment") + } + conn, err := net.Dial("unix", socket) + if err != nil { + return nil, fmt.Errorf("failed to dial SSH_AUTH_SOCK = %q: %w", socket, err) + } + + return &agentConn{agent.NewClient(conn), conn}, nil +} + +type tokenHelper struct { + Path string +} + +func (h tokenHelper) tokenPath() (string, error) { + if h.Path != "" { + return h.Path, nil + } + homedir, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(homedir, ".vault-token"), nil +} + +func (h tokenHelper) Get() (string, error) { + p, err := h.tokenPath() + if err != nil { + return "", fmt.Errorf("getting token path: %w", err) + } + f, err := os.Open(p) + if err != nil { + return "", fmt.Errorf("opening token at %v: %w", p, err) + } + defer f.Close() + tok, err := io.ReadAll(f) + if err != nil { + return "", fmt.Errorf("reading token at %v: %w", p, err) + } + return strings.TrimSpace(string(tok)), nil +} + +func (h tokenHelper) Put(token string) error { + p, err := h.tokenPath() + if err != nil { + return fmt.Errorf("getting token path: %w", err) + } + os.Remove(p) + f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600) + if err != nil { + return fmt.Errorf("creating token at %v: %w", p, err) + } + if _, err := fmt.Fprintf(f, "%s", token); err != nil { + f.Close() + return fmt.Errorf("writing token at %v: %w", p, err) + } + if err := f.Close(); err != nil { + return fmt.Errorf("completing write of token at %v: %w", p, err) + } + return nil +} + +func connectToVault() (*vapi.Client, error) { + cfg := vapi.DefaultConfig() + cfg.Address = *vaultAddress + cfg.AgentAddress = "" + vaultClient, err := vapi.NewClient(cfg) + if err != nil { + return nil, fmt.Errorf("vault/api.NewClient: %w", err) + } + + return vaultClient, nil +} + +func triggerVaultLogin(vaultClient *vapi.Client) (string, error) { + // TODO: implement actual inline login. + c := exec.Command("vault", "login", "-no-store", "-format=json", "-method=oidc", fmt.Sprintf("-address=%s", vaultClient.Address())) + c.Stdin = os.Stdin + c.Stderr = os.Stderr + out, err := c.Output() + if err != nil { + return "", fmt.Errorf("running vault login: %w", err) + } + var v struct { + Auth struct { + ClientToken string `json:"client_token"` + } `json:"auth"` + } + if err := json.Unmarshal(out, &v); err != nil { + return "", fmt.Errorf("parsing vault login JSON: %w", err) + } + if v.Auth.ClientToken == "" { + return "", fmt.Errorf("some error: %v", string(out)) + } + return v.Auth.ClientToken, nil +} + +func ensureLoggedIn(vaultClient *vapi.Client) error { + needLogin := false + th := tokenHelper{} + token, err := th.Get() + if errors.Is(err, fs.ErrNotExist) { + needLogin = true + } else if err != nil { + return fmt.Errorf("tokenHelper.Get: %w", err) + } else { + vaultClient.SetToken(token) + + // Check if we currently have anything in Vault. + _, err := vaultClient.Auth().Token().LookupSelf() + if err, ok := err.(*vapi.ResponseError); ok && err.StatusCode >= 400 && err.StatusCode < 500 { + log.Printf("got %v from server: assuming we need to log in", err.StatusCode) + needLogin = true + } else if err != nil { + return fmt.Errorf("checking self token: %w", err) + } + } + + if needLogin { + token, err = triggerVaultLogin(vaultClient) + if err != nil { + return fmt.Errorf("logging in to vault: %w", err) + } + if err := th.Put(token); err != nil { + return fmt.Errorf("writing vault token to storage: %w", err) + } + vaultClient.SetToken(token) + } + + return nil +} + +func main() { + flag.Parse() + + agentClient, err := connectToAgent() + if err != nil { + log.Fatalf("connectToAgent: %v", err) + } + defer agentClient.Close() + + vaultClient, err := connectToVault() + if err != nil { + log.Fatalf("connectToVault: %v", err) + } + + if err := ensureLoggedIn(vaultClient); err != nil { + log.Fatalf("ensureLoggedIn: %v", err) + } + + // OK, we're probably logged in. Let's go. + + // Remove any old certificates. + keys, err := agentClient.List() + if err != nil { + log.Fatalf("listing keys in agent: %v", err) + } + for _, key := range keys { + if key.Comment != sshAgentComment { + continue + } + if err := agentClient.Remove(key); err != nil { + log.Fatalf("removing existing key %s: %v", key, err) + } + } + + // Generate a new key. + pubKey, privKey, err := ed25519.GenerateKey(cryptorand.Reader) + if err != nil { + log.Fatalf("generating SSH key: %v", err) + } + sshPubKey, err := ssh.NewPublicKey(pubKey) + if err != nil { + log.Fatalf("converting new SSH key: %v", err) + } + + // Sign the key. + vssh := vaultClient.SSHWithMountPoint(*sshMountPoint) + sec, err := vssh.SignKey(*sshRole, map[string]interface{}{ + "public_key": string(ssh.MarshalAuthorizedKey(sshPubKey)), + }) + if err != nil { + log.Fatalf("signing SSH key: %v", err) + } + + // Add the certificate to the agent. + log.Printf("signed as serial number = %v", sec.Data["serial_number"]) + signedCertStr := sec.Data["signed_key"].(string) + signedCertPK, _, _, _, err := ssh.ParseAuthorizedKey([]byte(signedCertStr)) + if err != nil { + log.Fatalf("parsing SSH certificate: %w", err) + } + signedCert := signedCertPK.(*ssh.Certificate) + expiresAt := time.Unix(int64(signedCert.ValidBefore), 0) + certLifetime := expiresAt.Sub(time.Now()) + log.Printf("certificate expires at %v (in %v)", expiresAt, certLifetime) + + if err := agentClient.Add(agent.AddedKey{ + PrivateKey: privKey, + Certificate: signedCert, + Comment: sshAgentComment, + LifetimeSecs: uint32(certLifetime.Seconds()), + ConfirmBeforeUse: true, + }); err != nil { + log.Fatalf("adding key to agent: %w", err) + } +} diff --git a/go/access/default.nix b/go/access/default.nix new file mode 100644 index 0000000000..8382925688 --- /dev/null +++ b/go/access/default.nix @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: 2020 Luke Granger-Brown +# +# SPDX-License-Identifier: Apache-2.0 + +{ depot, ... }: +depot.third_party.buildGo.program { + name = "access"; + srcs = [ ./access.go ]; + deps = with depot.third_party; [ + gopkgs."golang.org".x.crypto.ssh + gopkgs."golang.org".x.crypto.ssh.agent + gopkgs."github.com".hashicorp.vault.api + ]; +} diff --git a/go/default.nix b/go/default.nix index f8166204d5..599513c7e5 100644 --- a/go/default.nix +++ b/go/default.nix @@ -12,4 +12,5 @@ args: { journal2clickhouse = import ./journal2clickhouse args; secretsmgr = import ./secretsmgr args; tokend = import ./tokend args; + access = import ./access args; } diff --git a/ops/nixos/lib/home-manager/client.nix b/ops/nixos/lib/home-manager/client.nix index ca4e0266be..90b59a1cfb 100644 --- a/ops/nixos/lib/home-manager/client.nix +++ b/ops/nixos/lib/home-manager/client.nix @@ -39,6 +39,7 @@ in ''; home.packages = lib.mkAfter (with pkgs; ([ + depot.go.access depot.nix.pkgs.datez depot.nix.pkgs.common-updater-scripts go