diff --git a/caddyhttp/gzip/gzip.go b/caddyhttp/gzip/gzip.go index 3d34ca27..48172c72 100644 --- a/caddyhttp/gzip/gzip.go +++ b/caddyhttp/gzip/gzip.go @@ -9,8 +9,6 @@ import ( "net/http" "strings" - "errors" - "github.com/mholt/caddy" "github.com/mholt/caddy/caddyhttp/httpserver" ) @@ -155,7 +153,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { return pusher.Push(target, opts) } - return errors.New("push is unavailable (probably chained http.ResponseWriter does not implement http.Pusher)") + return httpserver.NonFlusherError{Underlying: w.ResponseWriter} } // Interface guards diff --git a/caddyhttp/header/header.go b/caddyhttp/header/header.go index c2c3fad8..a2685fbb 100644 --- a/caddyhttp/header/header.go +++ b/caddyhttp/header/header.go @@ -9,7 +9,6 @@ import ( "net/http" "strings" - "errors" "github.com/mholt/caddy/caddyhttp/httpserver" ) @@ -141,7 +140,7 @@ func (rww *responseWriterWrapper) Push(target string, opts *http.PushOptions) er return pusher.Push(target, opts) } - return errors.New("push is unavailable (probably chained http.ResponseWriter does not implement http.Pusher)") + return httpserver.NonPusherError{Underlying: rww.ResponseWriter} } // Interface guards diff --git a/caddyhttp/httpserver/error.go b/caddyhttp/httpserver/error.go index 2cfa530e..2fbd486c 100644 --- a/caddyhttp/httpserver/error.go +++ b/caddyhttp/httpserver/error.go @@ -8,6 +8,7 @@ var ( _ error = NonHijackerError{} _ error = NonFlusherError{} _ error = NonCloseNotifierError{} + _ error = NonPusherError{} ) // NonHijackerError is more descriptive error caused by a non hijacker @@ -42,3 +43,14 @@ type NonCloseNotifierError struct { func (c NonCloseNotifierError) Error() string { return fmt.Sprintf("%T is not a closeNotifier", c.Underlying) } + +// NonPusherError is more descriptive error caused by a non pusher +type NonPusherError struct { + // underlying type which doesn't implement pusher + Underlying interface{} +} + +// Implement Error +func (c NonPusherError) Error() string { + return fmt.Sprintf("%T is not a pusher", c.Underlying) +} diff --git a/caddyhttp/httpserver/recorder.go b/caddyhttp/httpserver/recorder.go index f9f70bcc..2bb919fc 100644 --- a/caddyhttp/httpserver/recorder.go +++ b/caddyhttp/httpserver/recorder.go @@ -2,7 +2,6 @@ package httpserver import ( "bufio" - "errors" "net" "net/http" "time" @@ -103,7 +102,7 @@ func (r *ResponseRecorder) Push(target string, opts *http.PushOptions) error { return pusher.Push(target, opts) } - return errors.New("push is unavailable (probably chained http.ResponseWriter does not implement http.Pusher)") + return NonPusherError{Underlying: r.ResponseWriter} } // Interface guards diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index a7922d4a..044a3127 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -2,6 +2,7 @@ package proxy import ( + "context" "errors" "net" "net/http" @@ -103,7 +104,8 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { replacer := httpserver.NewReplacer(r, nil, "") // outreq is the request that makes a roundtrip to the backend - outreq := createUpstreamRequest(r) + outreq, cancel := createUpstreamRequest(w, r) + defer cancel() // If we have more than one upstream host defined and if retrying is enabled // by setting try_duration to a non-zero value, caddy will try to @@ -131,7 +133,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // loop and try to select another host, or false if we // should break and stop retrying. start := time.Now() - keepRetrying := func() bool { + keepRetrying := func(backendErr error) bool { + // if downstream has canceled the request, break + if backendErr == context.Canceled { + return false + } // if we've tried long enough, break if time.Since(start) >= upstream.GetTryDuration() { return false @@ -150,7 +156,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { if backendErr == nil { backendErr = errors.New("no hosts available upstream") } - if !keepRetrying() { + if !keepRetrying(backendErr) { break } continue @@ -238,7 +244,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { } // if we've tried long enough, break - if !keepRetrying() { + if !keepRetrying(backendErr) { break } } @@ -267,9 +273,23 @@ func (p Proxy) match(r *http.Request) Upstream { // that can be sent upstream. // // Derived from reverseproxy.go in the standard Go httputil package. -func createUpstreamRequest(r *http.Request) *http.Request { - outreq := new(http.Request) - *outreq = *r // includes shallow copies of maps, but okay +func createUpstreamRequest(rw http.ResponseWriter, r *http.Request) (*http.Request, context.CancelFunc) { + // Original incoming server request may be canceled by the + // user or by std lib(e.g. too many idle connections). + ctx, cancel := context.WithCancel(r.Context()) + if cn, ok := rw.(http.CloseNotifier); ok { + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + } + + outreq := r.WithContext(ctx) // includes shallow copies of maps, but okay + // We should set body to nil explicitly if request body is empty. // For server requests the Request Body is always non-nil. if r.ContentLength == 0 { @@ -319,7 +339,7 @@ func createUpstreamRequest(r *http.Request) *http.Request { outreq.Header.Set("X-Forwarded-For", clientIP) } - return outreq + return outreq, cancel } func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn { diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 6ba05000..d81c411b 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -12,7 +12,6 @@ import ( "net" "net/http" "net/http/httptest" - "net/http/httptrace" "net/url" "os" "path/filepath" @@ -101,7 +100,7 @@ func TestReverseProxy(t *testing.T) { // Make sure {upstream} placeholder is set r.Body = ioutil.NopCloser(strings.NewReader("test")) - rr := httpserver.NewResponseRecorder(httptest.NewRecorder()) + rr := httpserver.NewResponseRecorder(testResponseRecorder{httptest.NewRecorder()}) rr.Replacer = httpserver.NewReplacer(r, rr, "-") p.ServeHTTP(rr, r) @@ -1123,7 +1122,18 @@ func TestReverseProxyLargeBody(t *testing.T) { } func TestCancelRequest(t *testing.T) { + reqInFlight := make(chan struct{}) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(reqInFlight) // cause the client to cancel its request + + select { + case <-time.After(10 * time.Second): + t.Error("Handler never saw CloseNotify") + return + case <-w.(http.CloseNotifier).CloseNotify(): + } + + w.WriteHeader(http.StatusOK) w.Write([]byte("Hello, client")) })) defer backend.Close() @@ -1140,26 +1150,21 @@ func TestCancelRequest(t *testing.T) { defer cancel() req = req.WithContext(ctx) - // add GotConn hook to cancel the request - gotC := make(chan struct{}) - defer close(gotC) - trace := &httptrace.ClientTrace{ - GotConn: func(connInfo httptrace.GotConnInfo) { - gotC <- struct{}{} - }, - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - // wait for canceling the request go func() { - <-gotC + <-reqInFlight cancel() }() - status, err := p.ServeHTTP(httptest.NewRecorder(), req) - if status != 0 || err != nil { - t.Errorf("expect proxy handle normally, but not, status:%d, err:%q", - status, err) + rec := httptest.NewRecorder() + status, err := p.ServeHTTP(rec, req) + expectedStatus, expectErr := http.StatusBadGateway, context.Canceled + if status != expectedStatus || err != expectErr { + t.Errorf("expect proxy handle return status[%d] with error[%v], but got status[%d] with error[%v]", + expectedStatus, expectErr, status, err) + } + if body := rec.Body.String(); body != "" { + t.Errorf("expect a blank response, but got %q", body) } } @@ -1310,6 +1315,28 @@ func (c *fakeConn) Close() error { return nil } func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) } func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) } +// testResponseRecorder wraps `httptest.ResponseRecorder`, +// also implements `http.CloseNotifier`, `http.Hijacker` and `http.Pusher`. +type testResponseRecorder struct { + *httptest.ResponseRecorder +} + +func (testResponseRecorder) CloseNotify() <-chan bool { return nil } +func (t testResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, httpserver.NonHijackerError{Underlying: t} +} +func (t testResponseRecorder) Push(target string, opts *http.PushOptions) error { + return httpserver.NonPusherError{Underlying: t} +} + +// Interface guards +var ( + _ http.Pusher = testResponseRecorder{} + _ http.Flusher = testResponseRecorder{} + _ http.CloseNotifier = testResponseRecorder{} + _ http.Hijacker = testResponseRecorder{} +) + func BenchmarkProxy(b *testing.B) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello, client")) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index eb0f63e8..c2bdbaff 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -12,7 +12,6 @@ package proxy import ( - "context" "crypto/tls" "io" "net" @@ -252,14 +251,6 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, rp.Director(outreq) - // Original incoming server request may be canceled by the - // user or by std lib(e.g. too many idle connections). - // Now we issue the new outgoing client request which - // doesn't depend on the original one. (issue 1345) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - outreq = outreq.WithContext(ctx) - res, err := transport.RoundTrip(outreq) if err != nil { return err