diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..09dd4e8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.cache +imageproxy diff --git a/README.md b/README.md index f7e9cf1..9a5b5f2 100644 --- a/README.md +++ b/README.md @@ -184,6 +184,18 @@ first check an in-memory cache for an image, followed by a gcs bucket: [tiered fashion]: https://godoc.org/github.com/die-net/lrucache/twotier +#### Cache Duration + +By default, images are cached for the duration specified in response headers. +If an image has no cache directives, or an explicit `Cache-Control: no-cache` header, +then the response is not cached. + +To override the response cache directives, set a minimum duration for which responses should be cached. +This will ignore `no-cache` and `no-store` directives, and will set `max-age` +to the specified value if it is greater than the original `max-age` value. + + imageproxy -cache /tmp/imageproxy -minCacheDuration 5m + ### Allowed Referrer List You can limit images to only be accessible for certain hosts in the HTTP diff --git a/cmd/imageproxy/main.go b/cmd/imageproxy/main.go index de66551..a4cc04d 100644 --- a/cmd/imageproxy/main.go +++ b/cmd/imageproxy/main.go @@ -46,6 +46,7 @@ var verbose = flag.Bool("verbose", false, "print verbose logging messages") var _ = flag.Bool("version", false, "Deprecated: this flag does nothing") var contentTypes = flag.String("contentTypes", "image/*", "comma separated list of allowed content types") var userAgent = flag.String("userAgent", "willnorris/imageproxy", "specify the user-agent used by imageproxy when fetching images from origin website") +var minCacheDuration = flag.Duration("minCacheDuration", 0, "minimum duration to cache remote images") func init() { flag.Var(&cache, "cache", "location to cache images (see https://github.com/willnorris/imageproxy#cache)") @@ -87,6 +88,7 @@ func main() { p.ScaleUp = *scaleUp p.Verbose = *verbose p.UserAgent = *userAgent + p.MinimumCacheDuration = *minCacheDuration server 
:= &http.Server{
 		Addr: *addr,
diff --git a/imageproxy.go b/imageproxy.go
index fba2f03..b7f3e68 100644
--- a/imageproxy.go
+++ b/imageproxy.go
@@ -28,6 +28,7 @@ import (
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	tphttp "willnorris.com/go/imageproxy/third_party/http"
+	tphc "willnorris.com/go/imageproxy/third_party/httpcache"
 )
 
 // Maximum number of redirection-followings allowed.
@@ -91,6 +92,10 @@ type Proxy struct {
 	// PassRequestHeaders identifies HTTP headers to pass from inbound
 	// requests to the proxied server.
 	PassRequestHeaders []string
+
+	// MinimumCacheDuration is the minimum duration to cache remote images.
+	// This will override cache-control instructions from the remote server.
+	MinimumCacheDuration time.Duration
 }
 
 // NewProxy constructs a new proxy. The provided http RoundTripper will be
@@ -118,6 +123,7 @@ func NewProxy(transport http.RoundTripper, cache Cache) *Proxy {
 					proxy.logf(format, v...)
 				}
 			},
+			updateCacheHeaders: proxy.updateCacheHeaders,
 		},
 		Cache:               cache,
 		MarkCachedResponses: true,
@@ -128,6 +134,51 @@ func NewProxy(transport http.RoundTripper, cache Cache) *Proxy {
 	return proxy
 }
 
+// updateCacheHeaders rewrites the caching headers in hdr so that the response
+// is cached for at least MinimumCacheDuration. It does nothing unless a
+// positive MinimumCacheDuration is configured.
+//
+// The Cache-Control max-age is set to the largest of the minimum cache
+// duration, the lifetime implied by the Expires header (measured from the Date
+// header), and the existing max-age directive. The no-cache and no-store
+// directives are dropped, and the Expires header is removed.
+func (p *Proxy) updateCacheHeaders(hdr http.Header) {
+	// A zero or negative minimum leaves the origin's directives untouched.
+	if p.MinimumCacheDuration <= 0 {
+		return
+	}
+	cc := tphc.ParseCacheControl(hdr)
+
+	var expiresDuration, maxAgeDuration time.Duration
+
+	if maxAge, ok := cc["max-age"]; ok {
+		// On a malformed max-age, ParseDuration returns zero, which simply
+		// drops out of the max below, so the error is deliberately ignored.
+		maxAgeDuration, _ = time.ParseDuration(maxAge + "s")
+	}
+	// Expires only yields a lifetime when the Date header is also usable.
+	if expiresHeader := hdr.Get("Expires"); expiresHeader != "" {
+		if date, err := httpcache.Date(hdr); err == nil {
+			expires, perr := time.Parse(time.RFC1123, expiresHeader)
+			if perr != nil {
+				// Also accept the other date formats HTTP permits.
+				expires, perr = http.ParseTime(expiresHeader)
+			}
+			if perr == nil {
+				expiresDuration = expires.Sub(date)
+			}
+		}
+	}
+
+	maxAge := max(p.MinimumCacheDuration, expiresDuration, maxAgeDuration)
+	cc["max-age"] = fmt.Sprintf("%d", int64(maxAge/time.Second))
+	delete(cc, "no-cache")
+	delete(cc, "no-store")
+
+	hdr.Set("Cache-Control", cc.String())
+	hdr.Del("Expires")
+}
+
 // ServeHTTP handles incoming requests.
 func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	if r.URL.Path == "/favicon.ico" {
@@ -475,6 +526,8 @@ type TransformingTransport struct {
 	CachingClient *http.Client
 
 	log func(format string, v ...any)
+
+	updateCacheHeaders func(hdr http.Header)
 }
 
 // RoundTrip implements the http.RoundTripper interface.
@@ -484,7 +525,11 @@ func (t *TransformingTransport) RoundTrip(req *http.Request) (*http.Response, er
 		if t.log != nil {
 			t.log("fetching remote URL: %v", req.URL)
 		}
-		return t.Transport.RoundTrip(req)
+		resp, err := t.Transport.RoundTrip(req)
+		if err == nil && t.updateCacheHeaders != nil {
+			t.updateCacheHeaders(resp.Header)
+		}
+		return resp, err
 	}
 
 	f := req.URL.Fragment
diff --git a/imageproxy_test.go b/imageproxy_test.go
index 86bd6a6..f39d000 100644
--- a/imageproxy_test.go
+++ b/imageproxy_test.go
@@ -12,6 +12,7 @@ import (
 	"image"
 	"image/png"
 	"log"
+	"maps"
 	"net/http"
 	"net/http/httptest"
 	"net/url"
@@ -21,6 +22,7 @@ import (
 	"strconv"
 	"strings"
 	"testing"
+	"time"
 )
 
 func TestPeekContentType(t *testing.T) {
@@ -368,6 +370,108 @@ func (t testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 	return http.ReadResponse(buf, req)
 }
 
+func TestProxy_UpdateCacheHeaders(t *testing.T) {
+	date := "Mon, 02 Jan 2006 15:04:05 MST"
+	exp := "Mon, 02 Jan 2006 16:04:05 MST"
+
+	tests := []struct {
+		name        string
+		minDuration time.Duration
+		headers     http.Header
+		want        http.Header
+	}{
+		{
+			name:    "zero",
+			headers: http.Header{},
+			want:    http.Header{},
+		},
+		{
+			name: "no min duration",
+			headers: http.Header{
+				"Date":          {date},
+				"Expires":       {exp},
+				"Cache-Control": {"max-age=600"},
+			},
+			want: http.Header{
+				"Date":          {date},
+				"Expires":       {exp},
+				"Cache-Control": {"max-age=600"},
+			},
+		},
+		{
+			name:        "cache control exceeds min duration",
+			minDuration: 30 * time.Second,
+			headers: http.Header{
+				"Cache-Control": {"max-age=600"},
+			},
+			want: http.Header{
+				"Cache-Control": {"max-age=600"},
+			},
+		},
+		{
+			name:        "cache control exceeds min duration, expires",
+			minDuration: 30 * time.Second,
+			headers: http.Header{
+				"Date":          {date},
+				"Expires":       {exp},
+				"Cache-Control": {"max-age=86400"},
+			},
+			want: http.Header{
+				"Date":          {date},
+				"Cache-Control": {"max-age=86400"},
+			},
+		},
+		{
+			name:        "min duration exceeds cache control",
+			minDuration: 1 * time.Hour,
+			headers: http.Header{
+				"Cache-Control": {"max-age=600"},
+			},
+			want: http.Header{
+				"Cache-Control": {"max-age=3600"},
+			},
+		},
+		{
+			name:        "min duration exceeds cache control, expires",
+			minDuration: 2 * time.Hour,
+			headers: http.Header{
+				"Date":          {date},
+				"Expires":       {exp},
+				"Cache-Control": {"max-age=600"},
+			},
+			want: http.Header{
+				"Date":          {date},
+				"Cache-Control": {"max-age=7200"},
+			},
+		},
+		{
+			name:        "expires exceeds min duration, cache control",
+			minDuration: 30 * time.Minute,
+			headers: http.Header{
+				"Date":          {date},
+				"Expires":       {exp},
+				"Cache-Control": {"max-age=600"},
+			},
+			want: http.Header{
+				"Date":          {date},
+				"Cache-Control": {"max-age=3600"},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			p := &Proxy{MinimumCacheDuration: tt.minDuration}
+			hdr := maps.Clone(tt.headers)
+			p.updateCacheHeaders(hdr)
+
+			if !reflect.DeepEqual(hdr, tt.want) {
+				t.Errorf("updateCacheHeaders(%v) returned %v, want %v", tt.headers, hdr, tt.want)
+			}
+		})
+	}
+}
+
 func TestProxy_ServeHTTP(t *testing.T) {
 	p := &Proxy{
 		Client: &http.Client{
diff --git a/third_party/httpcache/httpcache.go b/third_party/httpcache/httpcache.go
index f996a47..654cc9c 100644
--- a/third_party/httpcache/httpcache.go
+++ b/third_party/httpcache/httpcache.go
@@ -5,10 +5,15 @@ import (
+	"sort"
 	"strings"
 )
 
-type cacheControl map[string]string
+// CacheControl maps Cache-Control directive names to their values. String
+// treats an entry with an empty value as a directive that has no value.
+type CacheControl map[string]string
 
-func parseCacheControl(headers http.Header) cacheControl {
-	cc := cacheControl{}
+// ParseCacheControl parses the Cache-Control header found in headers into a
+// CacheControl map of its comma-separated directives.
+func ParseCacheControl(headers http.Header) CacheControl {
+	cc := CacheControl{}
 	ccHeader := headers.Get("Cache-Control")
 	for _, part := range strings.Split(ccHeader, ",") {
 		part = strings.Trim(part, " ")
@@ -24,3 +29,20 @@ func parseCacheControl(headers http.Header) cacheControl {
 	}
 	return cc
 }
+
+// String formats the directives as a Cache-Control header value. Directives
+// are emitted in sorted order so that the result is deterministic; a directive
+// with an empty value is written as its bare name.
+func (cc CacheControl) String() string {
+	parts := make([]string, 0, len(cc))
+	for k, v := range cc {
+		if v == "" {
+			parts = append(parts, k)
+			continue
+		}
+		parts = append(parts, k+"="+v)
+	}
+	// Map iteration order is random; sort for stable output.
+	sort.Strings(parts)
+	return strings.Join(parts, ", ")
+}