From f04ff063ed62b8ce0b2c3cafdf6bcb90fffd4a4b Mon Sep 17 00:00:00 2001 From: Abiola Ibrahim Date: Fri, 18 Dec 2015 20:58:23 +0100 Subject: [PATCH] Gzip: Fix missing gzip encoding headers. --- middleware/gzip/gzip.go | 11 ++++++++--- middleware/gzip/gzip_test.go | 4 ++-- middleware/gzip/response_filter.go | 4 ++-- middleware/gzip/response_filter_test.go | 2 +- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/middleware/gzip/gzip.go b/middleware/gzip/gzip.go index 147d739f..99903ff0 100644 --- a/middleware/gzip/gzip.go +++ b/middleware/gzip/gzip.go @@ -57,7 +57,7 @@ outer: return http.StatusInternalServerError, err } defer gzipWriter.Close() - gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w} + gz := &gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w} var rw http.ResponseWriter // if no response filter is used @@ -104,21 +104,26 @@ func newWriter(c Config, w io.Writer) (*gzip.Writer, error) { type gzipResponseWriter struct { io.Writer http.ResponseWriter + statusCodeWritten bool } // WriteHeader wraps the underlying WriteHeader method to prevent // problems with conflicting headers from proxied backends. For // example, a backend system that calculates Content-Length would // be wrong because it doesn't know it's being gzipped. -func (w gzipResponseWriter) WriteHeader(code int) { +func (w *gzipResponseWriter) WriteHeader(code int) { w.Header().Del("Content-Length") w.Header().Set("Content-Encoding", "gzip") w.Header().Set("Vary", "Accept-Encoding") w.ResponseWriter.WriteHeader(code) + w.statusCodeWritten = true } // Write wraps the underlying Write method to do compression. -func (w gzipResponseWriter) Write(b []byte) (int, error) { +func (w *gzipResponseWriter) Write(b []byte) (int, error) { + if !w.statusCodeWritten { + w.WriteHeader(http.StatusOK) + } if w.Header().Get("Content-Type") == "" { w.Header().Set("Content-Type", http.DetectContentType(b)) } diff --git a/middleware/gzip/gzip_test.go b/middleware/gzip/gzip_test.go index c35c99c6..b9bafc6d 100644 --- a/middleware/gzip/gzip_test.go +++ b/middleware/gzip/gzip_test.go @@ -92,7 +92,7 @@ func nextFunc(shouldGzip bool) middleware.Handler { if w.Header().Get("Vary") != "Accept-Encoding" { return 0, fmt.Errorf("Vary must be Accept-Encoding, found %v", r.Header.Get("Vary")) } - if _, ok := w.(gzipResponseWriter); !ok { + if _, ok := w.(*gzipResponseWriter); !ok { return 0, fmt.Errorf("ResponseWriter should be gzipResponseWriter, found %T", w) } return 0, nil @@ -103,7 +103,7 @@ func nextFunc(shouldGzip bool) middleware.Handler { if w.Header().Get("Content-Encoding") == "gzip" { return 0, fmt.Errorf("Content-Encoding must not be gzip, found gzip") } - if _, ok := w.(gzipResponseWriter); ok { + if _, ok := w.(*gzipResponseWriter); ok { return 0, fmt.Errorf("ResponseWriter should not be gzipResponseWriter") } return 0, nil diff --git a/middleware/gzip/response_filter.go b/middleware/gzip/response_filter.go index c599b3e1..87a34e60 100644 --- a/middleware/gzip/response_filter.go +++ b/middleware/gzip/response_filter.go @@ -31,11 +31,11 @@ func (l LengthFilter) ShouldCompress(w http.ResponseWriter) bool { type ResponseFilterWriter struct { filters []ResponseFilter shouldCompress bool - gzipResponseWriter + *gzipResponseWriter } // NewResponseFilterWriter creates and initializes a new ResponseFilterWriter. -func NewResponseFilterWriter(filters []ResponseFilter, gz gzipResponseWriter) *ResponseFilterWriter { +func NewResponseFilterWriter(filters []ResponseFilter, gz *gzipResponseWriter) *ResponseFilterWriter { return &ResponseFilterWriter{filters: filters, gzipResponseWriter: gz} } diff --git a/middleware/gzip/response_filter_test.go b/middleware/gzip/response_filter_test.go index cd7c7191..73867e7f 100644 --- a/middleware/gzip/response_filter_test.go +++ b/middleware/gzip/response_filter_test.go @@ -33,7 +33,7 @@ func TestLengthFilter(t *testing.T) { for j, filter := range filters { r := httptest.NewRecorder() r.Header().Set("Content-Length", fmt.Sprint(ts.length)) - wWriter := NewResponseFilterWriter([]ResponseFilter{filter}, gzipResponseWriter{gzip.NewWriter(r), r}) + wWriter := NewResponseFilterWriter([]ResponseFilter{filter}, &gzipResponseWriter{gzip.NewWriter(r), r, false}) if filter.ShouldCompress(wWriter) != ts.shouldCompress[j] { t.Errorf("Test %v: Expected %v found %v", i, ts.shouldCompress[j], filter.ShouldCompress(r)) }