diff --git a/web/fup/fuphttp/fuphttp.go b/web/fup/fuphttp/fuphttp.go index 39841584d1..9baff24d8a 100644 --- a/web/fup/fuphttp/fuphttp.go +++ b/web/fup/fuphttp/fuphttp.go @@ -98,7 +98,6 @@ func (a *Application) rawDownload(rw http.ResponseWriter, r *http.Request) { case gcerrors.OK: // OK } - _ = attrs // Do things with attributes? if a.redirectToBlobstore { @@ -157,10 +156,19 @@ func (a *Application) rawDownloadDirect(ctx context.Context, rw http.ResponseWri if v := attrs.ETag; v != "" { rw.Header().Set("ETag", v) } + + // If we're serving from local disk, use http.ServeContent, using this somewhat opaque method. + var ior io.Reader + if rdr.As(&ior) { + if iors, ok := ior.(io.ReadSeeker); ok { + http.ServeContent(rw, r, filename, attrs.ModTime, iors) + return + } + } + if v := attrs.ModTime; !v.IsZero() { rw.Header().Set("Last-Modified", v.UTC().Format(http.TimeFormat)) } - if v := attrs.Size; v != 0 { rw.Header().Set("Content-Length", fmt.Sprintf("%d", v)) } diff --git a/web/fup/fuphttp/fuphttp_test.go b/web/fup/fuphttp/fuphttp_test.go index 9944f3f594..6f367ca7ee 100644 --- a/web/fup/fuphttp/fuphttp_test.go +++ b/web/fup/fuphttp/fuphttp_test.go @@ -7,14 +7,18 @@ package fuphttp_test import ( "context" "fmt" + "io" "net/http" "net/http/httptest" + "net/url" "testing" "hg.lukegb.com/lukegb/depot/web/fup/fuphttp" "hg.lukegb.com/lukegb/depot/web/fup/fupstatic" - _ "gocloud.dev/blob/memblob" + "gocloud.dev/blob" + "gocloud.dev/blob/fileblob" + "gocloud.dev/blob/memblob" ) var cfg = &fuphttp.Config{ @@ -36,7 +40,143 @@ func TestNotFound(t *testing.T) { if err != nil { t.Fatalf("Get: %v", err) } + defer resp.Body.Close() if resp.StatusCode != http.StatusNotFound { t.Errorf("response status was %v; want %v", resp.StatusCode, http.StatusNotFound) } } + +func mustURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(fmt.Sprintf("parsing %q: %v", s, err)) + } + return u +} + +func mustBucket(b *blob.Bucket, err error) *blob.Bucket { + if err != nil { + panic(fmt.Sprintf("opening bucket: %v", err)) + } + return b +} + +func TestRetrieve(t *testing.T) { + signer := fileblob.NewURLSignerHMAC(mustURL("http://example.com"), []byte("not secret")) + tcs := []struct { + name string + backend *blob.Bucket + wantSignedRedirect string + supportsRange bool + }{{ + name: "LocalServe", + backend: memblob.OpenBucket(nil), + supportsRange: false, + }, { + name: "ReadSeeker", + backend: mustBucket(fileblob.OpenBucket(t.TempDir(), nil)), + supportsRange: true, + }, { + name: "SignedURL", + backend: mustBucket(fileblob.OpenBucket(t.TempDir(), &fileblob.Options{ + URLSigner: signer, + })), + wantSignedRedirect: "http://example.com", + }} + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + ccfg := *cfg + ccfg.StorageURL = "" + ccfg.StorageBackend = tc.backend + ccfg.RedirectToBlobstore = true + + if err := tc.backend.WriteAll(ctx, "hello.txt", []byte("hello world\n"), nil); err != nil { + t.Fatalf("WriteAll: %v", err) + } + + a, err := fuphttp.New(ctx, &ccfg) + if err != nil { + t.Fatalf("fuphttp.New: %v", err) + } + s := httptest.NewServer(a.Handler()) + defer s.Close() + + client := s.Client() + client.CheckRedirect = func(r *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + + t.Run("get all", func(t *testing.T) { + resp, err := s.Client().Get(fmt.Sprintf("%s/raw/hello.txt", s.URL)) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer resp.Body.Close() + + if tc.wantSignedRedirect != "" { + if resp.StatusCode != http.StatusFound { + t.Errorf("response status was %v; want %v", resp.StatusCode, http.StatusFound) + } + + gotLoc := resp.Header.Get("Location") + gotLocURL, err := url.Parse(gotLoc) + if err != nil { + t.Errorf("parsing Location %q: %v", gotLoc, err) + } else { + obj, err := signer.KeyFromURL(ctx, gotLocURL) + if err != nil { + t.Errorf("KeyFromURL(%q): %v", gotLoc, err) + } else if obj != "hello.txt" { + t.Errorf("KeyFromURL(%q) = %v; want %v", gotLoc, obj, "hello.txt") + } + } + } else { + if resp.StatusCode != http.StatusOK { + t.Errorf("response status was %v; want %v", resp.StatusCode, http.StatusOK) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + t.Errorf("ReadAll: %v", err) + } + if got, want := string(data), "hello world\n"; got != want { + t.Errorf("read data was %q; want %q", got, want) + } + } + }) + + t.Run("get range", func(t *testing.T) { + if !tc.supportsRange { + t.Skip("range unsupported by this backend") + } + + req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/raw/hello.txt", s.URL), nil) + if err != nil { + t.Fatalf("NewRequestWithContext: %v", err) + } + req.Header.Set("Range", "bytes=1-3") + + resp, err := s.Client().Do(req) + if err != nil { + t.Fatalf("Client.Do: %v", err) + } + + if resp.StatusCode != http.StatusPartialContent { + t.Fatalf("response status was %v; want %v", resp.StatusCode, http.StatusPartialContent) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + t.Errorf("ReadAll: %v", err) + } + if got, want := string(data), "ell"; got != want { + t.Errorf("read data was %q; want %q", got, want) + } + }) + }) + } +}