package main

import (
	"context"
	"errors"
	"flag"
	"fmt"
	"log"
	"os"
	"sync"

	"git.lukegb.com/lukegb/depot/go/nix/nar/narinfo"
	"git.lukegb.com/lukegb/depot/go/nix/nixdrv"
	"git.lukegb.com/lukegb/depot/go/nix/nixpool"
	"git.lukegb.com/lukegb/depot/go/nix/nixstore"
)

// Command-line flags.
var (
	// remoteFlag is the daemon address handed to nixpool.DaemonDialer. The
	// `copy` subcommand uses it as the source; every other subcommand dials it
	// for its single connection.
	remoteFlag   = flag.String("remote", "unix://", "Remote.")
	// remote2Flag is the daemon address used as the destination by `copy`.
	remote2Flag  = flag.String("remote2", "unix://", "Remote. But the second one. The destination, for copy.")
	// maxConnsFlag is passed to nixpool.New for both pools created by `copy`
	// (presumably the per-pool connection limit; confirm against nixpool).
	maxConnsFlag = flag.Int("max_conns", 10, "Max connections.")

	// verboseFlag is copied into nixstore.VerboseActivities at startup.
	verboseFlag = flag.Bool("verbose", false, "Verbose.")
)

// ensure implements the `ensure` subcommand: it calls EnsurePath on the daemon
// d for the given store path, passing the activity tracker at, and returns
// whatever error the daemon reports. ctx is currently unused.
func ensure(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker, path string) error {
	err := d.EnsurePath(at, path)
	return err
}

// narInfo implements the `nar-info` subcommand: it asks the daemon d for the
// NAR info of path and prints it to standard output followed by a newline.
// The lookup error, if any, is returned unchanged and nothing is printed.
// ctx and at are currently unused.
func narInfo(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker, path string) error {
	info, lookupErr := d.NARInfo(path)
	if lookupErr != nil {
		return lookupErr
	}

	fmt.Fprintf(os.Stdout, "%s\n", info)
	return nil
}

// tryEnsure calls EnsurePath on d for path and classifies the outcome:
//
//   - (true, nil) when EnsurePath succeeded;
//   - (false, nil) when it failed with a nixstore.NixError whose Code is 1
//     (presumably "path not available"; confirm against nixstore);
//   - (false, err) for every other failure.
//
// ctx is currently unused.
func tryEnsure(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker, path string) (bool, error) {
	var nixErr nixstore.NixError
	switch err := d.EnsurePath(at, path); {
	case err == nil:
		return true, nil
	case errors.As(err, &nixErr) && nixErr.Code == 1:
		// A code-1 daemon error is reported as "not ensured" instead of a failure.
		return false, nil
	default:
		return false, err
	}
}

// workState identifies the step of the per-path copy state machine that a
// workItem is currently in.
type workState int

// The states below are listed in the order a workItem normally visits them.
// StateDone and StateFailed are the only terminal states.
const (
	// StateCheckingAlreadyExists: checking whether the destination already has the path.
	StateCheckingAlreadyExists workState = iota
	// StateFetchingPathInfo: fetching the path's NAR info from the source.
	StateFetchingPathInfo
	// StateEnqueueingReferences: creating work items for the path's references.
	StateEnqueueingReferences
	// StateAwaitingReferences: waiting for those work items to finish.
	StateAwaitingReferences
	// StateCopying: copying the path itself.
	StateCopying
	// StateDone: finished successfully.
	StateDone
	// StateFailed: finished with an error.
	StateFailed
)

// String returns a human-readable name for s, as used in log output.
//
// Values outside the declared constants are rendered as "workState(N)" so that
// an unexpected state is still identifiable in logs and panic messages.
func (s workState) String() string {
	// A switch avoids allocating a lookup map on every call; String is invoked
	// for every log line the state machine emits.
	switch s {
	case StateCheckingAlreadyExists:
		return "checking already exists"
	case StateFetchingPathInfo:
		return "fetching path info"
	case StateEnqueueingReferences:
		return "enqueueing references"
	case StateAwaitingReferences:
		return "awaiting references"
	case StateCopying:
		return "copying"
	case StateDone:
		return "done!"
	case StateFailed:
		return "failed :("
	default:
		return fmt.Sprintf("workState(%d)", int(s))
	}
}

// Terminal reports whether s is a final state (StateDone or StateFailed), i.e.
// one from which no further transitions happen.
func (s workState) Terminal() bool { return s == StateDone || s == StateFailed }

// workItem is the unit of work for copying a single store path. Each item is
// created by workHolder.addWork and driven by its own goroutine running run.
type workItem struct {
	// doneCh is closed when run returns; other goroutines wait on it before
	// reading state and err.
	doneCh chan struct{}
	// state is the item's current position in the state machine; run updates it.
	state  workState
	// path is the store path this item is responsible for.
	path   string

	// err is the error recorded by run when a step fails.
	err       error
	// narInfo is filled in by fetchPathInfo.
	narInfo   *narinfo.NarInfo
	// childWork holds the work items for this path's references (excluding a
	// self-reference); it is filled in by enqueueReferences.
	childWork []*workItem

	// Source and destination connection pools and activity trackers, copied
	// from the owning workHolder when the item is created.
	srcPool *nixpool.Pool
	srcAT   *nixstore.ActivityTracker
	dstPool *nixpool.Pool
	dstAT   *nixstore.ActivityTracker

	// wh is the holder this item belongs to; it is used to enqueue references.
	wh *workHolder
}

// checkAlreadyExists handles StateCheckingAlreadyExists. It borrows a
// destination connection and runs tryEnsure for the item's path on it.
//
// It returns StateDone when tryEnsure reports the path as ensured,
// StateFetchingPathInfo when it does not, and StateFailed together with the
// error when the pool or tryEnsure fails.
func (i *workItem) checkAlreadyExists(ctx context.Context) (workState, error) {
	conn, err := i.dstPool.Get()
	if err != nil {
		return StateFailed, err
	}
	// Hand the connection back to the pool on every exit path.
	defer i.dstPool.Put(conn)

	present, err := tryEnsure(ctx, conn, i.dstAT, i.path)
	switch {
	case err != nil:
		return StateFailed, err
	case present:
		return StateDone, nil
	default:
		return StateFetchingPathInfo, nil
	}
}

// fetchPathInfo handles StateFetchingPathInfo. It borrows a source connection,
// requests the NAR info for the item's path and stores it in i.narInfo.
//
// On success the next state is StateEnqueueingReferences; if the pool or the
// lookup fails it returns StateFailed with the error and leaves i.narInfo
// untouched. ctx is currently unused.
func (i *workItem) fetchPathInfo(ctx context.Context) (workState, error) {
	conn, err := i.srcPool.Get()
	if err != nil {
		return StateFailed, err
	}
	defer i.srcPool.Put(conn)

	info, err := conn.NARInfo(i.path)
	if err == nil {
		i.narInfo = info
		return StateEnqueueingReferences, nil
	}
	return StateFailed, err
}

// enqueueReferences handles StateEnqueueingReferences. For every reference
// listed in i.narInfo it registers a work item with the holder (which reuses an
// existing item for a path it already knows) and records the results in
// i.childWork. A reference to the item's own path is skipped.
//
// It returns StateAwaitingReferences when there is at least one child to wait
// for and StateCopying otherwise; it never returns an error.
func (i *workItem) enqueueReferences(ctx context.Context) (workState, error) {
	// References are joined onto this fixed store directory to form full paths.
	const storeDir = "/nix/store/"

	var pending []*workItem
	for _, ref := range i.narInfo.References {
		if full := storeDir + ref; full != i.path {
			pending = append(pending, i.wh.addWork(ctx, full))
		}
	}
	i.childWork = pending

	if len(pending) > 0 {
		return StateAwaitingReferences, nil
	}
	return StateCopying, nil
}

// awaitReferences handles StateAwaitingReferences. It waits, in order, for each
// item in i.childWork to close its doneCh.
//
// It returns StateCopying once every reference has finished in StateDone. It
// returns StateFailed with ctx.Err() if ctx is cancelled while waiting, and
// StateFailed with an error naming the reference (wrapping the reference's own
// error when one was recorded) if a reference failed.
//
// It panics if a reference signals completion in a non-terminal state while ctx
// is still live, since that indicates a bug in the state machine.
func (i *workItem) awaitReferences(ctx context.Context) (workState, error) {
	for _, ci := range i.childWork {
		select {
		case <-ctx.Done():
			return StateFailed, ctx.Err()
		case <-ci.doneCh:
		}

		// doneCh is closed, so the reference's goroutine has finished and its
		// state and err can be read.
		switch {
		case ci.state == StateFailed:
			if ci.err != nil {
				// Keep the root cause in the chain instead of only naming the path.
				return StateFailed, fmt.Errorf("reference %v failed: %w", ci.path, ci.err)
			}
			return StateFailed, fmt.Errorf("reference %v failed", ci.path)
		case !ci.state.Terminal():
			// A reference interrupted by cancellation may stop mid-way; when both
			// select branches were ready that is a cancellation, not a bug.
			if err := ctx.Err(); err != nil {
				return StateFailed, err
			}
			panic(fmt.Sprintf("%v: reference %v declared done in non-terminal state %v", i.path, ci.path, ci.state))
		}
	}
	return StateCopying, nil
}

// performCopy handles StateCopying. At present it only borrows one connection
// from the source pool and one from the destination pool, returns both, and
// reports StateDone; the transfer itself is not implemented yet.
//
// A failure to obtain either connection yields StateFailed with that error.
// ctx is currently unused.
func (i *workItem) performCopy(ctx context.Context) (workState, error) {
	from, srcErr := i.srcPool.Get()
	if srcErr != nil {
		return StateFailed, srcErr
	}
	defer i.srcPool.Put(from)

	to, dstErr := i.dstPool.Get()
	if dstErr != nil {
		return StateFailed, dstErr
	}
	defer i.dstPool.Put(to)

	// TODO: the rest of the owl

	return StateDone, nil
}

// run drives the item through its state machine, starting at
// StateCheckingAlreadyExists, until it reaches a terminal state. doneCh is
// closed when run returns, whatever the outcome.
//
// By the time doneCh is closed the item is always in a terminal state: on
// success i.state is StateDone; on any failure, including cancellation of ctx,
// i.state is StateFailed and i.err holds the cause. The returned error is the
// cause prefixed with the step that was executing, or ctx.Err() on
// cancellation.
func (i *workItem) run(ctx context.Context) error {
	defer close(i.doneCh)
	i.state = StateCheckingAlreadyExists
	for !i.state.Terminal() {
		if err := ctx.Err(); err != nil {
			// Record the cancellation so that waiters see a failed item with an
			// error, not a finished item stuck in a non-terminal state.
			i.state = StateFailed
			i.err = err
			return err
		}

		log.Printf("%s: running worker for %s", i.path, i.state)
		cur := i.state
		var nextState workState
		var err error
		switch cur {
		case StateCheckingAlreadyExists:
			nextState, err = i.checkAlreadyExists(ctx)
		case StateFetchingPathInfo:
			nextState, err = i.fetchPathInfo(ctx)
		case StateEnqueueingReferences:
			nextState, err = i.enqueueReferences(ctx)
		case StateAwaitingReferences:
			nextState, err = i.awaitReferences(ctx)
		case StateCopying:
			nextState, err = i.performCopy(ctx)
		default:
			// Fail the item instead of blocking forever, which would leak this
			// goroutine and leave every waiter on doneCh hanging.
			log.Printf("%s: ended up in unimplemented state %s", i.path, cur)
			nextState, err = StateFailed, fmt.Errorf("unimplemented state %s", cur)
		}
		if err != nil {
			// An error always ends the item, whatever state the step returned.
			nextState = StateFailed
			log.Printf("%s: transitioning to %s: %s", i.path, nextState, err)
			i.state = nextState
			i.err = err
			// Prefix with the step that failed, not with the failed state.
			return fmt.Errorf("%s: %w", cur, err)
		}
		i.state = nextState
		log.Printf("%s: transitioning to %s", i.path, nextState)
	}
	return nil
}

// workHolder owns the set of work items for one copy operation and hands out at
// most one workItem per store path.
type workHolder struct {
	// mu guards work.
	mu   sync.Mutex
	// work maps a store path to the item responsible for it.
	work map[string]*workItem

	// Source and destination connection pools and activity trackers; they are
	// copied into every workItem the holder creates.
	srcPool *nixpool.Pool
	srcAT   *nixstore.ActivityTracker
	dstPool *nixpool.Pool
	dstAT   *nixstore.ActivityTracker
}

// addWork returns the work item responsible for path. If the holder has no item
// for that path yet, one is created with the holder's pools and trackers,
// registered, and started on its own goroutine with ctx; otherwise the existing
// item is returned and nothing new is started.
//
// The lookup and registration happen under wh.mu.
func (wh *workHolder) addWork(ctx context.Context, path string) *workItem {
	wh.mu.Lock()
	defer wh.mu.Unlock()

	item := wh.work[path]
	if item == nil {
		item = &workItem{
			doneCh: make(chan struct{}),
			path:   path,

			srcPool: wh.srcPool,
			srcAT:   wh.srcAT,
			dstPool: wh.dstPool,
			dstAT:   wh.dstAT,

			wh: wh,
		}
		wh.work[path] = item
		go item.run(ctx)
	}
	return item
}

// copyPath copies path from the store behind poolFrom to the store behind
// poolTo. It creates a fresh workHolder, enqueues path as the root work item,
// blocks until that item signals completion, and returns the error recorded on
// it (nil if none was recorded).
func copyPath(ctx context.Context, poolFrom *nixpool.Pool, atFrom *nixstore.ActivityTracker, poolTo *nixpool.Pool, atTo *nixstore.ActivityTracker, path string) error {
	holder := &workHolder{
		work:    make(map[string]*workItem),
		srcPool: poolFrom,
		srcAT:   atFrom,
		dstPool: poolTo,
		dstAT:   atTo,
	}

	root := holder.addWork(ctx, path)
	// Wait for the root item's goroutine to finish before reading its result.
	<-root.doneCh
	return root.err
}

// rs is the resolver used to load derivations. It is assigned in main before
// any subcommand runs.
var rs nixdrv.Resolver

// loadDerivation loads the derivation at path through the package-level
// resolver rs and returns the resolver's result unchanged.
func loadDerivation(ctx context.Context, path string) (*nixdrv.Derivation, error) {
	drv, err := rs.LoadDerivation(ctx, path)
	return drv, err
}

// buildDerivation implements the `build-derivation` subcommand. It loads the
// derivation at path, converts it to a basic derivation using the package-level
// resolver, asks the daemon d to build it in nixstore.BMNormal mode, and logs
// the build result.
//
// Loading and resolving failures are returned wrapped with a short prefix; a
// failure from the daemon's BuildDerivation call is returned as is.
func buildDerivation(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker, path string) error {
	loaded, err := loadDerivation(ctx, path)
	if err != nil {
		return fmt.Errorf("loading derivation: %w", err)
	}

	basic, err := loaded.ToBasicDerivation(ctx, rs)
	if err != nil {
		return fmt.Errorf("resolving: %w", err)
	}

	result, err := d.BuildDerivation(at, path, basic, nixstore.BMNormal)
	if err == nil {
		log.Printf("build result: %v", result)
		return nil
	}
	return err
}

// main parses the flags, selects the subcommand named by the first positional
// argument, and runs it.
//
// `copy` builds its own source and destination connection pools (from -remote
// and -remote2) and returns as soon as the copy finishes. Every other
// subcommand is wrapped in a closure and run against a single daemon
// connection dialed from -remote. Usage errors print the flag help and exit
// with status 1; runtime failures exit through log.Fatalf.
func main() {
	flag.Parse()

	nixstore.VerboseActivities = *verboseFlag

	// badCall reports a usage error, prints the flag help and exits.
	badCall := func(f string, xs ...interface{}) {
		fmt.Fprintf(os.Stderr, f+"\n", xs...)
		flag.Usage()
		os.Exit(1)
	}

	if flag.NArg() < 1 {
		badCall("need a subcommand")
	}

	ctx := context.Background()
	// cmd is the action to run once a daemon connection has been established.
	var cmd func(context.Context, *nixstore.Daemon, *nixstore.ActivityTracker) error
	switch flag.Arg(0) {
	case "ensure":
		if flag.NArg() != 2 {
			badCall("`ensure` needs a store path")
		}
		cmd = func(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker) error {
			return ensure(ctx, d, at, flag.Arg(1))
		}
	case "show-derivation":
		if flag.NArg() != 2 {
			badCall("`show-derivation` needs a derivation")
		}
		cmd = func(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker) error {
			drv, err := loadDerivation(ctx, flag.Arg(1))
			if err != nil {
				return err
			}
			fmt.Printf("%#v\n", drv)
			return nil
		}
	case "build-derivation":
		if flag.NArg() != 2 {
			badCall("`build-derivation` needs a derivation")
		}
		cmd = func(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker) error {
			return buildDerivation(ctx, d, at, flag.Arg(1))
		}
	case "nar-info":
		if flag.NArg() != 2 {
			badCall("`nar-info` needs a store path")
		}
		cmd = func(ctx context.Context, d *nixstore.Daemon, at *nixstore.ActivityTracker) error {
			return narInfo(ctx, d, at, flag.Arg(1))
		}
	case "copy":
		if flag.NArg() != 2 {
			badCall("`copy` needs a store path")
		}
		factorySrc, err := nixpool.DaemonDialer(ctx, *remoteFlag)
		if err != nil {
			log.Fatalf("creating dialer for src: %v", err)
		}
		poolSrc := nixpool.New(ctx, factorySrc, *maxConnsFlag)
		atSrc := nixstore.NewActivityTracker()

		factoryDst, err := nixpool.DaemonDialer(ctx, *remote2Flag)
		if err != nil {
			log.Fatalf("creating dialer for dst: %v", err)
		}
		poolDst := nixpool.New(ctx, factoryDst, *maxConnsFlag)
		atDst := nixstore.NewActivityTracker()

		if err := copyPath(ctx, poolSrc, atSrc, poolDst, atDst, flag.Arg(1)); err != nil {
			log.Fatalf("copy: %v", err)
		}
		// `copy` manages its own connections; skip the single-connection path.
		return
	default:
		badCall("bad subcommand %s", flag.Arg(0))
	}

	at := nixstore.NewActivityTracker()

	factory, err := nixpool.DaemonDialer(ctx, *remoteFlag)
	if err != nil {
		log.Fatalf("DaemonDialer(%q): %v", *remoteFlag, err)
	}
	d, err := factory()
	if err != nil {
		log.Fatalf("connecting to %q: %v", *remoteFlag, err)
	}
	rs = nixdrv.LocalFSResolver{}

	// Close the connection explicitly, not with defer: log.Fatalf exits
	// through os.Exit, which skips deferred calls, so a deferred Close would
	// never run on the failure path.
	err = cmd(ctx, d, at)
	d.Close()
	if err != nil {
		log.Fatalf("%s: %s", flag.Arg(0), err)
	}
}