// Binary bcachegc uploads Nix store paths, together with their reference
// closures, to a Nix binary cache.
package main

import (
	"context"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/jamespfennell/xz"
	"github.com/numtide/go-nix/nixbase32"
	"gocloud.dev/blob"
	"golang.org/x/sync/errgroup"
	"golang.org/x/sync/singleflight"
	"hg.lukegb.com/lukegb/depot/go/nix/nar"
	"hg.lukegb.com/lukegb/depot/go/nix/nar/narinfo"
	"hg.lukegb.com/lukegb/depot/go/nix/nixstore"

	_ "gocloud.dev/blob/fileblob"
	_ "gocloud.dev/blob/gcsblob"
	_ "hg.lukegb.com/lukegb/depot/go/vault/vaultgcsblob"
)

var (
	blobURLFlag              = flag.String("cache_url", "", "Destination binary cache bucket URL (a gocloud.dev/blob URL).")
	stateSummaryIntervalFlag = flag.Duration("state_summary_interval", 10*time.Second, "Time between state summary outputs.")
	deepCheckGalacticFlag    = flag.Bool("deep_check_galactic", false, "Ensure that all references are available in the cache before skipping, rather than just checking that the path itself is available.")
	maximumConcurrentUploads = flag.Int("maximum_concurrent_uploads", 64, "Limit for the number of concurrent uploads.")
)

var (
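	// hashExtractRegexp matches the 32-character nixbase32 hash component of
	// a store path; the nixbase32 alphabet omits the letters e, o, u and t.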
	hashExtractRegexp = regexp.MustCompile(`(^|/)([0-9a-df-np-sv-z]{32})([-.].*)?$`)

	trustedCaches = []string{
		"https://cache.nixos.org",
	}
)

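// hashExtract returns the 32-character nixbase32 hash component of a store
// path, or "" if none can be found. For example, given the (illustrative)
// path "/nix/store/0c01017qhm8rfn1bhhgsdqcibkgf5qhw-hello-2.12" it returns
// "0c01017qhm8rfn1bhhgsdqcibkgf5qhw".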
func hashExtract(s string) string {
	res := hashExtractRegexp.FindStringSubmatch(s)
	if len(res) == 0 {
		return ""
	}
	return res[2]
}

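// state describes where a single store path currently is in the upload
// pipeline.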
type state int

const (
	stateUnknown state = iota
	stateCheckingShouldUpload
	stateUploadingReferences
	stateWaitingForUploadSlot
	stateUploadingContent
	stateCopyingContent
	stateUploadingNarinfo
	stateSkipped
	stateFailed
	stateUploaded
	stateMax
)

func (s state) Terminal() bool {
	return s == stateSkipped || s == stateUploaded || s == stateFailed
}
func (s state) String() string {
	return map[state]string{
		stateUnknown:              "unknown",
		stateCheckingShouldUpload: "determining if upload required",
		stateUploadingReferences:  "uploading references",
		stateWaitingForUploadSlot: "waiting for upload slot",
		stateUploadingContent:     "uploading content",
		stateCopyingContent:       "copying content",
		stateUploadingNarinfo:     "uploading narinfo",
		stateSkipped:              "skipped",
		stateFailed:               "failed",
		stateUploaded:             "uploaded",
	}[s]
}

type stateInfo struct {
	Current state
	Since   time.Time
	Path    string
}

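// stateTracker records the current pipeline state of each store path and
// when it entered that state, so periodic summaries can surface slow or
// stuck uploads. It is safe for concurrent use.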
type stateTracker struct {
	mu        sync.Mutex
	pathState map[string]stateInfo
}

func (t *stateTracker) SetState(p string, s state) {
	si := stateInfo{
		Current: s,
		Since:   time.Now(),
		Path:    p,
	}
	t.mu.Lock()
	if t.pathState == nil {
		t.pathState = make(map[string]stateInfo)
	}
	t.pathState[p] = si
	t.mu.Unlock()
}

func (t *stateTracker) CurrentState() map[string]stateInfo {
	t.mu.Lock()
	// Reading len(t.pathState) must also happen under the lock.
	out := make(map[string]stateInfo, len(t.pathState))
	for k, v := range t.pathState {
		out[k] = v
	}
	t.mu.Unlock()
	return out
}

func (t *stateTracker) StateSummary() string {
	states := t.CurrentState()

	countByState := map[state]int{}
	var oldestActive []stateInfo
	for _, s := range states {
		countByState[s.Current]++
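		// Paths in stateUploadingReferences are just waiting on their
		// children, so they aren't interesting as "oldest active" entries.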
		if !s.Current.Terminal() && s.Current != stateUploadingReferences {
			oldestActive = append(oldestActive, s)
		}
	}
	sort.Slice(oldestActive, func(i, j int) bool {
		a, b := oldestActive[i], oldestActive[j]
		return a.Since.Before(b.Since)
	})

	var firstLineBits []string
	for n := stateUnknown; n < stateMax; n++ {
		c := countByState[n]
		if c != 0 {
			firstLineBits = append(firstLineBits, fmt.Sprintf("%d %s", c, n))
		}
	}

	lines := []string{
		strings.Join(firstLineBits, ", "),
	}

	for n := 0; n < len(oldestActive) && n < 20; n++ {
		si := oldestActive[n]
		lines = append(lines, fmt.Sprintf("\t%s: %s (for %s)", si.Path, si.Current, time.Since(si.Since).Truncate(time.Second)))
	}

	return strings.Join(lines, "\n")
}

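// uploader uploads store paths and their reference closures to the bucket,
// skipping anything already present locally or in a trusted upstream cache.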
type uploader struct {
	bucket    *blob.Bucket
	store     nixstore.Store
	storePath string
	st        stateTracker

	uploadLimiter chan int

	uploadSF          singleflight.Group
	deepCheckGalactic bool // if true, don't skip if this item is already present; always check the references to make sure they exist too.
}

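// byteCounterWriter counts the bytes written through it; it's used to
// measure the compressed file size without buffering the output.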
type byteCounterWriter struct{ n uint64 }

func (w *byteCounterWriter) Write(b []byte) (int, error) {
	w.n += uint64(len(b))
	return len(b), nil
}

// inStore reports whether the narinfo for path already exists in our bucket.
func (u *uploader) inStore(ctx context.Context, path string) (bool, error) {
	key, err := keyForPath(path)
	if err != nil {
		return false, fmt.Errorf("computing narinfo key for %v: %w", path, err)
	}

	return u.bucket.Exists(ctx, key)
}

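// inTrustedCaches reports whether the narinfo for path is already served by
// any of the trustedCaches, in which case we needn't store our own copy.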
func (u *uploader) inTrustedCaches(ctx context.Context, path string) (bool, error) {
	key, err := keyForPath(path)
	if err != nil {
		return false, fmt.Errorf("computing narinfo key for %v: %w", path, err)
	}

	for _, c := range trustedCaches {
		req, err := http.NewRequestWithContext(ctx, http.MethodHead, fmt.Sprintf("%v/%v", c, key), nil)
		if err != nil {
			return false, fmt.Errorf("constructing request for %v/%v: %w", c, key, err)
		}

		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return false, fmt.Errorf("making request for %v/%v: %w", c, key, err)
		}
		resp.Body.Close()

		if resp.StatusCode == http.StatusOK {
			return true, nil
		}
	}

	return false, nil
}

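// shouldUpload reports whether path needs uploading: it must be absent from
// both our bucket and every trusted upstream cache.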
func (u *uploader) shouldUpload(ctx context.Context, path string) (bool, error) {
	inStore, err := u.inStore(ctx, path)
	if err != nil {
		return false, err
	}
	if inStore {
		return false, nil
	}

	inTrustedCaches, err := u.inTrustedCaches(ctx, path)
	if err != nil {
		return false, err
	}
	if inTrustedCaches {
		return false, nil
	}

	return true, nil
}

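// uploadContent packs path as a NAR, compresses it with xz, and streams the
// result to dst, hashing both the uncompressed and compressed streams as it
// goes. It verifies the NAR size and hash against ni, then fills in ni's
// Compression, FileHash and FileSize to describe the compressed artifact.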
func (u *uploader) uploadContent(ctx context.Context, ni *narinfo.NarInfo, path string, dst io.Writer) error {
	if !ni.NarHash.Valid() {
		return fmt.Errorf("nar hash for %v is not valid", path)
	}
	narHasher := ni.NarHash.Algorithm.Hash()
	fileHasher := ni.NarHash.Algorithm.Hash()

	fileByteCounter := &byteCounterWriter{}

	xzWriter := xz.NewWriter(io.MultiWriter(fileHasher, fileByteCounter, dst))

	w := io.MultiWriter(narHasher, xzWriter)
	narSize, err := nar.Pack(w, nar.DirFS(u.storePath), filepath.Base(path))
	if err != nil {
		return fmt.Errorf("packing %v as NAR: %w", path, err)
	}

	if err := xzWriter.Close(); err != nil {
		return fmt.Errorf("compressing with xz: %w", err)
	}

	// Check that the NAR size and hash match the narinfo.
	if uint64(narSize) != ni.NarSize {
		return fmt.Errorf("uploaded nar was %d bytes; expected %d bytes", narSize, ni.NarSize)
	}
	narHash := narinfo.Hash{
		Hash:      narHasher.Sum(nil),
		Algorithm: ni.NarHash.Algorithm,
	}
	if len(narHash.Hash) != len(ni.NarHash.Hash) {
		return fmt.Errorf("uploaded nar hash length was %d bytes; expected %d bytes", len(narHash.Hash), len(ni.NarHash.Hash))
	}
	if got, want := narHash.String(), ni.NarHash.String(); got != want {
		return fmt.Errorf("uploaded nar hash was %v; wanted %v", got, want)
	}

	ni.Compression = narinfo.CompressionXz
	ni.FileHash = narinfo.Hash{
		Hash:      fileHasher.Sum(nil),
		Algorithm: ni.NarHash.Algorithm,
	}
	ni.FileSize = fileByteCounter.n
	return nil
}

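// keyForPath returns the bucket key of the narinfo for storePath: the store
// hash followed by ".narinfo".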
func keyForPath(storePath string) (string, error) {
	fileHash := hashExtract(storePath)
	if fileHash == "" {
		return "", fmt.Errorf("store path %v seems to be invalid: couldn't extract hash", storePath)
	}

	return fmt.Sprintf("%s.narinfo", fileHash), nil
}

func (u *uploader) uploadNARInfo(ctx context.Context, ni *narinfo.NarInfo) error {
	key, err := keyForPath(ni.StorePath)
	if err != nil {
		return err
	}
	return u.bucket.WriteAll(ctx, key, []byte(ni.String()), nil)
}

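// uploadRefs uploads every reference of current concurrently, skipping the
// self-reference.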
func (u *uploader) uploadRefs(ctx context.Context, current string, refs []string) error {
	if len(refs) == 0 {
		return nil
	}

	eg, egctx := errgroup.WithContext(ctx)

	for _, ref := range refs {
		refPath := filepath.Join(u.storePath, ref)
		if current == refPath {
			// We depend on ourselves, which is fine.
			continue
		}
		eg.Go(func() error {
			return u.Upload(egctx, refPath)
		})
	}

	return eg.Wait()
}

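// getUploadSlot blocks until an upload slot is free and returns a function
// that releases it; the intended pattern is defer u.getUploadSlot()().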
func (u *uploader) getUploadSlot() func() {
	slot := <-u.uploadLimiter
	return func() {
		u.uploadLimiter <- slot
	}
}

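// upload runs the full pipeline for a single store path: decide whether an
// upload is needed, upload all references first, stream the compressed NAR
// to a temporary key, copy it into its content-addressed location under
// nar/, and finally write the narinfo. The narinfo goes last so the cache
// never advertises a NAR that isn't fully present.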
func (u *uploader) upload(ctx context.Context, path string) error {
	u.st.SetState(path, stateCheckingShouldUpload)

	shouldUploadThis, err := u.shouldUpload(ctx, path)
	if err != nil {
		u.st.SetState(path, stateFailed)
		return fmt.Errorf("determining if we should upload %v: %w", path, err)
	}
	if !shouldUploadThis && !u.deepCheckGalactic {
		u.st.SetState(path, stateSkipped)
		return nil
	}

	if shouldUploadThis {
		log.Printf("Uploading %v", path)
	}

	ni, err := u.store.NARInfo(path)
	if err != nil {
		u.st.SetState(path, stateFailed)
		return fmt.Errorf("getting narinfo for %v: %w", path, err)
	}

	u.st.SetState(path, stateUploadingReferences)
	if err := u.uploadRefs(ctx, ni.StorePath, ni.References); err != nil {
		u.st.SetState(path, stateFailed)
		return fmt.Errorf("uploading references for %v: %w", path, err)
	}

	if !shouldUploadThis {
		u.st.SetState(path, stateSkipped)
		return nil
	}

	u.st.SetState(path, stateWaitingForUploadSlot)
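	// Acquire a slot now; the returned function releases it when we're done.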
	defer u.getUploadSlot()()

	u.st.SetState(path, stateUploadingContent)
	if !ni.NarHash.Valid() {
		u.st.SetState(path, stateFailed)
		return fmt.Errorf("nar hash is invalid")
	}

	tmpPath := fmt.Sprintf("tmp-uploading/%s", filepath.Base(path))
	dst, err := u.bucket.NewWriter(ctx, tmpPath, nil)
	if err != nil {
		u.st.SetState(path, stateFailed)
		return fmt.Errorf("creating new writer for upload of %v: %w", path, err)
	}
	defer dst.Close()

	if err := u.uploadContent(ctx, ni, path, dst); err != nil {
		u.st.SetState(path, stateFailed)
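		// Closing the writer commits the upload; if that commit succeeds we
		// have a garbage blob to delete, and if it fails there is nothing to
		// clean up.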
		if err := dst.Close(); err == nil {
			u.bucket.Delete(ctx, tmpPath)
		}
		return err
	}

	if err := dst.Close(); err != nil {
		u.bucket.Delete(ctx, tmpPath)
		u.st.SetState(path, stateFailed)
		return fmt.Errorf("completing tmp write of %v: %w", path, err)
	}

	// Copy to the "correct" place.
	u.st.SetState(path, stateCopyingContent)
	finalDstKey := fmt.Sprintf("nar/%s.nar.xz", nixbase32.EncodeToString(ni.FileHash.Hash))
	if err := u.bucket.Copy(ctx, finalDstKey, tmpPath, nil); err != nil {
		u.bucket.Delete(ctx, tmpPath)
		u.st.SetState(path, stateFailed)
		return fmt.Errorf("copying tmp write of %v from %v to %v: %w", path, tmpPath, finalDstKey, err)
	}
	if err := u.bucket.Delete(ctx, tmpPath); err != nil {
		u.bucket.Delete(ctx, finalDstKey)
		u.st.SetState(path, stateFailed)
		return fmt.Errorf("cleaning up tmp write of %v at %v: %w", path, tmpPath, err)
	}

	ni.URL = finalDstKey

	u.st.SetState(path, stateUploadingNarinfo)
	if err := u.uploadNARInfo(ctx, ni); err != nil {
		u.bucket.Delete(ctx, finalDstKey)
		u.st.SetState(path, stateFailed)
		return fmt.Errorf("uploading narinfo for %v: %w", path, err)
	}

	u.st.SetState(path, stateUploaded)
	return nil
}

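// Upload uploads path and its reference closure, deduplicating concurrent
// requests for the same path via singleflight.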
func (u *uploader) Upload(ctx context.Context, path string) error {
	resCh := u.uploadSF.DoChan(path, func() (any, error) {
		err := u.upload(ctx, path)
		if err != nil {
			log.Printf("Uploading %v: %v", path, err)
		}
		return nil, err
	})
	select {
	case <-ctx.Done():
		return ctx.Err()
	case res := <-resCh:
		return res.Err
	}
}

func main() {
	flag.Parse()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	log.Printf("Using cache URL %q", *blobURLFlag)
	bucket, err := blob.OpenBucket(ctx, *blobURLFlag)
	if err != nil {
		log.Fatalf("opening bucket %q: %v", *blobURLFlag, err)
	}
	defer bucket.Close()

	store, err := nixstore.Open()
	if err != nil {
		log.Fatalf("opening Nix store: %v", err)
	}
	defer store.Close()

	u := &uploader{
		bucket:            bucket,
		store:             store,
		storePath:         "/nix/store",
		deepCheckGalactic: *deepCheckGalacticFlag,
		uploadLimiter:     make(chan int, *maximumConcurrentUploads),
	}
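	// Fill the limiter with one token per allowed concurrent upload.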
	for n := 0; n < *maximumConcurrentUploads; n++ {
		u.uploadLimiter <- n
	}

	go func() {
		t := time.NewTicker(*stateSummaryIntervalFlag)
		defer t.Stop()
		for range t.C {
			log.Print(u.st.StateSummary())
		}
	}()

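	// Each argument is expected to be a symlink pointing at a store path,
	// e.g. a GC root such as a ./result link.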
	for _, p := range flag.Args() {
		realPath, err := os.Readlink(p)
		if err != nil {
			log.Fatalf("Readlink(%q): %v", p, err)
		}

		if err := u.Upload(ctx, realPath); err != nil {
			log.Fatalf("upload(%q): %v", p, err)
		}
	}
}