// Binary bcachegc garbage collects a Nix binary cache.
package main

import (
	"bufio"
	"context"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"regexp"
	"strings"
	"sync"
	"time"

	"gocloud.dev/blob"
	"golang.org/x/sync/errgroup"
	"golang.org/x/sync/singleflight"
	"hg.lukegb.com/lukegb/depot/go/nix/nar/narinfo"

	_ "gocloud.dev/blob/gcsblob"
	_ "hg.lukegb.com/lukegb/depot/go/vault/vaultgcsblob"
)

// Command-line flags.
//
//   - cache_url: URL of the cache bucket; main passes it to blob.OpenBucket.
//   - gc_roots_file: path of a text file read line by line by loadGCRoots; each
//     line that contains a store hash contributes one GC root.
//   - dry_run: when true (the default) Run stops after computing what would be
//     deleted and deletes nothing.
//   - trust_public_cache: when true (the default) the built-in publicCaches list
//     is consulted, and references found there are not followed and hence not
//     retained in this cache.
var (
	blobURLFlag          = flag.String("cache_url", "", "Cache URL")
	rootsFlag            = flag.String("gc_roots_file", "", "Path to file containing GC roots")
	dryRunFlag           = flag.Bool("dry_run", true, "If true, don't actually delete anything")
	trustPublicCacheFlag = flag.Bool("trust_public_cache", true, "If true, also remove things which are present in public caches")
)

var (
	// hashExtractRegexp matches a 32-character hash that is the final path
	// component of a string (at the start or right after a "/"), optionally
	// followed by a "-" or "." suffix such as "-name" or ".narinfo". The
	// character class omits the letters e, o, t and u (this appears to be the
	// Nix base32 alphabet). Capture group 2 holds the hash itself.
	hashExtractRegexp = regexp.MustCompile(`(^|/)([0-9a-df-np-sv-z]{32})([-.].*)?$`)

	// publicCaches lists the base URLs of public binary caches that are
	// consulted when the trust_public_cache flag is set.
	publicCaches = []string{
		"https://cache.nixos.org",
	}
)

// hashExtract returns the 32-character store hash found in the last path
// component of s (for example a store path or a "<hash>.narinfo" key), or the
// empty string when s does not contain one.
func hashExtract(s string) string {
	if m := hashExtractRegexp.FindStringSubmatch(s); m != nil {
		// Group 2 is the hash; group 1 is the leading anchor and group 3 the
		// optional suffix.
		return m[2]
	}
	return ""
}

// loadGCRoots reads the file at rootsFilePath line by line and returns the set
// of store hashes found in it. Surrounding whitespace on each line is ignored,
// and lines from which no hash can be extracted are skipped silently.
//
// ctx is accepted for signature symmetry but is not consulted.
//
// An error is returned if the file cannot be opened or scanning fails.
func loadGCRoots(ctx context.Context, rootsFilePath string) (map[string]bool, error) {
	file, err := os.Open(rootsFilePath)
	if err != nil {
		return nil, fmt.Errorf("opening file: %w", err)
	}
	defer file.Close()

	hashes := make(map[string]bool)
	lines := bufio.NewScanner(file)
	for lines.Scan() {
		hash := hashExtract(strings.TrimSpace(lines.Text()))
		if hash == "" {
			continue
		}
		hashes[hash] = true
	}
	if err := lines.Err(); err != nil {
		return nil, fmt.Errorf("scanning: %w", err)
	}

	return hashes, nil
}

// gcer holds the state needed to garbage collect one binary cache bucket.
//
//   - bucket is the blob bucket that is enumerated, read from and deleted from.
//   - publicCaches is the list of public cache base URLs probed by
//     isInPublicCache; when empty, nothing is considered publicly available.
//   - publicCacheSF collapses concurrent public-cache lookups for the same
//     narinfo hash into a single in-flight probe.
//   - publicCacheMap memoises completed lookups, mapping narinfo hash (string)
//     to whether it was found in a public cache (bool).
type gcer struct {
	bucket *blob.Bucket

	publicCaches   []string
	publicCacheSF  singleflight.Group
	publicCacheMap sync.Map
}

// enumerateBucket lists every object in the bucket and sorts the keys into two
// sets: keys ending in ".narinfo" go into narInfoFilenames and keys starting
// with "nar/" go into narFilenames. A key matching both patterns is recorded in
// both sets; keys matching neither are ignored.
//
// It returns an error (and nil maps) if listing the bucket fails.
func (g *gcer) enumerateBucket(ctx context.Context) (narInfoFilenames, narFilenames map[string]bool, err error) {
	narInfoFilenames = make(map[string]bool)
	narFilenames = make(map[string]bool)

	it := g.bucket.List(nil)
	for {
		entry, nextErr := it.Next(ctx)
		if nextErr == io.EOF {
			// The iterator signals exhaustion with io.EOF.
			break
		}
		if nextErr != nil {
			return nil, nil, fmt.Errorf("iterating files: %w", nextErr)
		}

		key := entry.Key
		if strings.HasSuffix(key, ".narinfo") {
			narInfoFilenames[key] = true
		}
		if strings.HasPrefix(key, "nar/") {
			narFilenames[key] = true
		}
	}
	return narInfoFilenames, narFilenames, nil
}

// isInPublicCache reports whether a narinfo for narinfoHash exists in any of
// the configured public caches.
//
// Each entry of g.publicCaches is probed in order with an HTTP HEAD request
// for "<cache>/<hash>.narinfo"; only a 200 OK response counts as present, any
// other status moves on to the next cache. When no public caches are
// configured the result is always false.
//
// Completed lookups (positive and negative) are memoised in g.publicCacheMap.
// Concurrent lookups of the same hash are collapsed through g.publicCacheSF,
// so the probe runs with the context of whichever caller started it. Lookups
// that fail with an error are not memoised.
//
// Request construction and transport errors are wrapped with %w so callers can
// inspect the underlying cause (for example context cancellation) using
// errors.Is / errors.As.
func (g *gcer) isInPublicCache(ctx context.Context, narinfoHash string) (bool, error) {
	// Fast path: a previous lookup already settled this hash.
	if cached, ok := g.publicCacheMap.Load(narinfoHash); ok {
		return cached.(bool), nil
	}

	r, err, _ := g.publicCacheSF.Do(narinfoHash, func() (interface{}, error) {
		for _, c := range g.publicCaches {
			req, err := http.NewRequestWithContext(ctx, "HEAD", fmt.Sprintf("%v/%v.narinfo", c, narinfoHash), nil)
			if err != nil {
				return false, fmt.Errorf("constructing request for %v/%v.narinfo: %w", c, narinfoHash, err)
			}

			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				return false, fmt.Errorf("making request for %v/%v.narinfo: %w", c, narinfoHash, err)
			}
			// Only the status code is of interest, so the body is closed
			// straight away.
			resp.Body.Close()

			if resp.StatusCode == http.StatusOK {
				g.publicCacheMap.Store(narinfoHash, true)
				return true, nil
			}
		}
		g.publicCacheMap.Store(narinfoHash, false)
		return false, nil
	})
	// The callback returns a bool on every path (including errors), so the
	// assertion cannot fail.
	return r.(bool), err
}

// expandRoots computes the transitive closure of roots over narinfo
// references and returns what must be kept.
//
// Parameters:
//   - roots: set of store hashes to start from.
//   - narInfoFilenames: set of "<hash>.narinfo" keys known to exist in the
//     bucket; a reachable hash whose narinfo is not in this set is an error.
//
// Returns:
//   - keepNarInfos: set of every reachable store hash (without the ".narinfo"
//     suffix).
//   - keepNars: set of the URL fields of the reachable narinfos.
//   - err: the first error encountered; both maps are nil in that case.
//
// References that isInPublicCache reports as present in a public cache are
// not followed and therefore not kept.
//
// Pipeline: hashes are announced on foundRootCh, a single goroutine
// deduplicates them into keepNarInfos and forwards new ones on workCh, a pool
// of workers parses the corresponding narinfo, reports its NAR on keepNarCh
// and announces its references back on foundRootCh. The running WaitGroup
// counts announced-but-unfinished hashes; when it drains, foundRootCh is
// closed and the pipeline shuts down.
//
// Every channel send selects on the errgroup context as well, and the
// downstream channels are only closed on normal completion. Without this, a
// failure in one worker could leave other goroutines blocked on a send nobody
// will receive (so eg.Wait would never return) or sending on an already
// closed channel (which panics).
func (g *gcer) expandRoots(ctx context.Context, roots, narInfoFilenames map[string]bool) (keepNarInfos map[string]bool, keepNars map[string]bool, err error) {
	const workerCount = 20
	keepNarInfos = map[string]bool{}
	keepNars = map[string]bool{}

	eg, egctx := errgroup.WithContext(ctx)
	var running sync.WaitGroup

	foundRootCh := make(chan string)
	keepNarCh := make(chan string)
	// Each hash is forwarded at most once and must have a narinfo to get
	// further, so this capacity normally keeps the deduplicator from blocking.
	workCh := make(chan string, len(narInfoFilenames))

	// NAR collector: sole writer of keepNars.
	eg.Go(func() error {
		for {
			select {
			case <-egctx.Done():
				return egctx.Err()
			case f, ok := <-keepNarCh:
				if !ok {
					return nil
				}
				keepNars[f] = true
			}
		}
	})

	// Deduplicating and reenqueueing worker: sole user of keepNarInfos until
	// eg.Wait returns.
	eg.Go(func() error {
		for {
			select {
			case <-egctx.Done():
				// Do not close workCh/keepNarCh here: workers may still be
				// about to send on keepNarCh, and everybody observes the
				// cancellation through egctx anyway.
				return egctx.Err()
			case r, ok := <-foundRootCh:
				if !ok {
					// foundRootCh is only closed once running has drained,
					// i.e. no worker is mid-item, so closing is safe.
					close(keepNarCh)
					close(workCh)
					return nil
				}
				if keepNarInfos[r] {
					// Already seen: this announcement is finished.
					running.Done()
					continue
				}
				keepNarInfos[r] = true
				select {
				case workCh <- r:
				case <-egctx.Done():
					return egctx.Err()
				}
			}
		}
	})

	// Workers
	processWork := func(w string) error {
		defer running.Done()

		// Check if it exists in the set of known narInfoFilenames.
		filename := fmt.Sprintf("%s.narinfo", w)
		if !narInfoFilenames[filename] {
			return fmt.Errorf("no known narinfo %v", filename)
		}

		// Great, parse it.
		r, err := g.bucket.NewReader(egctx, filename, nil)
		if err != nil {
			return fmt.Errorf("opening narinfo %v: %w", filename, err)
		}
		defer r.Close()

		ni, err := narinfo.ParseNarInfo(r)
		if err != nil {
			return fmt.Errorf("parsing narinfo: %w", err)
		}
		select {
		case keepNarCh <- ni.URL:
		case <-egctx.Done():
			return egctx.Err()
		}
		for _, ref := range ni.References {
			refHash := hashExtract(ref)
			exists, err := g.isInPublicCache(egctx, refHash)
			if err != nil {
				return fmt.Errorf("checking if %v is in public cache: %w", refHash, err)
			}
			if exists {
				continue
			}
			// Account for the reference before announcing it so running
			// cannot drop to zero while it is in flight.
			running.Add(1)
			select {
			case foundRootCh <- refHash:
			case <-egctx.Done():
				running.Done()
				return egctx.Err()
			}
		}

		return nil
	}
	for n := 0; n < workerCount; n++ {
		eg.Go(func() error {
			for {
				select {
				case <-egctx.Done():
					return egctx.Err()
				case w, ok := <-workCh:
					if !ok {
						return nil
					}
					if err := processWork(w); err != nil {
						return fmt.Errorf("working on %q: %w", w, err)
					}
				}
			}
		})
	}

	// Enqueue work! The extra Add(1) keeps running above zero until every
	// root has been announced.
	running.Add(1)
	go func() {
		defer running.Done()
		for k := range roots {
			running.Add(1)
			select {
			case foundRootCh <- k:
			case <-egctx.Done():
				running.Done()
				log.Printf("aborting enqueueing %v", egctx.Err())
				return
			}
		}
		log.Printf("Enqueued %d roots", len(roots))
	}()

	// Once there's no work remaining to be done, close the foundRootCh.
	// NOTE: after a failure, items abandoned in workCh are never marked done,
	// so this goroutine may stay parked in Wait; nothing sends on foundRootCh
	// at that point.
	go func() {
		running.Wait()
		log.Println("Work queues drained, shutting down")
		close(foundRootCh)
	}()

	if err := eg.Wait(); err != nil {
		return nil, nil, err
	}
	return keepNarInfos, keepNars, nil
}

// except returns a new set containing every key of src that is not a key of
// other. Only key presence matters: values stored in either input are ignored
// and every key in the result maps to true. Neither input is modified.
func except(src map[string]bool, other map[string]bool) map[string]bool {
	diff := make(map[string]bool)
	for key := range src {
		if _, excluded := other[key]; excluded {
			continue
		}
		diff[key] = true
	}
	return diff
}

// deleteFiles deletes every key of files from the bucket using a fixed pool of
// concurrent workers.
//
// Keys are fed to the workers through a bounded channel. Progress, with a
// rough time estimate derived from the enqueue rate so far, is logged each
// time the number of keys left to enqueue is a multiple of 1000. The first
// deletion error cancels the shared context, stops further enqueueing and is
// returned once all workers have finished.
func (g *gcer) deleteFiles(ctx context.Context, files map[string]bool) error {
	const workerCount = 64
	total := len(files)
	pending := make(chan string, workerCount)

	eg, egctx := errgroup.WithContext(ctx)
	worker := func() error {
		for key := range pending {
			if err := g.bucket.Delete(egctx, key); err != nil {
				return err
			}
		}
		return nil
	}
	for i := 0; i < workerCount; i++ {
		eg.Go(worker)
	}

	began := time.Now()
	remaining := total
enqueue:
	for key := range files {
		if queued := total - remaining; remaining%1000 == 0 && queued > 0 {
			elapsedSeconds := time.Since(began).Seconds()
			secondsPerFile := elapsedSeconds / float64(queued)
			eta := time.Duration(float64(remaining) * secondsPerFile * float64(time.Second))
			log.Printf("%d files left (guesstimating %s for this batch of files)...", remaining, eta)
		}
		select {
		case <-egctx.Done():
			// A worker failed (or the caller cancelled): stop feeding.
			break enqueue
		case pending <- key:
		}
		remaining--
	}
	// Closing the channel lets the workers' range loops terminate.
	close(pending)
	log.Printf("Done enqueueing %d files for deleting in %s.", total, time.Since(began))

	if err := eg.Wait(); err != nil {
		return err
	}
	log.Printf("Done deleting %d files in %s.", total, time.Since(began))
	return nil
}

// Run performs one garbage collection pass over the bucket.
//
// It enumerates the bucket, expands roots to everything reachable from them,
// and treats every other narinfo and NAR object as garbage. When the dry_run
// flag is set it stops after logging the counts; otherwise it deletes the
// garbage narinfos and then the garbage NARs.
//
// The returned error is wrapped with the name of the failing phase.
func (g *gcer) Run(ctx context.Context, roots map[string]bool) error {
	log.Println("Iterating through bucket")
	listStart := time.Now()
	narInfoFilenames, narFilenames, err := g.enumerateBucket(ctx)
	if err != nil {
		return fmt.Errorf("enumerateBucket: %w", err)
	}
	log.Printf("Done loading bucket data; found %d narinfo files and %d nar files in %v", len(narInfoFilenames), len(narFilenames), time.Since(listStart))

	// Walk the reference graph from the gcroots. A reachable hash without a
	// .narinfo in the bucket makes this fail.
	log.Println("Expanding gcroots")
	keepHashes, keepNars, err := g.expandRoots(ctx, roots, narInfoFilenames)
	if err != nil {
		return fmt.Errorf("expandRoots: %w", err)
	}
	log.Printf("Got %d narinfos to keep, with %d nars", len(keepHashes), len(keepNars))

	// Kept hashes become "<hash>.narinfo" keys so they can be subtracted from
	// the bucket listing; whatever remains is garbage.
	keepNarInfoFiles := make(map[string]bool, len(keepHashes))
	for hash := range keepHashes {
		keepNarInfoFiles[hash+".narinfo"] = true
	}
	deleteNarInfos := except(narInfoFilenames, keepNarInfoFiles)
	deleteNars := except(narFilenames, keepNars)
	log.Printf("Computed %d narinfos and %d nars to delete", len(deleteNarInfos), len(deleteNars))

	if *dryRunFlag {
		return nil
	}

	// Narinfos go first so that an interrupted run never leaves a narinfo
	// pointing at a NAR that has already been removed.
	log.Printf("Deleting %d narinfos...", len(deleteNarInfos))
	if err := g.deleteFiles(ctx, deleteNarInfos); err != nil {
		return fmt.Errorf("deleting narinfos: %w", err)
	}

	log.Printf("Deleting %d nars...", len(deleteNars))
	if err := g.deleteFiles(ctx, deleteNars); err != nil {
		return fmt.Errorf("deleting nars: %w", err)
	}

	return nil
}

// main parses the flags, opens the cache bucket, loads the GC roots and runs
// the collector. Any failure is logged and terminates the process.
func main() {
	flag.Parse()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	cacheURL := *blobURLFlag
	log.Printf("Using cache URL %q", cacheURL)
	bucket, err := blob.OpenBucket(ctx, cacheURL)
	if err != nil {
		log.Fatalf("opening bucket %q: %v", cacheURL, err)
	}
	defer bucket.Close()

	rootsPath := *rootsFlag
	roots, err := loadGCRoots(ctx, rootsPath)
	if err != nil {
		log.Fatalf("loading GC roots from %q: %v", rootsPath, err)
	}
	log.Printf("Using %d roots from %v", len(roots), rootsPath)

	// Public caches are only consulted when explicitly trusted; a nil list
	// means nothing is treated as publicly available.
	var caches []string
	if *trustPublicCacheFlag {
		caches = publicCaches
	}
	collector := &gcer{
		bucket:       bucket,
		publicCaches: caches,
	}
	if err := collector.Run(ctx, roots); err != nil {
		log.Fatalf("running GC: %v", err)
	}
}