// Copyright 2015 Light Code Labs, LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package proxy import ( "context" "errors" "fmt" "net" "net/http" "net/http/httptest" "net/url" "reflect" "strconv" "strings" "sync/atomic" "testing" "time" "github.com/lucas-clemente/quic-go/h2quic" "github.com/mholt/caddy/caddyfile" ) func TestNewHost(t *testing.T) { upstream := &staticUpstream{ FailTimeout: 10 * time.Second, MaxConns: 1, MaxFails: 1, } uh, err := upstream.NewHost("example.com") if err != nil { t.Error("Expected no error") } if uh.Name != "http://example.com" { t.Error("Expected default schema to be added to Name.") } if uh.FailTimeout != upstream.FailTimeout { t.Error("Expected default FailTimeout to be set.") } if uh.MaxConns != upstream.MaxConns { t.Error("Expected default MaxConns to be set.") } if uh.CheckDown == nil { t.Error("Expected default CheckDown to be set.") } if uh.CheckDown(uh) { t.Error("Expected new host not to be down.") } // mark Unhealthy uh.Unhealthy = 1 if !uh.CheckDown(uh) { t.Error("Expected unhealthy host to be down.") } // mark with Fails uh.Unhealthy = 0 uh.Fails = 1 if !uh.CheckDown(uh) { t.Error("Expected failed host to be down.") } } func TestHealthCheck(t *testing.T) { upstream := &staticUpstream{ from: "", Hosts: testPool(), Policy: &Random{}, FailTimeout: 10 * time.Second, MaxFails: 1, } upstream.healthCheck() if upstream.Hosts[0].Down() { t.Error("Expected first host in testpool to not fail healthcheck.") } if !upstream.Hosts[1].Down() { t.Error("Expected second host in testpool to fail healthcheck.") } } func TestSelect(t *testing.T) { upstream := &staticUpstream{ from: "", Hosts: testPool()[:3], Policy: &Random{}, FailTimeout: 10 * time.Second, MaxFails: 1, } r, _ := http.NewRequest("GET", "/", nil) upstream.Hosts[0].Unhealthy = 1 upstream.Hosts[1].Unhealthy = 1 upstream.Hosts[2].Unhealthy = 1 if h := upstream.Select(r); h != nil { t.Error("Expected select to return nil as all host are down") } upstream.Hosts[2].Unhealthy = 0 if h := upstream.Select(r); h == nil { t.Error("Expected select to not return nil") } upstream.Hosts[0].Conns = 1 upstream.Hosts[0].MaxConns = 1 upstream.Hosts[1].Conns = 1 upstream.Hosts[1].MaxConns = 1 upstream.Hosts[2].Conns = 1 upstream.Hosts[2].MaxConns = 1 if h := upstream.Select(r); h != nil { t.Error("Expected select to return nil as all hosts are full") } upstream.Hosts[2].Conns = 0 if h := upstream.Select(r); h == nil { t.Error("Expected select to not return nil") } } func TestRegisterPolicy(t *testing.T) { name := "custom" customPolicy := &customPolicy{} RegisterPolicy(name, func(string) Policy { return customPolicy }) if _, ok := supportedPolicies[name]; !ok { t.Error("Expected supportedPolicies to have a custom policy.") } } func TestAllowedPaths(t *testing.T) { upstream := &staticUpstream{ from: "/proxy", IgnoredSubPaths: []string{"/download", "/static"}, } tests := []struct { url string expected bool }{ {"/proxy", true}, {"/proxy/dl", true}, {"/proxy/download", false}, {"/proxy/download/static", false}, {"/proxy/static", false}, {"/proxy/static/download", false}, {"/proxy/something/download", true}, {"/proxy/something/static", true}, {"/proxy//static", false}, {"/proxy//static//download", false}, {"/proxy//download", false}, } for i, test := range tests { allowed := upstream.AllowedPath(test.url) if test.expected != allowed { t.Errorf("Test %d: expected %v found %v", i+1, test.expected, allowed) } } } func TestParseBlockHealthCheck(t *testing.T) { tests := []struct { config string interval string timeout string }{ // Test #1: Both options set correct time {"health_check /health\n health_check_interval 10s\n health_check_timeout 20s", "10s", "20s"}, // Test #2: Health check options flipped around. Making sure health_check doesn't overwrite it {"health_check_interval 10s\n health_check_timeout 20s\n health_check /health", "10s", "20s"}, // Test #3: No health_check options. So default. {"health_check /health", "30s", "1m0s"}, // Test #4: Interval sets it to 15s and timeout defaults {"health_check /health\n health_check_interval 15s", "15s", "1m0s"}, // Test #5: Timeout sets it to 15s and interval defaults {"health_check /health\n health_check_timeout 15s", "30s", "15s"}, // Test #6: Some funky spelling to make sure it still defaults {"health_check /health health_check_time 15s", "30s", "1m0s"}, } for i, test := range tests { u := staticUpstream{} c := caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)) for c.Next() { parseBlock(&c, &u, false) } if u.HealthCheck.Interval.String() != test.interval { t.Errorf( "Test %d: HealthCheck interval not the same from config. Got %v. Expected: %v", i+1, u.HealthCheck.Interval, test.interval, ) } if u.HealthCheck.Timeout.String() != test.timeout { t.Errorf( "Test %d: HealthCheck timeout not the same from config. Got %v. Expected: %v", i+1, u.HealthCheck.Timeout, test.timeout, ) } } } func TestStop(t *testing.T) { config := "proxy / %s {\n health_check /healthcheck \nhealth_check_interval %dms \n}" tests := []struct { name string intervalInMilliseconds int numHealthcheckIntervals int }{ { "No Healthchecks After Stop - 5ms, 1 intervals", 5, 1, }, { "No Healthchecks After Stop - 5ms, 2 intervals", 5, 2, }, { "No Healthchecks After Stop - 5ms, 3 intervals", 5, 3, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { // Set up proxy. var counter int64 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.Body.Close() atomic.AddInt64(&counter, 1) })) defer backend.Close() upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(fmt.Sprintf(config, backend.URL, test.intervalInMilliseconds))), "") if err != nil { t.Error("Expected no error. Got:", err.Error()) } // Give some time for healthchecks to hit the server. time.Sleep(time.Duration(test.intervalInMilliseconds*test.numHealthcheckIntervals) * time.Millisecond) for _, upstream := range upstreams { if err := upstream.Stop(); err != nil { t.Error("Expected no error stopping upstream. Got: ", err.Error()) } } counterValueAfterShutdown := atomic.LoadInt64(&counter) // Give some time to see if healthchecks are still hitting the server. time.Sleep(time.Duration(test.intervalInMilliseconds*test.numHealthcheckIntervals) * time.Millisecond) if counterValueAfterShutdown == 0 { t.Error("Expected healthchecks to hit test server. Got no healthchecks.") } counterValueAfterWaiting := atomic.LoadInt64(&counter) if counterValueAfterWaiting != counterValueAfterShutdown { t.Errorf("Expected no more healthchecks after shutdown. Got: %d healthchecks after shutdown", counterValueAfterWaiting-counterValueAfterShutdown) } }) } } func TestParseBlock(t *testing.T) { r, _ := http.NewRequest("GET", "/", nil) tests := []struct { config string }{ // Test #1: transparent preset {"proxy / localhost:8080 {\n transparent \n}"}, // Test #2: transparent preset with another param {"proxy / localhost:8080 {\n transparent \nheader_upstream X-Test Tester \n}"}, // Test #3: transparent preset on multiple sites {"proxy / localhost:8080 {\n transparent \n} \nproxy /api localhost:8081 { \ntransparent \n}"}, } for i, test := range tests { upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") if err != nil { t.Errorf("Expected no error. Got: %s", err.Error()) } for _, upstream := range upstreams { headers := upstream.Select(r).UpstreamHeaders if _, ok := headers["Host"]; !ok { t.Errorf("Test %d: Could not find the Host header", i+1) } if _, ok := headers["X-Real-Ip"]; !ok { t.Errorf("Test %d: Could not find the X-Real-Ip header", i+1) } if _, ok := headers["X-Forwarded-Proto"]; !ok { t.Errorf("Test %d: Could not find the X-Forwarded-Proto header", i+1) } } } } func TestHealthSetUp(t *testing.T) { // tests for insecure skip verify tests := []struct { config string flag bool }{ // Test #1: without flag {"proxy / localhost:8080 {\n health_check / \n}", false}, // Test #2: with flag {"proxy / localhost:8080 {\n health_check / \n insecure_skip_verify \n}", true}, } for i, test := range tests { upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") if err != nil { t.Errorf("Expected no error. Got: %s", err.Error()) } for _, upstream := range upstreams { staticUpstream, ok := upstream.(*staticUpstream) if !ok { t.Errorf("Type mismatch: %#v", upstream) continue } transport, ok := staticUpstream.HealthCheck.Client.Transport.(*http.Transport) if !ok { t.Errorf("Type mismatch: %#v", staticUpstream.HealthCheck.Client.Transport) continue } if test.flag != transport.TLSClientConfig.InsecureSkipVerify { t.Errorf("Test %d: expected transport.TLSClientCnfig.InsecureSkipVerify=%v, got %v", i, test.flag, transport.TLSClientConfig.InsecureSkipVerify) } } } } func TestHealthCheckHost(t *testing.T) { // tests for upstream host on health checks tests := []struct { config string flag bool host string }{ // Test #1: without upstream header {"proxy / localhost:8080 {\n health_check / \n}", false, "example.com"}, // Test #2: without upstream header, missing host {"proxy / localhost:8080 {\n health_check / \n}", true, ""}, // Test #3: with upstream header (via transparent preset) {"proxy / localhost:8080 {\n health_check / \n transparent \n}", true, "foo.example.com"}, // Test #4: with upstream header (explicit header) {"proxy / localhost:8080 {\n health_check / \n header_upstream Host {host} \n}", true, "example.com"}, // Test #5: with upstream header, missing host {"proxy / localhost:8080 {\n health_check / \n transparent \n}", true, ""}, } for i, test := range tests { upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), test.host) if err != nil { t.Errorf("Expected no error. Got: %s", err.Error()) } for _, upstream := range upstreams { staticUpstream, ok := upstream.(*staticUpstream) if !ok { t.Errorf("Type mismatch: %#v", upstream) continue } if test.flag != (staticUpstream.HealthCheck.Host == test.host) { t.Errorf("Test %d: expected staticUpstream.HealthCheck.Host=%v, got %v", i, test.host, staticUpstream.HealthCheck.Host) } } } } func TestHealthCheckPort(t *testing.T) { var counter int64 healthCounter := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.Body.Close() atomic.AddInt64(&counter, 1) })) _, healthPort, err := net.SplitHostPort(healthCounter.Listener.Addr().String()) if err != nil { t.Fatal(err) } defer healthCounter.Close() tests := []struct { config string }{ // Test #1: upstream with port {"proxy / localhost:8080 {\n health_check / health_check_port " + healthPort + "\n}"}, // Test #2: upstream without port (default to 80) {"proxy / localhost {\n health_check / health_check_port " + healthPort + "\n}"}, } for i, test := range tests { counterValueAtStart := atomic.LoadInt64(&counter) upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") if err != nil { t.Error("Expected no error. Got:", err.Error()) } // Give some time for healthchecks to hit the server. time.Sleep(500 * time.Millisecond) for _, upstream := range upstreams { if err := upstream.Stop(); err != nil { t.Errorf("Test %d: Expected no error stopping upstream. Got: %v", i, err.Error()) } } counterValueAfterShutdown := atomic.LoadInt64(&counter) if counterValueAfterShutdown == counterValueAtStart { t.Errorf("Test %d: Expected healthchecks to hit test server. Got no healthchecks.", i) } } t.Run("valid_port", func(t *testing.T) { tests := []struct { config string }{ // Test #1: invalid port (nil) {"proxy / localhost {\n health_check / health_check_port\n}"}, // Test #2: invalid port (string) {"proxy / localhost {\n health_check / health_check_port abc\n}"}, // Test #3: invalid port (negative) {"proxy / localhost {\n health_check / health_check_port -1\n}"}, } for i, test := range tests { _, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") if err == nil { t.Errorf("Test %d accepted invalid config", i) } } }) } func TestHealthCheckContentString(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "blablabla good blablabla") r.Body.Close() })) _, port, err := net.SplitHostPort(server.Listener.Addr().String()) if err != nil { t.Fatal(err) } defer server.Close() tests := []struct { config string shouldContain bool }{ {"proxy / localhost:" + port + " { health_check /testhealth " + " health_check_contains good\n}", true, }, {"proxy / localhost:" + port + " {\n health_check /testhealth health_check_port " + port + " \n health_check_contains bad\n}", false, }, } for i, test := range tests { u, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") if err != nil { t.Errorf("Expected no error. Test %d Got: %s", i, err.Error()) } for _, upstream := range u { staticUpstream, ok := upstream.(*staticUpstream) if !ok { t.Errorf("Type mismatch: %#v", upstream) continue } staticUpstream.healthCheck() for _, host := range staticUpstream.Hosts { if test.shouldContain && atomic.LoadInt32(&host.Unhealthy) == 0 { // healthcheck url was hit and the required test string was found continue } if !test.shouldContain && atomic.LoadInt32(&host.Unhealthy) != 0 { // healthcheck url was hit and the required string was not found continue } t.Errorf("Health check bad response") } upstream.Stop() } } } func TestQuicHost(t *testing.T) { // tests for QUIC proxy tests := []struct { config string flag bool }{ // Test #1: without flag {"proxy / quic://localhost:8080", false}, // Test #2: with flag {"proxy / quic://localhost:8080 {\n insecure_skip_verify \n}", true}, } for _, test := range tests { upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") if err != nil { t.Errorf("Expected no error. Got: %s", err.Error()) } for _, upstream := range upstreams { staticUpstream, ok := upstream.(*staticUpstream) if !ok { t.Errorf("Type mismatch: %#v", upstream) continue } for _, host := range staticUpstream.Hosts { _, ok := host.ReverseProxy.Transport.(*h2quic.RoundTripper) if !ok { t.Errorf("Type mismatch: %#v", host.ReverseProxy.Transport) continue } } } } } func TestParseSRVBlock(t *testing.T) { tests := []struct { config string shouldErr bool }{ {"proxy / srv://bogus.service", false}, {"proxy / srv://bogus.service:80", true}, {"proxy / srv://bogus.service srv://bogus.service.fallback", true}, {"proxy / srv://bogus.service http://bogus.service.fallback", true}, {"proxy / http://bogus.service srv://bogus.service.fallback", true}, {"proxy / srv://bogus.service bogus.service.fallback", true}, {`proxy / srv://bogus.service { upstream srv://bogus.service }`, true}, {"proxy / srv+https://bogus.service", false}, {"proxy / srv+https://bogus.service:80", true}, {"proxy / srv+https://bogus.service srv://bogus.service.fallback", true}, {"proxy / srv+https://bogus.service http://bogus.service.fallback", true}, {"proxy / http://bogus.service srv+https://bogus.service.fallback", true}, {"proxy / srv+https://bogus.service bogus.service.fallback", true}, {`proxy / srv+https://bogus.service { upstream srv://bogus.service }`, true}, {`proxy / srv+https://bogus.service { health_check_port 96 }`, true}, } for i, test := range tests { _, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") if err == nil && test.shouldErr { t.Errorf("Case %d - Expected an error. got nothing", i) } if err != nil && !test.shouldErr { t.Errorf("Case %d - Expected no error. got %s", i, err.Error()) } } } type testResolver struct { errOn string result []*net.SRV } func (r testResolver) LookupSRV(ctx context.Context, _, _, service string) (string, []*net.SRV, error) { if service == r.errOn { return "", nil, errors.New("an error occurred") } return "", r.result, nil } func TestResolveHost(t *testing.T) { upstream := &staticUpstream{ resolver: testResolver{ errOn: "srv://problematic.service.name", result: []*net.SRV{ {Target: "target-1.fqdn", Port: 85, Priority: 1, Weight: 1}, {Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1}, {Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1}, }, }, } tests := []struct { host string expect []string isSrv bool shouldErr bool }{ // Static DNS records {"http://subdomain.domain.service", []string{"http://subdomain.domain.service"}, false, false}, {"https://subdomain.domain.service", []string{"https://subdomain.domain.service"}, false, false}, {"http://subdomain.domain.service:76", []string{"http://subdomain.domain.service:76"}, false, false}, {"https://subdomain.domain.service:65", []string{"https://subdomain.domain.service:65"}, false, false}, // SRV lookups {"srv://service.name", []string{ "http://target-1.fqdn:85", "http://target-2.fqdn:33", "http://target-3.fqdn:94", }, true, false}, {"srv+https://service.name", []string{ "https://target-1.fqdn:85", "https://target-2.fqdn:33", "https://target-3.fqdn:94", }, true, false}, {"srv://problematic.service.name", []string{}, true, true}, } for i, test := range tests { results, isSrv, err := upstream.resolveHost(test.host) if err == nil && test.shouldErr { t.Errorf("Test %d - expected an error, got none", i) } if err != nil && !test.shouldErr { t.Errorf("Test %d - unexpected error %s", i, err.Error()) } if test.isSrv && !isSrv { t.Errorf("Test %d - expecting resolution to be SRV lookup but it isn't", i) } if isSrv && !test.isSrv { t.Errorf("Test %d - expecting resolution to be normal lookup, got SRV", i) } if !reflect.DeepEqual(results, test.expect) { t.Errorf("Test %d - resolution result %#v does not match expected value %#v", i, results, test.expect) } } } func TestSRVHealthCheck(t *testing.T) { serverURL, err := url.Parse(workableServer.URL) if err != nil { t.Errorf("Failed to parse test server URL: %s", err.Error()) } pp, err := strconv.Atoi(serverURL.Port()) if err != nil { t.Errorf("Failed to parse test server port [%s]: %s", serverURL.Port(), err.Error()) } port := uint16(pp) allGoodResolver := testResolver{ result: []*net.SRV{ {Target: serverURL.Hostname(), Port: port, Priority: 1, Weight: 1}, }, } partialFailureResolver := testResolver{ result: []*net.SRV{ {Target: serverURL.Hostname(), Port: port, Priority: 1, Weight: 1}, {Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1}, {Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1}, }, } fullFailureResolver := testResolver{ result: []*net.SRV{ {Target: "target-1.fqdn", Port: 876, Priority: 1, Weight: 1}, {Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1}, {Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1}, }, } resolutionErrorResolver := testResolver{ errOn: "srv://tag.service.consul", result: []*net.SRV{}, } upstream := &staticUpstream{ Hosts: []*UpstreamHost{ {Name: "srv://tag.service.consul"}, }, FailTimeout: 10 * time.Second, MaxFails: 1, } tests := []struct { resolver testResolver shouldFail bool shouldErr bool }{ {allGoodResolver, false, false}, {partialFailureResolver, false, false}, {fullFailureResolver, true, false}, {resolutionErrorResolver, true, true}, } for i, test := range tests { upstream.resolver = test.resolver upstream.healthCheck() if upstream.Hosts[0].Down() && !test.shouldFail { t.Errorf("Test %d - expected all healthchecks to pass, all failing", i) } if test.shouldFail && !upstream.Hosts[0].Down() { t.Errorf("Test %d - expected all healthchecks to fail, all passing", i) } status := fmt.Sprintf("%s", upstream.Hosts[0].HealthCheckResult.Load()) if test.shouldFail && !test.shouldErr && status != "Failed" { t.Errorf("Test %d - Expected health check result to be 'Failed', got '%s'", i, status) } if !test.shouldFail && status != "OK" { t.Errorf("Test %d - Expected health check result to be 'OK', got '%s'", i, status) } if test.shouldErr && status != "an error occurred" { t.Errorf("Test %d - Expected health check result to be 'an error occured', got '%s'", i, status) } } }