// Package vaultgcp allows fetching GCP service account credentials from Vault. package vaultgcp import ( "context" "encoding/json" "fmt" "log" "net" "net/http" "net/url" "time" "golang.org/x/oauth2" ) type tokenSource struct { ctx context.Context VaultAddress string VaultPath string } func (s tokenSource) Token() (*oauth2.Token, error) { return Token(s.ctx, s.VaultAddress, s.VaultPath) } // TokenSource returns an oauth2.TokenSource that fetchs OAuth2 tokens from the specified Vault address and Vault path. func TokenSource(ctx context.Context, vaultAddr string, vaultPath string) oauth2.TokenSource { return oauth2.ReuseTokenSource(nil, &tokenSource{ ctx: ctx, VaultAddress: vaultAddr, VaultPath: vaultPath, }) } // Token returns an OAuth token held by a GCP roleset in Vault. // vaultPath should look like e.g. gcp/roleset/binary-cache-deployer/token. func Token(ctx context.Context, vaultAddr string, vaultPath string) (*oauth2.Token, error) { resp, err := vaultGet(ctx, vaultAddr, vaultPath) if err != nil { return nil, fmt.Errorf("fetching %v from %v: %w", vaultAddr, vaultPath, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("bad status code %v %s from Vault", resp.StatusCode, resp.Status) } type vaultResponse struct { Data struct { ExpiresAtSeconds int64 `json:"expires_at_seconds"` Token string `json:"token"` } `json:"data"` Errors []string `json:"errors"` } var vr vaultResponse if err := json.NewDecoder(resp.Body).Decode(&vr); err != nil { return nil, fmt.Errorf("parsing Vault response: %w", err) } if len(vr.Errors) > 0 { return nil, fmt.Errorf("Vault returned errors: %q", vr.Errors) } if vr.Data.Token == "" { return nil, fmt.Errorf("Vault returned no errors, but also no tokens") } var expiryTime time.Time if vr.Data.ExpiresAtSeconds != 0 { expiryTime = time.Unix(vr.Data.ExpiresAtSeconds, 0) } return &oauth2.Token{ AccessToken: vr.Data.Token, Expiry: expiryTime, }, nil } func vaultGet(ctx context.Context, vaultAddrStr string, vaultPath string) (*http.Response, error) { vaultAddr, err := url.Parse(vaultAddrStr) if err != nil { return nil, fmt.Errorf("parsing Vault address %q: %w", vaultAddrStr, err) } vaultClient := http.DefaultClient vaultURL := vaultAddrStr if vaultAddr.Scheme == "unix" { vaultURL = "http://vault" vaultClient = &http.Client{ Transport: &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { d := &net.Dialer{} return d.DialContext(ctx, "unix", vaultAddr.Path) }, MaxIdleConns: 1, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, }, } } else { log.Printf("falling back to http.DefaultClient for talking to Vault; this is unlikely to work") } req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/v1/%s", vaultURL, vaultPath), nil) if err != nil { return nil, fmt.Errorf("formulating Vault request: %w", err) } req.Header.Set("X-Vault-Request", "true") return vaultClient.Do(req) }