depot/go/nix/bnix/bnix.go

570 lines
14 KiB
Go
Raw Normal View History

package main
import (
	"context"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net"
	"net/url"
	"os"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
	"hg.lukegb.com/lukegb/depot/go/nix/nar/narinfo"
	"hg.lukegb.com/lukegb/depot/go/nix/nixdrv"
	"hg.lukegb.com/lukegb/depot/go/nix/nixstore"
)
// remoteFlag is the store URL to talk to (parsed by connectTo). For the
// `copy` subcommand it is the source store.
var remoteFlag = flag.String("remote", "unix://", "Remote.")

// remote2Flag is the destination store URL; only `copy` uses it.
var remote2Flag = flag.String("remote2", "unix://", "Remote. But the second one. The destination, for copy.")

// maxConnsFlag caps how many connections `copy` opens to each store.
var maxConnsFlag = flag.Int("max_conns", 10, "Max connections.")

// verboseFlag is copied into nixstore.VerboseActivities by main.
var verboseFlag = flag.Bool("verbose", false, "Verbose.")
// ensure asks daemon d to make path valid in its store (d.EnsurePath) and
// returns the daemon's error unchanged. ctx is currently unused.
func ensure(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker, path string) error {
	if err := d.EnsurePath(at, path); err != nil {
		return err
	}
	return nil
}
// narInfo fetches the NAR info for path from daemon d and prints it to stdout
// using %s formatting. Nothing is printed if the lookup fails. ctx and at are
// currently unused.
func narInfo(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker, path string) error {
	info, lookupErr := d.NARInfo(path)
	if lookupErr == nil {
		fmt.Printf("%s\n", info)
	}
	return lookupErr
}
// tryEnsure is like ensure but distinguishes "path not available" from real
// failures. It reports:
//
//	(true, nil)   EnsurePath succeeded;
//	(false, nil)  EnsurePath failed with a nixstore.NixError whose Code is 1;
//	(false, err)  any other error.
//
// NOTE(review): Code 1 is presumably the daemon's status for "path is not
// valid / cannot be substituted" — confirm against the nixstore protocol.
func tryEnsure(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker, path string) (bool, error) {
	err := d.EnsurePath(at, path)
	if err == nil {
		return true, nil
	}
	var nixErr nixstore.NixError
	if !errors.As(err, &nixErr) || nixErr.Code != 1 {
		return false, err
	}
	return false, nil
}
// connectionPool hands out daemon connections, creating them lazily via
// factory and capping the number created at limit. Connections are returned
// with Put and reused LIFO.
type connectionPool struct {
	// factory creates a new connection; called without cond.L held.
	factory func() (*nixstore.Daemon, error)
	// cond's lock (cond.L) guards generated and available; waiters in Get are
	// woken by Put.
	cond *sync.Cond
	// limit is the maximum number of connections ever outstanding.
	limit int
	// generated counts connections created (or being created) and not yet
	// given back due to a factory failure.
	generated int
	// available holds idle connections ready to be handed out.
	available []*nixstore.Daemon
}
// newConnectionPool returns a pool that lazily connects to dst (any remote
// accepted by connectTo) and keeps at most limit connections outstanding.
// A limit below 1 is treated as 1: a pool that may never create a connection
// would block every Get forever.
//
// Each new connection is attempted up to 20 times with jittered exponential
// backoff (10ms, doubling, capped at 10s, each wait randomized by ±50%).
// No wait follows the final attempt, and waiting is abandoned as soon as ctx
// is cancelled.
func newConnectionPool(ctx context.Context, dst string, limit int) *connectionPool {
	if limit < 1 {
		limit = 1
	}
	return &connectionPool{
		factory: func() (*nixstore.Daemon, error) {
			const (
				maxAttempts = 20
				maxRetry    = 10 * time.Second
			)
			retryInterval := 10 * time.Millisecond
			var lastErr error
			for attempt := 1; attempt <= maxAttempts; attempt++ {
				d, _, err := connectTo(ctx, dst)
				if err == nil {
					return d, nil
				}
				lastErr = err
				if attempt == maxAttempts {
					break // don't sleep after the last attempt
				}
				// Jitter so concurrent dialers don't retry in lockstep.
				jitter := time.Duration((rand.Float64() - 0.5) * float64(retryInterval))
				timer := time.NewTimer(retryInterval + jitter)
				select {
				case <-ctx.Done():
					timer.Stop()
					return nil, fmt.Errorf("connecting to %q: %w (last error: %v)", dst, ctx.Err(), lastErr)
				case <-timer.C:
				}
				retryInterval *= 2
				if retryInterval > maxRetry {
					retryInterval = maxRetry
				}
			}
			return nil, fmt.Errorf("connecting to %q after %d attempts: %w", dst, maxAttempts, lastErr)
		},
		cond:  sync.NewCond(new(sync.Mutex)),
		limit: limit,
	}
}
// getFromAvailableLocked pops the most recently returned idle connection.
// The caller must hold p.cond.L and ensure p.available is non-empty.
func (p *connectionPool) getFromAvailableLocked() *nixstore.Daemon {
	n := len(p.available)
	top := p.available[n-1]
	p.available = p.available[:n-1]
	return top
}
// Get returns a connection from the pool. It takes an idle connection if one
// exists, otherwise creates one via p.factory while fewer than p.limit exist,
// and otherwise blocks until another goroutine calls Put.
//
// Returned connections must be given back with Put. A factory error is
// returned to the caller and frees the reserved slot.
func (p *connectionPool) Get() (*nixstore.Daemon, error) {
	p.cond.L.Lock()
	for {
		if len(p.available) > 0 {
			d := p.getFromAvailableLocked()
			p.cond.L.Unlock()
			return d, nil
		}
		if p.generated < p.limit {
			// Reserve a slot, then dial without holding the lock.
			p.generated++
			p.cond.L.Unlock()
			d, err := p.factory()
			if err != nil {
				p.cond.L.Lock()
				p.generated--
				// The slot is free again. Wake one waiter so it can try to create
				// a connection itself instead of waiting for a Put that may never
				// come (e.g. limit=1 and this was the only connection).
				p.cond.Signal()
				p.cond.L.Unlock()
				return nil, err
			}
			return d, nil
		}
		// Nothing idle and at the limit: wait for Put (or a freed slot), then
		// re-check both conditions.
		p.cond.Wait()
	}
}
// Put returns connection d to the idle list and wakes one goroutine blocked
// in Get.
func (p *connectionPool) Put(d *nixstore.Daemon) {
	p.cond.L.Lock()
	defer p.cond.L.Unlock()

	p.available = append(p.available, d)
	p.cond.Signal()
}
// workState is a step in a workItem's copy state machine.
type workState int

const (
	StateCheckingAlreadyExists workState = iota
	StateFetchingPathInfo
	StateEnqueueingReferences
	StateAwaitingReferences
	StateCopying
	StateDone
	StateFailed
)

// String returns a human-readable name for s. Values outside the defined
// constants render as "workState(N)" rather than an empty string, so bad
// states are still visible in logs.
func (s workState) String() string {
	switch s {
	case StateCheckingAlreadyExists:
		return "checking already exists"
	case StateFetchingPathInfo:
		return "fetching path info"
	case StateEnqueueingReferences:
		return "enqueueing references"
	case StateAwaitingReferences:
		return "awaiting references"
	case StateCopying:
		return "copying"
	case StateDone:
		return "done!"
	case StateFailed:
		return "failed :("
	default:
		// int conversion keeps fmt from re-entering String.
		return fmt.Sprintf("workState(%d)", int(s))
	}
}

// Terminal reports whether s is a final state (StateDone or StateFailed).
func (s workState) Terminal() bool { return s == StateDone || s == StateFailed }
// workItem tracks copying one store path from the source store to the
// destination store. Its run method drives it through the workState machine.
type workItem struct {
	// doneCh is closed when run returns. Other goroutines (awaitReferences,
	// copyPath) read state and err only after receiving from it.
	doneCh chan struct{}
	// state is the current workState, written by the goroutine executing run.
	state workState
	// path is the store path being copied. Children are built as
	// "/nix/store/" + reference by enqueueReferences.
	path string
	// err is set by run when a state handler fails.
	err error
	// narInfo is populated from the source store by fetchPathInfo.
	narInfo *narinfo.NarInfo
	// childWork holds the items for this path's references (excluding itself),
	// populated by enqueueReferences.
	childWork []*workItem
	// srcPool/srcAT: connections and activity tracker for the source store.
	srcPool *connectionPool
	srcAT   *nixstore.ActivityTracker
	// dstPool/dstAT: connections and activity tracker for the destination store.
	dstPool *connectionPool
	dstAT   *nixstore.ActivityTracker
	// wh is the holder that created this item; used to enqueue (deduplicated)
	// references.
	wh *workHolder
}
// checkAlreadyExists handles StateCheckingAlreadyExists. It asks the
// destination store to ensure i.path (tryEnsure). If that succeeds, the item
// is done. If the path is reported unavailable, the item moves on to fetching
// path info from the source. Any other error fails the item.
func (i *workItem) checkAlreadyExists(ctx context.Context) (workState, error) {
	conn, err := i.dstPool.Get()
	if err != nil {
		return StateFailed, err
	}
	defer i.dstPool.Put(conn)

	switch present, ensureErr := tryEnsure(ctx, conn, i.dstAT, i.path); {
	case ensureErr != nil:
		return StateFailed, ensureErr
	case present:
		return StateDone, nil
	default:
		return StateFetchingPathInfo, nil
	}
}
// fetchPathInfo handles StateFetchingPathInfo: it loads i.path's NAR info
// from the source store into i.narInfo, then moves on to enqueueing
// references.
func (i *workItem) fetchPathInfo(ctx context.Context) (workState, error) {
	conn, err := i.srcPool.Get()
	if err != nil {
		return StateFailed, err
	}
	defer i.srcPool.Put(conn)

	info, lookupErr := conn.NARInfo(i.path)
	if lookupErr != nil {
		return StateFailed, lookupErr
	}
	i.narInfo = info
	return StateEnqueueingReferences, nil
}
// enqueueReferences handles StateEnqueueingReferences. It registers a work
// item (deduplicated by the workHolder) for each of i.narInfo.References,
// skipping a self-reference. It then awaits those items, or goes straight to
// copying when there are none.
func (i *workItem) enqueueReferences(ctx context.Context) (workState, error) {
	// References in NAR info are store-path base names; rebuild full paths.
	const storeDir = "/nix/store/"

	var children []*workItem
	for _, ref := range i.narInfo.References {
		full := storeDir + ref
		if full != i.path {
			children = append(children, i.wh.addWork(ctx, full))
		}
	}
	i.childWork = children

	if len(children) > 0 {
		return StateAwaitingReferences, nil
	}
	return StateCopying, nil
}
// awaitReferences handles StateAwaitingReferences. It blocks until every item
// in i.childWork has finished, and fails as soon as one of them failed or ctx
// is cancelled. A failed child's error is wrapped (%w) so the root cause
// reaches copyPath's caller.
//
// It panics if a child closes doneCh without reaching a terminal state,
// because that breaks run's contract.
func (i *workItem) awaitReferences(ctx context.Context) (workState, error) {
	for _, ci := range i.childWork {
		select {
		case <-ctx.Done():
			return StateFailed, ctx.Err()
		case <-ci.doneCh:
		}
		// Reading ci.state/ci.err is safe: run writes them before doneCh closes.
		if ci.state == StateFailed {
			if ci.err != nil {
				return StateFailed, fmt.Errorf("reference %v failed: %w", ci.path, ci.err)
			}
			return StateFailed, fmt.Errorf("reference %v failed", ci.path)
		}
		if !ci.state.Terminal() {
			panic(fmt.Sprintf("%v: reference %v declared done in non-terminal state %v", i.path, ci.path, ci.state))
		}
	}
	return StateCopying, nil
}
// performCopy handles StateCopying. It acquires one source and one destination
// connection, in that order, and releases both on return.
//
// NOTE(review): the actual transfer is not implemented yet. This reports
// StateDone without copying anything, so `copy` currently "succeeds" for
// paths missing from the destination.
func (i *workItem) performCopy(ctx context.Context) (workState, error) {
	srcConn, srcErr := i.srcPool.Get()
	if srcErr != nil {
		return StateFailed, srcErr
	}
	defer i.srcPool.Put(srcConn)

	dstConn, dstErr := i.dstPool.Get()
	if dstErr != nil {
		return StateFailed, dstErr
	}
	defer i.dstPool.Put(dstConn)

	// TODO: the rest of the owl
	return StateDone, nil
}
// run drives i through its state machine, starting at
// StateCheckingAlreadyExists, until it reaches a terminal state. It closes
// i.doneCh on return.
//
// On return i.state is always terminal, and i.err is non-nil exactly when
// i.state == StateFailed. awaitReferences and copyPath rely on this to tell
// success from failure. Cancellation of ctx, handler errors and unknown states
// all end in StateFailed.
func (i *workItem) run(ctx context.Context) error {
	defer close(i.doneCh)
	i.state = StateCheckingAlreadyExists
	for !i.state.Terminal() {
		if err := ctx.Err(); err != nil {
			// Record cancellation as a failure. Leaving the state non-terminal
			// with a nil err would make copyPath report success and trip the
			// invariant check in awaitReferences.
			log.Printf("%s: transitioning to %s: %s", i.path, StateFailed, err)
			i.state = StateFailed
			i.err = err
			return err
		}
		log.Printf("%s: running worker for %s", i.path, i.state)
		var nextState workState
		var err error
		switch i.state {
		case StateCheckingAlreadyExists:
			nextState, err = i.checkAlreadyExists(ctx)
		case StateFetchingPathInfo:
			nextState, err = i.fetchPathInfo(ctx)
		case StateEnqueueingReferences:
			nextState, err = i.enqueueReferences(ctx)
		case StateAwaitingReferences:
			nextState, err = i.awaitReferences(ctx)
		case StateCopying:
			nextState, err = i.performCopy(ctx)
		default:
			// Fail rather than block forever: a hung item never closes doneCh
			// and deadlocks every item waiting on it.
			log.Printf("%s: ended up in unimplemented state %s", i.path, i.state)
			err = fmt.Errorf("unimplemented state %s", i.state)
		}
		if err != nil {
			// Any error is fatal for this item, whatever state the handler
			// suggested.
			nextState = StateFailed
			log.Printf("%s: transitioning to %s: %s", i.path, nextState, err)
			i.state = nextState
			i.err = err
			return fmt.Errorf("%s: %w", i.state, err)
		}
		log.Printf("%s: transitioning to %s", i.path, nextState)
		i.state = nextState
	}
	return nil
}
// workHolder owns the set of in-flight workItems for one copy operation,
// keyed by store path, so each path is processed at most once.
type workHolder struct {
	// mu guards work.
	mu sync.Mutex
	// work maps store path to its (single) work item.
	work map[string]*workItem
	// Source store connections/tracker, handed to every new workItem.
	srcPool *connectionPool
	srcAT   *nixstore.ActivityTracker
	// Destination store connections/tracker, handed to every new workItem.
	dstPool *connectionPool
	dstAT   *nixstore.ActivityTracker
}
// addWork returns the work item for path, creating it and starting its run
// goroutine if none exists yet. Concurrent calls for the same path share a
// single item. run's return value is discarded; callers observe the outcome
// through the item's doneCh, state and err.
func (wh *workHolder) addWork(ctx context.Context, path string) *workItem {
	wh.mu.Lock()
	defer wh.mu.Unlock()

	if existing := wh.work[path]; existing != nil {
		return existing
	}

	item := &workItem{
		wh:      wh,
		path:    path,
		doneCh:  make(chan struct{}),
		srcPool: wh.srcPool,
		srcAT:   wh.srcAT,
		dstPool: wh.dstPool,
		dstAT:   wh.dstAT,
	}
	wh.work[path] = item
	go item.run(ctx)
	return item
}
// copyPath copies path, and transitively its references, from the poolFrom
// store to the poolTo store. It blocks until the root item finishes and
// returns that item's error (nil on success).
func copyPath(ctx context.Context, poolFrom *connectionPool, atFrom *nixstore.ActivityTracker, poolTo *connectionPool, atTo *nixstore.ActivityTracker, path string) error {
	holder := &workHolder{
		work:    make(map[string]*workItem),
		srcPool: poolFrom,
		srcAT:   atFrom,
		dstPool: poolTo,
		dstAT:   atTo,
	}
	root := holder.addWork(ctx, path)
	<-root.doneCh
	return root.err
}
// rs resolves derivations for loadDerivation and buildDerivation. main sets
// it from the resolver returned by connectTo just before running a
// subcommand. It is nil until then, and is never set for `copy`.
var rs nixdrv.Resolver
// loadDerivation loads the derivation at path using the package-level
// resolver rs, which must already be set. ctx is currently unused.
func loadDerivation(ctx context.Context, path string) (*nixdrv.Derivation, error) {
	drv, err := rs.LoadDerivation(path)
	return drv, err
}
// buildDerivation loads the derivation at path, resolves it to a basic
// derivation using rs, and asks daemon d to build it in normal build mode.
// The build result is logged. Each stage's error is wrapped with the stage
// that failed.
func buildDerivation(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker, path string) error {
	drv, err := loadDerivation(ctx, path)
	if err != nil {
		return fmt.Errorf("loading derivation: %w", err)
	}
	basicDrv, err := drv.ToBasicDerivation(rs)
	if err != nil {
		return fmt.Errorf("resolving: %w", err)
	}
	brs, err := d.BuildDerivation(at, path, basicDrv, nixstore.BMNormal)
	if err != nil {
		return fmt.Errorf("building: %w", err)
	}
	log.Printf("build result: %v", brs)
	return nil
}
// connectTo opens a connection to the Nix daemon named by remote. It returns
// the connection together with the resolver to use for derivations alongside
// it.
//
// Supported remotes:
//
//	unix://[path]                                  local daemon socket; an empty
//	                                               path means nixstore.DaemonSock
//	ssh-ng://[user[:password]@]host[:port][?...]   daemon over SSH (see
//	                                               sshClientConfigFromURL and
//	                                               connectSSHNG for parameters)
//
// Both schemes currently return nixdrv.LocalFSResolver{}. ctx is unused.
func connectTo(ctx context.Context, remote string) (*nixstore.Daemon, nixdrv.Resolver, error) {
	u, err := url.Parse(remote)
	if err != nil {
		return nil, nil, fmt.Errorf("parsing remote %q as URL: %w", remote, err)
	}
	switch u.Scheme {
	case "unix":
		sockPath := u.Path
		if sockPath == "" {
			sockPath = nixstore.DaemonSock
		}
		d, err := nixstore.OpenDaemon(sockPath)
		if err != nil {
			return nil, nil, err
		}
		return d, nixdrv.LocalFSResolver{}, nil
	case "ssh-ng":
		d, err := connectSSHNG(u)
		if err != nil {
			return nil, nil, err
		}
		return d, nixdrv.LocalFSResolver{}, nil // TODO: change from LocalFSResolver?
	default:
		return nil, nil, fmt.Errorf("unknown remote %q", remote)
	}
}

// sshClientConfigFromURL builds the SSH client config for an ssh-ng:// URL.
//
// Query parameters:
//
//	privkey=PATH                     private key file offered for public-key
//	                                 auth; may be repeated
//	host-key=KEY                     pin the server host key; KEY is the rest
//	                                 of a known_hosts line ("<type> <base64>")
//	insecure-allow-any-ssh-host-key  accept any host key (no verification)
//
// The URL's username, and password if present, are used as the SSH user and
// for password auth. One of the two host-key parameters is required.
func sshClientConfigFromURL(u *url.URL) (*ssh.ClientConfig, error) {
	q := u.Query()
	cfg := &ssh.ClientConfig{}

	if q.Has("privkey") {
		var keys []ssh.Signer
		for _, privkeyPath := range q["privkey"] {
			privkeyF, err := os.Open(privkeyPath)
			if err != nil {
				return nil, fmt.Errorf("opening privkey %q: %w", privkeyPath, err)
			}
			// Close right after reading. A defer here would keep every key file
			// open until the whole connection attempt returned.
			privkeyB, err := io.ReadAll(privkeyF)
			privkeyF.Close()
			if err != nil {
				return nil, fmt.Errorf("reading privkey %q: %w", privkeyPath, err)
			}
			privkey, err := ssh.ParsePrivateKey(privkeyB)
			if err != nil {
				return nil, fmt.Errorf("parsing privkey %q: %w", privkeyPath, err)
			}
			keys = append(keys, privkey)
		}
		cfg.Auth = append(cfg.Auth, ssh.PublicKeys(keys...))
	}

	if u.User != nil {
		cfg.User = u.User.Username()
		if pw, ok := u.User.Password(); ok {
			cfg.Auth = append(cfg.Auth, ssh.Password(pw))
		}
	}

	switch {
	case q.Has("host-key"):
		hkStr := q.Get("host-key")
		// ParseKnownHosts expects "<hosts> <type> <key>"; supply a dummy host.
		_, _, hk, _, _, err := ssh.ParseKnownHosts(append([]byte("x "), []byte(hkStr)...))
		if err != nil {
			return nil, fmt.Errorf("parsing host-key %q: %w", hkStr, err)
		}
		cfg.HostKeyCallback = ssh.FixedHostKey(hk)
	case q.Has("insecure-allow-any-ssh-host-key"):
		cfg.HostKeyCallback = ssh.InsecureIgnoreHostKey()
	default:
		return nil, fmt.Errorf("some SSH host key configuration is required (?host-key=; ?insecure-allow-any-ssh-host-key)")
	}
	return cfg, nil
}

// connectSSHNG dials u's host over SSH (port 22 unless the URL gives one) and
// starts the remote command: the remote-cmd query parameter, default
// "nix-daemon --stdio". It then opens a daemon connection over that command's
// stdin/stdout. The SSH connection is closed on every error path.
func connectSSHNG(u *url.URL) (*nixstore.Daemon, error) {
	cfg, err := sshClientConfigFromURL(u)
	if err != nil {
		return nil, err
	}

	q := u.Query()
	remoteCmd := "nix-daemon --stdio"
	if q.Has("remote-cmd") {
		remoteCmd = q.Get("remote-cmd")
	}

	port := u.Port()
	if port == "" {
		port = "22"
	}
	// JoinHostPort brackets IPv6 literals ("[::1]:22"). Concatenating
	// u.Hostname() directly produced the undialable "::1:22".
	remote := net.JoinHostPort(u.Hostname(), port)

	conn, err := ssh.Dial("tcp", remote, cfg)
	if err != nil {
		return nil, fmt.Errorf("dialing %v via SSH: %w", remote, err)
	}
	sess, err := conn.NewSession()
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("opening SSH session to %v: %w", remote, err)
	}
	stdin, err := sess.StdinPipe()
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("opening stdin pipe: %w", err)
	}
	stdout, err := sess.StdoutPipe()
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("opening stdout pipe: %w", err)
	}
	if err := sess.Start(remoteCmd); err != nil {
		conn.Close()
		return nil, fmt.Errorf("starting %q: %w", remoteCmd, err)
	}
	d, err := nixstore.OpenDaemonWithIOs(stdout, stdin, conn)
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("establishing connection to daemon: %w", err)
	}
	return d, nil
}
// main parses flags and runs one subcommand:
//
//	ensure <path>             EnsurePath on the -remote store
//	show-derivation <drv>     load and print a derivation
//	build-derivation <drv>    build a derivation on the -remote store
//	nar-info <path>           print NAR info for a path on -remote
//	copy <path>               copy a path from -remote to -remote2 (copyPath)
//
// Usage errors print a message plus flag usage and exit with status 1.
// Runtime errors are logged fatally.
func main() {
	flag.Parse()
	nixstore.VerboseActivities = *verboseFlag

	badCall := func(f string, xs ...interface{}) {
		fmt.Fprintf(os.Stderr, f+"\n", xs...)
		flag.Usage()
		os.Exit(1)
	}
	if flag.NArg() < 1 {
		badCall("need a subcommand")
	}

	ctx := context.Background()
	var cmd func(context.Context, *nixstore.Daemon, *nixstore.ActivityTracker) error
	switch flag.Arg(0) {
	case "ensure":
		if flag.NArg() != 2 {
			badCall("`ensure` needs a store path")
		}
		cmd = func(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker) error {
			return ensure(ctx, d, at, flag.Arg(1))
		}
	case "show-derivation":
		if flag.NArg() != 2 {
			badCall("`show-derivation` needs a derivation")
		}
		cmd = func(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker) error {
			drv, err := loadDerivation(ctx, flag.Arg(1))
			if err != nil {
				return err
			}
			fmt.Printf("%#v\n", drv)
			return nil
		}
	case "build-derivation":
		if flag.NArg() != 2 {
			badCall("`build-derivation` needs a derivation")
		}
		cmd = func(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker) error {
			return buildDerivation(ctx, d, at, flag.Arg(1))
		}
	case "nar-info":
		if flag.NArg() != 2 {
			badCall("`nar-info` needs a store path")
		}
		cmd = func(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker) error {
			return narInfo(ctx, d, at, flag.Arg(1))
		}
	case "copy":
		if flag.NArg() != 2 {
			badCall("`copy` needs a store path")
		}
		// copy manages its own pooled connections to both stores, so it
		// bypasses the single-connection path below.
		poolSrc := newConnectionPool(ctx, *remoteFlag, *maxConnsFlag)
		atSrc := nixstore.NewActivityTracker()
		poolDst := newConnectionPool(ctx, *remote2Flag, *maxConnsFlag)
		atDst := nixstore.NewActivityTracker()
		if err := copyPath(ctx, poolSrc, atSrc, poolDst, atDst, flag.Arg(1)); err != nil {
			log.Fatalf("copy: %v", err)
		}
		return
	default:
		badCall("bad subcommand %s", flag.Arg(0))
	}

	at := nixstore.NewActivityTracker()
	d, rss, err := connectTo(ctx, *remoteFlag)
	if err != nil {
		log.Fatalf("connectTo(%q): %v", *remoteFlag, err)
	}
	rs = rss
	cmdErr := cmd(ctx, d, at)
	// Close explicitly rather than via defer: log.Fatalf exits through
	// os.Exit, which skips deferred calls, so a deferred Close never ran when
	// the command failed.
	d.Close()
	if cmdErr != nil {
		log.Fatalf("%s: %s", flag.Arg(0), cmdErr)
	}
}