290 lines
7.2 KiB
Go
290 lines
7.2 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"flag"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"os"
|
||
|
"os/user"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"golang.org/x/sys/unix"
|
||
|
|
||
|
log "github.com/golang/glog"
|
||
|
vapi "github.com/hashicorp/vault/api"
|
||
|
)
|
||
|
|
||
|
// listenPath is the unix socket on which tokend accepts client connections.
var listenPath = flag.String("listen_path", "/run/tokend/sock", "Path to listen for connections to tokend.")

// agentAddr is the vault agent address; main strips a leading "unix://" and
// dials the remaining filesystem path.
var agentAddr = flag.String("agent_address", "unix:///run/vault-agent/sock", "Address of vault agent.")

// cacheTickInterval is the period of main's background loop that calls the
// token cache's tick method.
var cacheTickInterval = flag.Duration("cache_tick_interval", time.Minute, "Time between checking for expirations.")

// userGroup names the group whose members attachUserData marks as plain users;
// an empty value disables the group check.
var userGroup = flag.String("user_group", "users", "Name of a group that indicates that the requesting user is a 'user' and not a service.")
|
||
|
|
||
|
// userContextKeyType is an unexported, zero-size type used only as a
// context key, so no other package can collide with or read the value.
type userContextKeyType struct{}

// userContextKey is the context key under which attachUserData stores the
// connection's *userData.
var userContextKey userContextKeyType
|
||
|
|
||
|
type userData struct {
|
||
|
Username string
|
||
|
IsPlainUser bool
|
||
|
}
|
||
|
|
||
|
type vaultProxier struct {
|
||
|
v *vapi.Client
|
||
|
c *tokenUserCache
|
||
|
hc *http.Client
|
||
|
}
|
||
|
|
||
|
// shouldAttachToken reports whether ServeHTTP should inject the user's cached
// token as X-Vault-Token for a request to path. For revoke-self it returns
// false, so whatever X-Vault-Token the caller sent is forwarded unchanged.
func shouldAttachToken(path string) bool {
	switch path {
	case "/v1/auth/token/revoke-self":
		return false
	default:
		return true
	}
}
|
||
|
|
||
|
// shouldObfuscateTokenResponse reports whether the agent's response must have
// its token ID and accessor stripped before being returned to the caller.
// That is the case only for a 2xx response to lookup-self or renew-self when
// tokend itself attached the cached token to the request.
func shouldObfuscateTokenResponse(r *http.Request, resp *http.Response, attachedToken bool) bool {
	reqPath := r.URL.Path
	if code := resp.StatusCode; code < http.StatusOK || code >= http.StatusMultipleChoices {
		return false
	}
	if !attachedToken {
		return false
	}
	switch reqPath {
	case "/v1/auth/token/lookup-self", "/v1/auth/token/renew-self":
		return true
	}
	return false
}
|
||
|
|
||
|
func obfuscateAndCopyResponse(w io.Writer, r io.Reader) error {
|
||
|
sec, err := vapi.ParseSecret(r)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
delete(sec.Data, "id")
|
||
|
delete(sec.Data, "accessor")
|
||
|
|
||
|
if sec.Auth != nil {
|
||
|
sec.Auth.ClientToken = ""
|
||
|
sec.Auth.Accessor = ""
|
||
|
}
|
||
|
|
||
|
return json.NewEncoder(w).Encode(sec)
|
||
|
}
|
||
|
|
||
|
func (vp *vaultProxier) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||
|
ctx := r.Context()
|
||
|
|
||
|
u, ok := ctx.Value(userContextKey).(*userData)
|
||
|
if !ok {
|
||
|
http.Error(rw, "no username could be determined from the request", http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
t, err := vp.c.Get(ctx, u.Username, u.IsPlainUser)
|
||
|
if err != nil {
|
||
|
http.Error(rw, fmt.Sprintf("fetching token for %s: %v", u.Username, err), http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
tokenStr, err := t.TokenID()
|
||
|
if err != nil {
|
||
|
http.Error(rw, fmt.Sprintf("extracting token string for %s: %v", u.Username, err), http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
outReq := r.Clone(ctx)
|
||
|
outReq.RequestURI = ""
|
||
|
outReq.URL.Scheme = "http"
|
||
|
outReq.URL.Host = "vault-agent"
|
||
|
outReq.Trailer = nil
|
||
|
|
||
|
attachedToken := false
|
||
|
if shouldAttachToken(r.URL.Path) {
|
||
|
outReq.Header.Set("X-Vault-Token", tokenStr)
|
||
|
attachedToken = true
|
||
|
}
|
||
|
log.Infof("incoming request [%v / isPlainUser=%v] %v %v", u.Username, u.IsPlainUser, r.Method, r.URL.Path)
|
||
|
start := time.Now()
|
||
|
|
||
|
resp, err := vp.hc.Do(outReq)
|
||
|
if err != nil {
|
||
|
http.Error(rw, fmt.Sprintf("making backend request to vault agent: %v", err), http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
defer resp.Body.Close()
|
||
|
|
||
|
log.Infof("outgoing response [%v / isPlainUser=%v] %v %v: %v %v", u.Username, u.IsPlainUser, r.Method, r.URL.Path, resp.StatusCode, time.Now().Sub(start))
|
||
|
|
||
|
for k, vs := range resp.Header {
|
||
|
rw.Header()[k] = vs
|
||
|
}
|
||
|
if shouldObfuscateTokenResponse(r, resp, attachedToken) {
|
||
|
rw.Header().Del("Content-Length")
|
||
|
rw.WriteHeader(resp.StatusCode)
|
||
|
if err := obfuscateAndCopyResponse(rw, resp.Body); err != nil {
|
||
|
log.Errorf("copying obfuscated lookup-self response: %v", err)
|
||
|
}
|
||
|
} else {
|
||
|
rw.WriteHeader(resp.StatusCode)
|
||
|
if _, err := io.Copy(rw, resp.Body); err != nil {
|
||
|
log.Errorf("copying response from agent: %v", err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
userGroupGidSaved string
|
||
|
userGroupGidOnce sync.Once
|
||
|
)
|
||
|
|
||
|
func userGroupGid() string {
|
||
|
userGroupGidOnce.Do(func() {
|
||
|
userGroupGidSaved = ""
|
||
|
if *userGroup == "" {
|
||
|
// Disabled.
|
||
|
return
|
||
|
}
|
||
|
g, err := user.LookupGroup(*userGroup)
|
||
|
if err != nil {
|
||
|
log.Errorf("looking up user group %q: %v", *userGroup, err)
|
||
|
return
|
||
|
}
|
||
|
userGroupGidSaved = g.Gid
|
||
|
})
|
||
|
return userGroupGidSaved
|
||
|
}
|
||
|
|
||
|
func attachUserData(ctx context.Context, c net.Conn) context.Context {
|
||
|
uc, ok := c.(*net.UnixConn)
|
||
|
if !ok {
|
||
|
log.Warningf("asked to attachUserData to a non UnixConn (%T)", c)
|
||
|
return ctx
|
||
|
}
|
||
|
|
||
|
raw, err := uc.SyscallConn()
|
||
|
if err != nil {
|
||
|
log.Warningf("unable to get the underlying raw connection for UnixConn")
|
||
|
return ctx
|
||
|
}
|
||
|
|
||
|
var cred *unix.Ucred
|
||
|
if ctrlErr := raw.Control(func(fd uintptr) {
|
||
|
cred, err = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
|
||
|
}); ctrlErr != nil {
|
||
|
log.Warningf("control operation to get username failed: %v", err)
|
||
|
return ctx
|
||
|
} else if err != nil {
|
||
|
log.Warningf("getsockoptucred failed: %v", err)
|
||
|
return ctx
|
||
|
}
|
||
|
|
||
|
u, err := user.LookupId(strconv.FormatUint(uint64(cred.Uid), 10))
|
||
|
if err != nil {
|
||
|
log.Warningf("looking up UID %d from unix socket: %v", cred.Uid, err)
|
||
|
return ctx
|
||
|
}
|
||
|
|
||
|
isPlainUser := false
|
||
|
ugs, err := u.GroupIds()
|
||
|
if err != nil {
|
||
|
log.Warningf("looking up groups for user %s: %v", u.Username)
|
||
|
} else if u.Username == "root" {
|
||
|
// We treat root as a plain user for convenience's sake.
|
||
|
isPlainUser = true
|
||
|
} else {
|
||
|
plainUserGid := userGroupGid()
|
||
|
for _, ug := range ugs {
|
||
|
if ug == plainUserGid {
|
||
|
isPlainUser = true
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return context.WithValue(ctx, userContextKey, &userData{
|
||
|
Username: u.Username,
|
||
|
IsPlainUser: isPlainUser,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func main() {
|
||
|
flag.Parse()
|
||
|
|
||
|
vcfg := vapi.DefaultConfig()
|
||
|
vcfg.AgentAddress = *agentAddr
|
||
|
v, err := vapi.NewClient(vcfg)
|
||
|
if err != nil {
|
||
|
log.Exitf("creating vault client against %v: %v", *agentAddr, err)
|
||
|
}
|
||
|
|
||
|
ctx := context.Background()
|
||
|
c := newCache(&vaultTokenInteractor{v})
|
||
|
go func() {
|
||
|
t := time.NewTicker(*cacheTickInterval)
|
||
|
defer t.Stop()
|
||
|
for {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return
|
||
|
case <-t.C:
|
||
|
if err := c.tick(ctx); err != nil {
|
||
|
log.Errorf("ticking the cache: %v", err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
d := &net.Dialer{
|
||
|
Timeout: 30 * time.Second,
|
||
|
KeepAlive: 30 * time.Second,
|
||
|
}
|
||
|
agentPath := strings.TrimPrefix(*agentAddr, "unix://")
|
||
|
vp := &vaultProxier{v: v, c: c, hc: &http.Client{
|
||
|
Transport: &http.Transport{
|
||
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
|
// Ignore what they want.
|
||
|
return d.DialContext(ctx, "unix", agentPath)
|
||
|
},
|
||
|
ForceAttemptHTTP2: true,
|
||
|
MaxIdleConns: 100,
|
||
|
IdleConnTimeout: 90 * time.Second,
|
||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||
|
ExpectContinueTimeout: 1 * time.Second,
|
||
|
},
|
||
|
}}
|
||
|
|
||
|
// Just try to delete the listen path before we start listening.
|
||
|
os.Remove(*listenPath)
|
||
|
|
||
|
listener, err := net.Listen("unix", *listenPath)
|
||
|
if err != nil {
|
||
|
log.Exitf("listening on %v: %v", *listenPath, err)
|
||
|
}
|
||
|
|
||
|
if err := os.Chmod(*listenPath, 0777); err != nil {
|
||
|
log.Exitf("chmodding our unix socket at %v: %v", *listenPath, err)
|
||
|
}
|
||
|
|
||
|
m := http.NewServeMux()
|
||
|
m.Handle("/", vp)
|
||
|
m.HandleFunc("/tokend/cachez", func(rw http.ResponseWriter, r *http.Request) {
|
||
|
c.l.RLock()
|
||
|
for username, token := range c.m {
|
||
|
accessor, _ := token.TokenSecret.TokenAccessor()
|
||
|
log.Infof("cachez: %v: accessor= %v ; expiration=%v, renewal at=%v", username, accessor, token.expiration, token.renewThreshold)
|
||
|
}
|
||
|
c.l.RUnlock()
|
||
|
|
||
|
rw.Header().Set("Content-type", "text/plain")
|
||
|
fmt.Fprintf(rw, "written to log\n")
|
||
|
})
|
||
|
|
||
|
log.Infof("listening on %v", *listenPath)
|
||
|
server := &http.Server{Handler: m, ConnContext: attachUserData}
|
||
|
log.Exit(server.Serve(listener))
|
||
|
}
|