// Package vaultgcp allows fetching GCP service account credentials from Vault.
package vaultgcp

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/url"
	"time"

	"golang.org/x/oauth2"
)

// tokenSource adapts Token to the oauth2.TokenSource interface.
//
// oauth2.TokenSource.Token accepts no context, so the context handed to
// TokenSource is captured here and reused for every fetch.
type tokenSource struct {
	// ctx is passed through to each Vault request made by Token.
	ctx context.Context

	// VaultAddress is the Vault server address passed to Token.
	VaultAddress string

	// VaultPath is the Vault path read on every fetch.
	VaultPath string
}

// Token fetches a token from Vault using the stored context, address and path.
func (ts tokenSource) Token() (*oauth2.Token, error) {
	ctx, addr, path := ts.ctx, ts.VaultAddress, ts.VaultPath
	return Token(ctx, addr, path)
}

// TokenSource returns an oauth2.TokenSource that fetches OAuth2 tokens from
// the specified Vault address and Vault path.
//
// The result is wrapped in oauth2.ReuseTokenSource, so a fetched token is
// cached and Vault is only contacted again once that token is no longer valid.
// ctx is used for every fetch the returned source performs.
func TokenSource(ctx context.Context, vaultAddr string, vaultPath string) oauth2.TokenSource {
	src := &tokenSource{
		ctx:          ctx,
		VaultAddress: vaultAddr,
		VaultPath:    vaultPath,
	}
	// No initial token: the first Token call always goes to Vault.
	return oauth2.ReuseTokenSource(nil, src)
}

// Token returns an OAuth token held by a GCP roleset in Vault.
// vaultPath should look like e.g. gcp/roleset/binary-cache-deployer/token.
//
// vaultAddr is either a unix:// socket address (e.g. unix:///path/to/socket)
// or a base URL used as-is. The returned token's Expiry comes from Vault's
// expires_at_seconds field; when Vault omits it, Expiry is left zero, which
// oauth2 treats as a token that never expires.
//
// A non-200 response is reported as an error that includes Vault's
// "errors" array when the body carries one.
func Token(ctx context.Context, vaultAddr string, vaultPath string) (*oauth2.Token, error) {
	resp, err := vaultGet(ctx, vaultAddr, vaultPath)
	if err != nil {
		return nil, fmt.Errorf("fetching %v from %v: %w", vaultPath, vaultAddr, err)
	}
	defer resp.Body.Close()

	// vaultResponse models the subset of Vault's JSON envelope used here.
	type vaultResponse struct {
		Data struct {
			ExpiresAtSeconds int64  `json:"expires_at_seconds"`
			Token            string `json:"token"`
		} `json:"data"`
		Errors []string `json:"errors"`
	}
	var vr vaultResponse

	if resp.StatusCode != http.StatusOK {
		// Vault error bodies carry the same "errors" array decoded below;
		// surface it when it is present so callers can see why Vault refused.
		// resp.Status already includes the numeric code (e.g. "403 Forbidden").
		if err := json.NewDecoder(resp.Body).Decode(&vr); err == nil && len(vr.Errors) > 0 {
			return nil, fmt.Errorf("bad status %s from Vault: %q", resp.Status, vr.Errors)
		}
		return nil, fmt.Errorf("bad status %s from Vault", resp.Status)
	}

	if err := json.NewDecoder(resp.Body).Decode(&vr); err != nil {
		return nil, fmt.Errorf("parsing Vault response: %w", err)
	}

	if len(vr.Errors) > 0 {
		return nil, fmt.Errorf("Vault returned errors: %q", vr.Errors)
	}
	if vr.Data.Token == "" {
		return nil, fmt.Errorf("Vault returned no errors, but also no tokens")
	}
	var expiryTime time.Time
	if vr.Data.ExpiresAtSeconds != 0 {
		// Only set Expiry when Vault reports one; a zero Expiry means "never expires" to oauth2.
		expiryTime = time.Unix(vr.Data.ExpiresAtSeconds, 0)
	}
	return &oauth2.Token{
		AccessToken: vr.Data.Token,
		Expiry:      expiryTime,
	}, nil
}

// vaultGet issues a GET for /v1/<vaultPath> against the Vault server at
// vaultAddrStr and returns the raw response. On success the caller must close
// resp.Body.
//
// A unix:// address (e.g. unix:///path/to/socket) is dialed as a Unix domain
// socket at the URL's path, with requests addressed to the placeholder host
// "vault". Any other address is used verbatim as the base URL with
// http.DefaultClient, and a warning is logged.
func vaultGet(ctx context.Context, vaultAddrStr string, vaultPath string) (*http.Response, error) {
	vaultAddr, err := url.Parse(vaultAddrStr)
	if err != nil {
		return nil, fmt.Errorf("parsing Vault address %q: %w", vaultAddrStr, err)
	}

	vaultClient := http.DefaultClient
	vaultURL := vaultAddrStr
	if vaultAddr.Scheme == "unix" {
		// The host part is ignored by the dialer below; it only has to make a valid URL.
		vaultURL = "http://vault"
		socketPath := vaultAddr.Path
		vaultClient = &http.Client{
			Transport: &http.Transport{
				DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
					var d net.Dialer
					return d.DialContext(ctx, "unix", socketPath)
				},
				// This client is created per call and discarded afterwards, so a
				// pooled idle connection could never be reused and would only sit
				// open until the idle timeout. Close it once the response is done.
				DisableKeepAlives:     true,
				TLSHandshakeTimeout:   10 * time.Second,
				ExpectContinueTimeout: 1 * time.Second,
			},
		}
	} else {
		log.Printf("falling back to http.DefaultClient for talking to Vault; this is unlikely to work")
	}

	req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/v1/%s", vaultURL, vaultPath), nil)
	if err != nil {
		return nil, fmt.Errorf("formulating Vault request: %w", err)
	}
	req.Header.Set("X-Vault-Request", "true")
	return vaultClient.Do(req)
}