Luke Granger-Brown
7bac0aee74
This is a very basic implementation of expiry gating: we only return files which haven't expired. If a file has expired, we delete it instead of returning it.
216 lines
5.7 KiB
Go
216 lines
5.7 KiB
Go
// SPDX-FileCopyrightText: 2021 Luke Granger-Brown <depot@lukegb.com>
|
|
//
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package fuphttp_test
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"testing"
|
|
|
|
"hg.lukegb.com/lukegb/depot/web/fup/fuphttp"
|
|
"hg.lukegb.com/lukegb/depot/web/fup/fupstatic"
|
|
|
|
"gocloud.dev/blob"
|
|
"gocloud.dev/blob/fileblob"
|
|
"gocloud.dev/blob/memblob"
|
|
)
|
|
|
|
var cfg = &fuphttp.Config{
|
|
Templates: fupstatic.Templates,
|
|
StorageURL: "mem://",
|
|
}
|
|
|
|
func TestNotFound(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
a, err := fuphttp.New(ctx, cfg)
|
|
if err != nil {
|
|
t.Fatalf("fuphttp.New: %v", err)
|
|
}
|
|
s := httptest.NewServer(a.Handler())
|
|
defer s.Close()
|
|
|
|
resp, err := s.Client().Get(fmt.Sprintf("%s/not-found", s.URL))
|
|
if err != nil {
|
|
t.Fatalf("Get: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusNotFound {
|
|
t.Errorf("response status was %v; want %v", resp.StatusCode, http.StatusNotFound)
|
|
}
|
|
}
|
|
|
|
// mustURL parses s with url.Parse and returns the result. It panics with a
// message naming the offending input if parsing fails, so it is only suitable
// for fixed, known-good URLs in test setup.
func mustURL(s string) *url.URL {
	parsed, err := url.Parse(s)
	if err == nil {
		return parsed
	}
	panic(fmt.Sprintf("parsing %q: %v", s, err))
}
|
|
|
|
func mustBucket(b *blob.Bucket, err error) *blob.Bucket {
|
|
if err != nil {
|
|
panic(fmt.Sprintf("opening bucket: %v", err))
|
|
}
|
|
return b
|
|
}
|
|
|
|
func TestRetrieve(t *testing.T) {
|
|
signer := fileblob.NewURLSignerHMAC(mustURL("http://example.com"), []byte("not secret"))
|
|
tcs := []struct {
|
|
name string
|
|
backend *blob.Bucket
|
|
wantSignedRedirect string
|
|
supportsRange bool
|
|
}{{
|
|
name: "LocalServe",
|
|
backend: memblob.OpenBucket(nil),
|
|
supportsRange: false,
|
|
}, {
|
|
name: "ReadSeeker",
|
|
backend: mustBucket(fileblob.OpenBucket(t.TempDir(), nil)),
|
|
supportsRange: true,
|
|
}, {
|
|
name: "SignedURL",
|
|
backend: mustBucket(fileblob.OpenBucket(t.TempDir(), &fileblob.Options{
|
|
URLSigner: signer,
|
|
})),
|
|
wantSignedRedirect: "http://example.com",
|
|
}}
|
|
|
|
for _, tc := range tcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
t.Cleanup(cancel)
|
|
|
|
ccfg := *cfg
|
|
ccfg.StorageURL = ""
|
|
ccfg.StorageBackend = tc.backend
|
|
ccfg.RedirectToBlobstore = true
|
|
|
|
if err := tc.backend.WriteAll(ctx, "hello.txt", []byte("hello world\n"), nil); err != nil {
|
|
t.Fatalf("WriteAll: %v", err)
|
|
}
|
|
if err := tc.backend.WriteAll(ctx, "expired.txt", []byte("hello world\n"), &blob.WriterOptions{
|
|
Metadata: map[string]string{
|
|
// This unix timestamp is very much in the past.
|
|
"expires-at": "1000",
|
|
},
|
|
}); err != nil {
|
|
t.Fatalf("WriteAll: %v", err)
|
|
}
|
|
|
|
a, err := fuphttp.New(ctx, &ccfg)
|
|
if err != nil {
|
|
t.Fatalf("fuphttp.New: %v", err)
|
|
}
|
|
s := httptest.NewServer(a.Handler())
|
|
defer s.Close()
|
|
|
|
client := s.Client()
|
|
client.CheckRedirect = func(r *http.Request, via []*http.Request) error {
|
|
return http.ErrUseLastResponse
|
|
}
|
|
|
|
t.Run("get all", func(t *testing.T) {
|
|
resp, err := s.Client().Get(fmt.Sprintf("%s/raw/hello.txt", s.URL))
|
|
if err != nil {
|
|
t.Fatalf("Get: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if tc.wantSignedRedirect != "" {
|
|
if resp.StatusCode != http.StatusFound {
|
|
t.Errorf("response status was %v; want %v", resp.StatusCode, http.StatusFound)
|
|
}
|
|
|
|
gotLoc := resp.Header.Get("Location")
|
|
gotLocURL, err := url.Parse(gotLoc)
|
|
if err != nil {
|
|
t.Errorf("parsing Location %q: %v", gotLoc, err)
|
|
} else {
|
|
obj, err := signer.KeyFromURL(ctx, gotLocURL)
|
|
if err != nil {
|
|
t.Errorf("KeyFromURL(%q): %v", gotLoc, err)
|
|
} else if obj != "hello.txt" {
|
|
t.Errorf("KeyFromURL(%q) = %v; want %v", gotLoc, obj, "hello.txt")
|
|
}
|
|
}
|
|
} else {
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("response status was %v; want %v", resp.StatusCode, http.StatusOK)
|
|
}
|
|
|
|
data, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Errorf("ReadAll: %v", err)
|
|
}
|
|
if got, want := string(data), "hello world\n"; got != want {
|
|
t.Errorf("read data was %q; want %q", got, want)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("get range", func(t *testing.T) {
|
|
if !tc.supportsRange {
|
|
t.Skip("range unsupported by this backend")
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/raw/hello.txt", s.URL), nil)
|
|
if err != nil {
|
|
t.Fatalf("NewRequestWithContext: %v", err)
|
|
}
|
|
req.Header.Set("Range", "bytes=1-3")
|
|
|
|
resp, err := s.Client().Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Client.Do: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusPartialContent {
|
|
t.Fatalf("response status was %v; want %v", resp.StatusCode, http.StatusPartialContent)
|
|
}
|
|
|
|
data, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Errorf("ReadAll: %v", err)
|
|
}
|
|
if got, want := string(data), "ell"; got != want {
|
|
t.Errorf("read data was %q; want %q", got, want)
|
|
}
|
|
})
|
|
|
|
t.Run("expired file", func(t *testing.T) {
|
|
// Verify the file is there.
|
|
if ok, err := tc.backend.Exists(ctx, "expired.txt"); err != nil {
|
|
t.Fatalf("Exists(%q) before test: %v", "expired.txt", err)
|
|
} else if !ok {
|
|
t.Fatalf("Exists(%q) before test returned false", "expired.txt")
|
|
}
|
|
|
|
resp, err := s.Client().Get(fmt.Sprintf("%s/raw/expired.txt", s.URL))
|
|
if err != nil {
|
|
t.Fatalf("Get: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNotFound {
|
|
t.Errorf("response status was %v; want %v", resp.StatusCode, http.StatusNotFound)
|
|
}
|
|
|
|
// Verify the file is NOT there.
|
|
if ok, err := tc.backend.Exists(ctx, "expired.txt"); err != nil {
|
|
t.Fatalf("Exists(%q) after retrieving: %v", "expired.txt", err)
|
|
} else if ok {
|
|
t.Fatalf("Exists(%q) after retrieving returned true", "expired.txt")
|
|
}
|
|
})
|
|
})
|
|
}
|
|
}
|