diff --git a/imageproxy.go b/imageproxy.go index 3bfc262..752974d 100644 --- a/imageproxy.go +++ b/imageproxy.go @@ -5,12 +5,14 @@ import ( "fmt" "log" "net/http" + "strings" "github.com/willnorris/go-imageproxy/cache" "github.com/willnorris/go-imageproxy/proxy" ) var port = flag.Int("port", 8080, "port to listen on") +var whitelist = flag.String("whitelist", "", "comma separated list of allowed remote hosts") func main() { flag.Parse() @@ -19,6 +21,9 @@ func main() { p := proxy.NewProxy(nil) p.Cache = cache.NewMemoryCache() + if *whitelist != "" { + p.Whitelist = strings.Split(*whitelist, ",") + } server := &http.Server{ Addr: fmt.Sprintf(":%d", *port), Handler: p, diff --git a/proxy/proxy.go b/proxy/proxy.go index 3c4cd5c..754728c 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -66,6 +66,9 @@ func NewRequest(r *http.Request) (*data.Request, error) { type Proxy struct { Client *http.Client // client used to fetch remote URLs Cache cache.Cache + + // Whitelist specifies a list of remote hosts that images can be proxied from. An empty list means all hosts are allowed. + Whitelist []string } // NewProxy constructs a new proxy. The provided http Client will be used to @@ -88,6 +91,11 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { u := req.URL.String() glog.Infof("request for image: %v", u) + if !p.allowed(req.URL) { + http.Error(w, fmt.Sprintf("remote URL is not for an allowed host: %v", req.URL.Host), http.StatusForbidden) + return + } + image, ok := p.Cache.Get(u) if !ok { glog.Infof("image not cached") @@ -153,6 +161,21 @@ func (p *Proxy) fetchRemoteImage(u string, cached *data.Image) (*data.Image, err }, nil } +// allowed returns whether the specified URL is on the whitelist of remote hosts. +func (p *Proxy) allowed(u *url.URL) bool { + if len(p.Whitelist) == 0 { + return true + } + + for _, host := range p.Whitelist { + if u.Host == host { + return true + } + } + + return false +} + func parseExpires(resp *http.Response) time.Time { exp := resp.Header.Get("Expires") if exp == "" {