diff --git a/imageproxy.go b/imageproxy.go index 8afe6cd..d51ffc0 100644 --- a/imageproxy.go +++ b/imageproxy.go @@ -573,7 +573,14 @@ func (t *TransformingTransport) RoundTrip(req *http.Request) (*http.Response, er if should304(req, resp) { // bare 304 response, full response will be used from cache - return &http.Response{StatusCode: http.StatusNotModified}, nil + return &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Status: fmt.Sprintf("%d %s", http.StatusNotModified, http.StatusText(http.StatusNotModified)), + StatusCode: http.StatusNotModified, + Body: http.NoBody, + }, nil } b, err := io.ReadAll(resp.Body) diff --git a/imageproxy_test.go b/imageproxy_test.go index 6e05ddd..7909486 100644 --- a/imageproxy_test.go +++ b/imageproxy_test.go @@ -23,6 +23,10 @@ import ( "strings" "testing" "time" + + "github.com/die-net/lrucache" + "github.com/google/uuid" + "github.com/gregjones/httpcache" ) func TestPeekContentType(t *testing.T) { @@ -332,9 +336,11 @@ func TestShould304(t *testing.T) { // testTransport is an http.RoundTripper that returns certained canned // responses for particular requests. -type testTransport struct{} +type testTransport struct { + replyNotModified bool +} -func (t testTransport) RoundTrip(req *http.Request) (*http.Response, error) { +func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { var raw string switch req.URL.Path { @@ -352,6 +358,19 @@ func (t testTransport) RoundTrip(req *http.Request) (*http.Response, error) { _ = png.Encode(img, m) raw = fmt.Sprintf("HTTP/1.1 200 OK\nContent-Length: %d\nContent-Type: image/png\n\n%s", len(img.Bytes()), img.Bytes()) + case "/redirect-to-notmodified": + parts := []string{ + "HTTP/1.1 303\nLocation: http://notmodified.test/notmodified?X-Security-Token=", + uuid.NewString(), + "#_=_\nCache-Control: no-store\n\n", + } + raw = strings.Join(parts, "") + case "/notmodified": + if t.replyNotModified { + raw = "HTTP/1.1 304 Not modified\nEtag: \"abcdef\"\n\n" + } else { + raw = "HTTP/1.1 200 OK\nEtag: \"abcdef\"\n\nOriginal response\n" + } default: redirectRegexp := regexp.MustCompile(`/redirects-(\d+)`) if redirectRegexp.MatchString(req.URL.Path) { @@ -526,7 +545,7 @@ func TestProxy_UpdateCacheHeaders(t *testing.T) { func TestProxy_ServeHTTP(t *testing.T) { p := &Proxy{ Client: &http.Client{ - Transport: testTransport{}, + Transport: &testTransport{}, }, AllowHosts: []string{"good.test"}, ContentTypes: []string{"image/*"}, @@ -564,7 +583,7 @@ func TestProxy_ServeHTTP(t *testing.T) { func TestProxy_ServeHTTP_is304(t *testing.T) { p := &Proxy{ Client: &http.Client{ - Transport: testTransport{}, + Transport: &testTransport{}, }, } @@ -581,10 +600,57 @@ func TestProxy_ServeHTTP_is304(t *testing.T) { } } +func TestProxy_ServeHTTP_cached304(t *testing.T) { + cache := lrucache.New(1024*1024*8, 0) + client := new(http.Client) + tt := testTransport{} + client.Transport = &httpcache.Transport{ + Transport: &TransformingTransport{ + Transport: &tt, + CachingClient: client, + }, + Cache: cache, + MarkCachedResponses: true, + } + + p := &Proxy{ + Client: client, + FollowRedirects: true, + } + + // prime the cache + req := httptest.NewRequest("GET", "http://localhost//http://good.test/redirect-to-notmodified", nil) + recorder := httptest.NewRecorder() + p.ServeHTTP(recorder, req) + + resp := recorder.Result() + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("ServeHTTP(%v) returned status %d, want %d", req, got, want) + } + if _, found := cache.Get("http://good.test/redirect-to-notmodified#0x0"); !found { + t.Errorf("Response to http://good.test/redirect-to-notmodified#0x0 should be cached") + } + + // now make the same request again, but this time make sure the server responds with a 304 + tt.replyNotModified = true + req = httptest.NewRequest("GET", "http://localhost//http://good.test/redirect-to-notmodified", nil) + recorder = httptest.NewRecorder() + p.ServeHTTP(recorder, req) + + resp = recorder.Result() + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("ServeHTTP(%v) returned status %d, want %d", req, got, want) + } + + if recorder.Body.String() != "Original response\n" { + t.Errorf("Response isn't what we expected: %v", recorder.Body.String()) + } +} + func TestProxy_ServeHTTP_maxRedirects(t *testing.T) { p := &Proxy{ Client: &http.Client{ - Transport: testTransport{}, + Transport: &testTransport{}, }, FollowRedirects: true, } @@ -658,7 +724,7 @@ func TestProxy_log_default(t *testing.T) { func TestTransformingTransport(t *testing.T) { client := new(http.Client) tr := &TransformingTransport{ - Transport: testTransport{}, + Transport: &testTransport{}, CachingClient: client, } client.Transport = tr