package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"os/user"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/sys/unix"

	log "github.com/golang/glog"
	vapi "github.com/hashicorp/vault/api"
)

// listenPath is the unix socket path on which tokend accepts local clients.
var listenPath = flag.String("listen_path", "/run/tokend/sock", "Path to listen for connections to tokend.")

// agentAddr is the vault agent's address; only its "unix://" path is dialed.
var agentAddr = flag.String("agent_address", "unix:///run/vault-agent/sock", "Address of vault agent.")

// cacheTickInterval is the period of the token cache's expiration sweep.
var cacheTickInterval = flag.Duration("cache_tick_interval", time.Minute, "Time between checking for expirations.")

// userGroup names the group whose members count as plain (human) users.
var userGroup = flag.String("user_group", "users", "Name of a group that indicates that the requesting user is a 'user' and not a service.")

// userContextKeyType is an unexported, zero-size key type, so no other package
// can construct a colliding context key.
type userContextKeyType struct{}

// userContextKey is the context key under which attachUserData stores the
// connection's *userData and from which ServeHTTP reads it.
var userContextKey userContextKeyType

// userData describes the local user on the other end of a tokend
// connection. attachUserData stores a *userData in the connection context
// under userContextKey, and ServeHTTP reads it back for each request.
type userData struct {
	// Username is the account name resolved from the peer's UID.
	Username    string
	// IsPlainUser marks a human user rather than a service; attachUserData
	// derives it from membership in the group named by the user_group flag
	// and treats root as a plain user.
	IsPlainUser bool
}

// vaultProxier is the http.Handler that forwards local requests to the
// vault agent with the calling user's token attached (see ServeHTTP).
type vaultProxier struct {
	// v is the Vault API client built in main. NOTE(review): ServeHTTP does
	// not read it; confirm nothing else does before removing it.
	v  *vapi.Client
	// c supplies the per-user Vault token for each proxied request.
	c  *tokenUserCache
	// hc sends the forwarded requests; main configures its dialer to always
	// connect to the agent's unix socket.
	hc *http.Client
}

// shouldAttachToken reports whether the proxy should inject the user's cached
// token into a request for path. revoke-self is excluded so the caller's own
// X-Vault-Token header is forwarded untouched for that endpoint.
func shouldAttachToken(path string) bool {
	const revokeSelfPath = "/v1/auth/token/revoke-self"
	return path != revokeSelfPath
}

// shouldObfuscateTokenResponse reports whether resp, the agent's reply to r,
// must have token identifiers scrubbed before reaching the client. That is the
// case for 2xx lookup-self and renew-self replies when the proxy itself
// attached the token, since those replies describe tokend's token.
func shouldObfuscateTokenResponse(r *http.Request, resp *http.Response, attachedToken bool) bool {
	describesSelfToken := r.URL.Path == "/v1/auth/token/lookup-self" || r.URL.Path == "/v1/auth/token/renew-self"
	succeeded := resp.StatusCode >= 200 && resp.StatusCode < 300
	return succeeded && attachedToken && describesSelfToken
}

// obfuscateAndCopyResponse decodes a Vault secret from r, strips the token's
// identifiers (the "id" and "accessor" entries of Data, plus Auth.ClientToken
// and Auth.Accessor when Auth is present), and writes the result to w as JSON.
// It returns any decode or encode error.
//
// NOTE(review): the body is round-tripped through vapi.Secret, so any
// top-level fields that type does not model are dropped from the output.
func obfuscateAndCopyResponse(w io.Writer, r io.Reader) error {
	secret, parseErr := vapi.ParseSecret(r)
	if parseErr != nil {
		return parseErr
	}

	for _, key := range []string{"id", "accessor"} {
		delete(secret.Data, key)
	}
	if auth := secret.Auth; auth != nil {
		auth.ClientToken, auth.Accessor = "", ""
	}

	enc := json.NewEncoder(w)
	return enc.Encode(secret)
}

func (vp *vaultProxier) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	u, ok := ctx.Value(userContextKey).(*userData)
	if !ok {
		http.Error(rw, "no username could be determined from the request", http.StatusInternalServerError)
		return
	}

	t, err := vp.c.Get(ctx, u.Username, u.IsPlainUser)
	if err != nil {
		http.Error(rw, fmt.Sprintf("fetching token for %s: %v", u.Username, err), http.StatusInternalServerError)
		return
	}

	tokenStr, err := t.TokenID()
	if err != nil {
		http.Error(rw, fmt.Sprintf("extracting token string for %s: %v", u.Username, err), http.StatusInternalServerError)
		return
	}

	outReq := r.Clone(ctx)
	outReq.RequestURI = ""
	outReq.URL.Scheme = "http"
	outReq.URL.Host = "vault-agent"
	outReq.Trailer = nil

	attachedToken := false
	if shouldAttachToken(r.URL.Path) {
		outReq.Header.Set("X-Vault-Token", tokenStr)
		attachedToken = true
	}
	log.Infof("incoming request [%v / isPlainUser=%v] %v %v", u.Username, u.IsPlainUser, r.Method, r.URL.Path)
	start := time.Now()

	resp, err := vp.hc.Do(outReq)
	if err != nil {
		http.Error(rw, fmt.Sprintf("making backend request to vault agent: %v", err), http.StatusInternalServerError)
		return
	}
	defer resp.Body.Close()

	log.Infof("outgoing response [%v / isPlainUser=%v] %v %v: %v %v", u.Username, u.IsPlainUser, r.Method, r.URL.Path, resp.StatusCode, time.Now().Sub(start))

	for k, vs := range resp.Header {
		rw.Header()[k] = vs
	}
	if shouldObfuscateTokenResponse(r, resp, attachedToken) {
		rw.Header().Del("Content-Length")
		rw.WriteHeader(resp.StatusCode)
		if err := obfuscateAndCopyResponse(rw, resp.Body); err != nil {
			log.Errorf("copying obfuscated lookup-self response: %v", err)
		}
	} else {
		rw.WriteHeader(resp.StatusCode)
		if _, err := io.Copy(rw, resp.Body); err != nil {
			log.Errorf("copying response from agent: %v", err)
		}
	}
}

// Process-wide cache for userGroupGid.
var (
	userGroupGidSaved string
	userGroupGidOnce  sync.Once
)

// userGroupGid returns the GID of the group named by the user_group flag.
// The lookup runs at most once per process and its result is cached. It yields
// "" when the flag is empty (the plain-user check is disabled) or when the
// lookup fails. A failure is logged once and is not retried.
func userGroupGid() string {
	userGroupGidOnce.Do(func() {
		gid := ""
		if groupName := *userGroup; groupName != "" {
			if grp, err := user.LookupGroup(groupName); err != nil {
				log.Errorf("looking up user group %q: %v", groupName, err)
			} else {
				gid = grp.Gid
			}
		}
		userGroupGidSaved = gid
	})
	return userGroupGidSaved
}

// attachUserData is installed as http.Server.ConnContext. It identifies the
// peer of a unix-socket connection via SO_PEERCRED, resolves its UID to a
// local account, and returns ctx extended with a *userData under
// userContextKey. On any failure it logs a warning and returns ctx unchanged,
// so ServeHTTP finds no user data and rejects the request.
func attachUserData(ctx context.Context, c net.Conn) context.Context {
	uc, ok := c.(*net.UnixConn)
	if !ok {
		log.Warningf("asked to attachUserData to a non UnixConn (%T)", c)
		return ctx
	}

	raw, err := uc.SyscallConn()
	if err != nil {
		log.Warningf("unable to get the underlying raw connection for UnixConn: %v", err)
		return ctx
	}

	var (
		cred    *unix.Ucred
		credErr error
	)
	ctrlErr := raw.Control(func(fd uintptr) {
		cred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
	})
	if ctrlErr != nil {
		// Report Control's own error; credErr is meaningless if the callback
		// never ran.
		log.Warningf("control operation to get username failed: %v", ctrlErr)
		return ctx
	}
	if credErr != nil {
		log.Warningf("getsockoptucred failed: %v", credErr)
		return ctx
	}

	u, err := user.LookupId(strconv.FormatUint(uint64(cred.Uid), 10))
	if err != nil {
		log.Warningf("looking up UID %d from unix socket: %v", cred.Uid, err)
		return ctx
	}

	isPlainUser := false
	if u.Username == "root" {
		// We treat root as a plain user for convenience's sake. This is checked
		// before the group lookup so a lookup failure cannot demote root.
		isPlainUser = true
	} else if ugs, err := u.GroupIds(); err != nil {
		log.Warningf("looking up groups for user %s: %v", u.Username, err)
	} else if plainUserGid := userGroupGid(); plainUserGid != "" {
		// userGroupGid returns "" when the check is disabled or the group could
		// not be resolved; in that case nobody (except root) is a plain user.
		for _, ug := range ugs {
			if ug == plainUserGid {
				isPlainUser = true
				break
			}
		}
	}

	return context.WithValue(ctx, userContextKey, &userData{
		Username:    u.Username,
		IsPlainUser: isPlainUser,
	})
}

// main parses flags, builds a Vault client and an HTTP client that both dial
// the vault agent's unix socket, starts the periodic cache tick, and serves
// the proxy on the unix socket at -listen_path. Each accepted connection is
// labelled with its peer's identity by attachUserData before requests reach
// the proxy.
func main() {
	flag.Parse()

	// time.NewTicker panics on a non-positive interval; fail clearly at
	// startup instead of crashing inside the ticker goroutine.
	if *cacheTickInterval <= 0 {
		log.Exitf("cache_tick_interval must be positive, got %v", *cacheTickInterval)
	}

	d := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	// Only unix-socket agents are reachable: the "unix://" prefix is stripped
	// and whatever remains is dialed as a socket path.
	agentPath := strings.TrimPrefix(*agentAddr, "unix://")
	agentDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
		// Ignore what they want.
		return d.DialContext(ctx, "unix", agentPath)
	}

	vcfg := vapi.DefaultConfig()
	vcfg.AgentAddress = "http://vault-agent"
	vcfg.HttpClient.Transport.(*http.Transport).DialContext = agentDialer
	v, err := vapi.NewClient(vcfg)
	if err != nil {
		log.Exitf("creating vault client against %v: %v", *agentAddr, err)
	}

	// ctx is never cancelled: the tick goroutine lives as long as the process.
	ctx := context.Background()
	c := newCache(&vaultTokenInteractor{v})
	go func() {
		t := time.NewTicker(*cacheTickInterval)
		defer t.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-t.C:
				if err := c.tick(ctx); err != nil {
					log.Errorf("ticking the cache: %v", err)
				}
			}
		}
	}()
	vp := &vaultProxier{v: v, c: c, hc: &http.Client{
		Transport: &http.Transport{
			DialContext:           agentDialer,
			ForceAttemptHTTP2:     true,
			MaxIdleConns:          100,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		},
	}}

	// Best-effort removal of a stale socket from a previous run. A missing file
	// is expected; anything else is logged, and net.Listen reports the failure.
	if err := os.Remove(*listenPath); err != nil && !os.IsNotExist(err) {
		log.Warningf("removing stale socket at %v: %v", *listenPath, err)
	}

	listener, err := net.Listen("unix", *listenPath)
	if err != nil {
		log.Exitf("listening on %v: %v", *listenPath, err)
	}

	// World-connectable on purpose: any local user may connect, and callers are
	// identified per connection by attachUserData rather than by file mode.
	if err := os.Chmod(*listenPath, 0777); err != nil {
		log.Exitf("chmodding our unix socket at %v: %v", *listenPath, err)
	}

	m := http.NewServeMux()
	m.Handle("/", vp)
	m.HandleFunc("/tokend/cachez", func(rw http.ResponseWriter, r *http.Request) {
		c.l.RLock()
		for username, token := range c.m {
			// Diagnostic only: an accessor lookup failure just logs an empty value.
			accessor, _ := token.TokenSecret.TokenAccessor()
			log.Infof("cachez: %v: accessor= %v ; expiration=%v, renewal at=%v", username, accessor, token.expiration, token.renewThreshold)
		}
		c.l.RUnlock()

		rw.Header().Set("Content-type", "text/plain")
		fmt.Fprintf(rw, "written to log\n")
	})

	log.Infof("listening on %v", *listenPath)
	server := &http.Server{Handler: m, ConnContext: attachUserData}
	log.Exit(server.Serve(listener))
}