diff --git a/config/setup/gzip.go b/config/setup/gzip.go index 714f8198..ea93a128 100644 --- a/config/setup/gzip.go +++ b/config/setup/gzip.go @@ -28,7 +28,8 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { config := gzip.Config{} pathFilter := gzip.PathFilter{make(gzip.Set)} - extFilter := gzip.DefaultExtFilter() + mimeFilter := gzip.MIMEFilter{make(gzip.Set)} + extFilter := gzip.ExtFilter{make(gzip.Set)} // no extra args expected if len(c.RemainingArgs()) > 0 { @@ -37,6 +38,17 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { for c.NextBlock() { switch c.Val() { + case "mimes": + mimes := c.RemainingArgs() + if len(mimes) == 0 { + return configs, c.ArgErr() + } + for _, m := range mimes { + if !gzip.ValidMIME(m) { + return configs, fmt.Errorf("Invalid MIME %v.", m) + } + mimeFilter.Types.Add(m) + } case "ext": exts := c.RemainingArgs() if len(exts) == 0 { @@ -74,8 +86,25 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { } } - // put pathFilter in front to filter with path first - config.Filters = []gzip.Filter{pathFilter, extFilter} + config.Filters = []gzip.Filter{} + + // if ignored paths are specified, put in front to filter with path first + if len(pathFilter.IgnoredPaths) > 0 { + config.Filters = []gzip.Filter{pathFilter} + } + + // if mime types are specified, use it and ignore extensions + if len(mimeFilter.Types) > 0 { + config.Filters = append(config.Filters, mimeFilter) + + // if extensions are specified, use it + } else if len(extFilter.Exts) > 0 { + config.Filters = append(config.Filters, extFilter) + + // neither is specified, use default mime types + } else { + config.Filters = append(config.Filters, gzip.DefaultMIMEFilter()) + } configs = append(configs, config) } diff --git a/config/setup/gzip_test.go b/config/setup/gzip_test.go index accede6a..b228dbcc 100644 --- a/config/setup/gzip_test.go +++ b/config/setup/gzip_test.go @@ -59,14 +59,35 @@ func TestGzip(t *testing.T) { level 3 } `, false}, + {`gzip { mimes text/html + }`, false}, + {`gzip { mimes text/html application/json + }`, false}, + {`gzip { mimes text/html application/ + }`, true}, + {`gzip { mimes text/html /json + }`, true}, + {`gzip { mimes /json text/html + }`, true}, + {`gzip { not /file + ext .html + level 1 + mimes text/html text/plain + } + gzip { not /file1 + ext .htm + level 3 + mimes text/html text/css + } + `, false}, } for i, test := range tests { c := newTestController(test.input) _, err := gzipParse(c) if test.shouldErr && err == nil { - t.Errorf("Text %v: Expected error but found nil", i) + t.Errorf("Test %v: Expected error but found nil", i) } else if !test.shouldErr && err != nil { - t.Errorf("Text %v: Expected no error but found error: ", i, err) + t.Errorf("Test %v: Expected no error but found error: %v", i, err) } } } diff --git a/middleware/gzip/filter.go b/middleware/gzip/filter.go index 517f8858..a945fed9 100644 --- a/middleware/gzip/filter.go +++ b/middleware/gzip/filter.go @@ -3,13 +3,14 @@ package gzip import ( "net/http" "path" + "strings" "github.com/mholt/caddy/middleware" ) // Filter determines if a request should be gzipped. type Filter interface { - // ShouldCompress tells if compression gzip compression + // ShouldCompress tells if gzip compression // should be done on the request. ShouldCompress(*http.Request) bool } @@ -20,24 +21,12 @@ type ExtFilter struct { Exts Set } -// textExts is a list of extensions for text related files. -var textExts = []string{ - ".html", ".htm", ".css", ".json", ".php", ".js", ".txt", ".md", ".xml", -} - // extWildCard is the wildcard for extensions. const extWildCard = "*" -// DefaultExtFilter creates a default ExtFilter with -// file extensions for text types. -func DefaultExtFilter() ExtFilter { - e := ExtFilter{make(Set)} - for _, ext := range textExts { - e.Exts.Add(ext) - } - return e -} - +// ShouldCompress checks if the request file extension matches any +// of the registered extensions. It returns true if the extension is +// found and false otherwise. func (e ExtFilter) ShouldCompress(r *http.Request) bool { ext := path.Ext(r.URL.Path) return e.Exts.Contains(extWildCard) || e.Exts.Contains(ext) @@ -50,7 +39,7 @@ type PathFilter struct { } // ShouldCompress checks if the request path matches any of the -// registered paths to ignore. If returns false if an ignored path +// registered paths to ignore. It returns false if an ignored path // is found and true otherwise. func (p PathFilter) ShouldCompress(r *http.Request) bool { return !p.IgnoredPaths.ContainsFunc(func(value string) bool { @@ -58,6 +47,39 @@ func (p PathFilter) ShouldCompress(r *http.Request) bool { }) } +// MIMEFilter is Filter for request content types. +type MIMEFilter struct { + // Types is the MIME types to accept. + Types Set +} + +// defaultMIMETypes is the list of default MIME types to use. +var defaultMIMETypes = []string{ + "text/plain", "text/html", "text/css", "application/json", "application/javascript", + "text/x-markdown", "text/xml", "application/xml", +} + +// DefaultMIMEFilter creates a MIMEFilter with default types. +func DefaultMIMEFilter() MIMEFilter { + m := MIMEFilter{Types: make(Set)} + for _, mime := range defaultMIMETypes { + m.Types.Add(mime) + } + return m +} + +// ShouldCompress checks if the content type of the request +// matches any of the registered ones. It returns true if +// found and false otherwise. +func (m MIMEFilter) ShouldCompress(r *http.Request) bool { + return m.Types.Contains(r.Header.Get("Content-Type")) +} + +func ValidMIME(mime string) bool { + s := strings.Split(mime, "/") + return len(s) == 2 && strings.TrimSpace(s[0]) != "" && strings.TrimSpace(s[1]) != "" +} + // Set stores distinct strings. type Set map[string]struct{} diff --git a/middleware/gzip/filter_test.go b/middleware/gzip/filter_test.go index 56d054cf..c3664cdd 100644 --- a/middleware/gzip/filter_test.go +++ b/middleware/gzip/filter_test.go @@ -47,13 +47,13 @@ func TestSet(t *testing.T) { } func TestExtFilter(t *testing.T) { - var filter Filter = DefaultExtFilter() - _ = filter.(ExtFilter) - for i, e := range textExts { - r := urlRequest("file" + e) - if !filter.ShouldCompress(r) { - t.Errorf("Test %v: Should be valid filter", i) - } + var filter Filter = ExtFilter{make(Set)} + for _, e := range []string{".txt", ".html", ".css", ".md"} { + filter.(ExtFilter).Exts.Add(e) + } + r := urlRequest("file.txt") + if !filter.ShouldCompress(r) { + t.Errorf("Should be valid filter") } var exts = []string{ ".html", ".css", ".md", @@ -100,6 +100,32 @@ func TestPathFilter(t *testing.T) { } } +func TestMIMEFilter(t *testing.T) { + var filter Filter = DefaultMIMEFilter() + _ = filter.(MIMEFilter) + var mimes = []string{ + "text/html", "text/css", "application/json", + } + for i, m := range mimes { + r := urlRequest("file" + m) + r.Header.Set("Content-Type", m) + if !filter.ShouldCompress(r) { + t.Errorf("Test %v: Should be valid filter", i) + } + } + mimes = []string{ + "image/jpeg", "image/png", + } + filter = DefaultMIMEFilter() + for i, m := range mimes { + r := urlRequest("file" + m) + r.Header.Set("Content-Type", m) + if filter.ShouldCompress(r) { + t.Errorf("Test %v: Should not be valid filter", i) + } + } +} + func urlRequest(url string) *http.Request { r, _ := http.NewRequest("GET", url, nil) return r diff --git a/middleware/gzip/gzip_test.go b/middleware/gzip/gzip_test.go index 7015a5b7..ae0e300c 100644 --- a/middleware/gzip/gzip_test.go +++ b/middleware/gzip/gzip_test.go @@ -16,13 +16,20 @@ func TestGzipHandler(t *testing.T) { for _, p := range badPaths { pathFilter.IgnoredPaths.Add(p) } + extFilter := ExtFilter{make(Set)} + for _, e := range []string{".txt", ".html", ".css", ".md"} { + extFilter.Exts.Add(e) + } gz := Gzip{Configs: []Config{ - Config{Filters: []Filter{DefaultExtFilter(), pathFilter}}, + Config{Filters: []Filter{pathFilter, extFilter}}, }} w := httptest.NewRecorder() gz.Next = nextFunc(true) - for _, e := range textExts { + var exts = []string{ + ".html", ".css", ".md", + } + for _, e := range exts { url := "/file" + e r, err := http.NewRequest("GET", url, nil) if err != nil { @@ -38,7 +45,7 @@ func TestGzipHandler(t *testing.T) { w = httptest.NewRecorder() gz.Next = nextFunc(false) for _, p := range badPaths { - for _, e := range textExts { + for _, e := range exts { url := p + "/file" + e r, err := http.NewRequest("GET", url, nil) if err != nil { @@ -54,7 +61,7 @@ func TestGzipHandler(t *testing.T) { w = httptest.NewRecorder() gz.Next = nextFunc(false) - exts := []string{ + exts = []string{ ".htm1", ".abc", ".mdx", } for _, e := range exts { @@ -70,6 +77,45 @@ func TestGzipHandler(t *testing.T) { } } + gz.Configs[0].Filters[1] = DefaultMIMEFilter() + w = httptest.NewRecorder() + gz.Next = nextFunc(true) + var mimes = []string{ + "text/html", "text/css", "application/json", + } + for _, m := range mimes { + url := "/file" + r, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Error(err) + } + r.Header.Set("Content-Type", m) + r.Header.Set("Accept-Encoding", "gzip") + _, err = gz.ServeHTTP(w, r) + if err != nil { + t.Error(err) + } + } + + w = httptest.NewRecorder() + gz.Next = nextFunc(false) + mimes = []string{ + "image/jpeg", "image/png", + } + for _, m := range mimes { + url := "/file" + r, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Error(err) + } + r.Header.Set("Content-Type", m) + r.Header.Set("Accept-Encoding", "gzip") + _, err = gz.ServeHTTP(w, r) + if err != nil { + t.Error(err) + } + } + } func nextFunc(shouldGzip bool) middleware.Handler {