mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-23 22:27:38 -05:00
Fix deleted Content-Length header bug.
This commit is contained in:
parent
8631f33940
commit
23631cfaca
4 changed files with 65 additions and 22 deletions
|
@ -3,6 +3,7 @@
|
||||||
package gzip
|
package gzip
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -47,9 +48,13 @@ outer:
|
||||||
// Delete this header so gzipping is not repeated later in the chain
|
// Delete this header so gzipping is not repeated later in the chain
|
||||||
r.Header.Del("Accept-Encoding")
|
r.Header.Del("Accept-Encoding")
|
||||||
|
|
||||||
w.Header().Set("Content-Encoding", "gzip")
|
// gzipWriter modifies underlying writer at init,
|
||||||
w.Header().Set("Vary", "Accept-Encoding")
|
// use a buffer instead to leave ResponseWriter in
|
||||||
gzipWriter, err := newWriter(c, w)
|
// original form.
|
||||||
|
var buf = &bytes.Buffer{}
|
||||||
|
defer buf.Reset()
|
||||||
|
|
||||||
|
gzipWriter, err := newWriter(c, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// should not happen
|
// should not happen
|
||||||
return http.StatusInternalServerError, err
|
return http.StatusInternalServerError, err
|
||||||
|
@ -60,6 +65,8 @@ outer:
|
||||||
var rw http.ResponseWriter
|
var rw http.ResponseWriter
|
||||||
// if no response filter is used
|
// if no response filter is used
|
||||||
if len(c.ResponseFilters) == 0 {
|
if len(c.ResponseFilters) == 0 {
|
||||||
|
// replace buffer with ResponseWriter
|
||||||
|
gzipWriter.Reset(w)
|
||||||
rw = gz
|
rw = gz
|
||||||
} else {
|
} else {
|
||||||
// wrap gzip writer with ResponseFilterWriter
|
// wrap gzip writer with ResponseFilterWriter
|
||||||
|
@ -88,7 +95,7 @@ outer:
|
||||||
// newWriter create a new Gzip Writer based on the compression level.
|
// newWriter create a new Gzip Writer based on the compression level.
|
||||||
// If the level is valid (i.e. between 1 and 9), it uses the level.
|
// If the level is valid (i.e. between 1 and 9), it uses the level.
|
||||||
// Otherwise, it uses default compression level.
|
// Otherwise, it uses default compression level.
|
||||||
func newWriter(c Config, w http.ResponseWriter) (*gzip.Writer, error) {
|
func newWriter(c Config, w io.Writer) (*gzip.Writer, error) {
|
||||||
if c.Level >= gzip.BestSpeed && c.Level <= gzip.BestCompression {
|
if c.Level >= gzip.BestSpeed && c.Level <= gzip.BestCompression {
|
||||||
return gzip.NewWriterLevel(w, c.Level)
|
return gzip.NewWriterLevel(w, c.Level)
|
||||||
}
|
}
|
||||||
|
@ -108,6 +115,8 @@ type gzipResponseWriter struct {
|
||||||
// be wrong because it doesn't know it's being gzipped.
|
// be wrong because it doesn't know it's being gzipped.
|
||||||
func (w gzipResponseWriter) WriteHeader(code int) {
|
func (w gzipResponseWriter) WriteHeader(code int) {
|
||||||
w.Header().Del("Content-Length")
|
w.Header().Del("Content-Length")
|
||||||
|
w.Header().Set("Content-Encoding", "gzip")
|
||||||
|
w.Header().Set("Vary", "Accept-Encoding")
|
||||||
w.ResponseWriter.WriteHeader(code)
|
w.ResponseWriter.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -80,6 +80,8 @@ func TestGzipHandler(t *testing.T) {
|
||||||
|
|
||||||
func nextFunc(shouldGzip bool) middleware.Handler {
|
func nextFunc(shouldGzip bool) middleware.Handler {
|
||||||
return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte("test"))
|
||||||
if shouldGzip {
|
if shouldGzip {
|
||||||
if r.Header.Get("Accept-Encoding") != "" {
|
if r.Header.Get("Accept-Encoding") != "" {
|
||||||
return 0, fmt.Errorf("Accept-Encoding header not expected")
|
return 0, fmt.Errorf("Accept-Encoding header not expected")
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package gzip
|
package gzip
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"compress/gzip"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
@ -29,7 +30,6 @@ func (l LengthFilter) ShouldCompress(w http.ResponseWriter) bool {
|
||||||
// uncompressed data otherwise.
|
// uncompressed data otherwise.
|
||||||
type ResponseFilterWriter struct {
|
type ResponseFilterWriter struct {
|
||||||
filters []ResponseFilter
|
filters []ResponseFilter
|
||||||
validated bool
|
|
||||||
shouldCompress bool
|
shouldCompress bool
|
||||||
gzipResponseWriter
|
gzipResponseWriter
|
||||||
}
|
}
|
||||||
|
@ -40,21 +40,33 @@ func NewResponseFilterWriter(filters []ResponseFilter, gz gzipResponseWriter) *R
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write wraps underlying Write method and compresses if filters
|
// Write wraps underlying Write method and compresses if filters
|
||||||
// are satisfied
|
// are satisfied.
|
||||||
func (r *ResponseFilterWriter) Write(b []byte) (int, error) {
|
func (r *ResponseFilterWriter) WriteHeader(code int) {
|
||||||
// One time validation to determine if compression should
|
// Determine if compression should be used or not.
|
||||||
// be used or not.
|
r.shouldCompress = true
|
||||||
if !r.validated {
|
for _, filter := range r.filters {
|
||||||
r.shouldCompress = true
|
if !filter.ShouldCompress(r) {
|
||||||
for _, filter := range r.filters {
|
r.shouldCompress = false
|
||||||
if !filter.ShouldCompress(r) {
|
break
|
||||||
r.shouldCompress = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
r.validated = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.shouldCompress {
|
||||||
|
// replace buffer with ResponseWriter
|
||||||
|
if gzWriter, ok := r.gzipResponseWriter.Writer.(*gzip.Writer); ok {
|
||||||
|
gzWriter.Reset(r.ResponseWriter)
|
||||||
|
}
|
||||||
|
// use gzip WriteHeader to include and delete
|
||||||
|
// necessary headers
|
||||||
|
r.gzipResponseWriter.WriteHeader(code)
|
||||||
|
} else {
|
||||||
|
r.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write wraps underlying Write method and compresses if filters
|
||||||
|
// are satisfied
|
||||||
|
func (r *ResponseFilterWriter) Write(b []byte) (int, error) {
|
||||||
if r.shouldCompress {
|
if r.shouldCompress {
|
||||||
return r.gzipResponseWriter.Write(b)
|
return r.gzipResponseWriter.Write(b)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,8 +3,11 @@ package gzip
|
||||||
import (
|
import (
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mholt/caddy/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLengthFilter(t *testing.T) {
|
func TestLengthFilter(t *testing.T) {
|
||||||
|
@ -30,7 +33,8 @@ func TestLengthFilter(t *testing.T) {
|
||||||
for j, filter := range filters {
|
for j, filter := range filters {
|
||||||
r := httptest.NewRecorder()
|
r := httptest.NewRecorder()
|
||||||
r.Header().Set("Content-Length", fmt.Sprint(ts.length))
|
r.Header().Set("Content-Length", fmt.Sprint(ts.length))
|
||||||
if filter.ShouldCompress(r) != ts.shouldCompress[j] {
|
wWriter := NewResponseFilterWriter([]ResponseFilter{filter}, gzipResponseWriter{gzip.NewWriter(r), r})
|
||||||
|
if filter.ShouldCompress(wWriter) != ts.shouldCompress[j] {
|
||||||
t.Errorf("Test %v: Expected %v found %v", i, ts.shouldCompress[j], filter.ShouldCompress(r))
|
t.Errorf("Test %v: Expected %v found %v", i, ts.shouldCompress[j], filter.ShouldCompress(r))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -47,16 +51,32 @@ func TestResponseFilterWriter(t *testing.T) {
|
||||||
{"Hello \t\t\nfrom gzip", true},
|
{"Hello \t\t\nfrom gzip", true},
|
||||||
{"Hello gzip\n", false},
|
{"Hello gzip\n", false},
|
||||||
}
|
}
|
||||||
|
|
||||||
filters := []ResponseFilter{
|
filters := []ResponseFilter{
|
||||||
LengthFilter(15),
|
LengthFilter(15),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
server := Gzip{Configs: []Config{
|
||||||
|
{ResponseFilters: filters},
|
||||||
|
}}
|
||||||
|
|
||||||
for i, ts := range tests {
|
for i, ts := range tests {
|
||||||
|
server.Next = middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
|
w.Header().Set("Content-Length", fmt.Sprint(len(ts.body)))
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(ts.body))
|
||||||
|
return 200, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
r := urlRequest("/")
|
||||||
|
r.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
w.Header().Set("Content-Length", fmt.Sprint(len(ts.body)))
|
|
||||||
gz := gzipResponseWriter{gzip.NewWriter(w), w}
|
server.ServeHTTP(w, r)
|
||||||
rw := NewResponseFilterWriter(filters, gz)
|
|
||||||
rw.Write([]byte(ts.body))
|
|
||||||
resp := w.Body.String()
|
resp := w.Body.String()
|
||||||
|
|
||||||
if !ts.shouldCompress {
|
if !ts.shouldCompress {
|
||||||
if resp != ts.body {
|
if resp != ts.body {
|
||||||
t.Errorf("Test %v: No compression expected, found %v", i, resp)
|
t.Errorf("Test %v: No compression expected, found %v", i, resp)
|
||||||
|
|
Loading…
Reference in a new issue