diff --git a/caddyhttp/proxy/policy_test.go b/caddyhttp/proxy/policy_test.go index 2a8dfe61..9b277197 100644 --- a/caddyhttp/proxy/policy_test.go +++ b/caddyhttp/proxy/policy_test.go @@ -60,13 +60,13 @@ func TestRoundRobinPolicy(t *testing.T) { t.Error("Expected third round robin host to be first host in the pool.") } // mark host as down - pool[1].Unhealthy = true + pool[1].Unhealthy = 1 h = rrPolicy.Select(pool, request) if h != pool[2] { t.Error("Expected to skip down host.") } // mark host as up - pool[1].Unhealthy = false + pool[1].Unhealthy = 0 h = rrPolicy.Select(pool, request) if h == pool[2] { @@ -161,7 +161,7 @@ func TestIPHashPolicy(t *testing.T) { // we should get a healthy host if the original host is unhealthy and a // healthy host is available request.RemoteAddr = "172.0.0.1" - pool[1].Unhealthy = true + pool[1].Unhealthy = 1 h = ipHash.Select(pool, request) if h != pool[2] { t.Error("Expected ip hash policy host to be the third host.") @@ -172,10 +172,10 @@ func TestIPHashPolicy(t *testing.T) { if h != pool[2] { t.Error("Expected ip hash policy host to be the third host.") } - pool[1].Unhealthy = false + pool[1].Unhealthy = 0 request.RemoteAddr = "172.0.0.3" - pool[2].Unhealthy = true + pool[2].Unhealthy = 1 h = ipHash.Select(pool, request) if h != pool[0] { t.Error("Expected ip hash policy host to be the first host.") @@ -219,8 +219,8 @@ func TestIPHashPolicy(t *testing.T) { } // We should get nil when there are no healthy hosts - pool[0].Unhealthy = true - pool[1].Unhealthy = true + pool[0].Unhealthy = 1 + pool[1].Unhealthy = 1 h = ipHash.Select(pool, request) if h != nil { t.Error("Expected ip hash policy host to be nil.") diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index c0c2bb4b..c2d05b49 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -49,6 +49,8 @@ type UpstreamHostDownFunc func(*UpstreamHost) bool // UpstreamHost represents a single proxy upstream type UpstreamHost struct { + // This field is read & written to concurrently, so all access must use + // atomic operations. Conns int64 // must be first field to be 64-bit aligned on 32-bit systems MaxConns int64 Name string // hostname of this upstream host @@ -59,7 +61,10 @@ type UpstreamHost struct { WithoutPathPrefix string ReverseProxy *ReverseProxy Fails int32 - Unhealthy bool + // This is an int32 so that we can use atomic operations to do concurrent + // reads & writes to this value. The default value of 0 indicates that it + // is healthy and any non-zero value indicates unhealthy. + Unhealthy int32 } // Down checks whether the upstream host is down or not. @@ -68,14 +73,14 @@ type UpstreamHost struct { func (uh *UpstreamHost) Down() bool { if uh.CheckDown == nil { // Default settings - return uh.Unhealthy || uh.Fails > 0 + return atomic.LoadInt32(&uh.Unhealthy) != 0 || atomic.LoadInt32(&uh.Fails) > 0 } return uh.CheckDown(uh) } // Full checks whether the upstream host has reached its maximum connections func (uh *UpstreamHost) Full() bool { - return uh.MaxConns > 0 && uh.Conns >= uh.MaxConns + return uh.MaxConns > 0 && atomic.LoadInt64(&uh.Conns) >= uh.MaxConns } // Available checks whether the upstream host is available for proxying to diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 90753ab3..380094e0 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -19,6 +19,7 @@ import ( "reflect" "runtime" "strings" + "sync" "sync/atomic" "testing" "time" @@ -143,6 +144,74 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) { } } +// This test will fail when using the race detector without atomic reads & +// writes of UpstreamHost.Conns and UpstreamHost.Unhealthy. +func TestReverseProxyMaxConnLimit(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + const MaxTestConns = 2 + connReceived := make(chan bool, MaxTestConns) + connContinue := make(chan bool) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + connReceived <- true + <-connContinue + })) + defer backend.Close() + + su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(` + proxy / `+backend.URL+` { + max_conns `+fmt.Sprint(MaxTestConns)+` + } + `))) + if err != nil { + t.Fatal(err) + } + + // set up proxy + p := &Proxy{ + Next: httpserver.EmptyNext, // prevents panic in some cases when test fails + Upstreams: su, + } + + var jobs sync.WaitGroup + + for i := 0; i < MaxTestConns; i++ { + jobs.Add(1) + go func(i int) { + defer jobs.Done() + w := httptest.NewRecorder() + code, err := p.ServeHTTP(w, httptest.NewRequest("GET", "/", nil)) + if err != nil { + t.Errorf("Request %d failed: %v", i, err) + } else if code != 0 { + t.Errorf("Bad return code for request %d: %d", i, code) + } else if w.Code != 200 { + t.Errorf("Bad statuc code for request %d: %d", i, w.Code) + } + }(i) + } + // Wait for all the requests to hit the backend. + for i := 0; i < MaxTestConns; i++ { + <-connReceived + } + + // Now we should have MaxTestConns requests connected and sitting on the backend + // server. Verify that the next request is rejected. + w := httptest.NewRecorder() + code, err := p.ServeHTTP(w, httptest.NewRequest("GET", "/", nil)) + if code != http.StatusBadGateway { + t.Errorf("Expected request to be rejected, but got: %d [%v]\nStatus code: %d", + code, err, w.Code) + } + + // Now let all the requests complete and verify the status codes for those: + close(connContinue) + + // Wait for the initial requests to finish and check their results. + jobs.Wait() +} + func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) { // Capture the expected panic defer func() { diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go index 5742eff0..0e831124 100644 --- a/caddyhttp/proxy/upstream.go +++ b/caddyhttp/proxy/upstream.go @@ -9,6 +9,7 @@ import ( "path" "strconv" "strings" + "sync/atomic" "time" "github.com/mholt/caddy/caddyfile" @@ -128,15 +129,15 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { Conns: 0, Fails: 0, FailTimeout: u.FailTimeout, - Unhealthy: false, + Unhealthy: 0, UpstreamHeaders: u.upstreamHeaders, DownstreamHeaders: u.downstreamHeaders, CheckDown: func(u *staticUpstream) UpstreamHostDownFunc { return func(uh *UpstreamHost) bool { - if uh.Unhealthy { + if atomic.LoadInt32(&uh.Unhealthy) != 0 { return true } - if uh.Fails >= u.MaxFails { + if atomic.LoadInt32(&uh.Fails) >= u.MaxFails { return true } return false @@ -355,12 +356,18 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { func (u *staticUpstream) healthCheck() { for _, host := range u.Hosts { hostURL := host.Name + u.HealthCheck.Path + var unhealthy bool if r, err := u.HealthCheck.Client.Get(hostURL); err == nil { io.Copy(ioutil.Discard, r.Body) r.Body.Close() - host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400 + unhealthy = r.StatusCode < 200 || r.StatusCode >= 400 } else { - host.Unhealthy = true + unhealthy = true + } + if unhealthy { + atomic.StoreInt32(&host.Unhealthy, 1) + } else { + atomic.StoreInt32(&host.Unhealthy, 0) } } } diff --git a/caddyhttp/proxy/upstream_test.go b/caddyhttp/proxy/upstream_test.go index 2d7828eb..1163fffe 100644 --- a/caddyhttp/proxy/upstream_test.go +++ b/caddyhttp/proxy/upstream_test.go @@ -36,12 +36,12 @@ func TestNewHost(t *testing.T) { t.Error("Expected new host not to be down.") } // mark Unhealthy - uh.Unhealthy = true + uh.Unhealthy = 1 if !uh.CheckDown(uh) { t.Error("Expected unhealthy host to be down.") } // mark with Fails - uh.Unhealthy = false + uh.Unhealthy = 0 uh.Fails = 1 if !uh.CheckDown(uh) { t.Error("Expected failed host to be down.") @@ -74,13 +74,13 @@ func TestSelect(t *testing.T) { MaxFails: 1, } r, _ := http.NewRequest("GET", "/", nil) - upstream.Hosts[0].Unhealthy = true - upstream.Hosts[1].Unhealthy = true - upstream.Hosts[2].Unhealthy = true + upstream.Hosts[0].Unhealthy = 1 + upstream.Hosts[1].Unhealthy = 1 + upstream.Hosts[2].Unhealthy = 1 if h := upstream.Select(r); h != nil { t.Error("Expected select to return nil as all host are down") } - upstream.Hosts[2].Unhealthy = false + upstream.Hosts[2].Unhealthy = 0 if h := upstream.Select(r); h == nil { t.Error("Expected select to not return nil") }