// SPDX-FileCopyrightText: 2023 Luke Granger-Brown <depot@lukegb.com>
//
// SPDX-License-Identifier: Apache-2.0

package nixpool

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/url"
	"os"

	"golang.org/x/crypto/ssh"
	"hg.lukegb.com/lukegb/depot/go/nix/nixstore"
)

// DaemonFactory is the shape of a factory function that yields a connected
// *nixstore.Daemon, or an error if no connection could be established. The
// factories built by DaemonDialer open a fresh connection on every call.
type DaemonFactory func() (daemon *nixstore.Daemon, err error)

// DaemonDialer creates a factory function from the provided remote.
//
// The remote is parsed as a URL and handled according to its scheme:
//
//   - unix:///path/to/socket opens the daemon socket at that path, falling
//     back to nixstore.DaemonSock when the path is empty.
//   - ssh-ng://[USER[:PASSWORD]@]HOST[:PORT] runs a command on HOST over SSH
//     (PORT defaults to 22) and speaks the daemon protocol over that
//     command's stdin and stdout. A password in the URL is offered as SSH
//     password authentication.
//
// For ssh-ng remotes the following query parameters are recognised:
//
//   - privkey: path to a private key file to authenticate with; may be
//     repeated.
//   - host-key: the expected host key, written as a known_hosts entry
//     without the leading host pattern. Query values decode '+' as a space,
//     so any '+' in the key must be percent-encoded as %2B.
//   - insecure-allow-any-ssh-host-key: accept any host key. One of host-key
//     or this parameter is required.
//   - remote-cmd: the command to run remotely (default
//     "nix-daemon --stdio").
//
// SSH connections are restricted to the chacha20-poly1305@openssh.com
// cipher.
//
// URL parsing and private key loading happen eagerly, so configuration
// errors are returned here rather than from the factory. Each call to the
// returned factory opens a new connection. ctx is currently unused.
func DaemonDialer(ctx context.Context, remote string) (DaemonFactory, error) {
	u, err := url.Parse(remote)
	if err != nil {
		return nil, fmt.Errorf("parsing remote %q as URL: %w", remote, err)
	}

	switch u.Scheme {
	case "unix":
		path := u.Path
		if path == "" {
			path = nixstore.DaemonSock
		}
		return func() (*nixstore.Daemon, error) {
			d, err := nixstore.OpenDaemon(path)
			if err != nil {
				return nil, err
			}
			return d, nil
		}, nil
	case "ssh-ng":
		// Parse the query string once rather than on every lookup.
		q := u.Query()

		// Construct a ClientConfig from the URL.
		cfg := &ssh.ClientConfig{}
		cfg.Ciphers = []string{"chacha20-poly1305@openssh.com"}
		if q.Has("privkey") {
			var keys []ssh.Signer
			for _, privkeyPath := range q["privkey"] {
				privkeyF, err := os.Open(privkeyPath)
				if err != nil {
					return nil, fmt.Errorf("opening privkey %q: %w", privkeyPath, err)
				}
				privkeyB, err := io.ReadAll(privkeyF)
				// Close right away instead of deferring: a defer inside this
				// loop would hold every key file open until DaemonDialer
				// returns. The file was only read, so a Close error is moot.
				privkeyF.Close()
				if err != nil {
					return nil, fmt.Errorf("reading privkey %q: %w", privkeyPath, err)
				}

				privkey, err := ssh.ParsePrivateKey(privkeyB)
				if err != nil {
					return nil, fmt.Errorf("parsing privkey %q: %w", privkeyPath, err)
				}
				keys = append(keys, privkey)
			}
			cfg.Auth = append(cfg.Auth, ssh.PublicKeys(keys...))
		}
		if u.User != nil {
			cfg.User = u.User.Username()
			if pw, ok := u.User.Password(); ok {
				cfg.Auth = append(cfg.Auth, ssh.Password(pw))
			}
		}
		switch {
		case q.Has("host-key"):
			hkStr := q.Get("host-key")
			// ParseKnownHosts expects a leading host-pattern field; "x" is a
			// throwaway placeholder since only the key itself is used.
			_, _, hk, _, _, err := ssh.ParseKnownHosts(append([]byte("x "), []byte(hkStr)...))
			if err != nil {
				return nil, fmt.Errorf("parsing host-key %q: %w", hkStr, err)
			}
			cfg.HostKeyCallback = ssh.FixedHostKey(hk)
		case q.Has("insecure-allow-any-ssh-host-key"):
			cfg.HostKeyCallback = ssh.InsecureIgnoreHostKey()
		default:
			return nil, fmt.Errorf("some SSH host key configuration is required (?host-key=; ?insecure-allow-any-ssh-host-key)")
		}

		// Work out other misc parameters.
		// ...remote command.
		remoteCmd := "nix-daemon --stdio"
		if q.Has("remote-cmd") {
			remoteCmd = q.Get("remote-cmd")
		}

		// Work out the host:port to connect to. Hostname strips the brackets
		// from IPv6 literals; JoinHostPort puts them back, whereas plain
		// concatenation would yield an undialable address like "::1:22".
		port := u.Port()
		if port == "" {
			port = "22"
		}
		addr := net.JoinHostPort(u.Hostname(), port)

		return func() (*nixstore.Daemon, error) {
			conn, err := ssh.Dial("tcp", addr, cfg)
			if err != nil {
				return nil, fmt.Errorf("dialing %v via SSH: %w", addr, err)
			}
			// On any failure below, closing the client connection also tears
			// down the session opened on it.
			sess, err := conn.NewSession()
			if err != nil {
				conn.Close()
				return nil, fmt.Errorf("opening SSH session to %v: %w", addr, err)
			}
			stdin, err := sess.StdinPipe()
			if err != nil {
				conn.Close()
				return nil, fmt.Errorf("opening stdin pipe: %w", err)
			}
			stdout, err := sess.StdoutPipe()
			if err != nil {
				conn.Close()
				return nil, fmt.Errorf("opening stdout pipe: %w", err)
			}
			if err := sess.Start(remoteCmd); err != nil {
				conn.Close()
				return nil, fmt.Errorf("starting %q: %w", remoteCmd, err)
			}
			d, err := nixstore.OpenDaemonWithIOs(stdout, stdin, conn)
			if err != nil {
				conn.Close()
				return nil, fmt.Errorf("establishing connection to daemon: %w", err)
			}
			return d, nil
		}, nil
	default:
		return nil, fmt.Errorf("unknown remote %q", remote)
	}
}