// SPDX-FileCopyrightText: 2023 Luke Granger-Brown <depot@lukegb.com>
//
// SPDX-License-Identifier: Apache-2.0

// Package nixdrv holds basic data types for describing Nix derivations.
package nixdrv

import (
	"bufio"
	"bytes"
	"context"
	"fmt"
	"io"
	"path"
	"sort"
	"strings"

	"hg.lukegb.com/lukegb/depot/go/nix/nixhash"
)

// Output describes one output of a derivation. The output's name is not
// stored here; it is the key under which the Output lives in
// BasicDerivation.Outputs.
//
// Path is the output's path. ToBasicDerivation adds it to InputSrcs when
// another derivation consumes this output. HashAlgorithm and Hash are
// serialized alongside it. NOTE(review): presumably they are only set for
// fixed-output derivations; nothing here enforces that.
type Output struct {
	Path, HashAlgorithm, Hash string
}

// InputDerivation records a dependency on another derivation.
//
// Path identifies the input derivation; ToBasicDerivation passes it to
// Resolver.LoadDerivation. Outputs lists the names of that derivation's
// outputs that are consumed. Each name is looked up in the loaded
// derivation's Outputs map.
type InputDerivation struct {
	Path    string
	Outputs []string
}

// BasicDerivation holds every part of a derivation except its input
// derivations.
//
//   - Outputs maps output names to their descriptions.
//   - InputSrcs is a set of input paths. This package only ever stores true
//     values in it.
//   - Platform, Builder and Args are serialized as quoted strings (Args as a
//     list) by Derivation.String. NOTE(review): their meaning (system string,
//     builder program, builder arguments) follows Nix convention and is not
//     validated here.
//   - Env holds environment variables. Derivation.String writes them sorted
//     by key.
type BasicDerivation struct {
	Outputs   map[string]Output
	InputSrcs map[string]bool // PathSet
	Platform  string
	Builder   string
	Args      []string
	Env       map[string]string
}

// Derivation is a BasicDerivation plus the other derivations it depends on.
// ToBasicDerivation turns it into a BasicDerivation by resolving each
// InputDerivation's consumed outputs into InputSrcs entries.
type Derivation struct {
	BasicDerivation

	InputDerivations []InputDerivation
}

// RequiredSystemFeatures returns the whitespace-separated feature names from
// the "requiredSystemFeatures" environment variable, as a set.
//
// It returns nil when the variable is unset or empty. A value consisting only
// of whitespace yields an empty, non-nil set.
func (drv *BasicDerivation) RequiredSystemFeatures() map[string]bool {
	raw := drv.Env["requiredSystemFeatures"]
	if len(raw) == 0 {
		return nil
	}
	features := strings.Fields(raw)
	set := make(map[string]bool, len(features))
	for _, feature := range features {
		set[feature] = true
	}
	return set
}

// Clone returns a copy of drv whose Outputs, InputSrcs and Env maps and Args
// slice are independent of drv's. Either value can be modified without
// affecting the other.
//
// The maps in the returned copy are always non-nil, even when drv's are nil.
// Args stays nil when drv.Args is empty.
func (drv *BasicDerivation) Clone() *BasicDerivation {
	o := &BasicDerivation{
		Outputs:   make(map[string]Output, len(drv.Outputs)),
		InputSrcs: make(map[string]bool, len(drv.InputSrcs)),
		Platform:  drv.Platform,
		Builder:   drv.Builder,
		Args:      append([]string(nil), drv.Args...), // don't share Args' backing array
		Env:       make(map[string]string, len(drv.Env)),
	}
	for k, v := range drv.Outputs {
		o.Outputs[k] = v
	}
	for k, v := range drv.InputSrcs {
		o.InputSrcs[k] = v
	}
	for k, v := range drv.Env {
		o.Env[k] = v
	}
	return o
}

// Clone returns a copy of drv.
//
// The embedded BasicDerivation is copied with BasicDerivation.Clone.
// InputDerivations is copied element by element, and each entry's Outputs
// slice is copied too. Editing the clone's input derivations therefore never
// changes drv's. A nil InputDerivations stays nil.
func (drv *Derivation) Clone() *Derivation {
	var inputs []InputDerivation
	if drv.InputDerivations != nil {
		inputs = make([]InputDerivation, len(drv.InputDerivations))
		for i, inp := range drv.InputDerivations {
			inputs[i] = InputDerivation{
				Path:    inp.Path,
				Outputs: append([]string(nil), inp.Outputs...),
			}
		}
	}
	return &Derivation{
		BasicDerivation:  *drv.BasicDerivation.Clone(),
		InputDerivations: inputs,
	}
}

// Resolver loads a Derivation given its path. ToBasicDerivation uses it to
// look up the outputs of each input derivation.
type Resolver interface {
	LoadDerivation(ctx context.Context, path string) (*Derivation, error)
}

// ToBasicDerivation flattens drv into a BasicDerivation.
//
// It starts from a clone of drv's BasicDerivation, which leaves drv
// untouched. For every InputDerivation it loads the referenced derivation
// through r and adds the Path of each consumed output to InputSrcs.
//
// It fails if a derivation cannot be loaded, or if a consumed output name is
// missing from the loaded derivation.
func (drv *Derivation) ToBasicDerivation(ctx context.Context, r Resolver) (*BasicDerivation, error) {
	basic := drv.BasicDerivation.Clone()
	for _, dep := range drv.InputDerivations {
		depDrv, err := r.LoadDerivation(ctx, dep.Path)
		if err != nil {
			return nil, fmt.Errorf("resolving %v: %w", dep.Path, err)
		}

		for _, name := range dep.Outputs {
			out, found := depDrv.Outputs[name]
			if !found {
				return nil, fmt.Errorf("input derivation %v has no output %v", dep.Path, name)
			}
			basic.InputSrcs[out.Path] = true
		}
	}
	return basic, nil
}

// StoreHash returns nixhash.StorePathForText computed over the serialized
// derivation (drv.String()), with name+".drv" as the name.
//
// The references passed in are every InputSrcs path plus every input
// derivation path. They are sorted first: InputSrcs is a map, so its
// iteration order is random, and the hash input must not depend on it.
//
// NOTE(review): StorePath uses this result as the hash component of a path.
// Confirm that StorePathForText returns only the hash part and not a full
// path.
func (drv *Derivation) StoreHash(name, storeRoot string) string {
	var references []string
	for ref := range drv.InputSrcs {
		references = append(references, ref)
	}
	for _, ref := range drv.InputDerivations {
		references = append(references, ref.Path)
	}
	sort.Strings(references)
	return nixhash.StorePathForText(name+".drv", drv.String(), storeRoot, references)
}

// StorePath returns the path of drv's .drv file under storeRoot, in the form
// storeRoot/<StoreHash>-<name>.drv.
func (drv *Derivation) StorePath(name, storeRoot string) string {
	base := drv.StoreHash(name, storeRoot) + "-" + name + ".drv"
	return path.Join(storeRoot, base)
}

// expect consumes exactly len(s) bytes from f and returns an error unless
// they equal s.
//
// It uses io.ReadFull rather than a single f.Read. bufio.Reader.Read performs
// at most one read on the underlying reader, so it may return fewer bytes
// than requested. The old code then compared a zero-padded buffer and
// reported a bogus mismatch.
//
// If f ends early, the returned error wraps io.EOF (nothing read) or
// io.ErrUnexpectedEOF (partial read).
func expect(f *bufio.Reader, s string) error {
	buf := make([]byte, len(s))
	if _, err := io.ReadFull(f, buf); err != nil {
		return fmt.Errorf("reading %q: %w", s, err)
	}
	if !bytes.Equal(buf, []byte(s)) {
		return fmt.Errorf("expected %q; got %q", s, string(buf))
	}
	return nil
}

// expectChar consumes one byte from f and returns an error if it could not
// be read or is not c.
func expectChar(f *bufio.Reader, c byte) error {
	got, err := f.ReadByte()
	if err != nil {
		return fmt.Errorf("reading char: %w", err)
	}
	if got != c {
		return fmt.Errorf("expected char '%c', got '%c'", c, got)
	}
	return nil
}

// readList parses a bracketed, comma-separated list such as [a,b,c], using
// readElement to parse each element.
//
// An empty list "[]" yields a nil slice. Any error from readElement, or an
// unexpected separator, aborts parsing and returns a nil slice.
func readList[T any](f *bufio.Reader, readElement func(f *bufio.Reader) (T, error)) ([]T, error) {
	if err := expectChar(f, '['); err != nil {
		return nil, fmt.Errorf("list open: %w", err)
	}

	// Peek at the first byte to detect the empty list.
	first, err := f.ReadByte()
	if err != nil {
		return nil, fmt.Errorf("list first char: %w", err)
	}
	if first == ']' {
		return nil, nil
	}
	if err := f.UnreadByte(); err != nil {
		return nil, fmt.Errorf("unreading: %w", err)
	}

	var elems []T
	for {
		elem, err := readElement(f)
		if err != nil {
			return nil, fmt.Errorf("list element: %w", err)
		}
		elems = append(elems, elem)

		sep, err := f.ReadByte()
		if err != nil {
			return nil, fmt.Errorf("reading list element or end of list: %w", err)
		}
		switch sep {
		case ']':
			return elems, nil
		case ',':
			// Another element follows.
		default:
			return nil, fmt.Errorf("expecting list element or end of list: got '%c'", sep)
		}
	}
}

// readStrings parses a list of quoted strings such as ["a","b"]. An empty
// list yields a nil slice.
func readStrings(f *bufio.Reader) ([]string, error) {
	return readList[string](f, readString)
}

// readString parses a double-quoted string and returns its unescaped
// contents.
//
// Escapes are handled as follows: \n, \r and \t become newline, carriage
// return and tab; a backslash before any other byte yields that byte
// literally. This covers \" and \\.
func readString(f *bufio.Reader) (string, error) {
	if err := expectChar(f, '"'); err != nil {
		return "", fmt.Errorf("expected string opening '\"': %w", err)
	}
	var out strings.Builder
	for {
		c, err := f.ReadByte()
		if err != nil {
			return "", fmt.Errorf("reading string: %w", err)
		}
		if c == '"' {
			return out.String(), nil
		}
		if c == '\\' {
			if c, err = f.ReadByte(); err != nil {
				return "", fmt.Errorf("reading string: %w", err)
			}
			switch c {
			case 'n':
				c = '\n'
			case 'r':
				c = '\r'
			case 't':
				c = '\t'
			}
		}
		out.WriteByte(c)
	}
}

// writeTuple writes ss to sb as a parenthesized, comma-separated tuple, for
// example ("a","b").
//
// Elements are written verbatim through writeString, so callers must pass
// text that is already quoted and escaped. With no elements it writes "()".
// The previous version sliced ss[1:] unconditionally and panicked on an
// empty ss.
func writeTuple(sb *strings.Builder, ss ...string) {
	sb.WriteByte('(')
	for i, s := range ss {
		if i > 0 {
			sb.WriteByte(',')
		}
		writeString(sb, s)
	}
	sb.WriteByte(')')
}

// writeList writes elements to sb as a bracketed, comma-separated list, for
// example [a,b,c]. writeElement renders each element.
//
// An empty or nil slice produces "[]". The previous version sliced
// elements[1:] unconditionally, which panics for empty input. That made
// Derivation.String crash on any derivation with an empty list field.
func writeList[T any](sb *strings.Builder, elements []T, writeElement func(*strings.Builder, T)) {
	sb.WriteByte('[')
	for i, e := range elements {
		if i > 0 {
			sb.WriteByte(',')
		}
		writeElement(sb, e)
	}
	sb.WriteByte(']')
}

// escaperReplacer backslash-escapes the bytes that cannot appear raw inside
// a quoted derivation string: double quote, backslash, newline, carriage
// return and tab. readString reverses this mapping.
var escaperReplacer = strings.NewReplacer(
	`"`, `\"`,
	`\`, `\\`,
	"\n", `\n`,
	"\r", `\r`,
	"\t", `\t`,
)

// escapeString returns s as a double-quoted literal, escaped with
// escaperReplacer.
func escapeString(s string) string {
	const quote = `"`
	return quote + escaperReplacer.Replace(s) + quote
}

func writeUnescapedString(sb *strings.Builder, s string) {
	sb.WriteString(escapeString(s))
}

// writeString appends s to sb verbatim, with no quoting or escaping. Callers
// must pass text that is already in serialized form.
func writeString(sb *strings.Builder, s string) {
	_, _ = sb.WriteString(s) // strings.Builder.WriteString always returns a nil error
}

// String serializes d in the textual format that Load parses:
//
//	Derive([outputs],[inputDrvs],[inputSrcs],"platform","builder",[args],[env])
//
// Every string value is quoted and escaped. The output does not depend on
// map iteration order or on the order in which input derivations were
// added, which matters because StoreHash hashes this text:
//
//   - outputs, input sources and environment entries are written sorted by
//     key;
//   - input derivations are written sorted by path, each with its output
//     names sorted.
//
// NOTE(review): the input-derivation ordering presumably matches how Nix
// itself writes .drv files; confirm. d is not modified; sorting is done on
// copies.
func (d *Derivation) String() string {
	var sb strings.Builder
	sb.WriteString("Derive(")

	outputKeys := make([]string, 0, len(d.Outputs))
	for key := range d.Outputs {
		outputKeys = append(outputKeys, key)
	}
	sort.Strings(outputKeys)
	writeList(&sb, outputKeys, func(sb *strings.Builder, k string) {
		v := d.Outputs[k]
		writeTuple(sb, escapeString(k), escapeString(v.Path), escapeString(v.HashAlgorithm), escapeString(v.Hash))
	})
	sb.WriteByte(',')

	// Sort a copy so the caller's InputDerivations (and Outputs slices) keep
	// their order.
	inputDrvs := make([]InputDerivation, len(d.InputDerivations))
	for i, inp := range d.InputDerivations {
		outputs := append([]string(nil), inp.Outputs...)
		sort.Strings(outputs)
		inputDrvs[i] = InputDerivation{Path: inp.Path, Outputs: outputs}
	}
	sort.SliceStable(inputDrvs, func(i, j int) bool { return inputDrvs[i].Path < inputDrvs[j].Path })
	writeList(&sb, inputDrvs, func(sb *strings.Builder, drv InputDerivation) {
		var outputsBuilder strings.Builder
		writeList(&outputsBuilder, drv.Outputs, writeUnescapedString)
		writeTuple(sb, escapeString(drv.Path), outputsBuilder.String())
	})
	sb.WriteByte(',')

	inputSrcs := make([]string, 0, len(d.InputSrcs))
	for is := range d.InputSrcs {
		inputSrcs = append(inputSrcs, is)
	}
	sort.Strings(inputSrcs)
	writeList(&sb, inputSrcs, writeUnescapedString)
	sb.WriteByte(',')

	writeUnescapedString(&sb, d.Platform)
	sb.WriteByte(',')

	writeUnescapedString(&sb, d.Builder)
	sb.WriteByte(',')

	writeList(&sb, d.Args, writeUnescapedString)
	sb.WriteByte(',')

	env := make([]string, 0, len(d.Env))
	for k := range d.Env {
		env = append(env, k)
	}
	sort.Strings(env)
	writeList(&sb, env, func(sb *strings.Builder, k string) {
		v := d.Env[k]
		writeTuple(sb, escapeString(k), escapeString(v))
	})

	sb.WriteByte(')') // Derive(
	return sb.String()
}

// Load parses one derivation in the textual Derive(...) format from f.
//
// It reads up to and including the final ')' and does not check for
// trailing data. Duplicate output names or environment keys are not
// rejected; the last occurrence wins. Every error is wrapped with the part
// of the derivation being read. Errors from the header, the input
// derivation list and the env list used to be returned without context.
func Load(f *bufio.Reader) (*Derivation, error) {
	drv := &Derivation{}

	if err := expect(f, "Derive("); err != nil {
		return nil, fmt.Errorf("reading derivation header: %w", err)
	}

	// Outputs: [("name","path","hashAlgo","hash"),...]
	type outputWithName struct {
		Name string
		Output
	}
	outputList, err := readList(f, func(f *bufio.Reader) (outputWithName, error) {
		var (
			empT, o outputWithName
			err     error
		)
		if err := expectChar(f, '('); err != nil {
			return empT, fmt.Errorf("output open: %w", err)
		}

		if o.Name, err = readString(f); err != nil {
			return empT, fmt.Errorf("output name: %w", err)
		}
		if err := expectChar(f, ','); err != nil {
			return empT, fmt.Errorf(", after output name: %w", err)
		}

		if o.Path, err = readString(f); err != nil {
			return empT, fmt.Errorf("output path: %w", err)
		}
		if err := expectChar(f, ','); err != nil {
			return empT, fmt.Errorf(", after output path: %w", err)
		}

		if o.HashAlgorithm, err = readString(f); err != nil {
			return empT, fmt.Errorf("output hash algorithm: %w", err)
		}
		if err := expectChar(f, ','); err != nil {
			return empT, fmt.Errorf(", after output hash algorithm: %w", err)
		}

		if o.Hash, err = readString(f); err != nil {
			return empT, fmt.Errorf("output hash: %w", err)
		}
		if err := expectChar(f, ')'); err != nil {
			return empT, fmt.Errorf("output close: %w", err)
		}
		return o, nil
	})
	if err != nil {
		return nil, fmt.Errorf("reading outputs: %w", err)
	}
	drv.Outputs = make(map[string]Output)
	for _, own := range outputList {
		drv.Outputs[own.Name] = own.Output
	}
	if err := expectChar(f, ','); err != nil {
		return nil, fmt.Errorf(", after outputs list: %w", err)
	}

	// Input derivations: [("path",["out1",...]),...]
	drv.InputDerivations, err = readList(f, func(f *bufio.Reader) (InputDerivation, error) {
		var (
			empT, id InputDerivation
			err      error
		)
		if err := expectChar(f, '('); err != nil {
			return empT, fmt.Errorf("input derivation open: %w", err)
		}

		if id.Path, err = readString(f); err != nil {
			return empT, fmt.Errorf("input derivation path: %w", err)
		}
		if err := expectChar(f, ','); err != nil {
			return empT, fmt.Errorf(", after input derivation path: %w", err)
		}

		if id.Outputs, err = readStrings(f); err != nil {
			return empT, fmt.Errorf("input derivation outputs: %w", err)
		}
		if err := expectChar(f, ')'); err != nil {
			return empT, fmt.Errorf("input derivation close: %w", err)
		}
		return id, nil
	})
	if err != nil {
		return nil, fmt.Errorf("reading input derivations: %w", err)
	}
	if err := expectChar(f, ','); err != nil {
		return nil, fmt.Errorf(", after input derivations list: %w", err)
	}

	inputSrcsList, err := readStrings(f)
	if err != nil {
		return nil, fmt.Errorf("reading input srcs list: %w", err)
	}
	drv.InputSrcs = make(map[string]bool)
	for _, k := range inputSrcsList {
		drv.InputSrcs[k] = true
	}
	if err := expectChar(f, ','); err != nil {
		return nil, fmt.Errorf(", after input srcs list: %w", err)
	}

	if drv.Platform, err = readString(f); err != nil {
		return nil, fmt.Errorf("reading platform: %w", err)
	}
	if err := expectChar(f, ','); err != nil {
		return nil, fmt.Errorf(", after platform: %w", err)
	}

	if drv.Builder, err = readString(f); err != nil {
		return nil, fmt.Errorf("reading builder: %w", err)
	}
	if err := expectChar(f, ','); err != nil {
		return nil, fmt.Errorf(", after builder: %w", err)
	}

	if drv.Args, err = readStrings(f); err != nil {
		return nil, fmt.Errorf("reading builder args: %w", err)
	}
	if err := expectChar(f, ','); err != nil {
		return nil, fmt.Errorf(", after builder args: %w", err)
	}

	// Environment: [("key","value"),...]
	type keyValue struct {
		key, value string
	}
	envList, err := readList(f, func(f *bufio.Reader) (keyValue, error) {
		var (
			empT, kv keyValue
			err      error
		)
		if err := expectChar(f, '('); err != nil {
			return empT, fmt.Errorf("key value open: %w", err)
		}

		if kv.key, err = readString(f); err != nil {
			return empT, fmt.Errorf("key: %w", err)
		}
		if err := expectChar(f, ','); err != nil {
			return empT, fmt.Errorf(", after key: %w", err)
		}

		if kv.value, err = readString(f); err != nil {
			return empT, fmt.Errorf("value: %w", err)
		}
		if err := expectChar(f, ')'); err != nil {
			return empT, fmt.Errorf("key value close: %w", err)
		}
		return kv, nil
	})
	if err != nil {
		return nil, fmt.Errorf("reading env: %w", err)
	}
	drv.Env = make(map[string]string)
	for _, kv := range envList {
		drv.Env[kv.key] = kv.value
	}

	if err := expectChar(f, ')'); err != nil {
		return nil, fmt.Errorf("closing brace: %w", err)
	}

	return drv, nil
}