// SPDX-FileCopyrightText: 2023 Luke Granger-Brown <depot@lukegb.com>
//
// SPDX-License-Identifier: Apache-2.0

package nixbuild

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"regexp"

	"github.com/numtide/go-nix/nixbase32"
	"git.lukegb.com/lukegb/depot/go/nix/nar/narinfo"
)

// hashExtractRegexp locates a 32-character nixbase32 store hash that starts
// either the whole string or a "/"-separated component, and is followed by
// the end of the string or by a '-' or '.' plus anything at all (for example
// "-hello-2.12" or ".narinfo").
//
// The character class is the nixbase32 alphabet: the ten digits plus the
// lowercase letters other than e, o, t and u.
var hashExtractRegexp = regexp.MustCompile(`(^|/)([0-9a-df-np-sv-z]{32})([-.].*)?$`)

// hashExtract returns the store hash found in s, or "" when s contains none.
func hashExtract(s string) string {
	if m := hashExtractRegexp.FindStringSubmatch(s); m != nil {
		// m[2] is the capture group holding the bare hash.
		return m[2]
	}
	return ""
}

// keyForPath maps a store path (or anything else hashExtract understands) to
// the cache object name "<hash>.narinfo". It fails when no store hash can be
// extracted from storePath.
func keyForPath(storePath string) (string, error) {
	hash := hashExtract(storePath)
	if hash != "" {
		return hash + ".narinfo", nil
	}
	return "", fmt.Errorf("store path %v seems to be invalid: couldn't extract hash", storePath)
}

// HTTPCacheResolver fetches .narinfo files and NARs from a binary cache
// reachable over HTTP at Endpoint.
type HTTPCacheResolver struct {
	// Endpoint is the cache's base URL; object keys are appended to it after
	// a "/" separator. NOTE(review): presumably it should not end in "/" —
	// confirm against callers.
	Endpoint   string
	// HTTPClient is the client intended for requests to Endpoint.
	// NOTE(review): confirm FetchNarInfo/FetchNar actually issue requests
	// through it (falling back to http.DefaultClient when nil).
	HTTPClient *http.Client
	// Priority is the value reported by RelativePriority.
	Priority   int
}

// Compile-time assertion that *HTTPCacheResolver satisfies Resolver.
var _ Resolver = ((*HTTPCacheResolver)(nil))

// RelativePriority reports the Priority configured on this resolver.
func (r *HTTPCacheResolver) RelativePriority() int {
	return r.Priority
}

// FetchNarInfo retrieves and parses the .narinfo describing path from the
// binary cache at r.Endpoint.
//
// path is anything keyForPath accepts (e.g. a /nix/store/<hash>-<name> path);
// its hash selects the "<hash>.narinfo" object. A 404 response yields
// os.ErrNotExist, and any other non-200 status yields an error carrying the
// status line. Requests go through r.HTTPClient, or http.DefaultClient when
// that is nil. Request and transport errors are wrapped with %w so callers
// can inspect them (e.g. errors.Is(err, context.Canceled)).
func (r *HTTPCacheResolver) FetchNarInfo(ctx context.Context, path string) (*narinfo.NarInfo, error) {
	// maxDrain bounds how much of an unread response body is discarded before
	// closing it, so connection reuse never costs an unbounded download.
	const maxDrain = 64 << 10

	key, err := keyForPath(path)
	if err != nil {
		return nil, err
	}
	url := fmt.Sprintf("%v/%v", r.Endpoint, key)

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("constructing request for %v: %w", url, err)
	}

	// Use the configured client, falling back to the package default.
	client := r.HTTPClient
	if client == nil {
		client = http.DefaultClient
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("making request for %v: %w", url, err)
	}
	defer func() {
		// net/http only reuses a connection whose body was read to EOF and
		// closed, so discard (a bounded amount of) whatever is left. Errors
		// here are irrelevant to the caller's result.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrain))
		resp.Body.Close()
	}()

	switch resp.StatusCode {
	case http.StatusOK:
		// Parsed below.
	case http.StatusNotFound:
		return nil, os.ErrNotExist
	default:
		return nil, fmt.Errorf("%v: HTTP %v", url, resp.Status)
	}

	return narinfo.ParseNarInfo(resp.Body)
}

// FetchNar opens the NAR described by ni from the binary cache at r.Endpoint.
//
// The object is fetched from "nar/<nixbase32 of ni.FileHash.Hash>.nar<suffix>",
// where the suffix comes from ni.Compression.FileSuffix. On success the raw
// response body is returned (this function does not decompress it) and the
// caller must Close it. A 404 response yields os.ErrNotExist; any other
// non-200 status yields an error carrying the status line. Requests go
// through r.HTTPClient, or http.DefaultClient when that is nil. Request and
// transport errors are wrapped with %w so callers can inspect them.
func (r *HTTPCacheResolver) FetchNar(ctx context.Context, ni *narinfo.NarInfo) (io.ReadCloser, error) {
	// maxDrain bounds how much of an unwanted response body is discarded
	// before closing it, so connection reuse never costs an unbounded download.
	const maxDrain = 64 << 10

	compressionSuffix, err := ni.Compression.FileSuffix()
	if err != nil {
		return nil, fmt.Errorf("determining filename: %w", err)
	}
	key := fmt.Sprintf("nar/%s.nar%s", nixbase32.EncodeToString(ni.FileHash.Hash), compressionSuffix)
	url := fmt.Sprintf("%v/%v", r.Endpoint, key)

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("constructing request for %v: %w", url, err)
	}

	// Use the configured client, falling back to the package default.
	client := r.HTTPClient
	if client == nil {
		client = http.DefaultClient
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("making request for %v: %w", url, err)
	}

	if resp.StatusCode == http.StatusOK {
		// Ownership of the body passes to the caller.
		return resp.Body, nil
	}

	// The body is not handed out on failure: drain a bounded amount and close
	// it so the transport can reuse the connection.
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrain))
	resp.Body.Close()

	if resp.StatusCode == http.StatusNotFound {
		return nil, os.ErrNotExist
	}
	return nil, fmt.Errorf("downloading %v: HTTP status code %v", url, resp.Status)
}