diff --git a/web/fup/fuphttp/fuphttp.go b/web/fup/fuphttp/fuphttp.go index 148ef739e0..f19c7fdd66 100644 --- a/web/fup/fuphttp/fuphttp.go +++ b/web/fup/fuphttp/fuphttp.go @@ -59,6 +59,13 @@ type Config struct { // UseDirectDownload decides whether the "pretty" wrapped page or the direct download page is the most appropriate for a given set of parameters. UseDirectDownload func(fileExtension string, mimeType string) bool + + // AuthMiddleware is a Gorilla middleware to provide authentication information. + // + // It runs on every handler, including public ones, and will be provided with information: + // - if the page is an upload page (either the homepage, the paste textbox page, or the upload handler), the IsMutate will return true. + // - if the page is an API handler, then IsAPIRequest will return true. + AuthMiddleware mux.MiddlewareFunc } type Application struct { @@ -80,6 +87,8 @@ type Application struct { filenameGenerator fngen.FilenameGenerator useDirectDownload func(fileExtension string, mimeType string) bool + + authMiddleware mux.MiddlewareFunc } func DefaultUseDirectDownload(fileExtension, mimeType string) bool { @@ -90,6 +99,39 @@ func DefaultUseDirectDownload(fileExtension, mimeType string) bool { return !strings.HasPrefix(mimeType, "text/") } +func isAPIRequest(r *http.Request) bool { + return r.Header.Get("Accept") == "application/json" +} + +type contextKey string + +const ( + ctxKeyIsMutate = contextKey("isMutate") + ctxKeyIsAPIRequest = contextKey("isAPIRequest") +) + +// IsAPIRequest determines whether a request was made from an API client rather than from a browser. +func IsAPIRequest(ctx context.Context) bool { + return ctx.Value(ctxKeyIsAPIRequest).(bool) +} + +// IsMutate determines whether a request for the given context is for a "mutate". +// This includes things like the HTML for the home page, which in itself is not a mutate but is useless if you're not allowed to auth. +func IsMutate(ctx context.Context) bool { + return ctx.Value(ctxKeyIsMutate).(bool) +} + +func contextPopulateMiddleware(isMutate bool) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = context.WithValue(ctx, ctxKeyIsMutate, isMutate) + ctx = context.WithValue(ctx, ctxKeyIsAPIRequest, isAPIRequest(r)) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} + func (a *Application) Handler() http.Handler { r := mux.NewRouter() @@ -102,12 +144,23 @@ func (a *Application) Handler() http.Handler { } r.NotFoundHandler = http.HandlerFunc(a.notFound) - r.HandleFunc("/", renderTemplate(a.indexTmpl)) - r.HandleFunc("/paste", renderTemplate(a.pasteTmpl)) - r.HandleFunc("/raw/{filename}", a.rawDownload) - r.HandleFunc("/upload", a.upload).Methods("POST", "PUT") - r.HandleFunc("/upload/{filename}", a.upload).Methods("PUT") - r.HandleFunc("/{filename}", a.view) + + authR := r.PathPrefix("/").Subrouter() + authR.HandleFunc("/", renderTemplate(a.indexTmpl)) + authR.HandleFunc("/paste", renderTemplate(a.pasteTmpl)) + authR.HandleFunc("/upload", a.upload).Methods("POST", "PUT") + authR.HandleFunc("/upload/{filename}", a.upload).Methods("PUT") + + publicR := r.PathPrefix("/").Subrouter() + publicR.HandleFunc("/raw/{filename}", a.rawDownload) + publicR.HandleFunc("/{filename}", a.view) + + if a.authMiddleware != nil { + authR.Use(contextPopulateMiddleware(true)) + authR.Use(a.authMiddleware) + publicR.Use(contextPopulateMiddleware(false)) + publicR.Use(a.authMiddleware) + } return r } @@ -168,6 +221,7 @@ func New(ctx context.Context, cfg *Config) (*Application, error) { useDirectDownload: cfg.UseDirectDownload, appRoot: cfg.AppRoot, highlighter: cfg.Highlighter, + authMiddleware: cfg.AuthMiddleware, } if a.redirectExpiry == 0 { a.redirectExpiry = defaultRedirectExpiry diff --git a/web/fup/fuphttp/fuphttp_test.go b/web/fup/fuphttp/fuphttp_test.go index 586ed7f9fb..15beadecef 100644 --- a/web/fup/fuphttp/fuphttp_test.go +++ b/web/fup/fuphttp/fuphttp_test.go @@ -11,6 +11,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "hg.lukegb.com/lukegb/depot/web/fup/fuphttp" @@ -214,3 +215,91 @@ func TestRetrieve(t *testing.T) { }) } } + +func TestAuthMiddleware(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + var gotReq, lastReqWasMutate, lastReqWasAPI bool + + ccfg := *cfg + ccfg.AuthMiddleware = func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + gotReq = true + lastReqWasMutate = fuphttp.IsMutate(r.Context()) + lastReqWasAPI = fuphttp.IsAPIRequest(r.Context()) + next.ServeHTTP(rw, r) + }) + } + + a, err := fuphttp.New(ctx, &ccfg) + if err != nil { + t.Fatalf("fuphttp.New: %v", err) + } + s := httptest.NewServer(a.Handler()) + t.Cleanup(s.Close) + + tcs := []struct { + name string + request func(url string) (*http.Request, error) + wantIsMutate bool + wantIsAPI bool + }{{ + name: "/ request", + request: func(url string) (*http.Request, error) { return http.NewRequest("GET", url, nil) }, + wantIsMutate: true, + wantIsAPI: false, + }, { + name: "/upload request", + request: func(url string) (*http.Request, error) { + return http.NewRequest("PUT", url+"/upload/foo.txt", strings.NewReader("slartibartfast\n")) + }, + wantIsMutate: true, + wantIsAPI: false, + }, { + name: "/upload request, application/json", + request: func(url string) (*http.Request, error) { + req, err := http.NewRequest("PUT", url+"/upload/foo.txt", strings.NewReader("slartibartfast\n")) + if err != nil { + return nil, err + } + req.Header.Add("Accept", "application/json") + return req, nil + }, + wantIsMutate: true, + wantIsAPI: true, + }, { + name: "/foo.txt request", + request: func(url string) (*http.Request, error) { return http.NewRequest("GET", url+"/foo.txt", nil) }, + wantIsMutate: false, + wantIsAPI: false, + }, { + name: "/raw/foo.txt request", + request: func(url string) (*http.Request, error) { return http.NewRequest("GET", url+"/raw/foo.txt", nil) }, + wantIsMutate: false, + wantIsAPI: false, + }} + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + gotReq = false + + req, err := tc.request(s.URL) + resp, err := s.Client().Do(req) + if err != nil { + t.Fatalf("Do: %v", err) + } + defer resp.Body.Close() + + if !gotReq { + t.Fatalf("gotReq = %v; want true", gotReq) + } + if lastReqWasMutate != tc.wantIsMutate { + t.Errorf("lastReqWasMutate = %v; want %v", lastReqWasMutate, tc.wantIsMutate) + } + if lastReqWasAPI != tc.wantIsAPI { + t.Errorf("lastReqWasAPI = %v; want %v", lastReqWasAPI, tc.wantIsAPI) + } + }) + } +}