diff --git a/config/setup/rewrite.go b/config/setup/rewrite.go index 897f597ec..b510a237b 100644 --- a/config/setup/rewrite.go +++ b/config/setup/rewrite.go @@ -60,6 +60,7 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) { return nil, c.ArgErr() } } + // ensure pattern and to are specified if pattern == "" || to == "" { return nil, c.ArgErr() } diff --git a/config/setup/rewrite_test.go b/config/setup/rewrite_test.go index 17a0e97b5..9ff294ef0 100644 --- a/config/setup/rewrite_test.go +++ b/config/setup/rewrite_test.go @@ -3,7 +3,9 @@ package setup import ( "testing" + "fmt" "github.com/mholt/caddy/middleware/rewrite" + "regexp" ) func TestRewrite(t *testing.T) { @@ -33,27 +35,27 @@ func TestRewrite(t *testing.T) { } func TestRewriteParse(t *testing.T) { - tests := []struct { + simpleTests := []struct { input string shouldErr bool expected []rewrite.Rule }{ {`rewrite /from /to`, false, []rewrite.Rule{ - {From: "/from", To: "/to"}, + rewrite.SimpleRule{"/from", "/to"}, }}, {`rewrite /from /to rewrite a b`, false, []rewrite.Rule{ - {From: "/from", To: "/to"}, - {From: "a", To: "b"}, + rewrite.SimpleRule{"/from", "/to"}, + rewrite.SimpleRule{"a", "b"}, }}, {`rewrite a`, true, []rewrite.Rule{}}, {`rewrite`, true, []rewrite.Rule{}}, {`rewrite a b c`, true, []rewrite.Rule{ - {From: "a", To: "b"}, + rewrite.SimpleRule{"a", "b"}, }}, } - for i, test := range tests { + for i, test := range simpleTests { c := newTestController(test.input) actual, err := rewriteParse(c) @@ -61,6 +63,8 @@ func TestRewriteParse(t *testing.T) { t.Errorf("Test %d didn't error, but it should have", i) } else if err != nil && !test.shouldErr { t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } else if err != nil && test.shouldErr { + continue } if len(actual) != len(test.expected) { @@ -68,8 +72,9 @@ func TestRewriteParse(t *testing.T) { i, len(test.expected), len(actual)) } - for j, expectedRule := range test.expected { - actualRule := actual[j] + for j, e := range test.expected { + actualRule := actual[j].(rewrite.SimpleRule) + expectedRule := e.(rewrite.SimpleRule) if actualRule.From != expectedRule.From { t.Errorf("Test %d, rule %d: Expected From=%s, got %s", @@ -82,4 +87,98 @@ func TestRewriteParse(t *testing.T) { } } } + + regexpTests := []struct { + input string + shouldErr bool + expected []rewrite.Rule + }{ + {`rewrite { + r .* + to /to + }`, false, []rewrite.Rule{ + &rewrite.RegexpRule{"/", "/to", nil, regexp.MustCompile(".*")}, + }}, + {`rewrite { + regexp .* + to /to + ext / html txt + }`, false, []rewrite.Rule{ + &rewrite.RegexpRule{"/", "/to", []string{"/", "html", "txt"}, regexp.MustCompile(".*")}, + }}, + {`rewrite /path { + r rr + to /dest + } + rewrite / { + regexp [a-z]+ + to /to + } + `, false, []rewrite.Rule{ + &rewrite.RegexpRule{"/path", "/dest", nil, regexp.MustCompile("rr")}, + &rewrite.RegexpRule{"/", "/to", nil, regexp.MustCompile("[a-z]+")}, + }}, + {`rewrite { + to /to + }`, true, []rewrite.Rule{ + &rewrite.RegexpRule{}, + }}, + {`rewrite { + r .* + }`, true, []rewrite.Rule{ + &rewrite.RegexpRule{}, + }}, + {`rewrite { + + }`, true, []rewrite.Rule{ + &rewrite.RegexpRule{}, + }}, + {`rewrite /`, true, []rewrite.Rule{ + &rewrite.RegexpRule{}, + }}, + } + + for i, test := range regexpTests { + c := newTestController(test.input) + actual, err := rewriteParse(c) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } else if err != nil && test.shouldErr { + continue + } + + if len(actual) != len(test.expected) { + t.Fatalf("Test %d expected %d rules, but got %d", + i, len(test.expected), len(actual)) + } + + for j, e := range test.expected { + actualRule := actual[j].(*rewrite.RegexpRule) + expectedRule := e.(*rewrite.RegexpRule) + + if actualRule.Base != expectedRule.Base { + t.Errorf("Test %d, rule %d: Expected Base=%s, got %s", + i, j, expectedRule.Base, actualRule.Base) + } + + if actualRule.To != expectedRule.To { + t.Errorf("Test %d, rule %d: Expected To=%s, got %s", + i, j, expectedRule.To, actualRule.To) + } + + if fmt.Sprint(actualRule.Exts) != fmt.Sprint(expectedRule.Exts) { + t.Errorf("Test %d, rule %d: Expected Ext=%v, got %v", + i, j, expectedRule.To, actualRule.To) + } + + if actualRule.String() != expectedRule.String() { + t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", + i, j, expectedRule.String(), actualRule.String()) + } + } + } + } diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go index cd51622ef..90b6b902a 100644 --- a/middleware/rewrite/rewrite.go +++ b/middleware/rewrite/rewrite.go @@ -30,31 +30,47 @@ func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) return rw.Next.ServeHTTP(w, r) } -// A Rule describes an internal location rewrite rule. +// Rule describes an internal location rewrite rule. type Rule interface { + // Rewrite rewrites the internal location of the current request. Rewrite(*http.Request) bool } -type SimpleRule [2]string +// SimpleRule is a simple rewrite rule. +type SimpleRule struct { + From, To string +} +// NewSimpleRule creates a new Simple Rule func NewSimpleRule(from, to string) SimpleRule { return SimpleRule{from, to} } +// Rewrite rewrites the internal location of the current request. func (s SimpleRule) Rewrite(r *http.Request) bool { - if s[0] == r.URL.Path { - r.URL.Path = s[1] + if s.From == r.URL.Path { + r.URL.Path = s.To return true } return false } +// RegexpRule is a rewrite rule based on a regular expression type RegexpRule struct { - base, to string - ext []string + // Path base. Request to this path and subpaths will be rewritten + Base string + + // Path to rewrite to + To string + + // Extensions to filter by + Exts []string + *regexp.Regexp } +// NewRegexpRule creates a new RegexpRule. It returns an error if regexp +// pattern (pattern) or extensions (ext) are invalid. func NewRegexpRule(base, pattern, to string, ext []string) (*RegexpRule, error) { r, err := regexp.Compile(pattern) if err != nil { @@ -64,7 +80,7 @@ func NewRegexpRule(base, pattern, to string, ext []string) (*RegexpRule, error) // validate extensions for _, v := range ext { if len(v) < 2 || (len(v) < 3 && v[0] == '!') { - // check if it is no extension + // check if no extension is specified if v != "/" && v != "!/" { return nil, fmt.Errorf("Invalid extension %v", v) } @@ -79,48 +95,62 @@ func NewRegexpRule(base, pattern, to string, ext []string) (*RegexpRule, error) }, nil } +// regexpVars are variables that can be used for To (rewrite destination path). var regexpVars []string = []string{ "{path}", "{query}", } +// Rewrite rewrites the internal location of the current request. func (r *RegexpRule) Rewrite(req *http.Request) bool { rPath := req.URL.Path - if strings.Index(rPath, r.base) != 0 { + + // validate base + if !strings.HasPrefix(rPath, r.Base) { return false } + + // validate extensions if !r.matchExt(rPath) { return false } - if !r.MatchString(rPath[len(r.base):]) { + + // validate regexp + if !r.MatchString(rPath[len(r.Base):]) { return false } - to := r.to + to := r.To // check variables for _, v := range regexpVars { - if strings.Contains(r.to, v) { + if strings.Contains(r.To, v) { switch v { - case regexpVars[0]: + case "{path}": to = strings.Replace(to, v, req.URL.Path[1:], -1) - case regexpVars[1]: + case "{query}": to = strings.Replace(to, v, req.URL.RawQuery, -1) } } } + // validate resulting path url, err := url.Parse(to) if err != nil { return false } + // perform rewrite req.URL.Path = url.Path - req.URL.RawQuery = url.RawQuery - + if url.RawQuery != "" { + // overwrite query string if present + req.URL.RawQuery = url.RawQuery + } return true } +// matchExt matches rPath against registered file extensions. +// Returns true if a match is found and false otherwise. func (r *RegexpRule) matchExt(rPath string) bool { f := filepath.Base(rPath) ext := path.Ext(f) @@ -129,7 +159,7 @@ func (r *RegexpRule) matchExt(rPath string) bool { } mustUse := false - for _, v := range r.ext { + for _, v := range r.Exts { use := true if v[0] == '!' { use = false diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go index e9793ac13..7dcd67685 100644 --- a/middleware/rewrite/rewrite_test.go +++ b/middleware/rewrite/rewrite_test.go @@ -7,16 +7,38 @@ import ( "testing" "github.com/mholt/caddy/middleware" + "strings" ) func TestRewrite(t *testing.T) { rw := Rewrite{ Next: middleware.HandlerFunc(urlPrinter), Rules: []Rule{ - {From: "/from", To: "/to"}, - {From: "/a", To: "/b"}, + NewSimpleRule("/from", "/to"), + NewSimpleRule("/a", "/b"), }, } + + regexpRules := [][]string{ + []string{"/reg/", ".*", "/to", ""}, + []string{"/r/", "[a-z]+", "/toaz", "!.html|"}, + []string{"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""}, + []string{"/ab/", "ab", "/ab?{query}", ".txt|"}, + []string{"/ab/", "ab", "/ab?type=html&{query}", ".html|"}, + } + + for _, regexpRule := range regexpRules { + var ext []string + if s := strings.Split(regexpRule[3], "|"); len(s) > 1 { + ext = s[:len(s)-1] + } + rule, err := NewRegexpRule(regexpRule[0], regexpRule[1], regexpRule[2], ext) + if err != nil { + t.Fatal(err) + } + rw.Rules = append(rw.Rules, rule) + } + tests := []struct { from string expectedTo string @@ -29,6 +51,25 @@ func TestRewrite(t *testing.T) { {"/asdf?foo=bar", "/asdf?foo=bar"}, {"/foo#bar", "/foo#bar"}, {"/a#foo", "/b#foo"}, + {"/reg/foo", "/to"}, + {"/re", "/re"}, + {"/r/", "/r/"}, + {"/r/123", "/r/123"}, + {"/r/a123", "/toaz"}, + {"/r/abcz", "/toaz"}, + {"/r/z", "/toaz"}, + {"/r/z.html", "/r/z.html"}, + {"/r/z.js", "/toaz"}, + {"/url/asAB", "/to/url/asAB"}, + {"/url/aBsAB", "/url/aBsAB"}, + {"/url/a00sAB", "/to/url/a00sAB"}, + {"/url/a0z0sAB", "/to/url/a0z0sAB"}, + {"/ab/aa", "/ab/aa"}, + {"/ab/ab", "/ab/ab"}, + {"/ab/ab.txt", "/ab"}, + {"/ab/ab.txt?name=name", "/ab?name=name"}, + {"/ab/ab.html?name=name", "/ab?type=html&name=name"}, + {"/ab/ab.html", "/ab?type=html&"}, } for i, test := range tests {