// SPDX-FileCopyrightText: 2020 Luke Granger-Brown <depot@lukegb.com>
//
// SPDX-License-Identifier: Apache-2.0

// Package narinfo implements a simple narinfo read/writer.
package narinfo

import (
	"bufio"
	"crypto/md5"
	"crypto/sha1"
	"crypto/sha256"
	"crypto/sha512"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"hash"
	"io"
	"regexp"
	"sort"
	"strconv"
	"strings"

	"github.com/numtide/go-nix/nixbase32"
)

// CompressionMethod identifies one of the compression methods this package
// knows how to name in a narinfo.
type CompressionMethod int

// Known compression methods. CompressionUnknown is the zero value and does
// not correspond to any real method.
const (
	CompressionUnknown CompressionMethod = iota
	CompressionNone
	CompressionXz
	CompressionBzip2
	CompressionGzip
	CompressionBrotli
)

// String returns the textual name of the method ("none", "xz", "bzip2",
// "gzip" or "br"). Any other value, including CompressionUnknown, yields
// the placeholder "!!unknown!!".
func (c CompressionMethod) String() string {
	names := [...]string{
		CompressionNone:   "none",
		CompressionXz:     "xz",
		CompressionBzip2:  "bzip2",
		CompressionGzip:   "gzip",
		CompressionBrotli: "br",
	}
	if c > CompressionUnknown && int(c) < len(names) {
		return names[c]
	}
	return "!!unknown!!"
}

// FileSuffix returns the file name suffix associated with the method,
// including the leading dot; CompressionNone has an empty suffix. An error
// is returned for CompressionUnknown and for any value outside the known
// set.
func (c CompressionMethod) FileSuffix() (string, error) {
	suffixes := [...]string{
		CompressionNone:   "",
		CompressionXz:     ".xz",
		CompressionBzip2:  ".bz2",
		CompressionGzip:   ".gz",
		CompressionBrotli: ".br",
	}
	if c <= CompressionUnknown || int(c) >= len(suffixes) {
		return "", fmt.Errorf("bad compression method %d", int(c))
	}
	return suffixes[c], nil
}

// compressionFromString maps the textual name of a compression method to
// its CompressionMethod. The empty string is treated like "none"; any
// unrecognised name yields CompressionUnknown.
func compressionFromString(s string) CompressionMethod {
	byName := map[string]CompressionMethod{
		"":      CompressionNone,
		"none":  CompressionNone,
		"xz":    CompressionXz,
		"bzip2": CompressionBzip2,
		"gzip":  CompressionGzip,
		"br":    CompressionBrotli,
	}
	if method, known := byName[s]; known {
		return method
	}
	return CompressionUnknown
}

// HashAlgorithm identifies one of the hashing algorithms this package
// supports.
type HashAlgorithm int

// Known hash algorithms. HashUnknown is the zero value and does not
// correspond to any real algorithm.
const (
	HashUnknown HashAlgorithm = iota
	HashMd5
	HashSha1
	HashSha256
	HashSha512
)

// String returns the lower-case algorithm name ("md5", "sha1", "sha256" or
// "sha512"). Any other value, including HashUnknown, yields the placeholder
// "!!unknown!!".
func (a HashAlgorithm) String() string {
	names := [...]string{
		HashMd5:    "md5",
		HashSha1:   "sha1",
		HashSha256: "sha256",
		HashSha512: "sha512",
	}
	if a > HashUnknown && int(a) < len(names) {
		return names[a]
	}
	return "!!unknown!!"
}

// Hash returns a freshly constructed hash.Hash for the algorithm.
//
// It panics if a is not one of the known algorithms (this includes
// HashUnknown), so callers must check the algorithm beforehand.
func (a HashAlgorithm) Hash() hash.Hash {
	constructors := [...]func() hash.Hash{
		HashMd5:    md5.New,
		HashSha1:   sha1.New,
		HashSha256: sha256.New,
		HashSha512: sha512.New,
	}
	if a > HashUnknown && int(a) < len(constructors) {
		return constructors[a]()
	}
	panic("unknown hash algorithm")
}

// sizes reports the digest length of the algorithm in bytes together with
// the length of that digest once encoded as hex, nix-base32 and standard
// base64.
//
// Because it goes through a.Hash(), it panics for an unknown algorithm.
func (a HashAlgorithm) sizes() sizes {
	digestLen := a.Hash().Size()

	var out sizes
	out.rawLen = digestLen
	out.base16Len = hex.EncodedLen(digestLen)
	out.base32Len = nixbase32.EncodedLen(digestLen)
	out.base64Len = base64.StdEncoding.EncodedLen(digestLen)
	return out
}

// sizes holds the length of a hash digest in its raw form and in each of
// the textual encodings accepted by HashFromString.
type sizes struct {
	// rawLen is the digest length in bytes.
	rawLen    int
	// base16Len is the length of the hex encoding of the digest.
	base16Len int
	// base32Len is the length of the nix-base32 encoding of the digest.
	base32Len int
	// base64Len is the length of the standard (padded) base64 encoding.
	base64Len int
}

// Hash represents a hash and its algorithm.
type Hash struct {
	// Algorithm is the algorithm that produced Hash.
	Algorithm HashAlgorithm
	// Hash is the raw (decoded) digest.
	Hash      []byte
	// IsSRI selects the textual form produced by String: when true the
	// hash is rendered as "<algorithm>-<base64>", otherwise as
	// "<algorithm>:<nix-base32>".
	IsSRI     bool
}

// Valid reports whether h has a known algorithm and a digest of exactly the
// length that algorithm produces.
func (h Hash) Valid() bool {
	if h.Algorithm == HashUnknown {
		return false
	}
	want := h.Algorithm.sizes().rawLen
	return len(h.Hash) == want
}

// String renders the hash as "<algorithm>-<base64>" when IsSRI is set and
// as "<algorithm>:<nix-base32>" otherwise.
func (h Hash) String() string {
	var encoded string
	if h.IsSRI {
		encoded = base64.StdEncoding.EncodeToString(h.Hash)
		return fmt.Sprintf("%s-%s", h.Algorithm, encoded)
	}
	encoded = nixbase32.EncodeToString(h.Hash)
	return fmt.Sprintf("%s:%s", h.Algorithm, encoded)
}

// HashFromString parses the textual representation of a hash.
//
// Two layouts are accepted:
//
//   - "<algorithm>:<digest>", where the digest is hex, nix-base32 or
//     standard base64; the encoding is inferred from the digest's length.
//   - "<algorithm>-<digest>" (SRI), where the digest must be standard
//     base64.
//
// The algorithm must be one of "md5", "sha1", "sha256" or "sha512". The
// decoded digest must have exactly the length the algorithm produces. On
// any error the zero Hash is returned alongside the error.
func HashFromString(s string) (Hash, error) {
	idx := strings.IndexAny(s, "-:")
	if idx == -1 {
		return Hash{}, fmt.Errorf("no hashtype separator")
	}
	h := Hash{IsSRI: s[idx] == '-'}
	hashtype, hashdata := s[:idx], s[idx+1:]
	switch hashtype {
	case "md5":
		h.Algorithm = HashMd5
	case "sha1":
		h.Algorithm = HashSha1
	case "sha256":
		h.Algorithm = HashSha256
	case "sha512":
		h.Algorithm = HashSha512
	default:
		return Hash{}, fmt.Errorf("unknown hash type %v", hashtype)
	}

	var (
		hashbytes []byte
		err       error
	)
	sz := h.Algorithm.sizes()
	// We support the same things as Nix, which is base16 (hex), base32 (why), and base64 hashes.
	// For SRI, base64 is required.
	switch {
	case !h.IsSRI && len(hashdata) == sz.base16Len:
		hashbytes, err = hex.DecodeString(hashdata)
		if err != nil {
			return Hash{}, fmt.Errorf("decoding hash %v as hex: %w", hashdata, err)
		}
	case !h.IsSRI && len(hashdata) == sz.base32Len:
		hashbytes, err = nixbase32.DecodeString(hashdata)
		if err != nil {
			return Hash{}, fmt.Errorf("decoding hash %v as nix-base32: %w", hashdata, err)
		}
	case len(hashdata) == sz.base64Len:
		hashbytes, err = base64.StdEncoding.DecodeString(hashdata)
		if err != nil {
			return Hash{}, fmt.Errorf("decoding hash %v as base64: %w", hashdata, err)
		}
	default:
		return Hash{}, fmt.Errorf("unknown hash type for length %d", len(hashdata))
	}

	// The encoded length alone does not pin down the decoded length: a
	// base64 string with more or less padding than the canonical encoding
	// has the same text length but decodes to a different number of bytes.
	if len(hashbytes) != sz.rawLen {
		return Hash{}, fmt.Errorf("decoded hash %v is %d bytes, want %d for %v", hashdata, len(hashbytes), sz.rawLen, h.Algorithm)
	}
	h.Hash = hashbytes
	return h, nil
}

// NarInfo represents the content of a .narinfo file.
//
// Each field corresponds to the narinfo line of the same name. When
// written with WriteTo, fields holding their zero value (or an invalid
// hash) are omitted.
type NarInfo struct {
	// StorePath is required by ParseNarInfo.
	StorePath   string
	// URL is required by ParseNarInfo.
	URL         string
	// Compression defaults to CompressionBzip2 in ParseNarInfo when the
	// input has no Compression line.
	Compression CompressionMethod
	FileHash    Hash
	FileSize    uint64
	// NarHash must be valid for ParseNarInfo to succeed.
	NarHash     Hash
	// NarSize must be non-zero for ParseNarInfo to succeed.
	NarSize     uint64
	// References holds the space-separated entries of the References line.
	References  []string
	// Deriver is left empty when the input says "unknown-deriver".
	Deriver     string
	System      string
	// Sig maps the part of each Sig line before the first ':' to the
	// base64-decoded bytes after it.
	Sig         map[string][]byte
	CA          string
}

// Sigs returns the signatures in their textual "<name>:<base64>" form,
// ordered by name so that the output is deterministic.
func (ni NarInfo) Sigs() []string {
	names := make([]string, 0, len(ni.Sig))
	for name := range ni.Sig {
		names = append(names, name)
	}
	sort.Strings(names)

	rendered := make([]string, len(names))
	for i, name := range names {
		encoded := base64.StdEncoding.EncodeToString(ni.Sig[name])
		rendered[i] = fmt.Sprintf("%s:%s", name, encoded)
	}
	return rendered
}

// WriteTo writes ni to w in narinfo text format, one "Key: value" line per
// populated field, in a fixed order. Empty strings, zero sizes, invalid
// hashes, an unknown compression method and empty lists are skipped. Each
// signature gets its own Sig line, in the order returned by Sigs.
//
// The first write error aborts the output and is returned. Note that this
// method returns only an error, so NarInfo does not satisfy io.WriterTo.
func (ni NarInfo) WriteTo(w io.Writer) error {
	// emit formats one line and writes it with a single Fprintf call,
	// reporting only the error.
	emit := func(format string, arg interface{}) error {
		_, err := fmt.Fprintf(w, format, arg)
		return err
	}

	if ni.StorePath != "" {
		if err := emit("StorePath: %s\n", ni.StorePath); err != nil {
			return err
		}
	}
	if ni.URL != "" {
		if err := emit("URL: %s\n", ni.URL); err != nil {
			return err
		}
	}
	if ni.Compression != CompressionUnknown {
		if err := emit("Compression: %s\n", ni.Compression); err != nil {
			return err
		}
	}
	if ni.FileHash.Valid() {
		if err := emit("FileHash: %s\n", ni.FileHash); err != nil {
			return err
		}
	}
	if ni.FileSize != 0 {
		if err := emit("FileSize: %d\n", ni.FileSize); err != nil {
			return err
		}
	}
	if ni.NarHash.Valid() {
		if err := emit("NarHash: %s\n", ni.NarHash); err != nil {
			return err
		}
	}
	if ni.NarSize != 0 {
		if err := emit("NarSize: %d\n", ni.NarSize); err != nil {
			return err
		}
	}
	if len(ni.References) != 0 {
		if err := emit("References: %s\n", strings.Join(ni.References, " ")); err != nil {
			return err
		}
	}
	if ni.Deriver != "" {
		if err := emit("Deriver: %s\n", ni.Deriver); err != nil {
			return err
		}
	}
	if ni.System != "" {
		if err := emit("System: %s\n", ni.System); err != nil {
			return err
		}
	}
	// Sigs yields nothing for an empty map, so no separate emptiness check
	// is needed here.
	for _, sig := range ni.Sigs() {
		if err := emit("Sig: %s\n", sig); err != nil {
			return err
		}
	}
	if ni.CA == "" {
		return nil
	}
	return emit("CA: %s\n", ni.CA)
}

// String returns the narinfo text that WriteTo would produce for ni.
func (ni NarInfo) String() string {
	out := new(strings.Builder)
	// Writes to a strings.Builder never fail, so the error is dropped.
	_ = ni.WriteTo(out)
	return out.String()
}

// narLineRegexp matches a single "Key: value" line. The key (first group)
// is one or more ASCII letters and must be followed by a colon and exactly
// one space; the value (second group) is the rest of the line and may be
// empty.
var narLineRegexp = regexp.MustCompile(`^([A-Za-z]+): (.*)$`)

// ParseNarInfo parses NarInfo data from a Reader.
//
// The input is read line by line with a bufio.Scanner, so a line longer
// than the scanner's default token limit is reported as an error. Lines
// that are not of the form "Key: value" and lines with an unrecognised key
// are ignored. If a key occurs more than once the last occurrence wins,
// except for Sig, where every line adds an entry keyed by signature name.
//
// Special cases:
//
//   - Compression defaults to bzip2 when no Compression line is present.
//   - A Deriver of "unknown-deriver" is treated as no deriver.
//   - References are split on whitespace; empty entries are dropped.
//
// An error is returned if a value cannot be parsed, or if the result lacks
// a StorePath, a URL, a valid NarHash or a non-zero NarSize.
func ParseNarInfo(r io.Reader) (*NarInfo, error) {
	ni := NarInfo{
		// Default used when the input has no Compression line.
		Compression: CompressionBzip2,
	}
	s := bufio.NewScanner(r)
	for s.Scan() {
		// The pattern has exactly two groups, so a match always yields
		// three submatches; nil means the line did not match at all.
		submatches := narLineRegexp.FindStringSubmatch(s.Text())
		if submatches == nil {
			continue
		}
		name, value := submatches[1], submatches[2]
		switch name {
		case "StorePath":
			ni.StorePath = value
		case "URL":
			ni.URL = value
		case "Compression":
			ni.Compression = compressionFromString(value)
			if ni.Compression == CompressionUnknown {
				return nil, fmt.Errorf("unknown compression method %q", value)
			}
		case "FileHash":
			h, err := HashFromString(value)
			if err != nil {
				return nil, fmt.Errorf("parsing %q as FileHash: %w", value, err)
			}
			ni.FileHash = h
		case "FileSize":
			var err error
			ni.FileSize, err = strconv.ParseUint(value, 10, 64)
			if err != nil {
				return nil, fmt.Errorf("parsing %q as FileSize: %w", value, err)
			}
		case "NarHash":
			h, err := HashFromString(value)
			if err != nil {
				return nil, fmt.Errorf("parsing %q as NarHash: %w", value, err)
			}
			ni.NarHash = h
		case "NarSize":
			var err error
			ni.NarSize, err = strconv.ParseUint(value, 10, 64)
			if err != nil {
				return nil, fmt.Errorf("parsing %q as NarSize: %w", value, err)
			}
		case "References":
			// strings.Fields, unlike strings.Split, does not produce empty
			// entries for repeated or trailing spaces.
			ni.References = strings.Fields(value)
			if len(ni.References) == 0 {
				ni.References = nil
			}
		case "Deriver":
			if value != "unknown-deriver" {
				ni.Deriver = value
			}
		case "System":
			ni.System = value
		case "Sig":
			if ni.Sig == nil {
				ni.Sig = make(map[string][]byte)
			}
			// Split on the first ':' only: the key name comes first and
			// the remainder is the base64 signature.
			bits := strings.SplitN(value, ":", 2)
			if len(bits) != 2 {
				return nil, fmt.Errorf("parsing %q as Sig: not enough pieces", value)
			}
			rawVal, err := base64.StdEncoding.DecodeString(bits[1])
			if err != nil {
				return nil, fmt.Errorf("parsing %q as Sig: %w", value, err)
			}
			ni.Sig[bits[0]] = rawVal
		case "CA":
			ni.CA = value
		}
	}
	if err := s.Err(); err != nil {
		return nil, err
	}
	// Reject results that lack the fields every narinfo must carry.
	switch {
	case ni.StorePath == "":
		return nil, fmt.Errorf("StorePath missing")
	case !ni.NarHash.Valid():
		return nil, fmt.Errorf("NarHash missing or invalid")
	case ni.URL == "":
		return nil, fmt.Errorf("URL empty")
	case ni.NarSize == 0:
		return nil, fmt.Errorf("NAR is 0-size")
	}
	return &ni, nil
}