From 078c991574647c27a3401e35ae667a166990165a Mon Sep 17 00:00:00 2001 From: Martin Redmond Date: Wed, 28 Jun 2017 17:54:29 -0400 Subject: [PATCH] proxy: custom upstream health check by body string, closes #324 (#1691) --- caddyhttp/proxy/upstream.go | 62 +++++++++++++++++++++----------- caddyhttp/proxy/upstream_test.go | 53 +++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 20 deletions(-) diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go index 25a42d7e..a33ffcdf 100644 --- a/caddyhttp/proxy/upstream.go +++ b/caddyhttp/proxy/upstream.go @@ -1,6 +1,7 @@ package proxy import ( + "bytes" "fmt" "io" "io/ioutil" @@ -38,12 +39,13 @@ type staticUpstream struct { TryInterval time.Duration MaxConns int64 HealthCheck struct { - Client http.Client - Path string - Interval time.Duration - Timeout time.Duration - Host string - Port string + Client http.Client + Path string + Interval time.Duration + Timeout time.Duration + Host string + Port string + ContentString string } WithoutPathPrefix string IgnoredSubPaths []string @@ -337,6 +339,11 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { return c.Errf("invalid health_check_port '%s'", port) } u.HealthCheck.Port = port + case "health_check_contains": + if !c.NextArg() { + return c.ArgErr() + } + u.HealthCheck.ContentString = c.Val() case "header_upstream": var header, value string if !c.Args(&header, &value) { @@ -402,27 +409,42 @@ func (u *staticUpstream) healthCheck() { } hostURL += u.HealthCheck.Path - var unhealthy bool - - // set up request, needed to be able to modify headers - // possible errors are bad HTTP methods or un-parsable urls - req, err := http.NewRequest("GET", hostURL, nil) - if err != nil { - unhealthy = true - } else { + unhealthy := func() bool { + // set up request, needed to be able to modify headers + // possible errors are bad HTTP methods or un-parsable urls + req, err := http.NewRequest("GET", hostURL, nil) + if err != nil { + return true + } // set host for request going upstream if u.HealthCheck.Host != "" { req.Host = u.HealthCheck.Host } - - if r, err := u.HealthCheck.Client.Do(req); err == nil { + r, err := u.HealthCheck.Client.Do(req) + if err != nil { + return true + } + defer func() { io.Copy(ioutil.Discard, r.Body) r.Body.Close() - unhealthy = r.StatusCode < 200 || r.StatusCode >= 400 - } else { - unhealthy = true + }() + if r.StatusCode < 200 || r.StatusCode >= 400 { + return true } - } + if u.HealthCheck.ContentString == "" { // don't check for content string + return false + } + // TODO ReadAll will be replaced if deemed necessary + // See https://github.com/mholt/caddy/pull/1691 + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return true + } + if bytes.Contains(buf, []byte(u.HealthCheck.ContentString)) { + return false + } + return true + }() if unhealthy { atomic.StoreInt32(&host.Unhealthy, 1) } else { diff --git a/caddyhttp/proxy/upstream_test.go b/caddyhttp/proxy/upstream_test.go index be42359d..c19547bb 100644 --- a/caddyhttp/proxy/upstream_test.go +++ b/caddyhttp/proxy/upstream_test.go @@ -448,3 +448,56 @@ func TestHealthCheckPort(t *testing.T) { }) } + +func TestHealthCheckContentString(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "blablabla good blablabla") + r.Body.Close() + })) + _, port, err := net.SplitHostPort(server.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + tests := []struct { + config string + shouldContain bool + }{ + {"proxy / localhost:" + port + + " { health_check /testhealth " + + " health_check_contains good\n}", + true, + }, + {"proxy / localhost:" + port + " {\n health_check /testhealth health_check_port " + port + + " \n health_check_contains bad\n}", + false, + }, + } + for i, test := range tests { + u, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") + if err != nil { + t.Error("Expected no error. Test %d Got:", i, err.Error()) + } + for _, upstream := range u { + staticUpstream, ok := upstream.(*staticUpstream) + if !ok { + t.Errorf("Type mismatch: %#v", upstream) + continue + } + staticUpstream.healthCheck() + for _, host := range staticUpstream.Hosts { + if test.shouldContain && atomic.LoadInt32(&host.Unhealthy) == 0 { + // healthcheck url was hit and the required test string was found + continue + } + if !test.shouldContain && atomic.LoadInt32(&host.Unhealthy) != 0 { + // healthcheck url was hit and the required string was not found + continue + } + t.Errorf("Health check bad response") + } + upstream.Stop() + } + } +}