// SPDX-FileCopyrightText: 2023 Luke Granger-Brown <depot@lukegb.com>
//
// SPDX-License-Identifier: Apache-2.0

package nar

import (
	"bytes"
	"fmt"
	"io"
	"strings"

	"hg.lukegb.com/lukegb/depot/go/nix/nixwire"
)

// WriteFile is a regular file being materialized by Unpack. The unpacker
// writes the file's contents into it, may mark it executable, and closes it
// once the file's NAR node has been fully read.
type WriteFile interface {
	io.WriteCloser
	MakeExecutable() error
}

// WriteFS is the destination Unpack writes into. Mkdir returns the WriteFS
// that the new directory's own entries are created in.
type WriteFS interface {
	Create(name string) (WriteFile, error)
	Mkdir(name string) (WriteFS, error)
	Symlink(name, target string) error
}

// NullFS is a WriteFS (and WriteFile) that accepts every operation and
// discards everything written to it. Every method is a no-op that reports
// success, so it can be used to parse a NAR without materializing it.
type NullFS struct{}

// Compile-time checks that NullFS satisfies both interfaces.
var (
	_ WriteFile = NullFS{}
	_ WriteFS   = NullFS{}
)

// Write discards p and reports it as fully written.
func (NullFS) Write(p []byte) (n int, err error) {
	n = len(p)
	return n, nil
}

// Close is a no-op.
func (NullFS) Close() error {
	return nil
}

// MakeExecutable is a no-op.
func (NullFS) MakeExecutable() error {
	return nil
}

// Create returns a NullFS that discards whatever is written to it.
func (NullFS) Create(name string) (WriteFile, error) {
	return NullFS{}, nil
}

// Mkdir returns a NullFS for the (nonexistent) child directory.
func (NullFS) Mkdir(name string) (WriteFS, error) {
	return NullFS{}, nil
}

// Symlink is a no-op.
func (NullFS) Symlink(name string, target string) error {
	return nil
}

// unpackFile parses a single NAR node tuple ("(" "type" <t> ... ")") from dw
// and materializes it in fs under the name fn:
//
//   - "regular" nodes are created with fs.Create. They are filled from the
//     "contents" field and marked executable if an "executable" field is present.
//   - "directory" nodes are created with fs.Mkdir. Each "entry" is unpacked
//     recursively into the WriteFS that Mkdir returned.
//   - "symlink" nodes are created with fs.Symlink once their "target" is read.
//
// An error is returned for malformed input, or if any WriteFS/WriteFile
// operation fails. Malformed input includes:
//
//   - an unknown, misplaced or repeated field;
//   - a missing type;
//   - a symlink without a target;
//   - invalid or unsorted directory entry names.
//
// A regular file created by this call is always closed, including on error
// paths.
func unpackFile(dw nixwire.Deserializer, fs WriteFS, fn string) error {
	s, err := dw.ReadString()
	if err != nil {
		return fmt.Errorf("reading file tuple open: %w", err)
	}
	if s != "(" {
		return fmt.Errorf("expected file tuple open '('; got %q", s)
	}

	var (
		fileType          string
		regularFile       WriteFile
		directoryFS       WriteFS
		directoryLastName string
		haveContents      bool
		haveTarget        bool
	)
	// Close a created file if we bail out early. On the success path
	// regularFile is cleared before the explicit Close at the end, so this
	// never double-closes.
	defer func() {
		if regularFile != nil {
			// Best-effort cleanup on an error path: the error already being
			// returned is more useful to the caller than a close error.
			_ = regularFile.Close()
		}
	}()

parseLoop:
	for {
		s, err := dw.ReadString()
		if err != nil {
			return fmt.Errorf("reading next file tuple part: %w", err)
		}
		switch {
		case s == ")":
			break parseLoop
		case s == "type":
			if fileType != "" {
				return fmt.Errorf("multiple file types for %q (already %q)", fn, fileType)
			}

			t, err := dw.ReadString()
			if err != nil {
				return fmt.Errorf("reading file type: %w", err)
			}
			switch t {
			case "regular":
				regularFile, err = fs.Create(fn)
				if err != nil {
					return fmt.Errorf("creating file %v: %w", fn, err)
				}
			case "directory":
				directoryFS, err = fs.Mkdir(fn)
				if err != nil {
					return fmt.Errorf("creating directory %v: %w", fn, err)
				}
			case "symlink":
				// Do nothing until we have a target.
			default:
				return fmt.Errorf("bad file type %q; want regular/directory/symlink", t)
			}
			fileType = t
		case s == "contents" && fileType == "regular":
			// A second "contents" would append to the data already written.
			if haveContents {
				return fmt.Errorf("multiple contents for file %q", fn)
			}
			haveContents = true
			size, err := dw.ReadUint64()
			if err != nil {
				return fmt.Errorf("reading file size: %w", err)
			}
			// io.CopyN takes an int64. A size that doesn't fit would wrap
			// negative, and CopyN would then silently copy nothing.
			n := int64(size)
			if n < 0 {
				return fmt.Errorf("file %q size %d is too large", fn, size)
			}
			if _, err := io.CopyN(regularFile, dw.Reader, n); err != nil {
				return fmt.Errorf("reading file content: %w", err)
			}
			if err := dw.ReadPadding(size); err != nil {
				return fmt.Errorf("reading file padding: %w", err)
			}
		case s == "executable" && fileType == "regular":
			v, err := dw.ReadString()
			if err != nil {
				return fmt.Errorf("reading executable emptystring: %w", err)
			}
			if v != "" {
				return fmt.Errorf("executable emptystring is not empty: %v", v)
			}
			if err := regularFile.MakeExecutable(); err != nil {
				return fmt.Errorf("making target file %v executable: %w", fn, err)
			}
		case s == "entry" && fileType == "directory":
			name, err := unpackDirEntry(dw, directoryFS, fn, directoryLastName)
			if err != nil {
				return err
			}
			directoryLastName = name
		case s == "target" && fileType == "symlink":
			if haveTarget {
				return fmt.Errorf("multiple targets for symlink %q", fn)
			}
			haveTarget = true
			target, err := dw.ReadString()
			if err != nil {
				return fmt.Errorf("reading symlink target: %w", err)
			}
			if err := fs.Symlink(fn, target); err != nil {
				return fmt.Errorf("creating symlink %q to %q: %w", fn, target, err)
			}
		default:
			return fmt.Errorf("unknown field %v (filetype %q, filename %q)", s, fileType, fn)
		}
	}

	switch fileType {
	case "":
		return fmt.Errorf("missing file type for %q", fn)
	case "symlink":
		if !haveTarget {
			return fmt.Errorf("symlink %q has no target", fn)
		}
	}

	if regularFile != nil {
		f := regularFile
		regularFile = nil // take ownership here; disarms the deferred cleanup
		if err := f.Close(); err != nil {
			return fmt.Errorf("closing file: %w", err)
		}
	}
	return nil
}

// unpackDirEntry parses the body of one directory "entry" tuple
// ("(" "name" <name> "node" <node> ")") and unpacks the child node into dirFS.
//
// dirName is the enclosing directory's name and is used only in error
// messages. lastName is the previous entry's name in this directory ("" for
// the first entry). The entry's name must sort strictly after lastName,
// because NAR directory entries are stored in sorted order without
// duplicates. On success the entry's name is returned so that the caller can
// check the next entry's ordering.
func unpackDirEntry(dw nixwire.Deserializer, dirFS WriteFS, dirName, lastName string) (string, error) {
	dirOpen, err := dw.ReadString()
	if err != nil {
		return "", fmt.Errorf("reading dirent open: %w", err)
	}
	if dirOpen != "(" {
		return "", fmt.Errorf("dirent open was %q; wanted (", dirOpen)
	}

	var name string
	var haveNode bool
	for {
		s, err := dw.ReadString()
		if err != nil {
			return "", fmt.Errorf("reading next dirent tuple part: %w", err)
		}
		switch s {
		case ")":
			if !haveNode {
				return "", fmt.Errorf("%s: entry %q has no node", dirName, name)
			}
			return name, nil
		case "name":
			// Valid names are never empty, so a non-empty name means one was
			// already read for this entry.
			if name != "" {
				return "", fmt.Errorf("%s: multiple names in entry (already %q)", dirName, name)
			}
			n, err := dw.ReadString()
			if err != nil {
				return "", fmt.Errorf("reading dirent name: %w", err)
			}
			switch {
			case n == "", n == ".", n == "..", strings.Contains(n, "/"), bytes.Contains([]byte(n), []byte{0}):
				return "", fmt.Errorf("invalid file name %q", n)
			case n <= lastName:
				return "", fmt.Errorf("NAR not sorted (%q came after %q)", n, lastName)
			}
			name = n
		case "node":
			if name == "" {
				return "", fmt.Errorf("entry name must come before node")
			}
			if haveNode {
				return "", fmt.Errorf("%s: multiple nodes for entry %q", dirName, name)
			}
			haveNode = true
			if err := unpackFile(dw, dirFS, name); err != nil {
				return "", fmt.Errorf("%s: %w", dirName, err)
			}
		default:
			return "", fmt.Errorf("%s: unknown field %s", dirName, s)
		}
	}
}

// Unpack reads a NAR stream from r and recreates its contents in fs. The
// stream must begin with the "nix-archive-1" magic string. fn is used as the
// name of the archive's top-level node. Any error from reading the magic,
// parsing the archive, or writing to fs is returned.
func Unpack(r io.Reader, fs WriteFS, fn string) error {
	const magic = "nix-archive-1"

	d := nixwire.Deserializer{Reader: r}
	got, err := d.ReadString()
	switch {
	case err != nil:
		return fmt.Errorf("reading magic: %w", err)
	case got != magic:
		return fmt.Errorf("invalid NAR magic %q (wanted %q)", got, magic)
	}

	return unpackFile(d, fs, fn)
}