diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index 89fa21ae..fc2d727f 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -13,23 +13,32 @@ import ( "github.com/mholt/caddy/caddyhttp/httpserver" ) -var errUnreachable = errors.New("unreachable backend") - // Proxy represents a middleware instance that can proxy requests. type Proxy struct { Next httpserver.Handler Upstreams []Upstream } -// Upstream manages a pool of proxy upstream hosts. Select should return a -// suitable upstream host, or nil if no such hosts are available. +// Upstream manages a pool of proxy upstream hosts. type Upstream interface { // The path this upstream host should be routed on From() string - // Selects an upstream host to be routed to. + + // Selects an upstream host to be routed to. It + // should return a suitable upstream host, or nil + // if no such hosts are available. Select(*http.Request) *UpstreamHost + // Checks if subpath is not an ignored path AllowedPath(string) bool + + // Gets how long to try selecting upstream hosts + // in the case of cascading failures. + GetTryDuration() time.Duration + + // Gets how long to wait between selecting upstream + // hosts in the case of cascading failures. + GetTryInterval() time.Duration } // UpstreamHostDownFunc can be used to customize how Down behaves. @@ -71,10 +80,6 @@ func (uh *UpstreamHost) Available() bool { return !uh.Down() && !uh.Full() } -// tryDuration is how long to try upstream hosts; failures result in -// immediate retries until this duration ends or we get a nil host. -var tryDuration = 60 * time.Second - // ServeHTTP satisfies the httpserver.Handler interface. func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // start by selecting most specific matching upstream config @@ -89,13 +94,33 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // outreq is the request that makes a roundtrip to the backend outreq := createUpstreamRequest(r) - // since Select() should give us "up" hosts, keep retrying - // hosts until timeout (or until we get a nil host). + // The keepRetrying function will return true if we should + // loop and try to select another host, or false if we + // should break and stop retrying. start := time.Now() - for time.Now().Sub(start) < tryDuration { + keepRetrying := func() bool { + // if we've tried long enough, break + if time.Since(start) >= upstream.GetTryDuration() { + return false + } + // otherwise, wait and try the next available host + time.Sleep(upstream.GetTryInterval()) + return true + } + + var backendErr error + for { + // since Select() should give us "up" hosts, keep retrying + // hosts until timeout (or until we get a nil host). host := upstream.Select(r) if host == nil { - return http.StatusBadGateway, errUnreachable + if backendErr == nil { + backendErr = errors.New("no hosts available upstream") + } + if !keepRetrying() { + break + } + continue } if rr, ok := w.(*httpserver.ResponseRecorder); ok && rr.Replacer != nil { rr.Replacer.Set("upstream", host.Name) @@ -141,29 +166,35 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // tell the proxy to serve the request atomic.AddInt64(&host.Conns, 1) - backendErr := proxy.ServeHTTP(w, outreq, downHeaderUpdateFn) + backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn) atomic.AddInt64(&host.Conns, -1) - // if no errors, we're done here; otherwise failover + // if no errors, we're done here if backendErr == nil { return 0, nil } + + // failover; remember this failure for some time if + // request failure counting is enabled timeout := host.FailTimeout - if timeout == 0 { - timeout = 10 * time.Second + if timeout > 0 { + atomic.AddInt32(&host.Fails, 1) + go func(host *UpstreamHost, timeout time.Duration) { + time.Sleep(timeout) + atomic.AddInt32(&host.Fails, -1) + }(host, timeout) + } + + // if we've tried long enough, break + if !keepRetrying() { + break } - atomic.AddInt32(&host.Fails, 1) - go func(host *UpstreamHost, timeout time.Duration) { - time.Sleep(timeout) - atomic.AddInt32(&host.Fails, -1) - }(host, timeout) } - return http.StatusBadGateway, errUnreachable + return http.StatusBadGateway, backendErr } -// match finds the best match for a proxy config based -// on r. +// match finds the best match for a proxy config based on r. func (p Proxy) match(r *http.Request) Upstream { var u Upstream var longestMatch int diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index e96898bd..a1fd889d 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -25,10 +25,6 @@ import ( "golang.org/x/net/websocket" ) -func init() { - tryDuration = 50 * time.Millisecond // prevent tests from hanging -} - func TestReverseProxy(t *testing.T) { log.SetOutput(ioutil.Discard) defer log.SetOutput(os.Stderr) @@ -792,9 +788,9 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost { return u.host } -func (u *fakeUpstream) AllowedPath(requestPath string) bool { - return true -} +func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true } +func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second } +func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } // newWebSocketTestProxy returns a test proxy that will // redirect to the specified backendAddr. The function @@ -834,9 +830,9 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost { } } -func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { - return true -} +func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true } +func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second } +func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } // recorderHijacker is a ResponseRecorder that can // be hijacked. diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go index 2a0f1a77..c5ca77f0 100644 --- a/caddyhttp/proxy/upstream.go +++ b/caddyhttp/proxy/upstream.go @@ -31,6 +31,8 @@ type staticUpstream struct { FailTimeout time.Duration MaxFails int32 + TryDuration time.Duration + TryInterval time.Duration MaxConns int64 HealthCheck struct { Client http.Client @@ -53,8 +55,8 @@ func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) { downstreamHeaders: make(http.Header), Hosts: nil, Policy: &Random{}, - FailTimeout: 10 * time.Second, MaxFails: 1, + TryInterval: 250 * time.Millisecond, MaxConns: 0, KeepAlive: http.DefaultMaxIdleConnsPerHost, } @@ -114,11 +116,6 @@ func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) { return upstreams, nil } -// RegisterPolicy adds a custom policy to the proxy. -func RegisterPolicy(name string, policy func() Policy) { - supportedPolicies[name] = policy -} - func (u *staticUpstream) From() string { return u.from } @@ -141,8 +138,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { if uh.Unhealthy { return true } - if uh.Fails >= u.MaxFails && - u.MaxFails != 0 { + if uh.Fails >= u.MaxFails { return true } return false @@ -237,7 +233,28 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { if err != nil { return err } + if n < 1 { + return c.Err("max_fails must be at least 1") + } u.MaxFails = int32(n) + case "try_duration": + if !c.NextArg() { + return c.ArgErr() + } + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + u.TryDuration = dur + case "try_interval": + if !c.NextArg() { + return c.ArgErr() + } + interval, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + u.TryInterval = interval case "max_conns": if !c.NextArg() { return c.ArgErr() @@ -397,3 +414,18 @@ func (u *staticUpstream) AllowedPath(requestPath string) bool { } return true } + +// GetTryDuration returns u.TryDuration. +func (u *staticUpstream) GetTryDuration() time.Duration { + return u.TryDuration +} + +// GetTryInterval returns u.TryInterval. +func (u *staticUpstream) GetTryInterval() time.Duration { + return u.TryInterval +} + +// RegisterPolicy adds a custom policy to the proxy. +func RegisterPolicy(name string, policy func() Policy) { + supportedPolicies[name] = policy +}