package nar

import (
	"fmt"
	"io"
	"io/fs"
	"path"
	"sort"

	"hg.lukegb.com/lukegb/depot/go/nix/nixwire"
)

// FS is the read-only file-system view that Pack archives from.
//
// Pack passes its fn argument through unchanged, and directory entries are
// addressed by joining names with path.Join, so paths are slash-separated.
type FS interface {
	// Open opens name for reading. When name is a directory, the returned
	// file must implement fs.ReadDirFile; Pack reports an error otherwise.
	Open(name string) (fs.File, error)

	// Stat returns file information for name.
	Stat(name string) (fs.FileInfo, error)

	// Lstat returns file information for name. Pack detects symlinks by
	// checking fs.ModeSymlink on this result, so implementations should
	// describe a symlink itself rather than its target.
	Lstat(name string) (fs.FileInfo, error)

	// Readlink returns the target of the symlink name.
	Readlink(name string) (string, error)
}

// packFile serializes the file-system object at fn, whose Lstat result is
// stat, into sw as a single node delimited by "(" and ")". It recurses into
// directories.
//
// Directories, symlinks and regular files are supported. Any other file type
// yields an error.
//
// The returned count is the number of bytes written to sw. This includes
// bytes written before an error, so callers can keep an accurate running
// total on failure.
func packFile(sw nixwire.Serializer, root FS, fn string, stat fs.FileInfo) (int64, error) {
	var nSoFar int64

	// write serializes each datum in order: strings via WriteString, uint64s
	// via WriteUint64. It adds every byte written to nSoFar, including those
	// of a partially failed write, so error paths can simply return nSoFar.
	write := func(data ...any) error {
		for _, datum := range data {
			var n int64
			var err error
			switch datum := datum.(type) {
			case string:
				n, err = sw.WriteString(datum)
			case uint64:
				n, err = sw.WriteUint64(datum)
			default:
				// Programmer error: only string and uint64 tokens are emitted.
				return fmt.Errorf("unknown data type %T", datum)
			}
			nSoFar += n
			if err != nil {
				return err
			}
		}
		return nil
	}

	if err := write("("); err != nil {
		return nSoFar, err
	}

	switch mode := stat.Mode(); {
	case mode&fs.ModeDir != 0:
		// Directory.
		if err := write("type", "directory"); err != nil {
			return nSoFar, err
		}

		// Read all entries and close the directory before recursing. Only
		// this directory's handle is open while its entries are read, rather
		// than one handle per level of the tree for the whole descent.
		dents, err := func() ([]fs.DirEntry, error) {
			f, err := root.Open(fn)
			if err != nil {
				return nil, err
			}
			defer f.Close()

			dirF, ok := f.(fs.ReadDirFile)
			if !ok {
				return nil, fmt.Errorf("%v didn't get me a ReadDirFile", fn)
			}

			dents, err := dirF.ReadDir(-1)
			if err != nil {
				return nil, fmt.Errorf("reading dents from %v: %w", fn, err)
			}
			return dents, nil
		}()
		if err != nil {
			return nSoFar, err
		}

		// Emit entries in byte-wise name order (Go string comparison) so the
		// output does not depend on the order ReadDir returned them in.
		sort.Slice(dents, func(i, j int) bool {
			return dents[i].Name() < dents[j].Name()
		})

		for _, dent := range dents {
			childFn := path.Join(fn, dent.Name())

			if err := write("entry", "(", "name", dent.Name(), "node"); err != nil {
				return nSoFar, err
			}

			dentStat, err := dent.Info()
			if err != nil {
				return nSoFar, fmt.Errorf("stat for %v: %w", childFn, err)
			}

			n, err := packFile(sw, root, childFn, dentStat)
			nSoFar += n
			if err != nil {
				return nSoFar, err
			}

			if err := write(")"); err != nil {
				return nSoFar, err
			}
		}

	case mode&fs.ModeSymlink != 0:
		// Symlink: the target is stored verbatim and not followed.
		target, err := root.Readlink(fn)
		if err != nil {
			return nSoFar, err
		}

		if err := write("type", "symlink", "target", target); err != nil {
			return nSoFar, err
		}

	case mode.Type() != 0:
		// Devices, sockets, named pipes, etc. have no representation here.
		return nSoFar, fmt.Errorf("not implemented (other: %s)", mode)

	default:
		// Regular file.
		if err := write("type", "regular"); err != nil {
			return nSoFar, err
		}
		// The owner-execute bit marks the file as executable in the archive.
		if mode&0o100 != 0 {
			if err := write("executable", ""); err != nil {
				return nSoFar, err
			}
		}

		size := stat.Size()
		if err := write("contents", uint64(size)); err != nil {
			return nSoFar, err
		}

		f, err := root.Open(fn)
		if err != nil {
			return nSoFar, err
		}
		defer f.Close()

		wrote, err := io.Copy(sw, f)
		nSoFar += wrote
		if err != nil {
			return nSoFar, err
		}
		// The length prefix above came from stat. If the file changed size
		// since then, the copied body no longer matches it and the archive
		// would be silently corrupt, so fail instead.
		if wrote != size {
			return nSoFar, fmt.Errorf("%v: size changed while packing (stat reported %d bytes, copied %d)", fn, size, wrote)
		}

		n, err := sw.WritePadding(wrote)
		nSoFar += n
		if err != nil {
			return nSoFar, err
		}
	}

	if err := write(")"); err != nil {
		return nSoFar, err
	}

	return nSoFar, nil
}

// Pack writes the "nix-archive-1" serialization of the file-system object at
// fn within fsys to w. It returns the number of bytes written, including any
// written before an error.
//
// fn is examined with fsys.Lstat, so a symlink at fn is archived as a symlink
// rather than followed.
func Pack(w io.Writer, fsys FS, fn string) (int64, error) {
	// The parameter is named fsys, not fs, so it does not shadow the io/fs
	// package imported by this file.
	sw := nixwire.Serializer{w}

	n, err := sw.WriteString("nix-archive-1")
	if err != nil {
		return n, err
	}

	stat, err := fsys.Lstat(fn)
	if err != nil {
		return n, fmt.Errorf("lstat(%q): %w", fn, err)
	}

	npf, err := packFile(sw, fsys, fn, stat)
	return n + npf, err
}