package nixwire

import (
	"encoding/binary"
	"fmt"
	"io"
	"sort"

	"hg.lukegb.com/lukegb/depot/go/nix/nixdrv"
)

// Serializer encodes values onto the embedded io.Writer. Its methods write
// little-endian uint64s and length-prefixed byte strings that are zero-padded
// to 8-byte boundaries, and each reports how many bytes it wrote.
type Serializer struct{ io.Writer }

// WritePadding writes the zero bytes that follow an n-byte payload so the
// stream lands on the next 8-byte boundary. Nothing is written when n%8 is
// not positive, which covers already-aligned lengths. It returns the number
// of padding bytes actually written.
func (w Serializer) WritePadding(n int64) (int64, error) {
	rem := n % 8
	if rem <= 0 {
		return 0, nil
	}
	written, err := w.Write(make([]byte, 8-rem))
	return int64(written), err
}

// WriteUint64 writes n as 8 little-endian bytes and returns the number of
// bytes written.
func (w Serializer) WriteUint64(n uint64) (int64, error) {
	var word [8]byte
	binary.LittleEndian.PutUint64(word[:], n)
	written, err := w.Write(word[:])
	return int64(written), err
}

// WriteBytes writes bs as a uint64 length, the raw bytes, and then zero
// padding up to the next 8-byte boundary. The returned count includes every
// byte written, even when a later step fails.
func (w Serializer) WriteBytes(bs []byte) (int64, error) {
	total, err := w.WriteUint64(uint64(len(bs)))
	if err != nil {
		return total, err
	}

	written, err := w.Write(bs)
	total += int64(written)
	if err != nil {
		return total, err
	}

	padded, err := w.WritePadding(int64(len(bs)))
	return total + padded, err
}

// WriteString writes s using the same length-prefixed, zero-padded encoding
// as WriteBytes and returns the number of bytes written.
func (w Serializer) WriteString(s string) (int64, error) {
	payload := []byte(s)
	return w.WriteBytes(payload)
}

// WriteStrings writes ss as a uint64 element count followed by each element
// encoded with WriteString. The returned count covers every byte written,
// including the bytes of an element that was only partially written before
// an error occurred.
func (w Serializer) WriteStrings(ss []string) (int64, error) {
	total, err := w.WriteUint64(uint64(len(ss)))
	if err != nil {
		return total, err
	}
	for _, s := range ss {
		n, err := w.WriteString(s)
		// Count the bytes before checking err so that a partial write is
		// reported, matching WriteBytes and WriteDerivation.
		total += n
		if err != nil {
			return total, err
		}
	}
	return total, nil
}

// WriteStringSet writes the keys of s as a string list (see WriteStrings).
// Every key is written regardless of its bool value.
//
// Keys are written in sorted order. Go map iteration order is randomized, so
// without sorting the same set could encode to different bytes on each call.
func (w Serializer) WriteStringSet(s map[string]bool) (int64, error) {
	members := make([]string, 0, len(s))
	for k := range s {
		members = append(members, k)
	}
	sort.Strings(members)
	return w.WriteStrings(members)
}

// WriteDerivation writes drv in the field order that ReadDerivation expects:
//   - the number of outputs, then each output's name, path, hash algorithm and hash;
//   - the input sources (via WriteStringSet);
//   - the platform and the builder;
//   - the argument list (via WriteStrings);
//   - the number of environment entries, then each key and value.
//
// Outputs and environment entries are written in sorted key order. Go map
// iteration order is randomized, so without sorting the same derivation
// could serialize to different bytes on every call.
//
// The returned count includes every byte written, including those from a
// step that failed part-way.
func (w Serializer) WriteDerivation(drv *nixdrv.BasicDerivation) (int64, error) {
	total, err := w.WriteUint64(uint64(len(drv.Outputs)))
	if err != nil {
		return total, err
	}

	outputNames := make([]string, 0, len(drv.Outputs))
	for name := range drv.Outputs {
		outputNames = append(outputNames, name)
	}
	sort.Strings(outputNames)

	var n int64
	for _, name := range outputNames {
		output := drv.Outputs[name]
		n, err = w.writeStringSeq(name, output.Path, output.HashAlgorithm, output.Hash)
		total += n
		if err != nil {
			return total, err
		}
	}

	n, err = w.WriteStringSet(drv.InputSrcs)
	total += n
	if err != nil {
		return total, err
	}
	n, err = w.writeStringSeq(drv.Platform, drv.Builder)
	total += n
	if err != nil {
		return total, err
	}
	n, err = w.WriteStrings(drv.Args)
	total += n
	if err != nil {
		return total, err
	}

	n, err = w.WriteUint64(uint64(len(drv.Env)))
	total += n
	if err != nil {
		return total, err
	}

	envKeys := make([]string, 0, len(drv.Env))
	for k := range drv.Env {
		envKeys = append(envKeys, k)
	}
	sort.Strings(envKeys)

	for _, k := range envKeys {
		n, err = w.writeStringSeq(k, drv.Env[k])
		total += n
		if err != nil {
			return total, err
		}
	}
	return total, nil
}

// writeStringSeq writes each of ss with WriteString, one after another and
// with no leading count. It stops at the first error and returns the total
// bytes written, including a partially written final string.
func (w Serializer) writeStringSeq(ss ...string) (int64, error) {
	var total int64
	for _, s := range ss {
		n, err := w.WriteString(s)
		total += n
		if err != nil {
			return total, err
		}
	}
	return total, nil
}

// Deserializer decodes values from the embedded io.Reader using the encoding
// that Serializer produces: little-endian uint64s and length-prefixed byte
// strings followed by zero padding to an 8-byte boundary.
type Deserializer struct{ io.Reader }

// ReadPadding consumes the zero padding that follows an n-byte payload: the
// 8-n%8 bytes that realign the stream to an 8-byte boundary, or nothing when
// n is a multiple of 8. It fails if the padding cannot be read in full or
// contains a non-zero byte.
func (r Deserializer) ReadPadding(n uint64) error {
	rem := n % 8
	if rem == 0 {
		return nil
	}

	// Stay in uint64 arithmetic. Converting n to int first would go negative
	// for n > math.MaxInt64 and produce a pad longer than 7 bytes.
	var scratch [8]byte
	pad := scratch[:8-rem]
	if _, err := io.ReadFull(r, pad); err != nil {
		return fmt.Errorf("reading padding: %w", err)
	}
	for _, b := range pad {
		if b != 0 {
			return fmt.Errorf("padding was not all 0")
		}
	}
	return nil
}

// ReadUint64 reads exactly 8 bytes and decodes them as a little-endian
// uint64. A short read yields 0 and a wrapped error.
func (r Deserializer) ReadUint64() (uint64, error) {
	var word [8]byte
	if _, err := io.ReadFull(r, word[:]); err != nil {
		return 0, fmt.Errorf("reading uint64: %w", err)
	}
	return binary.LittleEndian.Uint64(word[:]), nil
}

// ReadBytes reads a byte string as written by WriteBytes: a uint64 length,
// that many bytes, then zero padding to the next 8-byte boundary. A
// zero-length string yields a non-nil empty slice.
//
// The length comes from the stream, so it is not trusted for allocation. The
// buffer grows at most readChunk bytes ahead of the data actually received.
// A corrupt or hostile length therefore ends in a read error rather than a
// panic (an oversized value does not fit in int) or a huge up-front
// allocation.
func (r Deserializer) ReadBytes() ([]byte, error) {
	strLen, err := r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading length: %w", err)
	}

	const readChunk = 64 << 10

	initialCap := strLen
	if initialCap > readChunk {
		initialCap = readChunk
	}
	strBuf := make([]byte, 0, int(initialCap))
	for remaining := strLen; remaining > 0; {
		chunk := remaining
		if chunk > readChunk {
			chunk = readChunk
		}
		start := len(strBuf)
		strBuf = append(strBuf, make([]byte, int(chunk))...)
		if _, err := io.ReadFull(r, strBuf[start:]); err != nil {
			// io.ReadFull returns a bare io.EOF when a chunk receives no bytes.
			// After earlier chunks succeeded, the string is still truncated,
			// so report io.ErrUnexpectedEOF as a single full read would.
			if err == io.EOF && start > 0 {
				err = io.ErrUnexpectedEOF
			}
			return nil, fmt.Errorf("reading string buffer: %w", err)
		}
		remaining -= chunk
	}

	if err := r.ReadPadding(strLen); err != nil {
		return nil, fmt.Errorf("reading string padding: %w", err)
	}

	return strBuf, nil
}

// ReadString reads a byte string with ReadBytes and returns it as a Go
// string. Errors from ReadBytes are returned unchanged.
func (r Deserializer) ReadString() (string, error) {
	raw, err := r.ReadBytes()
	if err != nil {
		return "", err
	}
	s := string(raw)
	return s, nil
}

// ReadStrings reads a list of strings: a uint64 count followed by that many
// strings in the ReadString encoding. A zero count yields a non-nil empty
// slice.
//
// The count comes from the stream, so it only bounds the loop. The slice is
// preallocated for at most maxPrealloc elements and grows as strings actually
// arrive. This keeps a corrupt count from panicking make (int overflow) or
// reserving a huge slice before any data has been read.
func (r Deserializer) ReadStrings() ([]string, error) {
	setLen, err := r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading string set length: %w", err)
	}

	const maxPrealloc = 1024
	capHint := setLen
	if capHint > maxPrealloc {
		capHint = maxPrealloc
	}
	xs := make([]string, 0, int(capHint))
	for n := uint64(0); n < setLen; n++ {
		s, err := r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading string %d of %d: %w", n, setLen, err)
		}
		xs = append(xs, s)
	}
	return xs, nil
}

// ReadStringSet reads a string list with ReadStrings and returns its elements
// as a set. Every element maps to true, and duplicates collapse to one key.
// The returned map is non-nil on success, even when the list is empty.
func (r Deserializer) ReadStringSet() (map[string]bool, error) {
	members, err := r.ReadStrings()
	if err != nil {
		return nil, err
	}

	set := make(map[string]bool, len(members))
	for i := range members {
		set[members[i]] = true
	}
	return set, nil
}

// ReadDerivation decodes a derivation laid out as WriteDerivation writes it:
//   - an output count, then each output's name, path, hash algorithm and hash;
//   - the input sources as a string set;
//   - the platform and the builder;
//   - the argument list;
//   - an environment count, then key/value pairs.
//
// A repeated output name or environment key overwrites the earlier entry. On
// any error it returns nil together with an error naming the field being read.
func (r Deserializer) ReadDerivation() (*nixdrv.BasicDerivation, error) {
	drv := &nixdrv.BasicDerivation{}

	numOutputs, err := r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading outputs length: %w", err)
	}
	drv.Outputs = make(map[string]nixdrv.Output)
	for i := uint64(0); i < numOutputs; i++ {
		name, output, err := r.readDerivationOutput()
		if err != nil {
			return nil, err
		}
		drv.Outputs[name] = output
	}

	if drv.InputSrcs, err = r.ReadStringSet(); err != nil {
		return nil, fmt.Errorf("reading input srcs: %w", err)
	}
	if drv.Platform, err = r.ReadString(); err != nil {
		return nil, fmt.Errorf("reading platform: %w", err)
	}
	if drv.Builder, err = r.ReadString(); err != nil {
		return nil, fmt.Errorf("reading builder: %w", err)
	}
	if drv.Args, err = r.ReadStrings(); err != nil {
		return nil, fmt.Errorf("reading args: %w", err)
	}

	numEnv, err := r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading environment length: %w", err)
	}
	drv.Env = make(map[string]string)
	for remaining := numEnv; remaining > 0; remaining-- {
		name, err := r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading environment variable name: %w", err)
		}
		value, err := r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading environment variable value: %w", err)
		}
		drv.Env[name] = value
	}

	return drv, nil
}

// readDerivationOutput reads one output entry in the order WriteDerivation
// writes it: the output name, then its path, hash algorithm and hash.
func (r Deserializer) readDerivationOutput() (string, nixdrv.Output, error) {
	var out nixdrv.Output
	name, err := r.ReadString()
	if err != nil {
		return "", nixdrv.Output{}, fmt.Errorf("reading output name: %w", err)
	}
	if out.Path, err = r.ReadString(); err != nil {
		return "", nixdrv.Output{}, fmt.Errorf("reading output path: %w", err)
	}
	if out.HashAlgorithm, err = r.ReadString(); err != nil {
		return "", nixdrv.Output{}, fmt.Errorf("reading output hash algorithm: %w", err)
	}
	if out.Hash, err = r.ReadString(); err != nil {
		return "", nixdrv.Output{}, fmt.Errorf("reading output hash: %w", err)
	}
	return name, out, nil
}