From 6352c9054ab7a72460f6b75c0be797eaff9d0a67 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Mon, 26 Dec 2016 20:40:44 +0100 Subject: [PATCH 1/9] Fixed proxy not respecting the -http2 flag --- caddyhttp/proxy/reverseproxy.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index cfb466c7..050e7fb9 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -148,7 +148,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * } else { transport.MaxIdleConnsPerHost = keepalive } - http2.ConfigureTransport(transport) + if httpserver.HTTP2 { + http2.ConfigureTransport(transport) + } rp.Transport = transport } return rp @@ -168,10 +170,15 @@ func (rp *ReverseProxy) UseInsecureTransport() { TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } - http2.ConfigureTransport(transport) + if httpserver.HTTP2 { + http2.ConfigureTransport(transport) + } rp.Transport = transport } else if transport, ok := rp.Transport.(*http.Transport); ok { transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + // No http2.ConfigureTransport() here. + // For now this is only added in places where + // an http.Transport is actually created. } } From 53635ba538fb41bbfd38b70282cd6d59e693c0b3 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Mon, 26 Dec 2016 20:42:00 +0100 Subject: [PATCH 2/9] Fixed panic due to 0-length buffers being passed to io.CopyBuffer --- caddyhttp/proxy/reverseproxy.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 050e7fb9..49df1603 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -253,7 +253,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, } func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { - buf := bufferPool.Get() + buf := bufferPool.Get().([]byte) defer bufferPool.Put(buf) if rp.FlushInterval != 0 { @@ -268,7 +268,10 @@ func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { dst = mlw } } - io.CopyBuffer(dst, src, buf.([]byte)) + + // `CopyBuffer` only uses `buf` up to it's length and + // panics if it's 0 => Extend it's length up to it's capacity. + io.CopyBuffer(dst, src, buf[:cap(buf)]) } // skip these headers if they already exist. From 9f9ad21aaa526768b638105fbc6675c80fdfa0ce Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Mon, 26 Dec 2016 20:52:36 +0100 Subject: [PATCH 3/9] Fixed #1292: Failure to proxy WebSockets over HTTPS This issue was caused by connHijackerTransport trying to record HTTP response headers by "hijacking" the Read() method of the plain net.Conn. This does not simply work over TLS though since this will record the TLS handshake and encrypted data instead of the actual content. This commit fixes the problem by providing an alternative transport.DialTLS which correctly hijacks the overlying tls.Conn instead. --- caddyhttp/proxy/reverseproxy.go | 184 ++++++++++++++++++++++++-------- 1 file changed, 140 insertions(+), 44 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 49df1603..c980e905 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -27,6 +27,11 @@ import ( "github.com/mholt/caddy/caddyhttp/httpserver" ) +var defaultDialer = &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, +} + var bufferPool = sync.Pool{New: createBuffer} func createBuffer() interface{} { @@ -135,11 +140,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * // just use default transport, to avoid creating // a brand new transport transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, + Proxy: http.ProxyFromEnvironment, + Dial: defaultDialer.Dial, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } @@ -162,11 +164,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * func (rp *ReverseProxy) UseInsecureTransport() { if rp.Transport == nil { transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, + Proxy: http.ProxyFromEnvironment, + Dial: defaultDialer.Dial, TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } @@ -341,51 +340,148 @@ type connHijackerTransport struct { } func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { - transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, + t := &http.Transport{ MaxIdleConnsPerHost: -1, } - if base != nil { - if baseTransport, ok := base.(*http.Transport); ok { - transport.Proxy = baseTransport.Proxy - transport.TLSClientConfig = baseTransport.TLSClientConfig - transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout - transport.Dial = baseTransport.Dial - transport.DialTLS = baseTransport.DialTLS - transport.MaxIdleConnsPerHost = -1 + if b, _ := base.(*http.Transport); b != nil { + t.Proxy = b.Proxy + t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig) + t.TLSClientConfig.NextProtos = nil + t.TLSHandshakeTimeout = b.TLSHandshakeTimeout + t.Dial = b.Dial + t.DialTLS = b.DialTLS + } else { + t.Proxy = http.ProxyFromEnvironment + t.TLSHandshakeTimeout = 10 * time.Second + } + hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]} + + dial := getTransportDial(t) + dialTLS := getTransportDialTLS(t) + + t.Dial = func(network, addr string) (net.Conn, error) { + c, err := dial(network, addr) + hj.Conn = c + return &hijackedConn{c, hj}, err + } + + if dialTLS != nil { + t.DialTLS = func(network, addr string) (net.Conn, error) { + c, err := dialTLS(network, addr) + hj.Conn = c + return &hijackedConn{c, hj}, err } } - hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]} - oldDial := transport.Dial - oldDialTLS := transport.DialTLS - if oldDial == nil { - oldDial = (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial + + return hj +} + +// getTransportDial always returns a plain Dialer +// and defaults to the existing t.Dial. +func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) { + if t.Dial != nil { + return t.Dial } - hjTransport.Dial = func(network, addr string) (net.Conn, error) { - c, err := oldDial(network, addr) - hjTransport.Conn = c - return &hijackedConn{c, hjTransport}, err + return defaultDialer.Dial +} + +// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil +// and defaults to the existing t.DialTLS. +func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) { + if t.DialTLS != nil { + return t.DialTLS } - if oldDialTLS != nil { - hjTransport.DialTLS = func(network, addr string) (net.Conn, error) { - c, err := oldDialTLS(network, addr) - hjTransport.Conn = c - return &hijackedConn{c, hjTransport}, err + if t.TLSClientConfig == nil { + return nil + } + + // newConnHijackerTransport will modify t.Dial after calling this method + // => Create a backup reference. + plainDial := getTransportDial(t) + + return func(network, addr string) (net.Conn, error) { + plainConn, err := plainDial(network, addr) + if err != nil { + return nil, err } + + tlsConn := tls.Client(plainConn, t.TLSClientConfig) + errc := make(chan error, 2) + var timer *time.Timer + if d := t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + err := tlsConn.Handshake() + if timer != nil { + timer.Stop() + } + errc <- err + }() + if err := <-errc; err != nil { + plainConn.Close() + return nil, err + } + if !t.TLSClientConfig.InsecureSkipVerify { + serverName := t.TLSClientConfig.ServerName + if serverName == "" { + serverName = addr + idx := strings.LastIndex(serverName, ":") + if idx != -1 { + serverName = serverName[:idx] + } + } + if err := tlsConn.VerifyHostname(serverName); err != nil { + plainConn.Close() + return nil, err + } + } + + return tlsConn, nil + } +} + +type tlsHandshakeTimeoutError struct{} + +func (tlsHandshakeTimeoutError) Timeout() bool { return true } +func (tlsHandshakeTimeoutError) Temporary() bool { return true } +func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } + +// cloneTLSClientConfig is like cloneTLSConfig but omits +// the fields SessionTicketsDisabled and SessionTicketKey. +// This makes it safe to call cloneTLSClientConfig on a config +// in active use by a server. +func cloneTLSClientConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled, + Renegotiation: cfg.Renegotiation, } - return hjTransport } func requestIsWebsocket(req *http.Request) bool { - return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")) + return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") } type writeFlusher interface { From 20483c23f85961f7b540f77aa3bbfc7228bca235 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Mon, 26 Dec 2016 20:53:18 +0100 Subject: [PATCH 4/9] Added end-to-end test case for #1292 --- caddyhttp/proxy/proxy_test.go | 75 ++++++++++++++++++++++++++++++----- 1 file changed, 65 insertions(+), 10 deletions(-) diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 6359596c..85897c00 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -123,7 +123,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) { defer wsNop.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsNop.URL) + p := newWebSocketTestProxy(wsNop.URL, false) // Create client request r := httptest.NewRequest("GET", "/", nil) @@ -148,7 +148,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { defer wsNop.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsNop.URL) + p := newWebSocketTestProxy(wsNop.URL, false) // Create client request r := httptest.NewRequest("GET", "/", nil) @@ -189,7 +189,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) { defer wsEcho.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsEcho.URL) + p := newWebSocketTestProxy(wsEcho.URL, false) // This is a full end-end test, so the proxy handler // has to be part of a server listening on a port. Our @@ -228,6 +228,52 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) { } } +func TestWebSocketReverseProxyFromWSSClient(t *testing.T) { + wsEcho := newTLSServer(websocket.Handler(func(ws *websocket.Conn) { + io.Copy(ws, ws) + })) + defer wsEcho.Close() + + p := newWebSocketTestProxy(wsEcho.URL, true) + + echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + // Set up WebSocket client + url := strings.Replace(echoProxy.URL, "https://", "wss://", 1) + wsCfg, err := websocket.NewConfig(url, echoProxy.URL) + if err != nil { + t.Fatal(err) + } + wsCfg.TlsConfig = &tls.Config{InsecureSkipVerify: true} + ws, err := websocket.DialConfig(wsCfg) + + if err != nil { + t.Fatal(err) + } + defer ws.Close() + + // Send test message + trialMsg := "Is it working?" + + if sendErr := websocket.Message.Send(ws, trialMsg); sendErr != nil { + t.Fatal(sendErr) + } + + // It should be echoed back to us + var actualMsg string + + if rcvErr := websocket.Message.Receive(ws, &actualMsg); rcvErr != nil { + t.Fatal(rcvErr) + } + + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + func TestUnixSocketProxy(t *testing.T) { if runtime.GOOS == "windows" { return @@ -264,7 +310,7 @@ func TestUnixSocketProxy(t *testing.T) { defer ts.Close() url := strings.Replace(ts.URL, "http://", "unix:", 1) - p := newWebSocketTestProxy(url) + p := newWebSocketTestProxy(url, false) echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { p.ServeHTTP(w, r) @@ -982,10 +1028,14 @@ func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time. // redirect to the specified backendAddr. The function // also sets up the rules/environment for testing WebSocket // proxy. -func newWebSocketTestProxy(backendAddr string) *Proxy { +func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy { return &Proxy{ - Next: httpserver.EmptyNext, // prevents panic in some cases when test fails - Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}}, + Next: httpserver.EmptyNext, // prevents panic in some cases when test fails + Upstreams: []Upstream{&fakeWsUpstream{ + name: backendAddr, + without: "", + insecure: insecure, + }}, } } @@ -997,8 +1047,9 @@ func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy { } type fakeWsUpstream struct { - name string - without string + name string + without string + insecure bool } func (u *fakeWsUpstream) From() string { @@ -1007,13 +1058,17 @@ func (u *fakeWsUpstream) From() string { func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost { uri, _ := url.Parse(u.name) - return &UpstreamHost{ + host := &UpstreamHost{ Name: u.name, ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost), UpstreamHeaders: http.Header{ "Connection": {"{>Connection}"}, "Upgrade": {"{>Upgrade}"}}, } + if u.insecure { + host.ReverseProxy.UseInsecureTransport() + } + return host } func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true } From 153d4a5ac62ffcef77df8ad29ae13e7d68b188b6 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Wed, 28 Dec 2016 17:17:52 +0100 Subject: [PATCH 5/9] proxy: Improved handling of bufferPool --- caddyhttp/proxy/reverseproxy.go | 39 +++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index c980e905..d32be7cd 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -27,15 +27,28 @@ import ( "github.com/mholt/caddy/caddyhttp/httpserver" ) -var defaultDialer = &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, -} +var ( + defaultDialer = &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } -var bufferPool = sync.Pool{New: createBuffer} + bufferPool = sync.Pool{New: createBuffer} +) func createBuffer() interface{} { - return make([]byte, 32*1024) + return make([]byte, 0, 32*1024) +} + +func pooledIoCopy(dst io.Writer, src io.Reader) { + buf := bufferPool.Get().([]byte) + defer bufferPool.Put(buf) + + // CopyBuffer only uses buf up to its length and panics if it's 0. + // Due to that we extend buf's length to its capacity here and + // ensure it's always non-zero. + bufCap := cap(buf) + io.CopyBuffer(dst, src, buf[0:bufCap:bufCap]) } // onExitFlushLoop is a callback set by tests to detect the state of the @@ -234,10 +247,8 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, } defer backendConn.Close() - go func() { - io.Copy(backendConn, conn) // write tcp stream to backend. - }() - io.Copy(conn, backendConn) // read tcp stream from backend. + go pooledIoCopy(backendConn, conn) // write tcp stream to backend + pooledIoCopy(conn, backendConn) // read tcp stream from backend } else { defer res.Body.Close() for _, h := range hopHeaders { @@ -252,9 +263,6 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, } func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { - buf := bufferPool.Get().([]byte) - defer bufferPool.Put(buf) - if rp.FlushInterval != 0 { if wf, ok := dst.(writeFlusher); ok { mlw := &maxLatencyWriter{ @@ -267,10 +275,7 @@ func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { dst = mlw } } - - // `CopyBuffer` only uses `buf` up to it's length and - // panics if it's 0 => Extend it's length up to it's capacity. - io.CopyBuffer(dst, src, buf[:cap(buf)]) + pooledIoCopy(dst, src) } // skip these headers if they already exist. From b857265f9c759fde2c3a10580ea85a9b2f83d968 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Wed, 28 Dec 2016 17:20:31 +0100 Subject: [PATCH 6/9] proxy: Fixed support for TLS verification of WebSocket connections --- caddyhttp/proxy/reverseproxy.go | 69 +++++++++++++++++++++------------ 1 file changed, 45 insertions(+), 24 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index d32be7cd..021025a0 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -349,9 +349,14 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { MaxIdleConnsPerHost: -1, } if b, _ := base.(*http.Transport); b != nil { + tlsClientConfig := b.TLSClientConfig + if tlsClientConfig.NextProtos != nil { + tlsClientConfig = cloneTLSClientConfig(tlsClientConfig) + tlsClientConfig.NextProtos = nil + } + t.Proxy = b.Proxy - t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig) - t.TLSClientConfig.NextProtos = nil + t.TLSClientConfig = tlsClientConfig t.TLSHandshakeTimeout = b.TLSHandshakeTimeout t.Dial = b.Dial t.DialTLS = b.DialTLS @@ -363,19 +368,15 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { dial := getTransportDial(t) dialTLS := getTransportDialTLS(t) - t.Dial = func(network, addr string) (net.Conn, error) { c, err := dial(network, addr) hj.Conn = c return &hijackedConn{c, hj}, err } - - if dialTLS != nil { - t.DialTLS = func(network, addr string) (net.Conn, error) { - c, err := dialTLS(network, addr) - hj.Conn = c - return &hijackedConn{c, hj}, err - } + t.DialTLS = func(network, addr string) (net.Conn, error) { + c, err := dialTLS(network, addr) + hj.Conn = c + return &hijackedConn{c, hj}, err } return hj @@ -390,27 +391,35 @@ func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, e return defaultDialer.Dial } -// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil +// getTransportDial always returns a TLS Dialer // and defaults to the existing t.DialTLS. func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) { if t.DialTLS != nil { return t.DialTLS } - if t.TLSClientConfig == nil { - return nil - } // newConnHijackerTransport will modify t.Dial after calling this method // => Create a backup reference. plainDial := getTransportDial(t) + // The following DialTLS implementation stems from the Go stdlib and + // is identical to what happens if DialTLS is not provided. + // Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051 return func(network, addr string) (net.Conn, error) { plainConn, err := plainDial(network, addr) if err != nil { return nil, err } - tlsConn := tls.Client(plainConn, t.TLSClientConfig) + tlsClientConfig := t.TLSClientConfig + if tlsClientConfig == nil { + tlsClientConfig = &tls.Config{} + } + if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" { + tlsClientConfig.ServerName = stripPort(addr) + } + + tlsConn := tls.Client(plainConn, tlsClientConfig) errc := make(chan error, 2) var timer *time.Timer if d := t.TLSHandshakeTimeout; d != 0 { @@ -429,16 +438,12 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn plainConn.Close() return nil, err } - if !t.TLSClientConfig.InsecureSkipVerify { - serverName := t.TLSClientConfig.ServerName - if serverName == "" { - serverName = addr - idx := strings.LastIndex(serverName, ":") - if idx != -1 { - serverName = serverName[:idx] - } + if !tlsClientConfig.InsecureSkipVerify { + hostname := tlsClientConfig.ServerName + if hostname == "" { + hostname = stripPort(addr) } - if err := tlsConn.VerifyHostname(serverName); err != nil { + if err := tlsConn.VerifyHostname(hostname); err != nil { plainConn.Close() return nil, err } @@ -448,6 +453,22 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn } } +// stripPort returns address without its port if it has one and +// works with IP addresses as well as hostnames formatted as host:port. +// +// IPv6 addresses (excluding the port) must be enclosed in +// square brackets similar to the requirements of Go's stdlib. +func stripPort(address string) string { + // Keep in mind that the address might be a IPv6 address + // and thus contain a colon, but not have a port. + portIdx := strings.LastIndex(address, ":") + ipv6Idx := strings.LastIndex(address, "]") + if portIdx > ipv6Idx { + address = address[:portIdx] + } + return address +} + type tlsHandshakeTimeoutError struct{} func (tlsHandshakeTimeoutError) Timeout() bool { return true } From 533039e6d8cfe9c2a6c379f614bb7aa0645aca80 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Thu, 29 Dec 2016 16:07:22 +0100 Subject: [PATCH 7/9] proxy: Removed leftover restriction to HTTP/1.1 --- caddyhttp/proxy/reverseproxy.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 021025a0..552c1ab9 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -205,10 +205,6 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, } rp.Director(outreq) - outreq.Proto = "HTTP/1.1" - outreq.ProtoMajor = 1 - outreq.ProtoMinor = 1 - outreq.Close = false res, err := transport.RoundTrip(outreq) if err != nil { From 4babe4b201eef3e9794851b74f734a03338f7a20 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Fri, 30 Dec 2016 18:13:14 +0100 Subject: [PATCH 8/9] proxy: Added support for HTTP trailers --- caddyhttp/proxy/proxy.go | 24 ++++++++++++--- caddyhttp/proxy/proxy_test.go | 31 +++++++++++++++++++ caddyhttp/proxy/reverseproxy.go | 53 +++++++++++++++++++++++++++------ 3 files changed, 95 insertions(+), 13 deletions(-) diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index fe959791..0f48a61f 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -247,12 +247,28 @@ func createUpstreamRequest(r *http.Request) *http.Request { outreq.URL.Opaque = outreq.URL.RawPath } + // We are modifying the same underlying map from req (shallow + // copied above) so we only copy it if necessary. + copiedHeaders := false + + // Remove hop-by-hop headers listed in the "Connection" header. + // See RFC 2616, section 14.10. + if c := outreq.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, r.Header) + copiedHeaders = true + } + outreq.Header.Del(f) + } + } + } + // Remove hop-by-hop headers to the backend. Especially // important is "Connection" because we want a persistent - // connection, regardless of what the client sent to us. This - // is modifying the same underlying map from r (shallow - // copied above) so we only copy it if necessary. - var copiedHeaders bool + // connection, regardless of what the client sent to us. for _, h := range hopHeaders { if outreq.Header.Get(h) != "" { if !copiedHeaders { diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 85897c00..686a79c5 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -42,10 +42,32 @@ func TestReverseProxy(t *testing.T) { log.SetOutput(ioutil.Discard) defer log.SetOutput(os.Stderr) + verifyHeaders := func(headers http.Header, trailers http.Header) { + if headers.Get("X-Header") != "header-value" { + t.Error("Expected header 'X-Header' to be proxied properly") + } + + if trailers == nil { + t.Error("Expected to receive trailers") + } + if trailers.Get("X-Trailer") != "trailer-value" { + t.Error("Expected header 'X-Trailer' to be proxied properly") + } + } + var requestReceived bool backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // read the body (even if it's empty) to make Go parse trailers + io.Copy(ioutil.Discard, r.Body) + verifyHeaders(r.Header, r.Trailer) + requestReceived = true + + w.Header().Set("Trailer", "X-Trailer") + w.Header().Set("X-Header", "header-value") + w.WriteHeader(http.StatusOK) w.Write([]byte("Hello, client")) + w.Header().Set("X-Trailer", "trailer-value") })) defer backend.Close() @@ -59,12 +81,21 @@ func TestReverseProxy(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() + r.ContentLength = -1 // force chunked encoding (required for trailers) + r.Header.Set("X-Header", "header-value") + r.Trailer = map[string][]string{ + "X-Trailer": {"trailer-value"}, + } + p.ServeHTTP(w, r) if !requestReceived { t.Error("Expected backend to receive request, but it didn't") } + res := w.Result() + verifyHeaders(res.Header, res.Trailer) + // Make sure {upstream} placeholder is set rr := httpserver.NewResponseRecorder(httptest.NewRecorder()) rr.Replacer = httpserver.NewReplacer(r, rr, "-") diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 552c1ab9..a59f4bc8 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -211,10 +211,27 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, return err } + isWebsocket := res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" + + // Remove hop-by-hop headers listed in the + // "Connection" header of the response. + if c := res.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + res.Header.Del(f) + } + } + } + + for _, h := range hopHeaders { + res.Header.Del(h) + } + if respUpdateFn != nil { respUpdateFn(res) } - if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" { + + if isWebsocket { res.Body.Close() hj, ok := rw.(http.Hijacker) if !ok { @@ -246,13 +263,30 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, go pooledIoCopy(backendConn, conn) // write tcp stream to backend pooledIoCopy(conn, backendConn) // read tcp stream from backend } else { - defer res.Body.Close() - for _, h := range hopHeaders { - res.Header.Del(h) - } copyHeader(rw.Header(), res.Header) + + // The "Trailer" header isn't included in the Transport's response, + // at least for *http.Transport. Build it up from Trailer. + if len(res.Trailer) > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for k := range res.Trailer { + trailerKeys = append(trailerKeys, k) + } + rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + rw.WriteHeader(res.StatusCode) + if len(res.Trailer) > 0 { + // Force chunking if we saw a response trailer. + // This prevents net/http from calculating the length for short + // bodies and adding a Content-Length. + if fl, ok := rw.(http.Flusher); ok { + fl.Flush() + } + } rp.copyResponse(rw, res.Body) + res.Body.Close() // close now, instead of defer, to populate res.Trailer + copyHeader(rw.Header(), res.Trailer) } return nil @@ -305,16 +339,17 @@ func copyHeader(dst, src http.Header) { // Hop-by-hop headers. These are removed when sent to the backend. // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html var hopHeaders = []string{ + "Alt-Svc", + "Alternate-Protocol", "Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization", - "Te", // canonicalized version of "TE" - "Trailers", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522 "Transfer-Encoding", "Upgrade", - "Alternate-Protocol", - "Alt-Svc", } type respUpdateFn func(resp *http.Response) From 04bee0f36d218edf0593b2c2254974f35d00276e Mon Sep 17 00:00:00 2001 From: Sawood Alam Date: Sat, 31 Dec 2016 22:29:14 -0500 Subject: [PATCH 9/9] Implementing custom PathClean function to allow masking, closes #1298 (#1317) * Added path cleanup functions with masking to preserve certain patterns + unit tests, #1298 * Use custom PathClean function instead of path.Clean to apply masks to preserve protocol separator in the path * Indentation corrected in the test data map to pass the lint * Fixing ineffassign of a temporary string variable * Improved variable naming and documentation * Improved variable naming * Added benchmarks and improved variable naming in tests * Removed unnecessary value capture when iterating over a map for keys * A typo correction --- caddyhttp/httpserver/pathcleaner.go | 76 ++++++++++++++ caddyhttp/httpserver/pathcleaner_test.go | 120 +++++++++++++++++++++++ caddyhttp/httpserver/server.go | 3 +- 3 files changed, 197 insertions(+), 2 deletions(-) create mode 100644 caddyhttp/httpserver/pathcleaner.go create mode 100644 caddyhttp/httpserver/pathcleaner_test.go diff --git a/caddyhttp/httpserver/pathcleaner.go b/caddyhttp/httpserver/pathcleaner.go new file mode 100644 index 00000000..cf5f1aa0 --- /dev/null +++ b/caddyhttp/httpserver/pathcleaner.go @@ -0,0 +1,76 @@ +package httpserver + +import ( + "math/rand" + "path" + "strings" + "time" +) + +// CleanMaskedPath prevents one or more of the path cleanup operations: +// - collapse multiple slashes into one +// - eliminate "/." (current directory) +// - eliminate "/.." +// by masking certain patterns in the path with a temporary random string. +// This could be helpful when certain patterns in the path are desired to be preserved +// that would otherwise be changed by path.Clean(). +// One such use case is the presence of the double slashes as protocol separator +// (e.g., /api/endpoint/http://example.com). +// This is a common pattern in many applications to allow passing URIs as path argument. +func CleanMaskedPath(reqPath string, masks ...string) string { + var replacerVal string + maskMap := make(map[string]string) + + // Iterate over supplied masks and create temporary replacement strings + // only for the masks that are present in the path, then replace all occurrences + for _, mask := range masks { + if strings.Index(reqPath, mask) >= 0 { + replacerVal = "/_caddy" + generateRandomString() + "__" + maskMap[mask] = replacerVal + reqPath = strings.Replace(reqPath, mask, replacerVal, -1) + } + } + + reqPath = path.Clean(reqPath) + + // Revert the replaced masks after path cleanup + for mask, replacerVal := range maskMap { + reqPath = strings.Replace(reqPath, replacerVal, mask, -1) + } + return reqPath +} + +// CleanPath calls CleanMaskedPath() with the default mask of "://" +// to preserve double slashes of protocols +// such as "http://", "https://", and "ftp://" etc. +func CleanPath(reqPath string) string { + return CleanMaskedPath(reqPath, "://") +} + +// An efficient and fast method for random string generation. +// Inspired by http://stackoverflow.com/a/31832326. +const randomStringLength = 4 +const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +const ( + letterIdxBits = 6 + letterIdxMask = 1<= 0; { + if remain == 0 { + cache, remain = src.Int63(), letterIdxMax + } + if idx := int(cache & letterIdxMask); idx < len(letterBytes) { + b[i] = letterBytes[idx] + i-- + } + cache >>= letterIdxBits + remain-- + } + return string(b) +} diff --git a/caddyhttp/httpserver/pathcleaner_test.go b/caddyhttp/httpserver/pathcleaner_test.go new file mode 100644 index 00000000..34f39590 --- /dev/null +++ b/caddyhttp/httpserver/pathcleaner_test.go @@ -0,0 +1,120 @@ +package httpserver + +import ( + "path" + "testing" +) + +var paths = map[string]map[string]string{ + "/../a/b/../././/c": { + "preserve_all": "/../a/b/../././/c", + "preserve_protocol": "/a/c", + "preserve_slashes": "/a//c", + "preserve_dots": "/../a/b/../././c", + "clean_all": "/a/c", + }, + "/path/https://www.google.com": { + "preserve_all": "/path/https://www.google.com", + "preserve_protocol": "/path/https://www.google.com", + "preserve_slashes": "/path/https://www.google.com", + "preserve_dots": "/path/https:/www.google.com", + "clean_all": "/path/https:/www.google.com", + }, + "/a/b/../././/c/http://example.com/foo//bar/../blah": { + "preserve_all": "/a/b/../././/c/http://example.com/foo//bar/../blah", + "preserve_protocol": "/a/c/http://example.com/foo/blah", + "preserve_slashes": "/a//c/http://example.com/foo/blah", + "preserve_dots": "/a/b/../././c/http:/example.com/foo/bar/../blah", + "clean_all": "/a/c/http:/example.com/foo/blah", + }, +} + +func assertEqual(t *testing.T, expected, received string) { + if expected != received { + t.Errorf("\tExpected: %s\n\t\t\tRecieved: %s", expected, received) + } +} + +func maskedTestRunner(t *testing.T, variation string, masks ...string) { + for reqPath, transformation := range paths { + assertEqual(t, transformation[variation], CleanMaskedPath(reqPath, masks...)) + } +} + +// No need to test the built-in path.Clean() function. +// However, it could be useful to cross-examine the test dataset. +func TestPathClean(t *testing.T) { + for reqPath, transformation := range paths { + assertEqual(t, transformation["clean_all"], path.Clean(reqPath)) + } +} + +func TestCleanAll(t *testing.T) { + maskedTestRunner(t, "clean_all") +} + +func TestPreserveAll(t *testing.T) { + maskedTestRunner(t, "preserve_all", "//", "/..", "/.") +} + +func TestPreserveProtocol(t *testing.T) { + maskedTestRunner(t, "preserve_protocol", "://") +} + +func TestPreserveSlashes(t *testing.T) { + maskedTestRunner(t, "preserve_slashes", "//") +} + +func TestPreserveDots(t *testing.T) { + maskedTestRunner(t, "preserve_dots", "/..", "/.") +} + +func TestDefaultMask(t *testing.T) { + for reqPath, transformation := range paths { + assertEqual(t, transformation["preserve_protocol"], CleanPath(reqPath)) + } +} + +func maskedBenchmarkRunner(b *testing.B, masks ...string) { + for n := 0; n < b.N; n++ { + for reqPath := range paths { + CleanMaskedPath(reqPath, masks...) + } + } +} + +func BenchmarkPathClean(b *testing.B) { + for n := 0; n < b.N; n++ { + for reqPath := range paths { + path.Clean(reqPath) + } + } +} + +func BenchmarkCleanAll(b *testing.B) { + maskedBenchmarkRunner(b) +} + +func BenchmarkPreserveAll(b *testing.B) { + maskedBenchmarkRunner(b, "//", "/..", "/.") +} + +func BenchmarkPreserveProtocol(b *testing.B) { + maskedBenchmarkRunner(b, "://") +} + +func BenchmarkPreserveSlashes(b *testing.B) { + maskedBenchmarkRunner(b, "//") +} + +func BenchmarkPreserveDots(b *testing.B) { + maskedBenchmarkRunner(b, "/..", "/.") +} + +func BenchmarkDefaultMask(b *testing.B) { + for n := 0; n < b.N; n++ { + for reqPath := range paths { + CleanPath(reqPath) + } + } +} diff --git a/caddyhttp/httpserver/server.go b/caddyhttp/httpserver/server.go index 49956ab3..7cc36040 100644 --- a/caddyhttp/httpserver/server.go +++ b/caddyhttp/httpserver/server.go @@ -9,7 +9,6 @@ import ( "net" "net/http" "os" - "path" "runtime" "strings" "sync" @@ -351,7 +350,7 @@ func sanitizePath(r *http.Request) { if r.URL.Path == "/" { return } - cleanedPath := path.Clean(r.URL.Path) + cleanedPath := CleanPath(r.URL.Path) if cleanedPath == "." { r.URL.Path = "/" } else {