diff --git a/imageproxy.go b/imageproxy.go index fdab7f6..9a603c4 100644 --- a/imageproxy.go +++ b/imageproxy.go @@ -88,10 +88,10 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if !p.allowed(req.URL) { - msg := fmt.Sprintf("remote URL is not for an allowed host: %v", req.URL) + if !p.allowed(req) { + msg := fmt.Sprintf("request does not contain an allowed host") glog.Error(msg) - http.Error(w, msg, http.StatusBadRequest) + http.Error(w, msg, http.StatusForbidden) return } @@ -135,13 +135,26 @@ func copyHeader(w http.ResponseWriter, r *http.Response, header string) { } } -// allowed returns whether the specified URL is on the whitelist of remote hosts. -func (p *Proxy) allowed(u *url.URL) bool { +// allowed returns whether the specified request is allowed because it matches +// a host in the proxy whitelist. +func (p *Proxy) allowed(r *Request) bool { if len(p.Whitelist) == 0 { - return true + return true // no whitelist, all requests accepted } - for _, host := range p.Whitelist { + if len(p.Whitelist) > 0 { + if validHost(p.Whitelist, r.URL) { + return true + } + glog.Infof("remote URL is not for an allowed host: %v", r.URL) + } + + return false +} + +// validHost returns whether the host in u matches one of hosts. +func validHost(hosts []string, u *url.URL) bool { + for _, host := range hosts { if u.Host == host { return true } diff --git a/imageproxy_test.go b/imageproxy_test.go index 3a9d9a4..ece73a1 100644 --- a/imageproxy_test.go +++ b/imageproxy_test.go @@ -15,7 +15,7 @@ import ( ) func TestAllowed(t *testing.T) { - whitelist := []string{"a.test", "*.b.test", "*c.test"} + whitelist := []string{"good.test"} tests := []struct { url string @@ -25,16 +25,8 @@ func TestAllowed(t *testing.T) { {"http://foo/image", nil, true}, {"http://foo/image", []string{}, true}, - {"http://a.test/image", whitelist, true}, - {"http://x.a.test/image", whitelist, false}, - - {"http://b.test/image", whitelist, true}, - {"http://x.b.test/image", whitelist, true}, - {"http://x.y.b.test/image", whitelist, true}, - - {"http://c.test/image", whitelist, false}, - {"http://xc.test/image", whitelist, false}, - {"/image", whitelist, false}, + {"http://good.test/image", whitelist, true}, + {"http://bad.test/image", whitelist, false}, } for _, tt := range tests { @@ -45,12 +37,43 @@ func TestAllowed(t *testing.T) { if err != nil { t.Errorf("error parsing url %q: %v", tt.url, err) } - if got, want := p.allowed(u), tt.allowed; got != want { + req := &Request{u, emptyOptions} + if got, want := p.allowed(req), tt.allowed; got != want { t.Errorf("allowed(%q) returned %v, want %v", u, got, want) } } } +func TestValidHost(t *testing.T) { + whitelist := []string{"a.test", "*.b.test", "*c.test"} + + tests := []struct { + url string + valid bool + }{ + {"http://a.test/image", true}, + {"http://x.a.test/image", false}, + + {"http://b.test/image", true}, + {"http://x.b.test/image", true}, + {"http://x.y.b.test/image", true}, + + {"http://c.test/image", false}, + {"http://xc.test/image", false}, + {"/image", false}, + } + + for _, tt := range tests { + u, err := url.Parse(tt.url) + if err != nil { + t.Errorf("error parsing url %q: %v", tt.url, err) + } + if got, want := validHost(whitelist, u), tt.valid; got != want { + t.Errorf("validHost(%v, %q) returned %v, want %v", whitelist, u, got, want) + } + } +} + func TestCheck304(t *testing.T) { tests := []struct { req, resp string @@ -168,7 +191,7 @@ func TestProxy_ServeHTTP(t *testing.T) { }{ {"/favicon.ico", http.StatusOK}, {"//foo", http.StatusBadRequest}, // invalid request URL - {"/http://bad.test/", http.StatusBadRequest}, // Disallowed host + {"/http://bad.test/", http.StatusForbidden}, // Disallowed host {"/http://good.test/error", http.StatusInternalServerError}, // HTTP protocol error {"/http://good.test/nocontent", http.StatusNoContent}, // non-OK response