depot/go/vault/vaultgcp/token.go

110 lines
3.1 KiB
Go
Raw Normal View History

// Package vaultgcp allows fetching GCP service account credentials from Vault.
package vaultgcp
import (
"context"
"encoding/json"
"fmt"
"log"
"net"
"net/http"
"net/url"
"time"
"golang.org/x/oauth2"
)
// tokenSource adapts Token to the oauth2.TokenSource interface for a fixed
// Vault address and path.
//
// The context is kept on the struct only because oauth2.TokenSource's Token
// method takes no context argument; the same ctx is used for every fetch.
type tokenSource struct {
	ctx          context.Context
	VaultAddress string
	VaultPath    string
}

// Token fetches a token from Vault on every call. It does no caching itself;
// TokenSource wraps it in oauth2.ReuseTokenSource for that.
func (src tokenSource) Token() (*oauth2.Token, error) {
	addr, path := src.VaultAddress, src.VaultPath
	return Token(src.ctx, addr, path)
}
// TokenSource returns an oauth2.TokenSource that fetches OAuth2 tokens from the
// specified Vault address and Vault path.
//
// Tokens are cached by oauth2.ReuseTokenSource and re-fetched from Vault only
// once the cached token is no longer valid. ctx is retained and used for every
// fetch, so once it is cancelled all subsequent fetches fail.
func TokenSource(ctx context.Context, vaultAddr string, vaultPath string) oauth2.TokenSource {
	return oauth2.ReuseTokenSource(nil, &tokenSource{
		ctx:          ctx,
		VaultAddress: vaultAddr,
		VaultPath:    vaultPath,
	})
}
// Token returns an OAuth token held by a GCP roleset in Vault.
// vaultPath should look like e.g. gcp/roleset/binary-cache-deployer/token.
//
// vaultAddr is either a unix-scheme socket address or a base URL (see vaultGet).
// The returned token's Expiry comes from Vault's expires_at_seconds field. If
// Vault omits that field, Expiry is left zero, which the oauth2 package treats
// as a token that never expires.
func Token(ctx context.Context, vaultAddr string, vaultPath string) (*oauth2.Token, error) {
	resp, err := vaultGet(ctx, vaultAddr, vaultPath)
	if err != nil {
		return nil, fmt.Errorf("fetching %q from %q: %w", vaultPath, vaultAddr, err)
	}
	defer resp.Body.Close()

	var vr struct {
		Data struct {
			ExpiresAtSeconds int64  `json:"expires_at_seconds"`
			Token            string `json:"token"`
		} `json:"data"`
		Errors []string `json:"errors"`
	}
	// Vault error responses carry their messages in the "errors" array, so
	// decode before checking the status to surface them in the error.
	decodeErr := json.NewDecoder(resp.Body).Decode(&vr)
	if resp.StatusCode != http.StatusOK {
		// resp.Status already includes the numeric code, e.g. "403 Forbidden".
		if decodeErr == nil && len(vr.Errors) > 0 {
			return nil, fmt.Errorf("bad status %s from Vault: %q", resp.Status, vr.Errors)
		}
		return nil, fmt.Errorf("bad status %s from Vault", resp.Status)
	}
	if decodeErr != nil {
		return nil, fmt.Errorf("parsing Vault response: %w", decodeErr)
	}
	if len(vr.Errors) > 0 {
		return nil, fmt.Errorf("Vault returned errors: %q", vr.Errors)
	}
	if vr.Data.Token == "" {
		return nil, fmt.Errorf("Vault returned no errors, but also no tokens")
	}

	var expiry time.Time
	if vr.Data.ExpiresAtSeconds != 0 {
		expiry = time.Unix(vr.Data.ExpiresAtSeconds, 0)
	}
	return &oauth2.Token{
		AccessToken: vr.Data.Token,
		Expiry:      expiry,
	}, nil
}
// vaultGet sends a GET request for /v1/<vaultPath> to the Vault server at
// vaultAddrStr, with the X-Vault-Request header set, and returns the raw
// response. The caller must close the response Body.
//
// If vaultAddrStr has the "unix" scheme (e.g. unix:///path/to/socket), the
// request is sent over that Unix domain socket. Presumably this is a local
// Vault agent listener; confirm against the deployment. Any other address is
// used verbatim as the base URL with http.DefaultClient.
func vaultGet(ctx context.Context, vaultAddrStr string, vaultPath string) (*http.Response, error) {
	vaultAddr, err := url.Parse(vaultAddrStr)
	if err != nil {
		return nil, fmt.Errorf("parsing Vault address %q: %w", vaultAddrStr, err)
	}

	vaultClient := http.DefaultClient
	vaultURL := vaultAddrStr
	if vaultAddr.Scheme == "unix" {
		socketPath := vaultAddr.Path
		if socketPath == "" {
			// "unix:relative/path" parses with the path in Opaque, not Path.
			socketPath = vaultAddr.Opaque
		}
		if socketPath == "" {
			return nil, fmt.Errorf("Vault address %q has no socket path", vaultAddrStr)
		}
		// The host in this URL is a placeholder: DialContext below ignores the
		// requested address and always connects to socketPath.
		vaultURL = "http://vault"
		vaultClient = &http.Client{
			Transport: &http.Transport{
				DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
					var d net.Dialer
					return d.DialContext(ctx, "unix", socketPath)
				},
				// This Transport is built per call and never reused. Pooling
				// idle connections would only leave each connection (and its
				// goroutines) open until an idle timeout fired, so close the
				// connection once the response body is closed.
				DisableKeepAlives: true,
			},
		}
	} else {
		log.Printf("falling back to http.DefaultClient for talking to Vault; this is unlikely to work")
	}

	req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/v1/%s", vaultURL, vaultPath), nil)
	if err != nil {
		return nil, fmt.Errorf("formulating Vault request: %w", err)
	}
	req.Header.Set("X-Vault-Request", "true")
	return vaultClient.Do(req)
}