// Command tokend listens on a unix socket, identifies the connecting peer via
// SO_PEERCRED, and proxies each HTTP request to a Vault agent with a per-user
// token attached. Token identifiers are scrubbed from lookup-self/renew-self
// responses before they are returned to the caller.
package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"os/user"
	"strconv"
	"strings"
	"sync"
	"time"

	log "github.com/golang/glog"
	vapi "github.com/hashicorp/vault/api"
	"golang.org/x/sys/unix"
)

var (
	listenPath        = flag.String("listen_path", "/run/tokend/sock", "Path to listen for connections to tokend.")
	agentAddr         = flag.String("agent_address", "unix:///run/vault-agent/sock", "Address of vault agent.")
	cacheTickInterval = flag.Duration("cache_tick_interval", 1*time.Minute, "Time between checking for expirations.")
	userGroup         = flag.String("user_group", "users", "Name of a group that indicates that the requesting user is a 'user' and not a service.")
)

// userContextKeyType is the private type of the context key under which the
// connecting peer's *userData is stored.
type userContextKeyType struct{}

// userContextKey is the context key used by attachUserData and ServeHTTP.
var userContextKey = userContextKeyType{}

// userData describes the peer on the other end of a unix socket connection.
type userData struct {
	// Username is the local account name resolved from the peer's UID.
	Username string
	// IsPlainUser is true for root and for members of the -user_group group.
	IsPlainUser bool
}

// vaultProxier is the http.Handler that forwards requests to the Vault agent.
type vaultProxier struct {
	v  *vapi.Client    // Vault API client built from -agent_address.
	c  *tokenUserCache // Source of per-user tokens (declared elsewhere in this package).
	hc *http.Client    // Client used for the outgoing request to the agent.
}

// shouldAttachToken reports whether the cached token should be attached to a
// request for path. Only revoke-self is excluded.
func shouldAttachToken(path string) bool {
	return path != "/v1/auth/token/revoke-self"
}

// shouldObfuscateTokenResponse reports whether the agent's response must have
// its token identifiers removed before being returned: it must be a 2xx
// response to a lookup-self or renew-self request on which we attached a token.
func shouldObfuscateTokenResponse(r *http.Request, resp *http.Response, attachedToken bool) bool {
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return false
	}
	if !attachedToken {
		return false
	}
	path := r.URL.Path
	return path == "/v1/auth/token/lookup-self" || path == "/v1/auth/token/renew-self"
}

// obfuscateAndCopyResponse parses a Vault secret from r, removes the token ID
// and accessor from it, and writes the re-encoded JSON to w.
//
// It returns any parse or encode error. An empty body results in nothing being
// written and a nil error.
func obfuscateAndCopyResponse(w io.Writer, r io.Reader) error {
	sec, err := vapi.ParseSecret(r)
	if err != nil {
		return err
	}
	if sec == nil {
		// ParseSecret reports an empty body as a nil secret with a nil error;
		// dereferencing it below would panic, and there is nothing to scrub.
		return nil
	}

	delete(sec.Data, "id")
	delete(sec.Data, "accessor")
	if sec.Auth != nil {
		sec.Auth.ClientToken = ""
		sec.Auth.Accessor = ""
	}
	return json.NewEncoder(w).Encode(sec)
}

// ServeHTTP proxies r to the Vault agent on behalf of the user recorded in the
// request context, attaching that user's token where appropriate and scrubbing
// token identifiers from self-lookup/renew responses.
func (vp *vaultProxier) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	u, ok := ctx.Value(userContextKey).(*userData)
	if !ok {
		http.Error(rw, "no username could be determined from the request", http.StatusInternalServerError)
		return
	}

	t, err := vp.c.Get(ctx, u.Username, u.IsPlainUser)
	if err != nil {
		http.Error(rw, fmt.Sprintf("fetching token for %s: %v", u.Username, err), http.StatusInternalServerError)
		return
	}
	tokenStr, err := t.TokenID()
	if err != nil {
		http.Error(rw, fmt.Sprintf("extracting token string for %s: %v", u.Username, err), http.StatusInternalServerError)
		return
	}

	// Rewrite the incoming request into a client request. The host is a
	// placeholder: the transport's dialer always connects to the agent socket.
	outReq := r.Clone(ctx)
	outReq.RequestURI = "" // Must be empty for client requests.
	outReq.URL.Scheme = "http"
	outReq.URL.Host = "vault-agent"
	outReq.Trailer = nil

	attachedToken := false
	if shouldAttachToken(r.URL.Path) {
		outReq.Header.Set("X-Vault-Token", tokenStr)
		attachedToken = true
	}

	log.Infof("incoming request [%v / isPlainUser=%v] %v %v", u.Username, u.IsPlainUser, r.Method, r.URL.Path)
	start := time.Now()
	resp, err := vp.hc.Do(outReq)
	if err != nil {
		http.Error(rw, fmt.Sprintf("making backend request to vault agent: %v", err), http.StatusInternalServerError)
		return
	}
	defer resp.Body.Close()
	log.Infof("outgoing response [%v / isPlainUser=%v] %v %v: %v %v", u.Username, u.IsPlainUser, r.Method, r.URL.Path, resp.StatusCode, time.Since(start))

	for k, vs := range resp.Header {
		rw.Header()[k] = vs
	}

	if shouldObfuscateTokenResponse(r, resp, attachedToken) {
		// The body is re-encoded, so the agent's Content-Length no longer applies.
		rw.Header().Del("Content-Length")
		rw.WriteHeader(resp.StatusCode)
		if err := obfuscateAndCopyResponse(rw, resp.Body); err != nil {
			log.Errorf("copying obfuscated lookup-self response: %v", err)
		}
		return
	}

	rw.WriteHeader(resp.StatusCode)
	if _, err := io.Copy(rw, resp.Body); err != nil {
		log.Errorf("copying response from agent: %v", err)
	}
}

var (
	userGroupGidSaved string
	userGroupGidOnce  sync.Once
)

// userGroupGid returns the GID of the group named by -user_group, resolving it
// at most once. It returns "" when the flag is empty or the lookup fails.
func userGroupGid() string {
	userGroupGidOnce.Do(func() {
		userGroupGidSaved = ""
		if *userGroup == "" {
			// Disabled.
			return
		}
		g, err := user.LookupGroup(*userGroup)
		if err != nil {
			log.Errorf("looking up user group %q: %v", *userGroup, err)
			return
		}
		userGroupGidSaved = g.Gid
	})
	return userGroupGidSaved
}

// attachUserData is the server's ConnContext hook. It resolves the peer of a
// unix socket connection to a local user and returns a context carrying the
// resulting *userData under userContextKey. On any failure it logs a warning
// and returns ctx unchanged.
func attachUserData(ctx context.Context, c net.Conn) context.Context {
	uc, ok := c.(*net.UnixConn)
	if !ok {
		log.Warningf("asked to attachUserData to a non UnixConn (%T)", c)
		return ctx
	}

	raw, err := uc.SyscallConn()
	if err != nil {
		log.Warningf("unable to get the underlying raw connection for UnixConn: %v", err)
		return ctx
	}

	// Control's own error and the getsockopt error are distinct failures, so
	// they are kept in separate variables and each is logged by its own branch.
	var (
		cred    *unix.Ucred
		credErr error
	)
	ctrlErr := raw.Control(func(fd uintptr) {
		cred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
	})
	if ctrlErr != nil {
		log.Warningf("control operation to get username failed: %v", ctrlErr)
		return ctx
	}
	if credErr != nil {
		log.Warningf("getsockoptucred failed: %v", credErr)
		return ctx
	}

	u, err := user.LookupId(strconv.FormatUint(uint64(cred.Uid), 10))
	if err != nil {
		log.Warningf("looking up UID %d from unix socket: %v", cred.Uid, err)
		return ctx
	}

	isPlainUser := false
	ugs, err := u.GroupIds()
	switch {
	case err != nil:
		// Group lookup failure is not fatal: the user is treated as a service.
		log.Warningf("looking up groups for user %s: %v", u.Username, err)
	case u.Username == "root":
		// We treat root as a plain user for convenience's sake.
		isPlainUser = true
	default:
		plainUserGid := userGroupGid()
		for _, ug := range ugs {
			if ug == plainUserGid {
				isPlainUser = true
				break
			}
		}
	}

	return context.WithValue(ctx, userContextKey, &userData{
		Username:    u.Username,
		IsPlainUser: isPlainUser,
	})
}

// main wires up the Vault client, the token cache and its periodic tick, the
// proxy handler, and the unix socket HTTP server, then serves until failure.
func main() {
	flag.Parse()

	vcfg := vapi.DefaultConfig()
	vcfg.AgentAddress = *agentAddr
	v, err := vapi.NewClient(vcfg)
	if err != nil {
		log.Exitf("creating vault client against %v: %v", *agentAddr, err)
	}

	ctx := context.Background()

	c := newCache(&vaultTokenInteractor{v})
	go func() {
		t := time.NewTicker(*cacheTickInterval)
		defer t.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-t.C:
				if err := c.tick(ctx); err != nil {
					log.Errorf("ticking the cache: %v", err)
				}
			}
		}
	}()

	d := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	agentPath := strings.TrimPrefix(*agentAddr, "unix://")
	vp := &vaultProxier{v: v, c: c, hc: &http.Client{
		Transport: &http.Transport{
			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
				// Ignore what they want.
				return d.DialContext(ctx, "unix", agentPath)
			},
			ForceAttemptHTTP2:     true,
			MaxIdleConns:          100,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		},
	}}

	// Just try to delete the listen path before we start listening. This is
	// best-effort: a missing file is expected, anything else is only logged
	// because the Listen call below will report the real failure.
	if err := os.Remove(*listenPath); err != nil && !os.IsNotExist(err) {
		log.Warningf("removing stale socket at %v: %v", *listenPath, err)
	}

	listener, err := net.Listen("unix", *listenPath)
	if err != nil {
		log.Exitf("listening on %v: %v", *listenPath, err)
	}
	if err := os.Chmod(*listenPath, 0777); err != nil {
		log.Exitf("chmodding our unix socket at %v: %v", *listenPath, err)
	}

	m := http.NewServeMux()
	m.Handle("/", vp)
	// /tokend/cachez dumps cache metadata (accessor and timing, not the token
	// itself) to the log rather than to the response.
	m.HandleFunc("/tokend/cachez", func(rw http.ResponseWriter, r *http.Request) {
		c.l.RLock()
		for username, token := range c.m {
			accessor, _ := token.TokenSecret.TokenAccessor()
			log.Infof("cachez: %v: accessor= %v ; expiration=%v, renewal at=%v", username, accessor, token.expiration, token.renewThreshold)
		}
		c.l.RUnlock()

		rw.Header().Set("Content-type", "text/plain")
		fmt.Fprintf(rw, "written to log\n")
	})

	log.Infof("listening on %v", *listenPath)
	server := &http.Server{Handler: m, ConnContext: attachUserData}
	log.Exit(server.Serve(listener))
}