package nixwire

import (
	"encoding/binary"
	"fmt"
	"io"
)

// Serializer writes values to the embedded io.Writer: uint64 values as 8
// little-endian bytes, and strings as a uint64 length, the raw bytes, then
// zero padding up to a multiple of 8 bytes.
type Serializer struct {
	io.Writer
}

// WritePadding writes the zero bytes that bring a payload of n bytes up to
// the next multiple of 8 and returns how many it wrote. Nothing is written
// when n%8 is not positive (n already aligned, or negative).
func (w Serializer) WritePadding(n int64) (int64, error) {
	rem := n % 8
	if rem <= 0 {
		return 0, nil
	}
	var zeros [8]byte
	wrote, err := w.Write(zeros[:8-rem])
	return int64(wrote), err
}

// WriteUint64 writes n as 8 little-endian bytes and returns the byte count
// reported by the underlying writer.
func (w Serializer) WriteUint64(n uint64) (int64, error) {
	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], n)
	wrote, err := w.Write(buf[:])
	return int64(wrote), err
}

// WriteString writes s as its uint64 length, its bytes, and zero padding up
// to a multiple of 8. It returns the total number of bytes written, including
// when an error stops it partway; errors from the underlying writer are
// returned unwrapped.
func (w Serializer) WriteString(s string) (int64, error) {
	total, err := w.WriteUint64(uint64(len(s)))
	if err != nil {
		return total, err
	}

	// io.WriteString skips the []byte(s) copy when the underlying writer
	// implements io.StringWriter.
	nData, err := io.WriteString(w.Writer, s)
	total += int64(nData)
	if err != nil {
		return total, err
	}

	nPad, err := w.WritePadding(int64(len(s)))
	return total + nPad, err
}

// Deserializer decodes, from the embedded io.Reader, the encoding produced by
// Serializer: uint64 values as 8 little-endian bytes, and strings as a uint64
// length, the raw bytes, then zero padding up to a multiple of 8 bytes.
type Deserializer struct {
	io.Reader
}

// ReadPadding consumes the padding that follows a payload of n bytes and
// checks that it is all zero. Nothing is read when n is a multiple of 8;
// otherwise 8-n%8 bytes are read.
func (r Deserializer) ReadPadding(n uint64) error {
	rem := n % 8
	if rem == 0 {
		return nil
	}

	// Keep the arithmetic in uint64: converting n to int first goes negative
	// for n above the int range and would make the pad longer than 8 bytes.
	var pad [8]byte
	buf := pad[:8-rem]
	if _, err := io.ReadFull(r, buf); err != nil {
		return fmt.Errorf("reading padding: %w", err)
	}
	for _, b := range buf {
		if b != 0 {
			return fmt.Errorf("padding was not all 0")
		}
	}
	return nil
}

// ReadUint64 reads one 8-byte little-endian unsigned integer.
func (r Deserializer) ReadUint64() (uint64, error) {
	var buf [8]byte
	if _, err := io.ReadFull(r, buf[:]); err != nil {
		return 0, fmt.Errorf("reading uint64: %w", err)
	}
	return binary.LittleEndian.Uint64(buf[:]), nil
}

// ReadString reads a length-prefixed string: a uint64 byte count, that many
// bytes, then zero padding validated by ReadPadding.
//
// The length prefix comes from the stream and is not trusted. A length that
// does not fit in an int is rejected (make would otherwise panic), and the
// buffer grows in bounded steps as bytes actually arrive, so memory use
// tracks the data received rather than the declared length. A stream that
// ends early yields io.ErrUnexpectedEOF (or io.EOF if no payload byte at all
// was read), matching io.ReadFull.
func (r Deserializer) ReadString() (string, error) {
	strLen, err := r.ReadUint64()
	if err != nil {
		return "", fmt.Errorf("reading length: %w", err)
	}

	size := int(strLen)
	if size < 0 || uint64(size) != strLen {
		return "", fmt.Errorf("reading string buffer: length %d does not fit in int", strLen)
	}

	// readChunk bounds how far the allocation may run ahead of the data
	// received; beyond that, each step is at most the size already read, so
	// the buffer at most doubles per step.
	const readChunk = 64 << 10
	initCap := size
	if initCap > readChunk {
		initCap = readChunk
	}
	strBuf := make([]byte, 0, initCap)
	for len(strBuf) < size {
		start := len(strBuf)
		limit := readChunk
		if start > limit {
			limit = start
		}
		step := size - start
		if step > limit {
			step = limit
		}
		strBuf = append(strBuf, make([]byte, step)...)
		if _, err := io.ReadFull(r, strBuf[start:]); err != nil {
			// ReadFull returns io.EOF only when this step read nothing; if an
			// earlier step succeeded, the string as a whole was truncated.
			if err == io.EOF && start > 0 {
				err = io.ErrUnexpectedEOF
			}
			return "", fmt.Errorf("reading string buffer: %w", err)
		}
	}

	if err := r.ReadPadding(strLen); err != nil {
		return "", fmt.Errorf("reading string padding: %w", err)
	}

	return string(strBuf), nil
}

// ReadStrings reads a uint64 count followed by that many strings, each
// decoded with ReadString. A count of zero yields an empty, non-nil slice.
func (r Deserializer) ReadStrings() ([]string, error) {
	setLen, err := r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading string set length: %w", err)
	}

	// The count is untrusted: int(setLen) can be negative (make panics) or
	// enormous. Pre-size only up to a modest cap and let append grow the
	// slice as strings are actually decoded.
	const maxPrealloc = 1024
	initCap := setLen
	if initCap > maxPrealloc {
		initCap = maxPrealloc
	}
	xs := make([]string, 0, int(initCap))
	for i := uint64(0); i < setLen; i++ {
		s, err := r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading string %d of %d: %w", i, setLen, err)
		}
		xs = append(xs, s)
	}
	return xs, nil
}