diff --git a/caddyhttp/httpserver/replacer.go b/caddyhttp/httpserver/replacer.go index 892dc4ae..c77230e8 100644 --- a/caddyhttp/httpserver/replacer.go +++ b/caddyhttp/httpserver/replacer.go @@ -40,10 +40,10 @@ type Replacer interface { // they will be used to overwrite other replacements // if there is a name conflict. type replacer struct { - replacements map[string]func() string - customReplacements map[string]func() string + customReplacements map[string]string emptyValue string responseRecorder *ResponseRecorder + request *http.Request } // NewReplacer makes a new replacer based on r and rr which @@ -55,90 +55,15 @@ type replacer struct { // of empty string (can still be empty string). func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Replacer { rep := &replacer{ + request: r, responseRecorder: rr, - customReplacements: make(map[string]func() string), - replacements: map[string]func() string{ - "{method}": func() string { return r.Method }, - "{scheme}": func() string { - if r.TLS != nil { - return "https" - } - return "http" - }, - "{hostname}": func() string { - name, err := os.Hostname() - if err != nil { - return "" - } - return name - }, - "{host}": func() string { return r.Host }, - "{hostonly}": func() string { - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - return r.Host - } - return host - }, - "{path}": func() string { return r.URL.Path }, - "{path_escaped}": func() string { return url.QueryEscape(r.URL.Path) }, - "{query}": func() string { return r.URL.RawQuery }, - "{query_escaped}": func() string { return url.QueryEscape(r.URL.RawQuery) }, - "{fragment}": func() string { return r.URL.Fragment }, - "{proto}": func() string { return r.Proto }, - "{remote}": func() string { - host, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return r.RemoteAddr - } - return host - }, - "{port}": func() string { - _, port, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return "" - } - return port - }, - "{uri}": func() string { return r.URL.RequestURI() }, - "{uri_escaped}": func() string { return url.QueryEscape(r.URL.RequestURI()) }, - "{when}": func() string { return time.Now().Format(timeFormat) }, - "{file}": func() string { - _, file := path.Split(r.URL.Path) - return file - }, - "{dir}": func() string { - dir, _ := path.Split(r.URL.Path) - return dir - }, - "{request}": func() string { - dump, err := httputil.DumpRequest(r, false) - if err != nil { - return "" - } - - return requestReplacer.Replace(string(dump)) - }, - "{request_body}": func() string { - if !canLogRequest(r) { - return "" - } - - body, err := readRequestBody(r, maxLogBodySize) - if err != nil { - return "" - } - - return requestReplacer.Replace(string(body)) - }, - }, - emptyValue: emptyValue, + customReplacements: make(map[string]string), + emptyValue: emptyValue, } // Header placeholders (case-insensitive) for header, values := range r.Header { - values := values - rep.replacements[headerReplacer+strings.ToLower(header)+"}"] = func() string { return strings.Join(values, ",") } + rep.customReplacements["{>"+strings.ToLower(header)+"}"] = strings.Join(values, ",") } return rep @@ -185,54 +110,37 @@ func (r *replacer) Replace(s string) string { return s } - // Make response placeholders now - if r.responseRecorder != nil { - r.replacements["{status}"] = func() string { return strconv.Itoa(r.responseRecorder.status) } - r.replacements["{size}"] = func() string { return strconv.Itoa(r.responseRecorder.size) } - r.replacements["{latency}"] = func() string { - dur := time.Since(r.responseRecorder.start) - return roundDuration(dur).String() - } - } - - // Include custom placeholders, overwriting existing ones if necessary - for key, val := range r.customReplacements { - r.replacements[key] = val - } - - // Header replacements - these are case-insensitive, so we can't just use strings.Replace() - for strings.Contains(s, headerReplacer) { - idxStart := strings.Index(s, headerReplacer) - endOffset := idxStart + len(headerReplacer) - idxEnd := strings.Index(s[endOffset:], "}") - if idxEnd > -1 { - placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1]) - replacement := "" - if getReplacement, ok := r.replacements[placeholder]; ok { - replacement = getReplacement() - } - if replacement == "" { - replacement = r.emptyValue - } - s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:] - } else { + result := "" + for { + idxStart := strings.Index(s, "{") + if idxStart == -1 { + // no placeholder anymore break } + idxEnd := strings.Index(s[idxStart:], "}") + if idxEnd == -1 { + // unpaired placeholder + break + } + idxEnd += idxStart + + // get a replacement + placeholder := s[idxStart : idxEnd+1] + // Header replacements - they are case-insensitive + if placeholder[1] == '>' { + placeholder = strings.ToLower(placeholder) + } + replacement := r.getSubstitution(placeholder) + + // append prefix + replacement + result += s[:idxStart] + replacement + + // strip out scanned parts + s = s[idxEnd+1:] } - // Regular replacements - these are easier because they're case-sensitive - for placeholder, getReplacement := range r.replacements { - if !strings.Contains(s, placeholder) { - continue - } - replacement := getReplacement() - if replacement == "" { - replacement = r.emptyValue - } - s = strings.Replace(s, placeholder, replacement, -1) - } - - return s + // append unscanned parts + return result + s } func roundDuration(d time.Duration) time.Duration { @@ -265,14 +173,114 @@ func round(d, r time.Duration) time.Duration { return d } +// getSubstitution retrieves value from corresponding key +func (r *replacer) getSubstitution(key string) string { + // search custom replacements first + if value, ok := r.customReplacements[key]; ok { + return value + } + + // search default replacements then + switch key { + case "{method}": + return r.request.Method + case "{scheme}": + if r.request.TLS != nil { + return "https" + } + return "http" + case "{hostname}": + name, err := os.Hostname() + if err != nil { + return r.emptyValue + } + return name + case "{host}": + return r.request.Host + case "{hostonly}": + host, _, err := net.SplitHostPort(r.request.Host) + if err != nil { + return r.request.Host + } + return host + case "{path}": + return r.request.URL.Path + case "{path_escaped}": + return url.QueryEscape(r.request.URL.Path) + case "{query}": + return r.request.URL.RawQuery + case "{query_escaped}": + return url.QueryEscape(r.request.URL.RawQuery) + case "{fragment}": + return r.request.URL.Fragment + case "{proto}": + return r.request.Proto + case "{remote}": + host, _, err := net.SplitHostPort(r.request.RemoteAddr) + if err != nil { + return r.request.RemoteAddr + } + return host + case "{port}": + _, port, err := net.SplitHostPort(r.request.RemoteAddr) + if err != nil { + return r.emptyValue + } + return port + case "{uri}": + return r.request.URL.RequestURI() + case "{uri_escaped}": + return url.QueryEscape(r.request.URL.RequestURI()) + case "{when}": + return time.Now().Format(timeFormat) + case "{file}": + _, file := path.Split(r.request.URL.Path) + return file + case "{dir}": + dir, _ := path.Split(r.request.URL.Path) + return dir + case "{request}": + dump, err := httputil.DumpRequest(r.request, false) + if err != nil { + return r.emptyValue + } + return requestReplacer.Replace(string(dump)) + case "{request_body}": + if !canLogRequest(r.request) { + return r.emptyValue + } + body, err := readRequestBody(r.request, maxLogBodySize) + if err != nil { + return r.emptyValue + } + return requestReplacer.Replace(string(body)) + case "{status}": + if r.responseRecorder == nil { + return r.emptyValue + } + return strconv.Itoa(r.responseRecorder.status) + case "{size}": + if r.responseRecorder == nil { + return r.emptyValue + } + return strconv.Itoa(r.responseRecorder.size) + case "{latency}": + if r.responseRecorder == nil { + return r.emptyValue + } + return roundDuration(time.Since(r.responseRecorder.start)).String() + } + + return r.emptyValue +} + // Set sets key to value in the r.customReplacements map. func (r *replacer) Set(key, value string) { - r.customReplacements["{"+key+"}"] = func() string { return value } + r.customReplacements["{"+key+"}"] = value } const ( timeFormat = "02/Jan/2006:15:04:05 -0700" - headerReplacer = "{>" headerContentType = "Content-Type" contentTypeJSON = "application/json" contentTypeXML = "application/xml" diff --git a/caddyhttp/httpserver/replacer_test.go b/caddyhttp/httpserver/replacer_test.go index cfd52fb2..25d39d07 100644 --- a/caddyhttp/httpserver/replacer_test.go +++ b/caddyhttp/httpserver/replacer_test.go @@ -24,28 +24,12 @@ func TestNewReplacer(t *testing.T) { switch v := rep.(type) { case *replacer: - if v.replacements["{host}"]() != "localhost" { + if v.getSubstitution("{host}") != "localhost" { t.Error("Expected host to be localhost") } - if v.replacements["{method}"]() != "POST" { + if v.getSubstitution("{method}") != "POST" { t.Error("Expected request method to be POST") } - - // Response placeholders should only be set after call to Replace() - got, want := "", "" - if getReplacement, ok := v.replacements["{status}"]; ok { - got = getReplacement() - } - if want := ""; got != want { - t.Errorf("Expected status to NOT be set before Replace() is called; was: %s", got) - } - rep.Replace("{foobar}") - if getReplacement, ok := v.replacements["{status}"]; ok { - got = getReplacement() - } - if want = "200"; got != want { - t.Errorf("Expected status to be %s, was: %s", want, got) - } default: t.Fatalf("Expected *replacer underlying Replacer type, got: %#v", rep) } @@ -94,19 +78,21 @@ func TestReplace(t *testing.T) { complexCases := []struct { template string - replacements map[string]func() string + replacements map[string]string expect string }{ - {"/a{1}/{2}", - map[string]func() string{ - "{1}": func() string { return "12" }, - "{2}": func() string { return "" }}, + { + "/a{1}/{2}", + map[string]string{ + "{1}": "12", + "{2}": "", + }, "/a12/"}, } for _, c := range complexCases { repl := &replacer{ - replacements: c.replacements, + customReplacements: c.replacements, } if expected, actual := c.expect, repl.Replace(c.template); expected != actual { t.Errorf("for template '%s', expected '%s', got '%s'", c.template, expected, actual)