// SPDX-FileCopyrightText: 2021 Luke Granger-Brown <depot@lukegb.com>
//
// SPDX-License-Identifier: Apache-2.0

package fuphttp_test

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"

	"hg.lukegb.com/lukegb/depot/web/fup/fuphttp"
	"hg.lukegb.com/lukegb/depot/web/fup/fupstatic"

	"gocloud.dev/blob"
	"gocloud.dev/blob/fileblob"
	"gocloud.dev/blob/memblob"
)

// cfg is the baseline configuration shared by the tests in this file: the
// bundled fupstatic templates plus an in-memory ("mem://") storage URL.
// Tests that need a different backend copy the struct rather than mutating it.
var cfg = &fuphttp.Config{
	StorageURL: "mem://",
	Templates:  fupstatic.Templates,
}

// TestNotFound checks that requesting a path the handler does not serve
// yields a 404 response.
func TestNotFound(t *testing.T) {
	app, err := fuphttp.New(context.Background(), cfg)
	if err != nil {
		t.Fatalf("fuphttp.New: %v", err)
	}
	srv := httptest.NewServer(app.Handler())
	defer srv.Close()

	target := fmt.Sprintf("%s/not-found", srv.URL)
	resp, err := srv.Client().Get(target)
	if err != nil {
		t.Fatalf("Get: %v", err)
	}
	defer resp.Body.Close()

	if got, want := resp.StatusCode, http.StatusNotFound; got != want {
		t.Errorf("response status was %v; want %v", got, want)
	}
}

// mustURL parses s with url.Parse and returns the result, panicking if s is
// not a valid URL. It is meant for test fixtures whose URLs are known-good.
func mustURL(s string) *url.URL {
	parsed, parseErr := url.Parse(s)
	if parseErr == nil {
		return parsed
	}
	panic(fmt.Sprintf("parsing %q: %v", s, parseErr))
}

// mustBucket unwraps the (bucket, error) pair returned by a bucket opener
// such as fileblob.OpenBucket, panicking if opening failed. It lets test
// tables open buckets inline.
func mustBucket(b *blob.Bucket, err error) *blob.Bucket {
	if err == nil {
		return b
	}
	panic(fmt.Sprintf("opening bucket: %v", err))
}

// TestRetrieve exercises the /raw/ endpoint against several storage backends:
//
//   - LocalServe: an in-memory bucket; ranges are not expected to work.
//   - ReadSeeker: a file-backed bucket, expected to honour Range requests.
//   - SignedURL: a file-backed bucket with an HMAC URL signer, for which the
//     handler is expected to answer with a 302 redirect to a signed URL.
//
// Table fields: wantSignedRedirect, when non-empty, is the base URL the
// redirect's Location must point at; supportsRange enables the Range subtest.
//
// Each backend is seeded with hello.txt and with expired.txt, whose
// "expires-at" metadata lies in the past. Fetching expired.txt must 404 and
// remove it from the backend.
func TestRetrieve(t *testing.T) {
	signer := fileblob.NewURLSignerHMAC(mustURL("http://example.com"), []byte("not secret"))
	tcs := []struct {
		name               string
		backend            *blob.Bucket
		wantSignedRedirect string
		supportsRange      bool
	}{{
		name:          "LocalServe",
		backend:       memblob.OpenBucket(nil),
		supportsRange: false,
	}, {
		name:          "ReadSeeker",
		backend:       mustBucket(fileblob.OpenBucket(t.TempDir(), nil)),
		supportsRange: true,
	}, {
		name: "SignedURL",
		backend: mustBucket(fileblob.OpenBucket(t.TempDir(), &fileblob.Options{
			URLSigner: signer,
		})),
		wantSignedRedirect: "http://example.com",
	}}

	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			ctx, cancel := context.WithCancel(context.Background())
			t.Cleanup(cancel)

			// Copy the shared config so this case can swap in its own backend.
			ccfg := *cfg
			ccfg.StorageURL = ""
			ccfg.StorageBackend = tc.backend
			ccfg.RedirectToBlobstore = true

			if err := tc.backend.WriteAll(ctx, "hello.txt", []byte("hello world\n"), nil); err != nil {
				t.Fatalf("WriteAll: %v", err)
			}
			if err := tc.backend.WriteAll(ctx, "expired.txt", []byte("hello world\n"), &blob.WriterOptions{
				Metadata: map[string]string{
					// This unix timestamp is very much in the past.
					"expires-at": "1000",
				},
			}); err != nil {
				t.Fatalf("WriteAll: %v", err)
			}

			a, err := fuphttp.New(ctx, &ccfg)
			if err != nil {
				t.Fatalf("fuphttp.New: %v", err)
			}
			s := httptest.NewServer(a.Handler())
			defer s.Close()

			// Don't follow redirects: the SignedURL case must inspect the
			// redirect response itself. All requests below go through this
			// client so they share that policy.
			client := s.Client()
			client.CheckRedirect = func(r *http.Request, via []*http.Request) error {
				return http.ErrUseLastResponse
			}

			t.Run("get all", func(t *testing.T) {
				resp, err := client.Get(fmt.Sprintf("%s/raw/hello.txt", s.URL))
				if err != nil {
					t.Fatalf("Get: %v", err)
				}
				defer resp.Body.Close()

				if tc.wantSignedRedirect != "" {
					if resp.StatusCode != http.StatusFound {
						t.Errorf("response status was %v; want %v", resp.StatusCode, http.StatusFound)
					}

					gotLoc := resp.Header.Get("Location")
					gotLocURL, err := url.Parse(gotLoc)
					if err != nil {
						t.Errorf("parsing Location %q: %v", gotLoc, err)
					} else {
						// KeyFromURL only validates the signature and query
						// parameters, so separately check the redirect targets
						// the signer's base URL.
						wantURL := mustURL(tc.wantSignedRedirect)
						if gotLocURL.Scheme != wantURL.Scheme || gotLocURL.Host != wantURL.Host {
							t.Errorf("Location %q does not point at %q", gotLoc, tc.wantSignedRedirect)
						}

						obj, err := signer.KeyFromURL(ctx, gotLocURL)
						if err != nil {
							t.Errorf("KeyFromURL(%q): %v", gotLoc, err)
						} else if obj != "hello.txt" {
							t.Errorf("KeyFromURL(%q) = %v; want %v", gotLoc, obj, "hello.txt")
						}
					}
				} else {
					if resp.StatusCode != http.StatusOK {
						t.Errorf("response status was %v; want %v", resp.StatusCode, http.StatusOK)
					}

					data, err := io.ReadAll(resp.Body)
					if err != nil {
						t.Errorf("ReadAll: %v", err)
					}
					if got, want := string(data), "hello world\n"; got != want {
						t.Errorf("read data was %q; want %q", got, want)
					}
				}
			})

			t.Run("get range", func(t *testing.T) {
				if !tc.supportsRange {
					t.Skip("range unsupported by this backend")
				}

				req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/raw/hello.txt", s.URL), nil)
				if err != nil {
					t.Fatalf("NewRequestWithContext: %v", err)
				}
				req.Header.Set("Range", "bytes=1-3")

				resp, err := client.Do(req)
				if err != nil {
					t.Fatalf("Client.Do: %v", err)
				}
				defer resp.Body.Close()

				if resp.StatusCode != http.StatusPartialContent {
					t.Fatalf("response status was %v; want %v", resp.StatusCode, http.StatusPartialContent)
				}

				data, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Errorf("ReadAll: %v", err)
				}
				if got, want := string(data), "ell"; got != want {
					t.Errorf("read data was %q; want %q", got, want)
				}
			})

			t.Run("expired file", func(t *testing.T) {
				// Verify the file is there.
				if ok, err := tc.backend.Exists(ctx, "expired.txt"); err != nil {
					t.Fatalf("Exists(%q) before test: %v", "expired.txt", err)
				} else if !ok {
					t.Fatalf("Exists(%q) before test returned false", "expired.txt")
				}

				resp, err := client.Get(fmt.Sprintf("%s/raw/expired.txt", s.URL))
				if err != nil {
					t.Fatalf("Get: %v", err)
				}
				defer resp.Body.Close()

				if resp.StatusCode != http.StatusNotFound {
					t.Errorf("response status was %v; want %v", resp.StatusCode, http.StatusNotFound)
				}

				// Verify the file is NOT there.
				if ok, err := tc.backend.Exists(ctx, "expired.txt"); err != nil {
					t.Fatalf("Exists(%q) after retrieving: %v", "expired.txt", err)
				} else if ok {
					t.Fatalf("Exists(%q) after retrieving returned true", "expired.txt")
				}
			})
		})
	}
}