diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go index 4f823e67..6b99cd6c 100644 --- a/caddyhttp/proxy/upstream.go +++ b/caddyhttp/proxy/upstream.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" "net/url" "path" @@ -42,6 +43,7 @@ type staticUpstream struct { Interval time.Duration Timeout time.Duration Host string + Port string } WithoutPathPrefix string IgnoredSubPaths []string @@ -321,6 +323,20 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { return err } u.HealthCheck.Timeout = dur + case "health_check_port": + if !c.NextArg() { + return c.ArgErr() + } + port := c.Val() + n, err := strconv.Atoi(port) + if err != nil { + return err + } + + if n < 0 { + return c.Errf("invalid health_check_port '%s'", port) + } + u.HealthCheck.Port = c.Val() case "header_upstream": var header, value string if !c.Args(&header, &value) { @@ -380,7 +396,12 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { func (u *staticUpstream) healthCheck() { for _, host := range u.Hosts { - hostURL := host.Name + u.HealthCheck.Path + hostURL := host.Name + if u.HealthCheck.Port != "" { + hostURL = replacePort(host.Name, u.HealthCheck.Port) + } + hostURL += u.HealthCheck.Path + var unhealthy bool // set up request, needed to be able to modify headers @@ -483,3 +504,19 @@ func (u *staticUpstream) Stop() error { func RegisterPolicy(name string, policy func() Policy) { supportedPolicies[name] = policy } + +func replacePort(originalURL string, newPort string) string { + parsedURL, err := url.Parse(originalURL) + if err != nil { + return originalURL + } + + // handles 'localhost' and 'localhost:8080' + parsedHost, _, err := net.SplitHostPort(parsedURL.Host) + if err != nil { + parsedHost = parsedURL.Host + } + + parsedURL.Host = net.JoinHostPort(parsedHost, newPort) + return parsedURL.String() +} diff --git a/caddyhttp/proxy/upstream_test.go b/caddyhttp/proxy/upstream_test.go index e1361dff..be42359d 100644 --- a/caddyhttp/proxy/upstream_test.go +++ b/caddyhttp/proxy/upstream_test.go @@ -2,6 +2,7 @@ package proxy import ( "fmt" + "net" "net/http" "net/http/httptest" "strings" @@ -375,3 +376,75 @@ func TestHealthCheckHost(t *testing.T) { } } } + +func TestHealthCheckPort(t *testing.T) { + var counter int64 + + healthCounter := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + atomic.AddInt64(&counter, 1) + })) + + _, healthPort, err := net.SplitHostPort(healthCounter.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + + defer healthCounter.Close() + + tests := []struct { + config string + }{ + // Test #1: upstream with port + {"proxy / localhost:8080 {\n health_check / health_check_port " + healthPort + "\n}"}, + + // Test #2: upstream without port (default to 80) + {"proxy / localhost {\n health_check / health_check_port " + healthPort + "\n}"}, + } + + for i, test := range tests { + counterValueAtStart := atomic.LoadInt64(&counter) + upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") + if err != nil { + t.Error("Expected no error. Got:", err.Error()) + } + + // Give some time for healthchecks to hit the server. + time.Sleep(500 * time.Millisecond) + + for _, upstream := range upstreams { + if err := upstream.Stop(); err != nil { + t.Errorf("Test %d: Expected no error stopping upstream. Got: %v", i, err.Error()) + } + } + + counterValueAfterShutdown := atomic.LoadInt64(&counter) + + if counterValueAfterShutdown == counterValueAtStart { + t.Errorf("Test %d: Expected healthchecks to hit test server. Got no healthchecks.", i) + } + } + + t.Run("valid_port", func(t *testing.T) { + tests := []struct { + config string + }{ + // Test #1: invalid port (nil) + {"proxy / localhost {\n health_check / health_check_port\n}"}, + + // Test #2: invalid port (string) + {"proxy / localhost {\n health_check / health_check_port abc\n}"}, + + // Test #3: invalid port (negative) + {"proxy / localhost {\n health_check / health_check_port -1\n}"}, + } + + for i, test := range tests { + _, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") + if err == nil { + t.Errorf("Test %d accepted invalid config", i) + } + } + }) + +}