From 737c7c437204922822d94bf41d6f374ccb13b935 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Wed, 24 Feb 2016 16:41:45 -0700 Subject: [PATCH] fastcgi: Only perform extra copy if necessary; added tests --- middleware/fastcgi/fastcgi.go | 44 ++++++++++--------- middleware/fastcgi/fastcgi_test.go | 68 ++++++++++++++++++++++++------ 2 files changed, 79 insertions(+), 33 deletions(-) diff --git a/middleware/fastcgi/fastcgi.go b/middleware/fastcgi/fastcgi.go index 3d01c416..c4ca935e 100644 --- a/middleware/fastcgi/fastcgi.go +++ b/middleware/fastcgi/fastcgi.go @@ -72,7 +72,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) // Connect to FastCGI gateway network, address := rule.parseAddress() - fcgi, err := Dial(network, address) + fcgiBackend, err := Dial(network, address) if err != nil { return http.StatusBadGateway, err } @@ -81,19 +81,19 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length")) switch r.Method { case "HEAD": - resp, err = fcgi.Head(env) + resp, err = fcgiBackend.Head(env) case "GET": - resp, err = fcgi.Get(env) + resp, err = fcgiBackend.Get(env) case "OPTIONS": - resp, err = fcgi.Options(env) + resp, err = fcgiBackend.Options(env) case "POST": - resp, err = fcgi.Post(env, r.Header.Get("Content-Type"), r.Body, contentLength) + resp, err = fcgiBackend.Post(env, r.Header.Get("Content-Type"), r.Body, contentLength) case "PUT": - resp, err = fcgi.Put(env, r.Header.Get("Content-Type"), r.Body, contentLength) + resp, err = fcgiBackend.Put(env, r.Header.Get("Content-Type"), r.Body, contentLength) case "PATCH": - resp, err = fcgi.Patch(env, r.Header.Get("Content-Type"), r.Body, contentLength) + resp, err = fcgiBackend.Patch(env, r.Header.Get("Content-Type"), r.Body, contentLength) case "DELETE": - resp, err = fcgi.Delete(env, r.Header.Get("Content-Type"), r.Body, contentLength) + resp, err = fcgiBackend.Delete(env, r.Header.Get("Content-Type"), r.Body, contentLength) default: return http.StatusMethodNotAllowed, nil } @@ -106,29 +106,35 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) return http.StatusBadGateway, err } - // Write the response body to a buffer - // To explicitly set Content-Length - // For FastCGI app that don't set it - var buf bytes.Buffer - io.Copy(&buf, resp.Body) + var responseBody io.Reader = resp.Body if r.Header.Get("Content-Length") == "" { + // If the upstream app didn't set a Content-Length (shame on them), + // we need to do it to prevent error messages being appended to + // an already-written response, and other problematic behavior. + // So we copy it to a buffer and read its size before flushing + // the response out to the client. See issues #567 and #614. + buf := new(bytes.Buffer) + _, err := io.Copy(buf, resp.Body) + if err != nil { + return http.StatusBadGateway, err + } w.Header().Set("Content-Length", strconv.Itoa(buf.Len())) + responseBody = buf } + + // Write the status code and header fields writeHeader(w, resp) // Write the response body - // TODO: If this has an error, the response will already be - // partly written. We should copy out of resp.Body into a buffer - // first, then write it to the response... - _, err = io.Copy(w, &buf) + _, err = io.Copy(w, responseBody) if err != nil { return http.StatusBadGateway, err } // FastCGI stderr outputs - if fcgi.stderr.Len() != 0 { + if fcgiBackend.stderr.Len() != 0 { // Remove trailing newline, error logger already does this. - err = LogError(strings.TrimSuffix(fcgi.stderr.String(), "\n")) + err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n")) } return resp.StatusCode, err diff --git a/middleware/fastcgi/fastcgi_test.go b/middleware/fastcgi/fastcgi_test.go index c33f47af..5fbba23f 100644 --- a/middleware/fastcgi/fastcgi_test.go +++ b/middleware/fastcgi/fastcgi_test.go @@ -1,13 +1,61 @@ package fastcgi import ( + "net" "net/http" + "net/http/fcgi" + "net/http/httptest" "net/url" + "strconv" "testing" ) -func TestRuleParseAddress(t *testing.T) { +func TestServeHTTPContentLength(t *testing.T) { + testWithBackend := func(body string, setContentLength bool) { + bodyLenStr := strconv.Itoa(len(body)) + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("BackendSetsContentLength=%v: Unable to create listener for test: %v", setContentLength, err) + } + defer listener.Close() + go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if setContentLength { + w.Header().Set("Content-Length", bodyLenStr) + } + w.Write([]byte(body)) + })) + handler := Handler{ + Next: nil, + Rules: []Rule{{Path: "/", Address: listener.Addr().String()}}, + } + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("BackendSetsContentLength=%v: Unable to create request: %v", setContentLength, err) + } + w := httptest.NewRecorder() + + status, err := handler.ServeHTTP(w, r) + + if got, want := status, http.StatusOK; got != want { + t.Errorf("BackendSetsContentLength=%v: Expected returned status code to be %d, got %d", setContentLength, want, got) + } + if err != nil { + t.Errorf("BackendSetsContentLength=%v: Expected nil error, got: %v", setContentLength, err) + } + if got, want := w.Header().Get("Content-Length"), bodyLenStr; got != want { + t.Errorf("BackendSetsContentLength=%v: Expected Content-Length to be '%s', got: '%s'", setContentLength, want, got) + } + if got, want := w.Body.String(), body; got != want { + t.Errorf("BackendSetsContentLength=%v: Expected response body to be '%s', got: '%s'", setContentLength, want, got) + } + } + + testWithBackend("Backend does NOT set Content-Length", false) + testWithBackend("Backend sets Content-Length", true) +} + +func TestRuleParseAddress(t *testing.T) { getClientTestTable := []struct { rule *Rule expectednetwork string @@ -27,28 +75,21 @@ func TestRuleParseAddress(t *testing.T) { if _, actualaddress := entry.rule.parseAddress(); actualaddress != entry.expectedaddress { t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address, actualaddress, entry.expectedaddress) } - } - } func TestBuildEnv(t *testing.T) { - - buildEnvSingle := func(r *http.Request, rule Rule, fpath string, envExpected map[string]string, t *testing.T) { - - h := Handler{} - + testBuildEnv := func(r *http.Request, rule Rule, fpath string, envExpected map[string]string) { + var h Handler env, err := h.buildEnv(r, rule, fpath) if err != nil { t.Error("Unexpected error:", err.Error()) } - for k, v := range envExpected { if env[k] != v { t.Errorf("Unexpected %v. Got %v, expected %v", k, env[k], v) } } - } rule := Rule{} @@ -80,16 +121,15 @@ func TestBuildEnv(t *testing.T) { } // 1. Test for full canonical IPv6 address - buildEnvSingle(&r, rule, fpath, envExpected, t) + testBuildEnv(&r, rule, fpath, envExpected) // 2. Test for shorthand notation of IPv6 address r.RemoteAddr = "[::1]:51688" envExpected["REMOTE_ADDR"] = "[::1]" - buildEnvSingle(&r, rule, fpath, envExpected, t) + testBuildEnv(&r, rule, fpath, envExpected) // 3. Test for IPv4 address r.RemoteAddr = "192.168.0.10:51688" envExpected["REMOTE_ADDR"] = "192.168.0.10" - buildEnvSingle(&r, rule, fpath, envExpected, t) - + testBuildEnv(&r, rule, fpath, envExpected) }