diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 1503eccd..7ffdb775 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -954,6 +954,90 @@ func TestReverseProxyRetry(t *testing.T) { } } +func TestReverseProxyLargeBody(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + // set up proxy + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + })) + defer backend.Close() + + su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`proxy / `+backend.URL))) + if err != nil { + t.Fatal(err) + } + + p := &Proxy{ + Next: httpserver.EmptyNext, // prevents panic in some cases when test fails + Upstreams: su, + } + + // middle is required to simulate closable downstream request body + middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err = p.ServeHTTP(w, r) + if err != nil { + t.Error(err) + } + })) + defer middle.Close() + + // Our request body will be 100MB + bodySize := uint64(100 * 1000 * 1000) + + // We want to see how much memory the proxy module requires for this request. + // So lets record the mem stats before we start it. + begMemstats := &runtime.MemStats{} + runtime.ReadMemStats(begMemstats) + + r, err := http.NewRequest("POST", middle.URL, &noopReader{len: bodySize}) + if err != nil { + t.Fatal(err) + } + resp, err := http.DefaultTransport.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + + // Finally we need the mem stats after the request is done... + endMemstats := &runtime.MemStats{} + runtime.ReadMemStats(endMemstats) + + // ...to calculate the total amount of allocated memory during the request. + totalAlloc := endMemstats.TotalAlloc - begMemstats.TotalAlloc + + // If that's as much as the size of the body itself it's a serious sign that the + // request was not "streamed" to the upstream without buffering it first. + if totalAlloc >= bodySize { + t.Fatalf("proxy allocated too much memory: %d bytes", totalAlloc) + } +} + +type noopReader struct { + len uint64 + pos uint64 +} + +var _ io.Reader = &noopReader{} + +func (r *noopReader) Read(b []byte) (int, error) { + if r.pos >= r.len { + return 0, io.EOF + } + n := int(r.len - r.pos) + if n > len(b) { + n = len(b) + } + for i := range b[:n] { + b[i] = 0 + } + r.pos += uint64(n) + return n, nil +} + func newFakeUpstream(name string, insecure bool) *fakeUpstream { uri, _ := url.Parse(name) u := &fakeUpstream{ @@ -998,6 +1082,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost { func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true } func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second } func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } +func (u *fakeUpstream) GetHostCount() int { return 1 } // newWebSocketTestProxy returns a test proxy that will // redirect to the specified backendAddr. The function @@ -1049,6 +1134,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost { func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true } func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second } func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } +func (u *fakeWsUpstream) GetHostCount() int { return 1 } // recorderHijacker is a ResponseRecorder that can // be hijacked.