mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-23 22:27:38 -05:00
fastcgi: Only perform extra copy if necessary; added tests
This commit is contained in:
parent
367397dbd6
commit
737c7c4372
2 changed files with 79 additions and 33 deletions
|
@ -72,7 +72,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
|||
|
||||
// Connect to FastCGI gateway
|
||||
network, address := rule.parseAddress()
|
||||
fcgi, err := Dial(network, address)
|
||||
fcgiBackend, err := Dial(network, address)
|
||||
if err != nil {
|
||||
return http.StatusBadGateway, err
|
||||
}
|
||||
|
@ -81,19 +81,19 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
|||
contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length"))
|
||||
switch r.Method {
|
||||
case "HEAD":
|
||||
resp, err = fcgi.Head(env)
|
||||
resp, err = fcgiBackend.Head(env)
|
||||
case "GET":
|
||||
resp, err = fcgi.Get(env)
|
||||
resp, err = fcgiBackend.Get(env)
|
||||
case "OPTIONS":
|
||||
resp, err = fcgi.Options(env)
|
||||
resp, err = fcgiBackend.Options(env)
|
||||
case "POST":
|
||||
resp, err = fcgi.Post(env, r.Header.Get("Content-Type"), r.Body, contentLength)
|
||||
resp, err = fcgiBackend.Post(env, r.Header.Get("Content-Type"), r.Body, contentLength)
|
||||
case "PUT":
|
||||
resp, err = fcgi.Put(env, r.Header.Get("Content-Type"), r.Body, contentLength)
|
||||
resp, err = fcgiBackend.Put(env, r.Header.Get("Content-Type"), r.Body, contentLength)
|
||||
case "PATCH":
|
||||
resp, err = fcgi.Patch(env, r.Header.Get("Content-Type"), r.Body, contentLength)
|
||||
resp, err = fcgiBackend.Patch(env, r.Header.Get("Content-Type"), r.Body, contentLength)
|
||||
case "DELETE":
|
||||
resp, err = fcgi.Delete(env, r.Header.Get("Content-Type"), r.Body, contentLength)
|
||||
resp, err = fcgiBackend.Delete(env, r.Header.Get("Content-Type"), r.Body, contentLength)
|
||||
default:
|
||||
return http.StatusMethodNotAllowed, nil
|
||||
}
|
||||
|
@ -106,29 +106,35 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
|||
return http.StatusBadGateway, err
|
||||
}
|
||||
|
||||
// Write the response body to a buffer
|
||||
// To explicitly set Content-Length
|
||||
// For FastCGI app that don't set it
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, resp.Body)
|
||||
var responseBody io.Reader = resp.Body
|
||||
if r.Header.Get("Content-Length") == "" {
|
||||
// If the upstream app didn't set a Content-Length (shame on them),
|
||||
// we need to do it to prevent error messages being appended to
|
||||
// an already-written response, and other problematic behavior.
|
||||
// So we copy it to a buffer and read its size before flushing
|
||||
// the response out to the client. See issues #567 and #614.
|
||||
buf := new(bytes.Buffer)
|
||||
_, err := io.Copy(buf, resp.Body)
|
||||
if err != nil {
|
||||
return http.StatusBadGateway, err
|
||||
}
|
||||
w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
|
||||
responseBody = buf
|
||||
}
|
||||
|
||||
// Write the status code and header fields
|
||||
writeHeader(w, resp)
|
||||
|
||||
// Write the response body
|
||||
// TODO: If this has an error, the response will already be
|
||||
// partly written. We should copy out of resp.Body into a buffer
|
||||
// first, then write it to the response...
|
||||
_, err = io.Copy(w, &buf)
|
||||
_, err = io.Copy(w, responseBody)
|
||||
if err != nil {
|
||||
return http.StatusBadGateway, err
|
||||
}
|
||||
|
||||
// FastCGI stderr outputs
|
||||
if fcgi.stderr.Len() != 0 {
|
||||
if fcgiBackend.stderr.Len() != 0 {
|
||||
// Remove trailing newline, error logger already does this.
|
||||
err = LogError(strings.TrimSuffix(fcgi.stderr.String(), "\n"))
|
||||
err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
|
||||
}
|
||||
|
||||
return resp.StatusCode, err
|
||||
|
|
|
@ -1,13 +1,61 @@
|
|||
package fastcgi
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/fcgi"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRuleParseAddress(t *testing.T) {
|
||||
func TestServeHTTPContentLength(t *testing.T) {
|
||||
testWithBackend := func(body string, setContentLength bool) {
|
||||
bodyLenStr := strconv.Itoa(len(body))
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("BackendSetsContentLength=%v: Unable to create listener for test: %v", setContentLength, err)
|
||||
}
|
||||
defer listener.Close()
|
||||
go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if setContentLength {
|
||||
w.Header().Set("Content-Length", bodyLenStr)
|
||||
}
|
||||
w.Write([]byte(body))
|
||||
}))
|
||||
|
||||
handler := Handler{
|
||||
Next: nil,
|
||||
Rules: []Rule{{Path: "/", Address: listener.Addr().String()}},
|
||||
}
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BackendSetsContentLength=%v: Unable to create request: %v", setContentLength, err)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
status, err := handler.ServeHTTP(w, r)
|
||||
|
||||
if got, want := status, http.StatusOK; got != want {
|
||||
t.Errorf("BackendSetsContentLength=%v: Expected returned status code to be %d, got %d", setContentLength, want, got)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("BackendSetsContentLength=%v: Expected nil error, got: %v", setContentLength, err)
|
||||
}
|
||||
if got, want := w.Header().Get("Content-Length"), bodyLenStr; got != want {
|
||||
t.Errorf("BackendSetsContentLength=%v: Expected Content-Length to be '%s', got: '%s'", setContentLength, want, got)
|
||||
}
|
||||
if got, want := w.Body.String(), body; got != want {
|
||||
t.Errorf("BackendSetsContentLength=%v: Expected response body to be '%s', got: '%s'", setContentLength, want, got)
|
||||
}
|
||||
}
|
||||
|
||||
testWithBackend("Backend does NOT set Content-Length", false)
|
||||
testWithBackend("Backend sets Content-Length", true)
|
||||
}
|
||||
|
||||
func TestRuleParseAddress(t *testing.T) {
|
||||
getClientTestTable := []struct {
|
||||
rule *Rule
|
||||
expectednetwork string
|
||||
|
@ -27,28 +75,21 @@ func TestRuleParseAddress(t *testing.T) {
|
|||
if _, actualaddress := entry.rule.parseAddress(); actualaddress != entry.expectedaddress {
|
||||
t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address, actualaddress, entry.expectedaddress)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestBuildEnv(t *testing.T) {
|
||||
|
||||
buildEnvSingle := func(r *http.Request, rule Rule, fpath string, envExpected map[string]string, t *testing.T) {
|
||||
|
||||
h := Handler{}
|
||||
|
||||
testBuildEnv := func(r *http.Request, rule Rule, fpath string, envExpected map[string]string) {
|
||||
var h Handler
|
||||
env, err := h.buildEnv(r, rule, fpath)
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err.Error())
|
||||
}
|
||||
|
||||
for k, v := range envExpected {
|
||||
if env[k] != v {
|
||||
t.Errorf("Unexpected %v. Got %v, expected %v", k, env[k], v)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
rule := Rule{}
|
||||
|
@ -80,16 +121,15 @@ func TestBuildEnv(t *testing.T) {
|
|||
}
|
||||
|
||||
// 1. Test for full canonical IPv6 address
|
||||
buildEnvSingle(&r, rule, fpath, envExpected, t)
|
||||
testBuildEnv(&r, rule, fpath, envExpected)
|
||||
|
||||
// 2. Test for shorthand notation of IPv6 address
|
||||
r.RemoteAddr = "[::1]:51688"
|
||||
envExpected["REMOTE_ADDR"] = "[::1]"
|
||||
buildEnvSingle(&r, rule, fpath, envExpected, t)
|
||||
testBuildEnv(&r, rule, fpath, envExpected)
|
||||
|
||||
// 3. Test for IPv4 address
|
||||
r.RemoteAddr = "192.168.0.10:51688"
|
||||
envExpected["REMOTE_ADDR"] = "192.168.0.10"
|
||||
buildEnvSingle(&r, rule, fpath, envExpected, t)
|
||||
|
||||
testBuildEnv(&r, rule, fpath, envExpected)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue