diff --git a/middleware/errors/errors.go b/middleware/errors/errors.go index 33a15269..ccd7e6af 100644 --- a/middleware/errors/errors.go +++ b/middleware/errors/errors.go @@ -43,7 +43,9 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er } if status >= 400 { - h.errorPage(w, r, status) + if w.Header().Get("Content-Length") == "" { + h.errorPage(w, r, status) + } return 0, err } diff --git a/middleware/errors/errors_test.go b/middleware/errors/errors_test.go index 8afa6bff..c0cf6325 100644 --- a/middleware/errors/errors_test.go +++ b/middleware/errors/errors_test.go @@ -9,6 +9,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "strconv" "strings" "testing" @@ -78,6 +79,13 @@ func TestErrors(t *testing.T) { expectedLog: "", expectedErr: nil, }, + { + next: genErrorHandler(http.StatusNotFound, nil, "normal"), + expectedCode: 0, + expectedBody: "normal", + expectedLog: "", + expectedErr: nil, + }, { next: genErrorHandler(http.StatusForbidden, nil, ""), expectedCode: 0, @@ -158,6 +166,9 @@ func TestVisibleErrorWithPanic(t *testing.T) { func genErrorHandler(status int, err error, body string) middleware.Handler { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + if len(body) > 0 { + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + } fmt.Fprint(w, body) return status, err }) diff --git a/middleware/fastcgi/fastcgi.go b/middleware/fastcgi/fastcgi.go index c4ca935e..fa9a6c46 100644 --- a/middleware/fastcgi/fastcgi.go +++ b/middleware/fastcgi/fastcgi.go @@ -107,7 +107,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) } var responseBody io.Reader = resp.Body - if r.Header.Get("Content-Length") == "" { + if resp.Header.Get("Content-Length") == "" { // If the upstream app didn't set a Content-Length (shame on them), // we need to do it to prevent error messages being appended to // an already-written response, and other problematic behavior. @@ -137,6 +137,11 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n")) } + // Normally we should only return a status >= 400 if no response + // body is written yet, however, upstream apps don't know about + // this contract and we still want the correct code logged, so error + // handling code in our stack needs to check Content-Length before + // writing an error message... oh well. return resp.StatusCode, err } } diff --git a/middleware/log/log.go b/middleware/log/log.go index feb6182a..acb695c5 100644 --- a/middleware/log/log.go +++ b/middleware/log/log.go @@ -26,7 +26,7 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // The error must be handled here so the log entry will record the response size. if l.ErrorFunc != nil { l.ErrorFunc(responseRecorder, r, status) - } else { + } else if responseRecorder.Header().Get("Content-Length") == "" { // ensure no body written since proxy backends may write an error page // Default failover error handler responseRecorder.WriteHeader(status) fmt.Fprintf(responseRecorder, "%d %s", status, http.StatusText(status)) diff --git a/server/server.go b/server/server.go index 3a336f3b..2df0deac 100644 --- a/server/server.go +++ b/server/server.go @@ -319,7 +319,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { status, _ := vh.stack.ServeHTTP(w, r) // Fallback error response in case error handling wasn't chained in - if status >= 400 { + if status >= 400 && w.Header().Get("Content-Length") == "" { DefaultErrorFunc(w, r, status) } } else { @@ -417,36 +417,6 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) { return ln.TCPListener.File() } -// copied from net/http/transport.go -/* - TODO - remove - not necessary? -func cloneTLSConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return &tls.Config{ - Rand: cfg.Rand, - Time: cfg.Time, - Certificates: cfg.Certificates, - NameToCertificate: cfg.NameToCertificate, - GetCertificate: cfg.GetCertificate, - RootCAs: cfg.RootCAs, - NextProtos: cfg.NextProtos, - ServerName: cfg.ServerName, - ClientAuth: cfg.ClientAuth, - ClientCAs: cfg.ClientCAs, - InsecureSkipVerify: cfg.InsecureSkipVerify, - CipherSuites: cfg.CipherSuites, - PreferServerCipherSuites: cfg.PreferServerCipherSuites, - SessionTicketsDisabled: cfg.SessionTicketsDisabled, - SessionTicketKey: cfg.SessionTicketKey, - ClientSessionCache: cfg.ClientSessionCache, - MinVersion: cfg.MinVersion, - MaxVersion: cfg.MaxVersion, - CurvePreferences: cfg.CurvePreferences, - } -}*/ - // ShutdownCallbacks executes all the shutdown callbacks // for all the virtualhosts in servers, and returns all the // errors generated during their execution. In other words,