// SPDX-FileCopyrightText: 2023 Luke Granger-Brown <depot@lukegb.com>
//
// SPDX-License-Identifier: Apache-2.0

package nixbuild

import (
	"context"
	"errors"
	"fmt"
	"log"
	"math/rand"
	"path"
	"sort"
	"sync"

	"hg.lukegb.com/lukegb/depot/go/nix/nar/narinfo"
	"hg.lukegb.com/lukegb/depot/go/nix/nixdrv"
	"hg.lukegb.com/lukegb/depot/go/nix/nixstore"
)

// WorkItem is one unit of work for the Coordinator: making a single store
// path, or the outputs of a single derivation, available. run drives it
// through a sequence of States until a terminal one is reached; other work
// items wait for it on doneCh and then inspect its state and err.
type WorkItem struct {
	// doneCh is closed when run returns, whether it succeeded or failed.
	doneCh chan struct{}
	// state is the current step of the state machine driven by run.
	state  State

	// Either (path) or (drv, drvPath) must be set.
	path    string
	drv     *nixdrv.BasicDerivation
	drvPath string

	// err records why run failed; it is only meaningful once doneCh is closed.
	err     error
	// cleanup holds functions that run, in order, when run returns (for
	// example, returning a builder connection to its pool).
	cleanup []func()

	// In use during: StateFetchingReferencesFromResolvers, StateFetchingFromResolver, StateDoneFromResolver
	// resolver is the resolver chosen by checkResolvers; narInfos are the NAR
	// infos it returned for paths().
	resolver Resolver
	narInfos []*narinfo.NarInfo

	// In use during: StateFetchingReferencesFromResolvers, StateCopyingReferencesToBuilder
	// childWork holds the dependency work items that awaitChildWorkItems waits on.
	childWork []*WorkItem

	// In use during: StateChoosingBuilder, StateResolvingReferencesOnBuilder, StateCopyingReferencesToBuilder, StateBuildingDerivation, StateCopying, StateDoneBuilt
	// builder is the builder picked by chooseBuilder; builderDaemon is a
	// connection checked out of builder.Pool (nil while it is released);
	// builderNeededReferences lists input sources that must be copied onto
	// the builder before the build.
	builder                 *Builder
	builderDaemon           *nixstore.Daemon
	builderNeededReferences []string

	// coord is the Coordinator this item belongs to; it supplies the
	// configuration and AddWork for dependencies.
	coord *Coordinator
	// at is passed to builder daemon operations (EnsurePath, BuildDerivation).
	at    *nixstore.ActivityTracker
}

// paths returns the store paths this item is responsible for. That is the
// single explicit path when one was given, and otherwise the output paths of
// the derivation.
func (i *WorkItem) paths() []string {
	if i.path == "" {
		outs := make([]string, 0, len(i.drv.Outputs))
		for _, out := range i.drv.Outputs {
			outs = append(outs, out.Path)
		}
		return outs
	}
	return []string{i.path}
}

// checkOnTarget handles StateCheckingOnTarget. With no target configured it
// moves straight on to StateResolvingOnTarget. Otherwise it asks the target
// which of paths() are valid. If the target reports as many valid paths as
// were asked for, the item is finished (StateDoneAlreadyExisted); if not, we
// continue to StateResolvingOnTarget.
func (i *WorkItem) checkOnTarget(ctx context.Context) (State, error) {
	target := i.coord.cfg.Target
	if target == nil {
		return StateResolvingOnTarget, nil
	}

	want := i.paths()
	valid, err := target.ValidPaths(ctx, want)
	switch {
	case err != nil:
		return StateFailed, err
	case len(valid) == len(want):
		return StateDoneAlreadyExisted, nil
	default:
		return StateResolvingOnTarget, nil
	}
}

// resolveOnTarget handles StateResolvingOnTarget. When a target is configured
// and ResolveOnTarget is enabled, it asks the target to obtain paths() itself
// via EnsurePaths.
//
// A nixstore.NixError from the target is tolerated and only what was resolved
// is counted; any other error fails the item. If every path was resolved, the
// item is finished (StateDoneResolvedOnTarget). Otherwise we fall through to
// StateCheckingResolvers.
func (i *WorkItem) resolveOnTarget(ctx context.Context) (State, error) {
	if i.coord.cfg.Target == nil || !i.coord.cfg.ResolveOnTarget {
		return StateCheckingResolvers, nil
	}

	want := i.paths()
	resolved, err := i.coord.cfg.Target.EnsurePaths(ctx, want)
	var nixErr nixstore.NixError
	if err != nil && !errors.As(err, &nixErr) {
		return StateFailed, err
	}
	if len(resolved) != len(want) {
		return StateCheckingResolvers, nil
	}
	return StateDoneResolvedOnTarget, nil
}

// checkResolvers handles StateCheckingResolvers. Every configured resolver is
// queried concurrently for the NAR info of each of paths(). A resolver only
// qualifies if it returns NAR info for every path.
//
// Among qualifying resolvers, the one with the lowest RelativePriority wins.
// Ties go to whichever appears first in cfg.Resolvers, so the choice does not
// depend on goroutine scheduling. The winner and its NAR infos are stored on
// the item.
//
// If no resolver has everything, we move on to StateLookingUpDerivation so the
// item can be built instead.
func (i *WorkItem) checkResolvers(ctx context.Context) (State, error) {
	resolvers := i.coord.cfg.Resolvers
	if len(resolvers) == 0 {
		return StateLookingUpDerivation, nil
	}

	type result struct {
		r  Resolver
		ni []*narinfo.NarInfo
	}
	// One slot per resolver, indexed by its position in the configuration.
	// Each goroutine writes only its own slot, so no locking is needed, and
	// wg.Wait makes those writes visible below.
	slots := make([]*result, len(resolvers))
	wantPaths := i.paths()
	var wg sync.WaitGroup
	for n, r := range resolvers {
		n, r := n, r
		wg.Add(1)
		go func() {
			defer wg.Done()
			res := &result{r: r}
			for _, p := range wantPaths {
				ni, err := r.FetchNarInfo(ctx, p)
				if err != nil {
					// This resolver doesn't have (all of) our paths.
					return
				}
				res.ni = append(res.ni, ni)
			}
			slots[n] = res
		}()
	}
	wg.Wait()

	valid := make([]*result, 0, len(resolvers))
	for _, res := range slots {
		if res != nil {
			valid = append(valid, res)
		}
	}
	if len(valid) == 0 {
		return StateLookingUpDerivation, nil
	}
	// Stable sort keeps configuration order among equal priorities.
	sort.SliceStable(valid, func(a, b int) bool {
		return valid[a].r.RelativePriority() < valid[b].r.RelativePriority()
	})
	i.resolver = valid[0].r
	i.narInfos = valid[0].ni
	return StateFetchingReferencesFromResolvers, nil
}

// pathSet converts a list of paths into a membership set in which each
// distinct path maps to true. A nil or empty input yields an empty, non-nil
// map.
func pathSet(ps []string) map[string]bool {
	set := make(map[string]bool, len(ps))
	for _, member := range ps {
		set[member] = true
	}
	return set
}

// pathSet returns this item's own paths (as reported by paths) as a
// membership set.
func (i *WorkItem) pathSet() map[string]bool {
	own := i.paths()
	return pathSet(own)
}

// fetchReferencesFromResolvers handles StateFetchingReferencesFromResolvers
// and expects i.narInfos to be populated. Every distinct reference listed in
// those NAR infos is queued on the coordinator as child work, joined onto the
// store directory, and then awaited. References to this item's own paths are
// skipped.
//
// With no target configured there is nowhere to fetch to, so this step is
// skipped. Either way, the next state is StateFetchingFromResolver unless a
// dependency fails.
func (i *WorkItem) fetchReferencesFromResolvers(ctx context.Context) (State, error) {
	if i.coord.cfg.Target == nil {
		// Nowhere to fetch to.
		return StateFetchingFromResolver, nil
	}

	// Seeded with our own paths so that self-references never become child work.
	seen := i.pathSet()
	storeDir := i.coord.cfg.storePath()
	var children []*WorkItem
	for _, ni := range i.narInfos {
		for _, ref := range ni.References {
			full := path.Join(storeDir, ref)
			if seen[full] {
				continue
			}
			seen[full] = true
			children = append(children, i.coord.AddWork(ctx, full))
		}
	}
	if len(children) == 0 {
		// No references to fetch.
		return StateFetchingFromResolver, nil
	}

	i.childWork = children
	if err := i.awaitChildWorkItems(ctx); err != nil {
		return StateFailed, err
	}
	return StateFetchingFromResolver, nil
}

// awaitChildWorkItems waits, in order, for every item in i.childWork to finish.
// It returns the first child failure it observes, wrapped with that child's
// paths, or ctx's error if the context is cancelled first. It returns nil once
// all children have finished without error.
func (i *WorkItem) awaitChildWorkItems(ctx context.Context) error {
	for _, child := range i.childWork {
		select {
		case <-child.doneCh:
		case <-ctx.Done():
			return ctx.Err()
		}
		// child.err was written before doneCh was closed, so reading it here is safe.
		if child.err != nil {
			return fmt.Errorf("dependency %v failed: %w", child.paths(), child.err)
		}
	}
	return nil
}

// fetchFromResolver handles StateFetchingFromResolver and expects i.resolver
// and i.narInfos to be populated. Each NAR described by i.narInfos is streamed
// from i.resolver into the target store, one goroutine per NAR, and all of
// them are awaited. If any transfer fails, one of the failures is returned.
//
// With no target configured, the paths are left available on the resolver and
// nothing is copied.
func (i *WorkItem) fetchFromResolver(ctx context.Context) (State, error) {
	if i.coord.cfg.Target == nil {
		// Present on a resolver and nowhere to copy to; we're done here.
		return StateDoneFromResolver, nil
	}

	// Buffered so that no goroutine ever blocks on send.
	errCh := make(chan error, len(i.narInfos))
	var wg sync.WaitGroup
	for _, ni := range i.narInfos {
		ni := ni
		wg.Add(1)
		go func() {
			defer wg.Done()
			rc, err := i.resolver.FetchNar(ctx, ni)
			if err != nil {
				errCh <- err
				return
			}
			err = i.coord.cfg.Target.PutNar(ctx, ni, rc)
			// Close the NAR stream whether or not PutNar succeeded, so that a
			// failed upload doesn't leak it.
			rc.Close()
			errCh <- err
		}()
	}
	wg.Wait()
	close(errCh)
	for err := range errCh {
		if err != nil {
			return StateFailed, err
		}
	}
	return StateDoneFromResolver, nil
}

// lookUpDerivation handles StateLookingUpDerivation. If the item already has a
// derivation, we go straight to StateChoosingBuilder.
//
// Otherwise, every configured resolver that implements DerivationFetcher is
// asked concurrently for the derivation producing i.path. Of those that
// answer, the one appearing first in cfg.Resolvers wins, so the result does
// not depend on goroutine scheduling. Its derivation and .drv path are
// recorded on the item. If no resolver can answer, the item fails.
func (i *WorkItem) lookUpDerivation(ctx context.Context) (State, error) {
	if i.drv != nil {
		return StateChoosingBuilder, nil
	}

	type result struct {
		drv     *nixdrv.BasicDerivation
		drvPath string
	}
	resolvers := i.coord.cfg.Resolvers
	// One slot per resolver; each goroutine writes only its own slot, and
	// wg.Wait makes those writes visible below.
	slots := make([]*result, len(resolvers))
	var wg sync.WaitGroup
	for n, r := range resolvers {
		df, ok := r.(DerivationFetcher)
		if !ok {
			continue
		}
		n := n
		wg.Add(1)
		go func() {
			defer wg.Done()
			drv, drvPath, err := df.FetchDerivationForPath(ctx, i.path)
			if err != nil {
				return
			}
			slots[n] = &result{drv, drvPath}
		}()
	}
	wg.Wait()

	for _, res := range slots {
		if res != nil {
			i.drv = res.drv
			i.drvPath = res.drvPath
			return StateChoosingBuilder, nil
		}
	}
	return StateFailed, fmt.Errorf("couldn't find someone who knows how to build %v", i.path)
}

// chooseBuilder handles StateChoosingBuilder and expects i.drv to be set.
//
// Only builders that support the derivation's platform and its required
// system features are considered. Among those, one of the least busy
// (according to Pool.Busyness) is picked uniformly at random. A daemon
// connection is then checked out of that builder's pool, and a cleanup hook is
// registered to return the connection when the item finishes.
func (i *WorkItem) chooseBuilder(ctx context.Context) (State, error) {
	if len(i.coord.cfg.Builders) == 0 {
		return StateFailed, fmt.Errorf("no builders configured")
	}

	type builderCandidate struct {
		builder  *Builder
		busyness float64
	}
	var candidates []builderCandidate
	rsf := i.drv.RequiredSystemFeatures()
	for _, b := range i.coord.cfg.Builders {
		if !b.SupportedPlatforms[i.drv.Platform] || !b.SupportsAll(rsf) || !b.CanBuild(rsf) {
			continue
		}
		candidates = append(candidates, builderCandidate{
			builder:  b,
			busyness: b.Pool.Busyness(),
		})
	}
	if len(candidates) == 0 {
		return StateFailed, fmt.Errorf("no builders support platform %q and features %v", i.drv.Platform, rsf)
	}

	// Narrow down to the candidates tied for the lowest busyness...
	sort.Slice(candidates, func(a, b int) bool {
		return candidates[a].busyness < candidates[b].busyness
	})
	leastBusy := candidates
	for n, c := range candidates {
		if c.busyness != candidates[0].busyness {
			leastBusy = candidates[:n]
			break
		}
	}
	// ...and pick one at random, to spread load across equally idle builders.
	i.builder = leastBusy[rand.Intn(len(leastBusy))].builder

	daemon, err := i.builder.Pool.Get()
	if err != nil {
		return StateFailed, err
	}
	i.builderDaemon = daemon
	i.cleanup = append(i.cleanup, func() {
		// builderDaemon may have been released (set to nil) by a later state.
		if i.builderDaemon != nil {
			i.builder.Pool.Put(i.builderDaemon)
		}
		i.builderDaemon = nil
	})
	return StateResolvingReferencesOnBuilder, nil
}

// ensureReferencesOnBuilder handles StateResolvingReferencesOnBuilder and
// expects i.drv and i.builderDaemon to be set. It records in
// i.builderNeededReferences which of the derivation's input sources must be
// copied to the builder.
//
// When ResolveDependenciesOnBuilders is off, every input source is recorded.
// Otherwise the builder is first asked to obtain each one itself via
// EnsurePath. Inputs that fail with PathNotValidError or NixError are recorded
// for copying; any other error fails the item.
func (i *WorkItem) ensureReferencesOnBuilder(ctx context.Context) (State, error) {
	resolveOnBuilder := i.coord.cfg.ResolveDependenciesOnBuilders
	for src := range i.drv.InputSrcs {
		if !resolveOnBuilder {
			i.builderNeededReferences = append(i.builderNeededReferences, src)
			continue
		}

		err := i.builderDaemon.EnsurePath(i.at, src)
		if err == nil {
			continue
		}
		var notValid nixstore.PathNotValidError
		var nixErr nixstore.NixError
		if !errors.As(err, &notValid) && !errors.As(err, &nixErr) {
			return StateFailed, fmt.Errorf("path %v: %w", src, err)
		}
		// The builder couldn't get it by itself; we'll copy it over.
		i.builderNeededReferences = append(i.builderNeededReferences, src)
	}
	return StateCopyingReferencesToBuilder, nil
}

// copyReferencesToBuilder handles StateCopyingReferencesToBuilder and expects
// i.builder and i.builderDaemon to be set.
//
// For every path in i.builderNeededReferences, it queues (or joins) a work
// item on the coordinator and waits for all of them, releasing the builder
// connection while it waits. It then reacquires a connection and asks the
// builder which of those paths are now valid.
//
// Each remaining path is copied onto the builder from wherever its work item
// left it: the target, the resolver that supplied it, or (via the target or
// the peer builder) the builder that built it.
func (i *WorkItem) copyReferencesToBuilder(ctx context.Context) (State, error) {
	if len(i.builderNeededReferences) == 0 {
		return StateBuildingDerivation, nil
	}

	// Release the connection while waiting on dependencies. Presumably this
	// is so that dependencies which need a builder (chooseBuilder calls
	// Pool.Get) are not starved by idle connections.
	i.builder.Pool.Put(i.builderDaemon)
	i.builderDaemon = nil
	wis := make([]*WorkItem, 0, len(i.builderNeededReferences))
	for _, p := range i.builderNeededReferences {
		// Must be in same order in wis/i.childWork as i.builderNeededReferences!
		wis = append(wis, i.coord.AddWork(ctx, p))
	}
	i.childWork = wis
	if err := i.awaitChildWorkItems(ctx); err != nil {
		return StateFailed, err
	}
	var err error
	i.builderDaemon, err = i.builder.Pool.Get()
	if err != nil {
		return StateFailed, fmt.Errorf("reacquiring builder connection: %w", err)
	}

	// Recheck which things we need to fetch; maybe we built or resolved some of them on the box.
	validPaths, err := i.builderDaemon.ValidPaths(i.builderNeededReferences)
	if err != nil {
		return StateFailed, fmt.Errorf("checking valid paths on builder: %w", err)
	}
	validOnBuilder := pathSet(validPaths)

	for n, wi := range i.childWork {
		refPath := i.builderNeededReferences[n]
		if validOnBuilder[refPath] {
			continue
		}

		var f Fetcher
		var ni *narinfo.NarInfo
		switch wi.state {
		case StateDoneAlreadyExisted, StateDoneResolvedOnTarget:
			// The dependency is on the target; copy it from there.
			f = i.coord.cfg.Target
		case StateDoneFromResolver:
			// Copy from the resolver that supplied it. We should already have
			// the NAR info, so skip a step.
			f = wi.resolver
			for _, wini := range wi.narInfos {
				if wini.StorePath == refPath {
					ni = wini
				}
			}
		case StateDoneBuilt:
			if wi.builder == i.builder {
				// Nothing to do, we built the dep on the same builder.
				continue
			}
			// With a target configured, copyFromBuilder has already pushed the
			// outputs there; otherwise fetch from the builder that built it.
			if i.coord.cfg.Target != nil {
				f = i.coord.cfg.Target
			} else {
				f = PeerFetcher{wi.builder.Pool}
			}
		default:
			return StateFailed, fmt.Errorf("don't know where the target of %s is (for path %s)", wi.state, wi.paths())
		}

		if ni == nil {
			var err error
			ni, err = f.FetchNarInfo(ctx, refPath)
			if err != nil {
				return StateFailed, fmt.Errorf("looking up dependency NAR info for %v: %w", refPath, err)
			}
		}
		rc, err := f.FetchNar(ctx, ni)
		if err != nil {
			return StateFailed, fmt.Errorf("fetching NAR for dependency %v: %w", wi.paths(), err)
		}
		err = i.builderDaemon.AddToStoreNar(ni, rc)
		// Close the NAR stream even when the copy fails, so it isn't leaked.
		rc.Close()
		if err != nil {
			return StateFailed, fmt.Errorf("copying NAR to builder for dependency %v: %w", wi.paths(), err)
		}
	}

	return StateBuildingDerivation, nil
}

// build handles StateBuildingDerivation and expects i.builderDaemon to be set.
// It asks the builder to build i.drv (identified as i.drvPath) in normal build
// mode. The item fails if the daemon call errors or the build result does not
// report success; otherwise we move on to StateCopying.
func (i *WorkItem) build(ctx context.Context) (State, error) {
	res, err := i.builderDaemon.BuildDerivation(i.at, i.drvPath, i.drv, nixstore.BMNormal)
	switch {
	case err != nil:
		return StateFailed, err
	case !res.IsSuccess():
		return StateFailed, fmt.Errorf("build failed: %v", res)
	}
	return StateCopying, nil
}

// topoPaths orders paths so that each path comes after any other member of
// paths that it references, according to pathNARs[p].References. References
// are resolved relative to the directory containing p. Self-references and
// references to paths outside the set are ignored, and each distinct path is
// emitted exactly once.
//
// If a pass makes no progress (a reference cycle between distinct paths), the
// remaining paths are appended in input order. The same guarantee covers
// duplicate entries in paths. Previously either case looped forever.
//
// pathNARs must contain an entry for every element of paths; a missing entry
// causes a nil-pointer panic.
func topoPaths(paths []string, pathNARs map[string]*narinfo.NarInfo) []string {
	pset := pathSet(paths)

	var outPaths []string
	donePaths := make(map[string]bool, len(pset))
	// Compare against the number of distinct paths, not len(paths), so that
	// duplicate inputs can't keep the loop from terminating.
	for len(outPaths) < len(pset) {
		progressed := false
	pathLoop:
		for _, p := range paths {
			if donePaths[p] {
				continue
			}
			for _, ref := range pathNARs[p].References {
				refX := path.Join(path.Dir(p), ref)
				if pset[refX] && p != refX && !donePaths[refX] {
					continue pathLoop
				}
			}

			donePaths[p] = true
			outPaths = append(outPaths, p)
			progressed = true
		}
		if !progressed {
			// A cycle among distinct paths: no valid order exists, so emit
			// the rest in input order rather than spinning forever.
			for _, p := range paths {
				if !donePaths[p] {
					donePaths[p] = true
					outPaths = append(outPaths, p)
				}
			}
		}
	}
	return outPaths
}

// copyFromBuilder handles StateCopying and expects i.builderDaemon to be set.
// After a successful build, the derivation's outputs are copied from the
// builder into the target store. With no target configured, the outputs simply
// stay on the builder.
//
// Outputs the target already reports as valid are skipped. For the rest, NAR
// info is fetched first; the NARs are then pushed in an order where referenced
// outputs go before the outputs that reference them (see topoPaths).
func (i *WorkItem) copyFromBuilder(ctx context.Context) (State, error) {
	target := i.coord.cfg.Target
	if target == nil {
		return StateDoneBuilt, nil
	}

	// Which outputs does the target already have?
	valid, err := target.ValidPaths(ctx, i.paths())
	if err != nil {
		return StateFailed, fmt.Errorf("checking valid paths on target: %w", err)
	}
	alreadyThere := make(map[string]bool, len(valid))
	for _, p := range valid {
		alreadyThere[p] = true
	}

	var missing []string
	for _, out := range i.drv.Outputs {
		if !alreadyThere[out.Path] {
			missing = append(missing, out.Path)
		}
	}

	narInfos := make(map[string]*narinfo.NarInfo, len(missing))
	for _, p := range missing {
		ni, err := i.builderDaemon.NARInfo(p)
		if err != nil {
			return StateFailed, fmt.Errorf("fetching narinfo for %v: %w", p, err)
		}
		narInfos[p] = ni
	}

	for _, p := range topoPaths(missing, narInfos) {
		// NOTE(review): the stream returned by NARFromPath is never closed
		// here; confirm whether the daemon expects callers to close it.
		rc, err := i.builderDaemon.NARFromPath(p)
		if err != nil {
			return StateFailed, fmt.Errorf("fetching nar for %v: %w", p, err)
		}
		if err := target.PutNar(ctx, narInfos[p], rc); err != nil {
			return StateFailed, fmt.Errorf("putting nar for %v to target: %w", p, err)
		}
	}
	return StateDoneBuilt, nil
}

// run drives the work item through its state machine. It starts at
// StateCheckingOnTarget and calls the handler for the current state until a
// terminal state is reached, returning nil on success.
//
// On failure (a handler error, context cancellation, or an unhandled state),
// the error is recorded in i.err, where awaitChildWorkItems looks for it, and
// then returned. Whatever the outcome, the registered cleanup functions run
// and doneCh is closed when run returns.
func (i *WorkItem) run(ctx context.Context) error {
	defer close(i.doneCh)
	defer func() {
		for _, cleanup := range i.cleanup {
			cleanup()
		}
	}()

	// Derivation-based items have an empty i.path; fall back to the .drv path
	// so that log lines still identify the item.
	name := i.path
	if name == "" {
		name = i.drvPath
	}

	i.state = StateCheckingOnTarget
	for !i.state.Terminal() {
		if err := ctx.Err(); err != nil {
			// Record the cancellation so that waiters on doneCh see a
			// failure rather than a nil error on an unfinished item.
			i.err = err
			return err
		}

		log.Printf("%s: running worker for %s", name, i.state)
		var nextstate State
		var err error
		switch i.state {
		case StateCheckingOnTarget:
			nextstate, err = i.checkOnTarget(ctx)
		case StateResolvingOnTarget:
			nextstate, err = i.resolveOnTarget(ctx)
		case StateCheckingResolvers:
			nextstate, err = i.checkResolvers(ctx)
		case StateFetchingReferencesFromResolvers:
			nextstate, err = i.fetchReferencesFromResolvers(ctx)
		case StateFetchingFromResolver:
			nextstate, err = i.fetchFromResolver(ctx)
		case StateLookingUpDerivation:
			nextstate, err = i.lookUpDerivation(ctx)
		case StateChoosingBuilder:
			nextstate, err = i.chooseBuilder(ctx)
		case StateResolvingReferencesOnBuilder:
			nextstate, err = i.ensureReferencesOnBuilder(ctx)
		case StateCopyingReferencesToBuilder:
			nextstate, err = i.copyReferencesToBuilder(ctx)
		case StateBuildingDerivation:
			nextstate, err = i.build(ctx)
		case StateCopying:
			nextstate, err = i.copyFromBuilder(ctx)
		default:
			// Fail rather than block: blocking here would leave doneCh open
			// forever and hang every item waiting on this one.
			nextstate, err = StateFailed, fmt.Errorf("ended up in unimplemented state %s", i.state)
		}
		prevstate := i.state
		i.state = nextstate
		if err != nil {
			log.Printf("%s: transitioning to %s: %s", name, nextstate, err)
			i.err = fmt.Errorf("%s: %w", prevstate, err)
			return i.err
		}
		log.Printf("%s: transitioning to %s", name, nextstate)
	}
	return nil
}

// Done returns a channel that is closed once the work item has finished
// running, whether it succeeded or failed.
func (i *WorkItem) Done() <-chan struct{} {
	return i.doneCh
}