package nixstore

import (
	"encoding/base64"
	"fmt"
	"io"
	"net"
	"path"
	"strings"
	"sync"

	"hg.lukegb.com/lukegb/depot/go/nix/nar/narinfo"
	"hg.lukegb.com/lukegb/depot/go/nix/nixdrv"
	"hg.lukegb.com/lukegb/depot/go/nix/nixwire"
)

const (
	// DaemonSock is the path of the nix-daemon Unix socket; pass it to OpenDaemon.
	DaemonSock      = "/nix/var/nix/daemon-socket/socket"
	// WorkerMagic1 is the first word the client sends in the handshake (see hello).
	WorkerMagic1    = 0x6e697863
	// WorkerMagic2 is the word the daemon must answer with during the handshake.
	WorkerMagic2    = 0x6478696f
	// ProtocolVersion is the worker protocol version this client announces,
	// encoded as major<<8 | minor, i.e. 1.21 (0x15 == 21).
	ProtocolVersion = 0x115

	// Worker operation codes, written as the first word of each request.
	WopBuildDerivation            = 36
	WopEnsurePath                 = 10
	WopQueryPathInfo              = 26
	WopQuerySubstitutablePathInfo = 21

	// MaxBuf caps how many bytes a single STDERR_READ request may pull from
	// the caller-supplied stdin in processStderr.
	MaxBuf = 1024 * 1024 // 1 MB
)

// Daemon is a client connection to a Nix daemon speaking the worker protocol.
// Create one with OpenDaemon or OpenDaemonWithIOs. The exported request
// methods hold mu for their whole request/response exchange, so concurrent
// callers sharing one Daemon are serialized.
type Daemon struct {
	conn io.Closer             // closed by Close
	w    *nixwire.Serializer   // encodes requests sent to the daemon
	r    *nixwire.Deserializer // decodes responses read from the daemon
	mu   sync.Mutex            // held for an entire request/response exchange
	err  error                 // NOTE(review): not referenced in this file chunk — confirm use before removing
}

// PathNotValidError is returned when the daemon reports that Path is not a
// valid store path.
type PathNotValidError struct{ Path string }

// Error implements the error interface, quoting the offending path.
func (e PathNotValidError) Error() string {
	msg := fmt.Sprintf("path %q is not valid", e.Path)
	return msg
}

// NARInfo asks the daemon for the metadata of storePath (WopQueryPathInfo)
// and returns it as a NarInfo.
//
// storePath is sent verbatim. In the result, Deriver and References are
// reduced to their base names with path.Base, NarHash is parsed as a sha256
// hash, and Sig maps each signing key name to its decoded signature bytes.
// If the daemon reports the path as not valid, a PathNotValidError is
// returned. The registration time, "ultimate" flag and content address are
// read from the wire but not returned.
func (d *Daemon) NARInfo(storePath string) (*narinfo.NarInfo, error) {
	d.mu.Lock()
	defer d.mu.Unlock()

	if _, err := d.w.WriteUint64(WopQueryPathInfo); err != nil {
		return nil, fmt.Errorf("writing worker op WopQueryPathInfo: %w", err)
	}
	if _, err := d.w.WriteString(storePath); err != nil {
		return nil, fmt.Errorf("writing store path query %v: %w", storePath, err)
	}

	if err := d.processStderr(nil, nil, nil); err != nil {
		return nil, fmt.Errorf("reading stderr from WopQueryPathInfo: %w", err)
	}

	validInt, err := d.r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading path validity: %w", err)
	}
	if validInt != 1 {
		return nil, PathNotValidError{storePath}
	}

	ni := &narinfo.NarInfo{
		StorePath: storePath,
	}
	if ni.Deriver, err = d.r.ReadString(); err != nil {
		return nil, fmt.Errorf("reading deriver: %w", err)
	}
	// An absent deriver arrives as "", which path.Base turns into ".".
	ni.Deriver = path.Base(ni.Deriver)
	if ni.Deriver == "." {
		ni.Deriver = ""
	}

	// The hash arrives without an algorithm prefix; it is treated as sha256.
	hashStr, err := d.r.ReadString()
	if err != nil {
		return nil, fmt.Errorf("reading NAR hash: %w", err)
	}
	if ni.NarHash, err = narinfo.HashFromString("sha256:" + hashStr); err != nil {
		return nil, fmt.Errorf("parsing NAR hash %q: %w", hashStr, err)
	}

	refs, err := d.r.ReadStrings()
	if err != nil {
		return nil, fmt.Errorf("reading references: %w", err)
	}
	for n, ref := range refs {
		refs[n] = path.Base(ref)
	}
	ni.References = refs

	if _, err := d.r.ReadUint64(); err != nil {
		return nil, fmt.Errorf("reading registration time: %w", err)
	}
	if ni.NarSize, err = d.r.ReadUint64(); err != nil {
		return nil, fmt.Errorf("reading narsize: %w", err)
	}
	if _, err := d.r.ReadUint64(); err != nil {
		return nil, fmt.Errorf("reading ultimate: %w", err)
	}
	sigs, err := d.r.ReadStrings()
	if err != nil {
		return nil, fmt.Errorf("reading sigs: %w", err)
	}
	// The content address is a string on the wire (empty for paths that are
	// not content-addressed). Reading only a uint64 here would consume just
	// the length prefix and leave the string body in the stream, corrupting
	// every later response on this connection.
	// NOTE(review): the value is discarded; populate NarInfo's CA field if it has one.
	if _, err := d.r.ReadString(); err != nil {
		return nil, fmt.Errorf("reading CA: %w", err)
	}

	// Signatures are parsed only after the full response has been consumed,
	// so a malformed one does not leave the connection out of sync.
	ni.Sig = make(map[string][]byte, len(sigs))
	for _, sigStr := range sigs {
		keyName, sigB64, ok := strings.Cut(sigStr, ":")
		if !ok {
			return nil, fmt.Errorf("malformed signature %q: missing key name separator", sigStr)
		}
		sigHash, err := base64.StdEncoding.DecodeString(sigB64)
		if err != nil {
			return nil, fmt.Errorf("decoding signature %q: %w", sigStr, err)
		}
		ni.Sig[keyName] = sigHash
	}

	return ni, nil
}

// Close closes the underlying connection and returns its error. It acquires
// d.mu first, so it waits for any in-progress request on this Daemon.
func (d *Daemon) Close() error {
	d.mu.Lock()
	defer d.mu.Unlock()

	closer := d.conn
	return closer.Close()
}

// readFields decodes a typed field list: a uint64 count followed by that
// many (type tag, value) pairs. Tag 0 carries a uint64 and tag 1 a string;
// values are returned in wire order. Any other tag is an error.
func (d *Daemon) readFields() ([]any, error) {
	count, err := d.r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading fields count: %w", err)
	}
	total := int(count)

	var out []any
	for i := 0; i < total; i++ {
		tag, err := d.r.ReadUint64()
		if err != nil {
			return nil, fmt.Errorf("reading field %d type: %w", i, err)
		}

		var val any
		switch tag {
		case 0: // tInt
			u, err := d.r.ReadUint64()
			if err != nil {
				return nil, fmt.Errorf("reading field %d int: %w", i, err)
			}
			val = u
		case 1: // tString
			s, err := d.r.ReadString()
			if err != nil {
				return nil, fmt.Errorf("reading field %d str: %w", i, err)
			}
			val = s
		default:
			return nil, fmt.Errorf("reading field %d: unsupported field type %x", i, tag)
		}
		out = append(out, val)
	}
	return out, nil
}

// NixError is a failure reported by the daemon through STDERR_ERROR,
// carrying the daemon's status code and error message.
type NixError struct {
	Code    uint64
	Message string
}

// Error renders the status code followed by the daemon's message.
func (e NixError) Error() string {
	code, text := e.Code, e.Message
	return fmt.Sprintf("nix error %d: %v", code, text)
}

// processStderr consumes the daemon's stderr message stream that follows a
// request, until STDERR_LAST (returns nil) or STDERR_ERROR (returns a
// NixError).
//
// Log lines, activity start/stop events and activity results are forwarded
// to al. STDERR_WRITE payloads are copied to stdout, and STDERR_READ requests
// are answered from stdin (at most MaxBuf bytes per request). Receiving
// either while the corresponding stream is nil is an error. An unrecognised
// message code is also an error, because its payload cannot be skipped and
// the stream can no longer be parsed reliably.
//
// NOTE(review): hello and NARInfo pass al == nil; log/activity messages then
// rely on the ActivityLogger methods tolerating a nil receiver — confirm.
func (d *Daemon) processStderr(al *ActivityLogger, stdout io.Writer, stdin io.Reader) error {
	// Message codes of the worker protocol's stderr channel.
	const (
		stderrWrite         = 0x64617416 // STDERR_WRITE
		stderrRead          = 0x64617461 // STDERR_READ
		stderrError         = 0x63787470 // STDERR_ERROR
		stderrNext          = 0x6f6c6d67 // STDERR_NEXT
		stderrStartActivity = 0x53545254 // STDERR_START_ACTIVITY
		stderrStopActivity  = 0x53544f50 // STDERR_STOP_ACTIVITY
		stderrResult        = 0x52534c54 // STDERR_RESULT
		stderrLast          = 0x616c7473 // STDERR_LAST
	)

	for {
		msg, err := d.r.ReadUint64()
		if err != nil {
			return fmt.Errorf("read stderr msg code: %w", err)
		}
		switch msg {
		case stderrWrite:
			if stdout == nil {
				return fmt.Errorf("STDERR_WRITE requested")
			}
			bs, err := d.r.ReadBytes()
			if err != nil {
				return fmt.Errorf("read bytes: %w", err)
			}
			if _, err := stdout.Write(bs); err != nil {
				return fmt.Errorf("write bytes into stdout: %w", err)
			}
		case stderrRead:
			if stdin == nil {
				return fmt.Errorf("STDERR_READ requested")
			}
			want, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_READ reading uint64: %w", err)
			}
			// Clamp while still unsigned. Converting an oversized request to
			// int first can wrap negative, which skips the clamp and makes
			// make() panic.
			if want > MaxBuf {
				want = MaxBuf
			}
			buf := make([]byte, want)
			n, err := stdin.Read(buf)
			// An io.Reader may return data together with an error (including
			// io.EOF). Forward the data rather than dropping it. Readers
			// conventionally report a persistent error again on the next call.
			if n == 0 && err != nil {
				return fmt.Errorf("STDERR_READ reading from stdin: %w", err)
			}
			if _, err := d.w.WriteBytes(buf[:n]); err != nil {
				return fmt.Errorf("STDERR_READ writing stdin to socket: %w", err)
			}
		case stderrError:
			errStr, err := d.r.ReadString()
			if err != nil {
				return fmt.Errorf("STDERR_ERROR reading error string: %w", err)
			}
			status, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_ERROR reading uint64: %w", err)
			}
			al.AddError(status, errStr)
			return NixError{status, errStr}
		case stderrNext:
			line, err := d.r.ReadString()
			if err != nil {
				return fmt.Errorf("STDERR_NEXT reading log string: %w", err)
			}
			al.AddLog(line)
		case stderrStartActivity:
			activity, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading ActivityId: %w", err)
			}
			lvl, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading level: %w", err)
			}
			actTypeU64, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading activity type: %w", err)
			}
			actType := ActivityType(actTypeU64)
			s, err := d.r.ReadString()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading s: %w", err)
			}
			fields, err := d.readFields()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY fields: %w", err)
			}
			parentActivity, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading parent activity: %w", err)
			}
			al.StartActivity(actType, ActivityMeta{
				ActivityID:       activity,
				Level:            lvl,
				String:           s,
				Fields:           fields,
				ParentActivityID: parentActivity,
			})
		case stderrStopActivity:
			activity, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_STOP_ACTIVITY reading ActivityId: %w", err)
			}
			al.EndActivity(al.Activity(activity))
		case stderrResult:
			activity, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_RESULT reading ActivityId: %w", err)
			}
			resultTypU64, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_RESULT reading result: %w", err)
			}
			resultTyp := ResultType(resultTypU64)
			fields, err := d.readFields()
			if err != nil {
				return fmt.Errorf("STDERR_RESULT fields: %w", err)
			}
			al.ActivityResult(al.Activity(activity), resultTyp, fields)
		case stderrLast:
			return nil
		default:
			// The payload layout of an unknown message is unknown, so it cannot
			// be skipped. Continuing would misread its payload as further
			// message codes, so stop here.
			return fmt.Errorf("unknown stderr msg code %#x", msg)
		}
	}
}

// hello performs the worker-protocol handshake on a fresh connection:
//   - it exchanges magic numbers;
//   - it requires the daemon to speak protocol major version 1 with minor
//     version at least 21;
//   - it announces ProtocolVersion;
//   - it sends 0 for the two obsolete option words (CPU affinity, reserveSpace);
//   - finally it drains the daemon's initial stderr stream.
func (d *Daemon) hello() error {
	if _, err := d.w.WriteUint64(WorkerMagic1); err != nil {
		return fmt.Errorf("writing magic 1: %w", err)
	}

	gotMagic, err := d.r.ReadUint64()
	if err != nil {
		return fmt.Errorf("reading magic 2: %w", err)
	}
	if gotMagic != WorkerMagic2 {
		return fmt.Errorf("magic 2 mismatch: got %x, wanted %x", gotMagic, WorkerMagic2)
	}

	version, err := d.r.ReadUint64()
	if err != nil {
		return fmt.Errorf("reading daemon version: %w", err)
	}
	// The version word packs the major number in bits 8-15 and the minor in bits 0-7.
	major, minor := int((version>>8)&0xff), int(version&0xff)
	switch {
	case major != 1:
		return fmt.Errorf("daemon major version mismatch: got %d, want 1", major)
	case minor < 21:
		return fmt.Errorf("daemon minor version too old: got %d, want at least 21", minor)
	}

	if _, err := d.w.WriteUint64(ProtocolVersion); err != nil {
		return fmt.Errorf("writing protocol version: %w", err)
	}

	// Both remaining handshake words are obsolete options; always send 0.
	if _, err := d.w.WriteUint64(0); err != nil {
		return fmt.Errorf("writing obsolete CPU affinity: %w", err)
	}
	if _, err := d.w.WriteUint64(0); err != nil {
		return fmt.Errorf("writing obsolete reserveSpace: %w", err)
	}

	return d.processStderr(nil, nil, nil)
}

// EnsurePath sends a WopEnsurePath request for storePath. Daemon log and
// activity messages are reported through a new action started on at. It
// fails if the request cannot be sent, the daemon reports an error, or the
// daemon's final reply is anything other than 1.
func (d *Daemon) EnsurePath(at *ActivityTracker, storePath string) error {
	al := at.StartAction(fmt.Sprintf("EnsurePath(%q)", storePath))
	defer al.Close()

	d.mu.Lock()
	defer d.mu.Unlock()

	if _, err := d.w.WriteUint64(WopEnsurePath); err != nil {
		return fmt.Errorf("writing worker op WopEnsurePath: %w", err)
	}
	if _, err := d.w.WriteString(storePath); err != nil {
		return fmt.Errorf("writing store path %v: %w", storePath, err)
	}
	if err := d.processStderr(al, nil, nil); err != nil {
		return fmt.Errorf("reading stderr from WopEnsurePath: %w", err)
	}

	ack, err := d.r.ReadUint64()
	switch {
	case err != nil:
		return fmt.Errorf("reading path validity: %w", err)
	case ack != 1:
		return fmt.Errorf("path validity was %d; wanted 1", ack)
	default:
		return nil
	}
}

// BuildMode selects how the daemon should build a derivation. Its numeric
// value is written to the wire by BuildDerivation.
type BuildMode uint64

const (
	BMNormal BuildMode = iota
	BMRepair
	BMCheck
)

// String returns the lowercase name of bm. Unknown values are rendered as
// "BuildMode(N)" rather than an empty string, so they stay visible in logs.
func (bm BuildMode) String() string {
	switch bm {
	case BMNormal:
		return "normal"
	case BMRepair:
		return "repair"
	case BMCheck:
		return "check"
	default:
		// Convert to uint64 so fmt does not recurse into String.
		return fmt.Sprintf("BuildMode(%d)", uint64(bm))
	}
}

// BuildResultStatus is the outcome code the daemon reports for a build. The
// values are decoded directly from the wire and must line up with Nix's
// BuildResult::Status enum, so entries may only be appended, never removed
// or reordered.
type BuildResultStatus uint64

const (
	BRSBuilt BuildResultStatus = iota
	BRSSubstituted
	BRSAlreadyValid
	BRSPermanentFailure
	BRSInputRejected
	BRSOutputRejected
	BRSTransientFailure
	// BRSCachedFailure is no longer produced by Nix but still occupies value
	// 7 in its enum. Without it, every later status would be off by one.
	BRSCachedFailure
	BRSTimedOut
	BRSMiscFailure
	BRSDependencyFailed
	BRSLogLimitExceeded
	BRSNotDeterministic
	BRSResolvesToAlreadyValid
	BRSNoSubstituters
)

// String returns a human-readable description of brs. Unknown values are
// rendered as "BuildResultStatus(N)".
func (brs BuildResultStatus) String() string {
	switch brs {
	case BRSBuilt:
		return "built"
	case BRSSubstituted:
		return "substituted"
	case BRSAlreadyValid:
		return "already valid"
	case BRSPermanentFailure:
		return "permanent failure"
	case BRSInputRejected:
		return "input rejected"
	case BRSOutputRejected:
		return "output rejected"
	case BRSTransientFailure:
		return "transient failure"
	case BRSCachedFailure:
		return "cached failure"
	case BRSTimedOut:
		return "timed out"
	case BRSMiscFailure:
		return "misc failure"
	case BRSDependencyFailed:
		return "dependency failed"
	case BRSLogLimitExceeded:
		return "log limit exceeded"
	case BRSNotDeterministic:
		return "not deterministic"
	case BRSResolvesToAlreadyValid:
		return "resolves to already valid"
	case BRSNoSubstituters:
		return "no substituters"
	default:
		// Convert to uint64 so fmt does not recurse into String.
		return fmt.Sprintf("BuildResultStatus(%d)", uint64(brs))
	}
}

// IsSuccess reports whether brs means the requested outputs are available:
// built, substituted, or already valid (directly or after resolution).
func (brs BuildResultStatus) IsSuccess() bool {
	switch brs {
	case BRSBuilt, BRSSubstituted, BRSAlreadyValid, BRSResolvesToAlreadyValid:
		return true
	default:
		return false
	}
}

// BuildDerivation asks the daemon to build derivation (stored at
// derivationPath) using buildMode, via WopBuildDerivation. Daemon log and
// activity messages are reported through a new action started on at.
//
// It returns the daemon's status code. A status for which IsSuccess is false
// comes with an error combining the status and the daemon's message. If the
// exchange itself fails, BRSMiscFailure is returned with the error.
func (d *Daemon) BuildDerivation(at *ActivityTracker, derivationPath string, derivation *nixdrv.BasicDerivation, buildMode BuildMode) (BuildResultStatus, error) {
	al := at.StartAction(fmt.Sprintf("BuildDerivation(%q, %q)", derivationPath, buildMode))
	defer al.Close()

	d.mu.Lock()
	defer d.mu.Unlock()

	// Request: op code, derivation path, derivation body, build mode.
	if _, err := d.w.WriteUint64(WopBuildDerivation); err != nil {
		return BRSMiscFailure, fmt.Errorf("writing worker op WopBuildDerivation: %w", err)
	}
	if _, err := d.w.WriteString(derivationPath); err != nil {
		return BRSMiscFailure, fmt.Errorf("writing derivation store path %v: %w", derivationPath, err)
	}
	if _, err := d.w.WriteDerivation(derivation); err != nil {
		return BRSMiscFailure, fmt.Errorf("writing derivation content of %v: %w", derivationPath, err)
	}
	if _, err := d.w.WriteUint64(uint64(buildMode)); err != nil {
		return BRSMiscFailure, fmt.Errorf("writing build mode %v: %w", buildMode, err)
	}

	if err := d.processStderr(al, nil, nil); err != nil {
		return BRSMiscFailure, fmt.Errorf("reading stderr from WopBuildDerivation: %w", err)
	}

	// Response: status code, then an error message.
	code, err := d.r.ReadUint64()
	if err != nil {
		return BRSMiscFailure, fmt.Errorf("reading build status code: %w", err)
	}
	msg, err := d.r.ReadString()
	if err != nil {
		return BRSMiscFailure, fmt.Errorf("reading build error message: %w", err)
	}

	status := BuildResultStatus(code)
	if status.IsSuccess() {
		return status, nil
	}
	return status, fmt.Errorf("%s: %s", status, msg)
}

// OpenDaemon dials the Unix socket at path (typically DaemonSock) and
// performs the worker-protocol handshake. On handshake failure the
// connection is closed and the handshake error is returned.
func OpenDaemon(path string) (*Daemon, error) {
	sock, err := net.Dial("unix", path)
	if err != nil {
		return nil, fmt.Errorf("dialing %v: %w", path, err)
	}

	d := &Daemon{
		conn: sock,
		w:    &nixwire.Serializer{sock},
		r:    &nixwire.Deserializer{sock},
	}
	if err = d.hello(); err != nil {
		// Best-effort cleanup; the handshake error is the one worth reporting.
		_ = d.Close()
		return nil, fmt.Errorf("sending hello to %v: %w", path, err)
	}
	return d, nil
}

// OpenDaemonWithIOs performs the worker-protocol handshake over
// caller-supplied streams. Requests are written to w and responses read from
// r. c is closed by Daemon.Close, or immediately if the handshake fails.
func OpenDaemonWithIOs(r io.Reader, w io.Writer, c io.Closer) (*Daemon, error) {
	d := &Daemon{
		conn: c,
		w:    &nixwire.Serializer{w},
		r:    &nixwire.Deserializer{r},
	}
	if err := d.hello(); err != nil {
		// Best-effort cleanup; the handshake error is the one worth reporting.
		_ = d.Close()
		return nil, fmt.Errorf("sending hello: %w", err)
	}
	return d, nil
}