109 lines
3.1 KiB
Go
109 lines
3.1 KiB
Go
// Package vaultgcp allows fetching GCP service account credentials from Vault.
|
|
package vaultgcp
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"time"
|
|
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
type tokenSource struct {
|
|
ctx context.Context
|
|
|
|
VaultAddress string
|
|
VaultPath string
|
|
}
|
|
|
|
func (s tokenSource) Token() (*oauth2.Token, error) {
|
|
return Token(s.ctx, s.VaultAddress, s.VaultPath)
|
|
}
|
|
|
|
// TokenSource returns an oauth2.TokenSource that fetches OAuth2 tokens from the
// Vault server at vaultAddr, reading the secret at vaultPath.
//
// The fetcher is wrapped in oauth2.ReuseTokenSource, so a fetched token is
// handed out again until it is no longer valid. ctx is retained and used for
// every later fetch, so it should stay live for as long as the returned source
// is in use.
func TokenSource(ctx context.Context, vaultAddr string, vaultPath string) oauth2.TokenSource {
	fetcher := &tokenSource{
		ctx:          ctx,
		VaultAddress: vaultAddr,
		VaultPath:    vaultPath,
	}
	// No initial token: the first Token call goes straight to Vault.
	return oauth2.ReuseTokenSource(nil, fetcher)
}
|
|
|
|
// Token returns an OAuth token held by a GCP roleset in Vault.
//
// vaultAddr is handed to vaultGet: either a "unix" URL naming a Unix domain
// socket, or a base URL used with http.DefaultClient. vaultPath is the path
// below /v1/ and should look like e.g. gcp/roleset/binary-cache-deployer/token.
//
// The response's data.token becomes AccessToken. data.expires_at_seconds
// (Unix seconds) becomes Expiry; when Vault omits it, Expiry is the zero time.
// An error is returned if the request fails, Vault answers with a non-200
// status or reports errors, or the response carries no token.
func Token(ctx context.Context, vaultAddr string, vaultPath string) (*oauth2.Token, error) {
	resp, err := vaultGet(ctx, vaultAddr, vaultPath)
	if err != nil {
		return nil, fmt.Errorf("fetching %v from %v: %w", vaultPath, vaultAddr, err)
	}
	defer resp.Body.Close()

	type vaultResponse struct {
		Data struct {
			ExpiresAtSeconds int64  `json:"expires_at_seconds"`
			Token            string `json:"token"`
		} `json:"data"`
		Errors []string `json:"errors"`
	}

	// Decode before looking at the status: Vault reports failures such as
	// "permission denied" as a non-200 response whose JSON body has an
	// "errors" list, which is far more useful than the bare status code.
	var vr vaultResponse
	decodeErr := json.NewDecoder(resp.Body).Decode(&vr)

	if resp.StatusCode != http.StatusOK {
		if decodeErr == nil && len(vr.Errors) > 0 {
			return nil, fmt.Errorf("bad status code %d %s from Vault: %q",
				resp.StatusCode, http.StatusText(resp.StatusCode), vr.Errors)
		}
		return nil, fmt.Errorf("bad status code %d %s from Vault",
			resp.StatusCode, http.StatusText(resp.StatusCode))
	}
	if decodeErr != nil {
		return nil, fmt.Errorf("parsing Vault response: %w", decodeErr)
	}

	if len(vr.Errors) > 0 {
		return nil, fmt.Errorf("Vault returned errors: %q", vr.Errors)
	}
	if vr.Data.Token == "" {
		return nil, fmt.Errorf("Vault returned no errors, but also no tokens")
	}

	// Leave Expiry as the zero time when Vault gives no expiry.
	var expiryTime time.Time
	if vr.Data.ExpiresAtSeconds != 0 {
		expiryTime = time.Unix(vr.Data.ExpiresAtSeconds, 0)
	}

	return &oauth2.Token{
		AccessToken: vr.Data.Token,
		Expiry:      expiryTime,
	}, nil
}
|
|
|
|
func vaultGet(ctx context.Context, vaultAddrStr string, vaultPath string) (*http.Response, error) {
|
|
vaultAddr, err := url.Parse(vaultAddrStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing Vault address %q: %w", vaultAddrStr, err)
|
|
}
|
|
|
|
vaultClient := http.DefaultClient
|
|
vaultURL := vaultAddrStr
|
|
if vaultAddr.Scheme == "unix" {
|
|
vaultURL = "http://vault"
|
|
vaultClient = &http.Client{
|
|
Transport: &http.Transport{
|
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
d := &net.Dialer{}
|
|
return d.DialContext(ctx, "unix", vaultAddr.Path)
|
|
},
|
|
MaxIdleConns: 1,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
},
|
|
}
|
|
} else {
|
|
log.Printf("falling back to http.DefaultClient for talking to Vault; this is unlikely to work")
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/v1/%s", vaultURL, vaultPath), nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("formulating Vault request: %w", err)
|
|
}
|
|
req.Header.Set("X-Vault-Request", "true")
|
|
return vaultClient.Do(req)
|
|
}
|