diff --git a/caddyhttp/fastcgi/dialer.go b/caddyhttp/fastcgi/dialer.go index be3cfe56..135908e3 100644 --- a/caddyhttp/fastcgi/dialer.go +++ b/caddyhttp/fastcgi/dialer.go @@ -19,8 +19,11 @@ type basicDialer struct { timeout time.Duration } -func (b basicDialer) Dial() (Client, error) { return Dial(b.network, b.address, b.timeout) } -func (b basicDialer) Close(c Client) error { return c.Close() } +func (b basicDialer) Dial() (Client, error) { + return DialTimeout(b.network, b.address, b.timeout) +} + +func (b basicDialer) Close(c Client) error { return c.Close() } // persistentDialer keeps a pool of fcgi connections. // connections are not closed after use, rather added back to the pool for reuse. @@ -47,7 +50,7 @@ func (p *persistentDialer) Dial() (Client, error) { p.Unlock() // no connection available, create new one - return Dial(p.network, p.address, p.timeout) + return DialTimeout(p.network, p.address, p.timeout) } func (p *persistentDialer) Close(client Client) error { diff --git a/caddyhttp/fastcgi/fastcgi.go b/caddyhttp/fastcgi/fastcgi.go index 74999592..c4dd9b32 100644 --- a/caddyhttp/fastcgi/fastcgi.go +++ b/caddyhttp/fastcgi/fastcgi.go @@ -6,6 +6,7 @@ package fastcgi import ( "errors" "io" + "net" "net/http" "os" "path" @@ -80,9 +81,14 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) // Connect to FastCGI gateway fcgiBackend, err := rule.dialer.Dial() if err != nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + return http.StatusGatewayTimeout, err + } return http.StatusBadGateway, err } + defer fcgiBackend.Close() fcgiBackend.SetReadTimeout(rule.ReadTimeout) + fcgiBackend.SetSendTimeout(rule.SendTimeout) var resp *http.Response contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length")) @@ -97,8 +103,12 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength) } - if err != nil && err != io.EOF { - return http.StatusBadGateway, err + if err != nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + return http.StatusGatewayTimeout, err + } else if err != io.EOF { + return http.StatusBadGateway, err + } } // Write response header @@ -110,8 +120,6 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) return http.StatusBadGateway, err } - defer rule.dialer.Close(fcgiBackend) - // Log any stderr output from upstream if stderr := fcgiBackend.StdErr(); stderr.Len() != 0 { // Remove trailing newline, error logger already does this. @@ -306,6 +314,9 @@ type Rule struct { // The duration used to set a deadline when reading from the FastCGI server. ReadTimeout time.Duration + // The duration used to set a deadline when sending to the FastCGI server. + SendTimeout time.Duration + // FCGI dialer dialer dialer } diff --git a/caddyhttp/fastcgi/fastcgi_test.go b/caddyhttp/fastcgi/fastcgi_test.go index b84a78c5..fdaf78bb 100644 --- a/caddyhttp/fastcgi/fastcgi_test.go +++ b/caddyhttp/fastcgi/fastcgi_test.go @@ -327,38 +327,124 @@ func TestBuildEnv(t *testing.T) { } func TestReadTimeout(t *testing.T) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Unable to create listener for test: %v", err) + tests := []struct { + sleep time.Duration + readTimeout time.Duration + shouldErr bool + }{ + {75 * time.Millisecond, 50 * time.Millisecond, true}, + {0, -1 * time.Second, true}, + {0, time.Minute, false}, } - defer listener.Close() - network, address := parseAddress(listener.Addr().String()) - handler := Handler{ - Next: nil, - Rules: []Rule{ - { - Path: "/", - Address: listener.Addr().String(), - dialer: basicDialer{network: network, address: address}, - ReadTimeout: time.Millisecond * 100, + var wg sync.WaitGroup + + for i, test := range tests { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Test %d: Unable to create listener for test: %v", i, err) + } + defer listener.Close() + + network, address := parseAddress(listener.Addr().String()) + handler := Handler{ + Next: nil, + Rules: []Rule{ + { + Path: "/", + Address: listener.Addr().String(), + dialer: basicDialer{network: network, address: address}, + ReadTimeout: test.readTimeout, + }, }, - }, - } - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatalf("Unable to create request: %v", err) - } - w := httptest.NewRecorder() + } + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Test %d: Unable to create request: %v", i, err) + } + w := httptest.NewRecorder() - go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(time.Millisecond * 130) - })) + wg.Add(1) + go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(test.sleep) + w.WriteHeader(http.StatusOK) + wg.Done() + })) - _, err = handler.ServeHTTP(w, r) - if err == nil { - t.Error("Expected i/o timeout error but had none") - } else if err, ok := err.(net.Error); !ok || !err.Timeout() { - t.Errorf("Expected i/o timeout error, got: '%s'", err.Error()) + got, err := handler.ServeHTTP(w, r) + if test.shouldErr { + if err == nil { + t.Errorf("Test %d: Expected i/o timeout error but had none", i) + } else if err, ok := err.(net.Error); !ok || !err.Timeout() { + t.Errorf("Test %d: Expected i/o timeout error, got: '%s'", i, err.Error()) + } + + want := http.StatusGatewayTimeout + if got != want { + t.Errorf("Test %d: Expected returned status code to be %d, got: %d", + i, want, got) + } + } else if err != nil { + t.Errorf("Test %d: Expected nil error, got: %v", i, err) + } + + wg.Wait() + } +} + +func TestSendTimeout(t *testing.T) { + tests := []struct { + sendTimeout time.Duration + shouldErr bool + }{ + {-1 * time.Second, true}, + {time.Minute, false}, + } + + for i, test := range tests { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Test %d: Unable to create listener for test: %v", i, err) + } + defer listener.Close() + + network, address := parseAddress(listener.Addr().String()) + handler := Handler{ + Next: nil, + Rules: []Rule{ + { + Path: "/", + Address: listener.Addr().String(), + dialer: basicDialer{network: network, address: address}, + SendTimeout: test.sendTimeout, + }, + }, + } + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Test %d: Unable to create request: %v", i, err) + } + w := httptest.NewRecorder() + + go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + got, err := handler.ServeHTTP(w, r) + if test.shouldErr { + if err == nil { + t.Errorf("Test %d: Expected i/o timeout error but had none", i) + } else if err, ok := err.(net.Error); !ok || !err.Timeout() { + t.Errorf("Test %d: Expected i/o timeout error, got: '%s'", i, err.Error()) + } + + want := http.StatusGatewayTimeout + if got != want { + t.Errorf("Test %d: Expected returned status code to be %d, got: %d", + i, want, got) + } + } else if err != nil { + t.Errorf("Test %d: Expected nil error, got: %v", i, err) + } } } diff --git a/caddyhttp/fastcgi/fcgiclient.go b/caddyhttp/fastcgi/fcgiclient.go index 1160be6b..d7db291c 100644 --- a/caddyhttp/fastcgi/fcgiclient.go +++ b/caddyhttp/fastcgi/fcgiclient.go @@ -15,7 +15,6 @@ import ( "bytes" "encoding/binary" "errors" - "fmt" "io" "io/ioutil" "mime/multipart" @@ -116,8 +115,8 @@ type Client interface { Post(pairs map[string]string, method string, bodyType string, body io.Reader, contentLength int) (response *http.Response, err error) Close() error StdErr() bytes.Buffer - ReadTimeout() time.Duration SetReadTimeout(time.Duration) error + SetSendTimeout(time.Duration) error } type header struct { @@ -174,57 +173,32 @@ func (rec *record) read(r io.Reader) (buf []byte, err error) { // interfacing external applications with Web servers. type FCGIClient struct { mutex sync.Mutex - rwc io.ReadWriteCloser + conn net.Conn h header buf bytes.Buffer stderr bytes.Buffer keepAlive bool reqID uint16 readTimeout time.Duration + sendTimeout time.Duration } -// DialWithDialer connects to the fcgi responder at the specified network address, using custom net.Dialer. +// DialTimeout connects to the fcgi responder at the specified network address, using default net.Dialer. // See func net.Dial for a description of the network and address parameters. -func DialWithDialer(network, address string, dialer net.Dialer) (fcgi *FCGIClient, err error) { - var conn net.Conn - conn, err = dialer.Dial(network, address) +func DialTimeout(network string, address string, timeout time.Duration) (fcgi *FCGIClient, err error) { + conn, err := net.DialTimeout(network, address, timeout) if err != nil { return } - fcgi = &FCGIClient{ - rwc: conn, - keepAlive: false, - reqID: 1, - } + fcgi = &FCGIClient{conn: conn, keepAlive: false, reqID: 1} - return -} - -// Dial connects to the fcgi responder at the specified network address, using default net.Dialer. -// See func net.Dial for a description of the network and address parameters. -func Dial(network string, address string, timeout time.Duration) (fcgi *FCGIClient, err error) { - return DialWithDialer(network, address, net.Dialer{Timeout: timeout}) + return fcgi, nil } // Close closes fcgi connnection. func (c *FCGIClient) Close() error { - return c.rwc.Close() -} - -// setReadDeadline sets a read deadline on FCGIClient based on the configured -// readTimeout. A zero value for readTimeout means no deadline will be set. -func (c *FCGIClient) setReadDeadline() error { - if c.readTimeout > 0 { - conn, ok := c.rwc.(net.Conn) - if ok { - conn.SetReadDeadline(time.Now().Add(c.readTimeout)) - } else { - return fmt.Errorf("Could not set Client ReadTimeout") - } - } - - return nil + return c.conn.Close() } func (c *FCGIClient) writeRecord(recType uint8, content []byte) error { @@ -245,7 +219,13 @@ func (c *FCGIClient) writeRecord(recType uint8, content []byte) error { return err } - if _, err := c.rwc.Write(c.buf.Bytes()); err != nil { + if c.sendTimeout != 0 { + if err := c.conn.SetWriteDeadline(time.Now().Add(c.sendTimeout)); err != nil { + return err + } + } + + if _, err := c.conn.Write(c.buf.Bytes()); err != nil { return err } @@ -369,7 +349,7 @@ func (w *streamReader) Read(p []byte) (n int, err error) { for { rec := &record{} var buf []byte - buf, err = rec.read(w.c.rwc) + buf, err = rec.read(w.c.conn) if err == errInvalidHeaderVersion { continue } else if err != nil { @@ -436,7 +416,6 @@ func (c clientCloser) Close() error { return c.f.Close() } // Request returns a HTTP Response with Header and Body // from fcgi responder func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Response, err error) { - r, err := c.Do(p, req) if err != nil { return @@ -446,8 +425,10 @@ func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Res tp := textproto.NewReader(rb) resp = new(http.Response) - if err = c.setReadDeadline(); err != nil { - return + if c.readTimeout != 0 { + if err = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil { + return + } } // Parse the response headers. @@ -582,10 +563,6 @@ func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[str return c.Post(p, "POST", bodyType, buf, buf.Len()) } -// ReadTimeout returns the read timeout for future calls that read from the -// fcgi responder. -func (c *FCGIClient) ReadTimeout() time.Duration { return c.readTimeout } - // SetReadTimeout sets the read timeout for future calls that read from the // fcgi responder. A zero value for t means no timeout will be set. func (c *FCGIClient) SetReadTimeout(t time.Duration) error { @@ -593,6 +570,13 @@ func (c *FCGIClient) SetReadTimeout(t time.Duration) error { return nil } +// SetSendTimeout sets the read timeout for future calls that send data to +// the fcgi responder. A zero value for t means no timeout will be set. +func (c *FCGIClient) SetSendTimeout(t time.Duration) error { + c.sendTimeout = t + return nil +} + // Checks whether chunked is part of the encodings stack func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } diff --git a/caddyhttp/fastcgi/fcgiclient_test.go b/caddyhttp/fastcgi/fcgiclient_test.go index bc2a2244..59879332 100644 --- a/caddyhttp/fastcgi/fcgiclient_test.go +++ b/caddyhttp/fastcgi/fcgiclient_test.go @@ -103,7 +103,7 @@ func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) { } func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) { - fcgi, err := Dial("tcp", ipPort, 0) + fcgi, err := DialTimeout("tcp", ipPort, 0) if err != nil { log.Println("err:", err) return diff --git a/caddyhttp/fastcgi/setup.go b/caddyhttp/fastcgi/setup.go index 5382ff67..e3615afe 100644 --- a/caddyhttp/fastcgi/setup.go +++ b/caddyhttp/fastcgi/setup.go @@ -59,7 +59,7 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { return rules, c.ArgErr() } - rule := Rule{Path: args[0], ReadTimeout: 60 * time.Second} + rule := Rule{Path: args[0], ReadTimeout: 60 * time.Second, SendTimeout: 60 * time.Second} upstreams := []string{args[1]} if len(args) == 3 { @@ -144,6 +144,15 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { return rules, err } rule.ReadTimeout = readTimeout + case "send_timeout": + if !c.NextArg() { + return rules, c.ArgErr() + } + sendTimeout, err := time.ParseDuration(c.Val()) + if err != nil { + return rules, err + } + rule.SendTimeout = sendTimeout } } diff --git a/caddyhttp/fastcgi/setup_test.go b/caddyhttp/fastcgi/setup_test.go index c5e1b681..488011b6 100644 --- a/caddyhttp/fastcgi/setup_test.go +++ b/caddyhttp/fastcgi/setup_test.go @@ -80,6 +80,7 @@ func TestFastcgiParse(t *testing.T) { dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}}}, IndexFiles: []string{"index.php"}, ReadTimeout: 60 * time.Second, + SendTimeout: 60 * time.Second, }}}, {`fastcgi /blog 127.0.0.1:9000 php { upstream 127.0.0.1:9001 @@ -92,6 +93,7 @@ func TestFastcgiParse(t *testing.T) { dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}, basicDialer{network: "tcp", address: "127.0.0.1:9001", timeout: 60 * time.Second}}}, IndexFiles: []string{"index.php"}, ReadTimeout: 60 * time.Second, + SendTimeout: 60 * time.Second, }}}, {`fastcgi /blog 127.0.0.1:9000 { upstream 127.0.0.1:9001 @@ -104,6 +106,7 @@ func TestFastcgiParse(t *testing.T) { dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}, basicDialer{network: "tcp", address: "127.0.0.1:9001", timeout: 60 * time.Second}}}, IndexFiles: []string{}, ReadTimeout: 60 * time.Second, + SendTimeout: 60 * time.Second, }}}, {`fastcgi / ` + defaultAddress + ` { split .html @@ -116,6 +119,7 @@ func TestFastcgiParse(t *testing.T) { dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}}, IndexFiles: []string{}, ReadTimeout: 60 * time.Second, + SendTimeout: 60 * time.Second, }}}, {`fastcgi / ` + defaultAddress + ` { split .html @@ -130,6 +134,7 @@ func TestFastcgiParse(t *testing.T) { IndexFiles: []string{}, IgnoredSubPaths: []string{"/admin", "/user"}, ReadTimeout: 60 * time.Second, + SendTimeout: 60 * time.Second, }}}, {`fastcgi / ` + defaultAddress + ` { pool 0 @@ -142,6 +147,7 @@ func TestFastcgiParse(t *testing.T) { dialer: &loadBalancingDialer{dialers: []dialer{&persistentDialer{size: 0, network: network, address: address, timeout: 60 * time.Second}}}, IndexFiles: []string{}, ReadTimeout: 60 * time.Second, + SendTimeout: 60 * time.Second, }}}, {`fastcgi / 127.0.0.1:8080 { upstream 127.0.0.1:9000 @@ -155,6 +161,7 @@ func TestFastcgiParse(t *testing.T) { dialer: &loadBalancingDialer{dialers: []dialer{&persistentDialer{size: 5, network: "tcp", address: "127.0.0.1:8080", timeout: 60 * time.Second}, &persistentDialer{size: 5, network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}}}, IndexFiles: []string{}, ReadTimeout: 60 * time.Second, + SendTimeout: 60 * time.Second, }}}, {`fastcgi / ` + defaultAddress + ` { split .php @@ -167,6 +174,7 @@ func TestFastcgiParse(t *testing.T) { dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}}, IndexFiles: []string{}, ReadTimeout: 60 * time.Second, + SendTimeout: 60 * time.Second, }}}, {`fastcgi / ` + defaultAddress + ` { connect_timeout 5s @@ -179,7 +187,13 @@ func TestFastcgiParse(t *testing.T) { dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 5 * time.Second}}}, IndexFiles: []string{}, ReadTimeout: 60 * time.Second, + SendTimeout: 60 * time.Second, }}}, + { + `fastcgi / ` + defaultAddress + ` { connect_timeout BADVALUE }`, + true, + []Rule{}, + }, {`fastcgi / ` + defaultAddress + ` { read_timeout 5s }`, @@ -191,7 +205,31 @@ func TestFastcgiParse(t *testing.T) { dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}}, IndexFiles: []string{}, ReadTimeout: 5 * time.Second, + SendTimeout: 60 * time.Second, }}}, + { + `fastcgi / ` + defaultAddress + ` { read_timeout BADVALUE }`, + true, + []Rule{}, + }, + {`fastcgi / ` + defaultAddress + ` { + send_timeout 5s + }`, + false, []Rule{{ + Path: "/", + Address: defaultAddress, + Ext: "", + SplitPath: "", + dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}}, + IndexFiles: []string{}, + ReadTimeout: 60 * time.Second, + SendTimeout: 5 * time.Second, + }}}, + { + `fastcgi / ` + defaultAddress + ` { send_timeout BADVALUE }`, + true, + []Rule{}, + }, {`fastcgi / { }`, @@ -251,6 +289,16 @@ func TestFastcgiParse(t *testing.T) { t.Errorf("Test %d expected %dth FastCGI IgnoredSubPaths to be %s , but got %s", i, j, test.expectedFastcgiConfig[j].IgnoredSubPaths, actualFastcgiConfig.IgnoredSubPaths) } + + if fmt.Sprint(actualFastcgiConfig.ReadTimeout) != fmt.Sprint(test.expectedFastcgiConfig[j].ReadTimeout) { + t.Errorf("Test %d expected %dth FastCGI ReadTimeout to be %s , but got %s", + i, j, test.expectedFastcgiConfig[j].ReadTimeout, actualFastcgiConfig.ReadTimeout) + } + + if fmt.Sprint(actualFastcgiConfig.SendTimeout) != fmt.Sprint(test.expectedFastcgiConfig[j].SendTimeout) { + t.Errorf("Test %d expected %dth FastCGI SendTimeout to be %s , but got %s", + i, j, test.expectedFastcgiConfig[j].SendTimeout, actualFastcgiConfig.SendTimeout) + } } } }