diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 45222570..14d4bd7a 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -28,6 +28,7 @@ package proxy import ( "context" "crypto/tls" + "crypto/x509" "fmt" "io" "net" @@ -310,6 +311,25 @@ func (rp *ReverseProxy) UseInsecureTransport() { } } +// UseOwnCertificate is used to facilitate HTTPS proxying +// with locally provided certificate. +func (rp *ReverseProxy) UseOwnCACertificates(CaCertPool *x509.CertPool) { + if transport, ok := rp.Transport.(*http.Transport); ok { + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} + } + transport.TLSClientConfig.RootCAs = CaCertPool + // No http2.ConfigureTransport() here. + // For now this is only added in places where + // an http.Transport is actually created. + } else if transport, ok := rp.Transport.(*h2quic.RoundTripper); ok { + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} + } + transport.TLSClientConfig.RootCAs = CaCertPool + } +} + // ServeHTTP serves the proxied request to the upstream by performing a roundtrip. // It is designed to handle websocket connection upgrades as well. func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error { diff --git a/caddyhttp/proxy/reverseproxy_test.go b/caddyhttp/proxy/reverseproxy_test.go index 57c335b0..5509af67 100644 --- a/caddyhttp/proxy/reverseproxy_test.go +++ b/caddyhttp/proxy/reverseproxy_test.go @@ -22,6 +22,8 @@ import ( "strconv" "testing" "time" + + "github.com/lucas-clemente/quic-go/h2quic" ) const ( @@ -30,6 +32,20 @@ const ( ) var upstreamHost *httptest.Server +var upstreamHostTLS *httptest.Server + +func setupTLSServer() { + upstreamHostTLS = httptest.NewTLSServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/test-path" { + w.WriteHeader(expectedStatus) + w.Write([]byte(expectedResponse)) + } else { + w.WriteHeader(404) + w.Write([]byte("Not found")) + } + })) +} func setupTest() { upstreamHost = httptest.NewServer(http.HandlerFunc( @@ -44,10 +60,76 @@ func setupTest() { })) } +func tearDownTLSServer() { + upstreamHostTLS.Close() +} + func tearDownTest() { upstreamHost.Close() } +func TestReverseProxyWithOwnCACertificates(t *testing.T) { + setupTLSServer() + defer tearDownTLSServer() + + // get http client from tls server + cl := upstreamHostTLS.Client() + + // add certs from httptest tls server to reverse proxy + var transport *http.Transport + if tr, ok := cl.Transport.(*http.Transport); ok { + transport = tr + } else { + t.Error("could not parse transport from upstreamHostTLS") + } + + pool := transport.TLSClientConfig.RootCAs + + u := staticUpstream{} + u.CaCertPool = pool + + upstreamURL, err := url.Parse(upstreamHostTLS.URL) + if err != nil { + t.Errorf("Failed to parse test server URL [%s]. %s", upstreamHost.URL, err.Error()) + } + + // setup host for reverse proxy + ups, err := u.NewHost(upstreamURL.String()) + if err != nil { + t.Errorf("Creating new host failed. %v", err) + } + + // UseOwnCACertificates called in NewHost sets the RootCAs based if the cert pool is set + if transport, ok := ups.ReverseProxy.Transport.(*http.Transport); ok { + if transport.TLSClientConfig.RootCAs == nil { + t.Errorf("RootCAs not set on TLSClientConfig.") + } + } else if transport, ok := ups.ReverseProxy.Transport.(*h2quic.RoundTripper); ok { + if transport.TLSClientConfig.RootCAs == nil { + t.Errorf("RootCAs not set on TLSClientConfig.") + } + } + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "https://test.host/test-path", nil) + if err != nil { + t.Errorf("Failed to create new request. %s", err.Error()) + } + + err = ups.ReverseProxy.ServeHTTP(resp, req, nil) + if err != nil { + t.Errorf("Failed to perform reverse proxy to upstream host. %s", err.Error()) + } + + rBody := resp.Body.String() + if rBody != expectedResponse { + t.Errorf("Unexpected proxy response received. Expected: '%s', Got: '%s'", expectedResponse, resp.Body.String()) + } + + if resp.Code != expectedStatus { + t.Errorf("Unexpected proxy status. Expected: '%d', Got: '%d'", expectedStatus, resp.Code) + } +} func TestSingleSRVHostReverseProxy(t *testing.T) { setupTest() defer tearDownTest() diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go index f9a78e62..a89fdfb0 100644 --- a/caddyhttp/proxy/upstream.go +++ b/caddyhttp/proxy/upstream.go @@ -17,6 +17,7 @@ package proxy import ( "bytes" "context" + "crypto/x509" "fmt" "io" "io/ioutil" @@ -69,6 +70,7 @@ type staticUpstream struct { insecureSkipVerify bool MaxFails int32 resolver srvResolver + CaCertPool *x509.CertPool } type srvResolver interface { @@ -233,6 +235,10 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { uh.ReverseProxy.UseInsecureTransport() } + if u.CaCertPool != nil { + uh.ReverseProxy.UseOwnCACertificates(u.CaCertPool) + } + return uh, nil } @@ -465,6 +471,34 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { u.IgnoredSubPaths = ignoredPaths case "insecure_skip_verify": u.insecureSkipVerify = true + case "ca_certificates": + caCertificates := c.RemainingArgs() + if len(caCertificates) == 0 { + return c.ArgErr() + } + + pool := x509.NewCertPool() + caCertificatesAdded := make(map[string]struct{}) + for _, caFile := range caCertificates { + // don't add cert to pool more than once + if _, ok := caCertificatesAdded[caFile]; ok { + continue + } + caCertificatesAdded[caFile] = struct{}{} + + // any client with a certificate from this CA will be allowed to connect + caCrt, err := ioutil.ReadFile(caFile) + if err != nil { + return c.Err(err.Error()) + } + + // attempt to parse pem and append to cert pool + if ok := pool.AppendCertsFromPEM(caCrt); !ok { + return c.Errf("loading CA certificate '%s': no certificates were successfully parsed", caFile) + } + } + + u.CaCertPool = pool case "keepalive": if !c.NextArg() { return c.ArgErr() @@ -489,6 +523,13 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { default: return c.Errf("unknown property '%s'", c.Val()) } + + // these settings are at odds with one another. insecure_skip_verify disables security features over HTTPS + // which is what we are trying to achieve with ca_certificates + if u.insecureSkipVerify && u.CaCertPool != nil { + return c.Errf("both insecure_skip_verify and ca_certificates cannot be set in the proxy directive") + } + return nil } diff --git a/caddyhttp/proxy/upstream_test.go b/caddyhttp/proxy/upstream_test.go index 1e7196f6..2c72f97a 100644 --- a/caddyhttp/proxy/upstream_test.go +++ b/caddyhttp/proxy/upstream_test.go @@ -163,6 +163,61 @@ func TestAllowedPaths(t *testing.T) { } } +func TestParseBlockCACertificates(t *testing.T) { + tests := []struct { + config string + shouldPass bool + subjectLength int + }{ + // Test #1: ca_certificates set but invalid file path provided + {"ca_certificates ./test.pem\n", false, 0}, + + // Test #2: ca_certificates set but no arguments provided + {"ca_certificates \n", false, 0}, + + // Test #3 valid ca_certificate (fullchain) and invalid public cert passed (privkey). CACertPool should not be set + {"ca_certificates ./testdata/fullchain.pem ./testdata/privkey.pem", false, 0}, + + // Test #4 valid ca_certificate section + {"ca_certificates ./testdata/fullchain.pem", true, 2}, + + // Test #5 ca_certificates and insecure_skip_verify cannot both be set + {"ca_certificates ./testdata/fullchain.pem\ninsecure_skip_verify", false, 0}, + } + + for i, test := range tests { + u := staticUpstream{} + c := caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)) + for c.Next() { + err := parseBlock(&c, &u, false) + if err != nil && test.shouldPass { + t.Errorf( + "Test %d: Could not parse CACertificates. %v.", + i+1, + err, + ) + } + } + + if test.shouldPass && u.CaCertPool == nil { + t.Errorf( + "Test %d: CACertificates not parsed correctly. CaCertPool %v. Expected value to be set.", + i+1, + u.CaCertPool, + ) + } + + if test.shouldPass && test.subjectLength != len(u.CaCertPool.Subjects()) { + t.Errorf( + "Test %d: CACertPool subject length incorrect. Got %v. Expected %v.", + i+1, + len(u.CaCertPool.Subjects()), + test.subjectLength, + ) + } + } +} + func TestParseBlockHealthCheck(t *testing.T) { tests := []struct { config string