diff --git a/imageproxy.go b/imageproxy.go index ac6b357..a4ca19e 100644 --- a/imageproxy.go +++ b/imageproxy.go @@ -142,27 +142,32 @@ func (p *Proxy) serveImage(w http.ResponseWriter, r *http.Request) { cached := resp.Header.Get(httpcache.XFromCache) glog.Infof("request: %v (served from cache: %v)", *req, cached == "1") - copyHeader(w, resp, "Cache-Control") - copyHeader(w, resp, "Last-Modified") - copyHeader(w, resp, "Expires") - copyHeader(w, resp, "Etag") - copyHeader(w, resp, "Link") + copyHeader(w.Header(), resp.Header, "Cache-Control", "Last-Modified", "Expires", "Etag", "Link") if should304(r, resp) { w.WriteHeader(http.StatusNotModified) return } - copyHeader(w, resp, "Content-Length") - copyHeader(w, resp, "Content-Type") + copyHeader(w.Header(), resp.Header, "Content-Length", "Content-Type") w.WriteHeader(resp.StatusCode) io.Copy(w, resp.Body) } -func copyHeader(w http.ResponseWriter, r *http.Response, header string) { - key := http.CanonicalHeaderKey(header) - if value, ok := r.Header[key]; ok { - w.Header()[key] = value +// copyHeader copies header values from src to dst, adding to any existing +// values with the same header name. If keys is not empty, only those header +// keys will be copied. +func copyHeader(dst, src http.Header, keys ...string) { + if len(keys) == 0 { + for k, _ := range src { + keys = append(keys, k) + } + } + for _, key := range keys { + k := http.CanonicalHeaderKey(key) + for _, v := range src[k] { + dst.Add(k, v) + } } } diff --git a/imageproxy_test.go b/imageproxy_test.go index 425059e..6a242c7 100644 --- a/imageproxy_test.go +++ b/imageproxy_test.go @@ -24,10 +24,78 @@ import ( "net/http" "net/http/httptest" "net/url" + "reflect" "strings" "testing" ) +func TestCopyHeader(t *testing.T) { + tests := []struct { + dst, src http.Header + keys []string + want http.Header + }{ + // empty + {http.Header{}, http.Header{}, nil, http.Header{}}, + {http.Header{}, http.Header{}, []string{}, http.Header{}}, + {http.Header{}, http.Header{}, []string{"A"}, http.Header{}}, + + // nothing to copy + { + dst: http.Header{"A": []string{"a1"}}, + src: http.Header{}, + keys: nil, + want: http.Header{"A": []string{"a1"}}, + }, + { + dst: http.Header{}, + src: http.Header{"A": []string{"a"}}, + keys: []string{"B"}, + want: http.Header{}, + }, + + // copy headers + { + dst: http.Header{}, + src: http.Header{"A": []string{"a"}}, + keys: nil, + want: http.Header{"A": []string{"a"}}, + }, + { + dst: http.Header{"A": []string{"a"}}, + src: http.Header{"B": []string{"b"}}, + keys: nil, + want: http.Header{"A": []string{"a"}, "B": []string{"b"}}, + }, + { + dst: http.Header{"A": []string{"a"}}, + src: http.Header{"B": []string{"b"}, "C": []string{"c"}}, + keys: []string{"B"}, + want: http.Header{"A": []string{"a"}, "B": []string{"b"}}, + }, + { + dst: http.Header{"A": []string{"a1"}}, + src: http.Header{"A": []string{"a2"}}, + keys: nil, + want: http.Header{"A": []string{"a1", "a2"}}, + }, + } + + for _, tt := range tests { + // copy dst map + got := make(http.Header) + for k, v := range tt.dst { + got[k] = v + } + + copyHeader(got, tt.src, tt.keys...) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("copyHeader(%v, %v, %v) returned %v, want %v", tt.dst, tt.src, tt.keys, got, tt.want) + } + + } +} + func TestAllowed(t *testing.T) { whitelist := []string{"good"} key := []byte("c0ffee")