From 6352c9054ab7a72460f6b75c0be797eaff9d0a67 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Mon, 26 Dec 2016 20:40:44 +0100 Subject: [PATCH 1/8] Fixed proxy not respecting the -http2 flag --- caddyhttp/proxy/reverseproxy.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index cfb466c7..050e7fb9 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -148,7 +148,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * } else { transport.MaxIdleConnsPerHost = keepalive } - http2.ConfigureTransport(transport) + if httpserver.HTTP2 { + http2.ConfigureTransport(transport) + } rp.Transport = transport } return rp @@ -168,10 +170,15 @@ func (rp *ReverseProxy) UseInsecureTransport() { TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } - http2.ConfigureTransport(transport) + if httpserver.HTTP2 { + http2.ConfigureTransport(transport) + } rp.Transport = transport } else if transport, ok := rp.Transport.(*http.Transport); ok { transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + // No http2.ConfigureTransport() here. + // For now this is only added in places where + // an http.Transport is actually created. } } From 53635ba538fb41bbfd38b70282cd6d59e693c0b3 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Mon, 26 Dec 2016 20:42:00 +0100 Subject: [PATCH 2/8] Fixed panic due to 0-length buffers being passed to io.CopyBuffer --- caddyhttp/proxy/reverseproxy.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 050e7fb9..49df1603 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -253,7 +253,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, } func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { - buf := bufferPool.Get() + buf := bufferPool.Get().([]byte) defer bufferPool.Put(buf) if rp.FlushInterval != 0 { @@ -268,7 +268,10 @@ func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { dst = mlw } } - io.CopyBuffer(dst, src, buf.([]byte)) + + // `CopyBuffer` only uses `buf` up to it's length and + // panics if it's 0 => Extend it's length up to it's capacity. + io.CopyBuffer(dst, src, buf[:cap(buf)]) } // skip these headers if they already exist. From 9f9ad21aaa526768b638105fbc6675c80fdfa0ce Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Mon, 26 Dec 2016 20:52:36 +0100 Subject: [PATCH 3/8] Fixed #1292: Failure to proxy WebSockets over HTTPS This issue was caused by connHijackerTransport trying to record HTTP response headers by "hijacking" the Read() method of the plain net.Conn. This does not simply work over TLS though since this will record the TLS handshake and encrypted data instead of the actual content. This commit fixes the problem by providing an alternative transport.DialTLS which correctly hijacks the overlying tls.Conn instead. --- caddyhttp/proxy/reverseproxy.go | 184 ++++++++++++++++++++++++-------- 1 file changed, 140 insertions(+), 44 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 49df1603..c980e905 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -27,6 +27,11 @@ import ( "github.com/mholt/caddy/caddyhttp/httpserver" ) +var defaultDialer = &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, +} + var bufferPool = sync.Pool{New: createBuffer} func createBuffer() interface{} { @@ -135,11 +140,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * // just use default transport, to avoid creating // a brand new transport transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, + Proxy: http.ProxyFromEnvironment, + Dial: defaultDialer.Dial, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } @@ -162,11 +164,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * func (rp *ReverseProxy) UseInsecureTransport() { if rp.Transport == nil { transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, + Proxy: http.ProxyFromEnvironment, + Dial: defaultDialer.Dial, TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } @@ -341,51 +340,148 @@ type connHijackerTransport struct { } func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { - transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, + t := &http.Transport{ MaxIdleConnsPerHost: -1, } - if base != nil { - if baseTransport, ok := base.(*http.Transport); ok { - transport.Proxy = baseTransport.Proxy - transport.TLSClientConfig = baseTransport.TLSClientConfig - transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout - transport.Dial = baseTransport.Dial - transport.DialTLS = baseTransport.DialTLS - transport.MaxIdleConnsPerHost = -1 + if b, _ := base.(*http.Transport); b != nil { + t.Proxy = b.Proxy + t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig) + t.TLSClientConfig.NextProtos = nil + t.TLSHandshakeTimeout = b.TLSHandshakeTimeout + t.Dial = b.Dial + t.DialTLS = b.DialTLS + } else { + t.Proxy = http.ProxyFromEnvironment + t.TLSHandshakeTimeout = 10 * time.Second + } + hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]} + + dial := getTransportDial(t) + dialTLS := getTransportDialTLS(t) + + t.Dial = func(network, addr string) (net.Conn, error) { + c, err := dial(network, addr) + hj.Conn = c + return &hijackedConn{c, hj}, err + } + + if dialTLS != nil { + t.DialTLS = func(network, addr string) (net.Conn, error) { + c, err := dialTLS(network, addr) + hj.Conn = c + return &hijackedConn{c, hj}, err } } - hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]} - oldDial := transport.Dial - oldDialTLS := transport.DialTLS - if oldDial == nil { - oldDial = (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial + + return hj +} + +// getTransportDial always returns a plain Dialer +// and defaults to the existing t.Dial. +func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) { + if t.Dial != nil { + return t.Dial } - hjTransport.Dial = func(network, addr string) (net.Conn, error) { - c, err := oldDial(network, addr) - hjTransport.Conn = c - return &hijackedConn{c, hjTransport}, err + return defaultDialer.Dial +} + +// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil +// and defaults to the existing t.DialTLS. +func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) { + if t.DialTLS != nil { + return t.DialTLS } - if oldDialTLS != nil { - hjTransport.DialTLS = func(network, addr string) (net.Conn, error) { - c, err := oldDialTLS(network, addr) - hjTransport.Conn = c - return &hijackedConn{c, hjTransport}, err + if t.TLSClientConfig == nil { + return nil + } + + // newConnHijackerTransport will modify t.Dial after calling this method + // => Create a backup reference. + plainDial := getTransportDial(t) + + return func(network, addr string) (net.Conn, error) { + plainConn, err := plainDial(network, addr) + if err != nil { + return nil, err } + + tlsConn := tls.Client(plainConn, t.TLSClientConfig) + errc := make(chan error, 2) + var timer *time.Timer + if d := t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + err := tlsConn.Handshake() + if timer != nil { + timer.Stop() + } + errc <- err + }() + if err := <-errc; err != nil { + plainConn.Close() + return nil, err + } + if !t.TLSClientConfig.InsecureSkipVerify { + serverName := t.TLSClientConfig.ServerName + if serverName == "" { + serverName = addr + idx := strings.LastIndex(serverName, ":") + if idx != -1 { + serverName = serverName[:idx] + } + } + if err := tlsConn.VerifyHostname(serverName); err != nil { + plainConn.Close() + return nil, err + } + } + + return tlsConn, nil + } +} + +type tlsHandshakeTimeoutError struct{} + +func (tlsHandshakeTimeoutError) Timeout() bool { return true } +func (tlsHandshakeTimeoutError) Temporary() bool { return true } +func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } + +// cloneTLSClientConfig is like cloneTLSConfig but omits +// the fields SessionTicketsDisabled and SessionTicketKey. +// This makes it safe to call cloneTLSClientConfig on a config +// in active use by a server. +func cloneTLSClientConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled, + Renegotiation: cfg.Renegotiation, } - return hjTransport } func requestIsWebsocket(req *http.Request) bool { - return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")) + return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") } type writeFlusher interface { From 20483c23f85961f7b540f77aa3bbfc7228bca235 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Mon, 26 Dec 2016 20:53:18 +0100 Subject: [PATCH 4/8] Added end-to-end test case for #1292 --- caddyhttp/proxy/proxy_test.go | 75 ++++++++++++++++++++++++++++++----- 1 file changed, 65 insertions(+), 10 deletions(-) diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 6359596c..85897c00 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -123,7 +123,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) { defer wsNop.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsNop.URL) + p := newWebSocketTestProxy(wsNop.URL, false) // Create client request r := httptest.NewRequest("GET", "/", nil) @@ -148,7 +148,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { defer wsNop.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsNop.URL) + p := newWebSocketTestProxy(wsNop.URL, false) // Create client request r := httptest.NewRequest("GET", "/", nil) @@ -189,7 +189,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) { defer wsEcho.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsEcho.URL) + p := newWebSocketTestProxy(wsEcho.URL, false) // This is a full end-end test, so the proxy handler // has to be part of a server listening on a port. Our @@ -228,6 +228,52 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) { } } +func TestWebSocketReverseProxyFromWSSClient(t *testing.T) { + wsEcho := newTLSServer(websocket.Handler(func(ws *websocket.Conn) { + io.Copy(ws, ws) + })) + defer wsEcho.Close() + + p := newWebSocketTestProxy(wsEcho.URL, true) + + echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + // Set up WebSocket client + url := strings.Replace(echoProxy.URL, "https://", "wss://", 1) + wsCfg, err := websocket.NewConfig(url, echoProxy.URL) + if err != nil { + t.Fatal(err) + } + wsCfg.TlsConfig = &tls.Config{InsecureSkipVerify: true} + ws, err := websocket.DialConfig(wsCfg) + + if err != nil { + t.Fatal(err) + } + defer ws.Close() + + // Send test message + trialMsg := "Is it working?" + + if sendErr := websocket.Message.Send(ws, trialMsg); sendErr != nil { + t.Fatal(sendErr) + } + + // It should be echoed back to us + var actualMsg string + + if rcvErr := websocket.Message.Receive(ws, &actualMsg); rcvErr != nil { + t.Fatal(rcvErr) + } + + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + func TestUnixSocketProxy(t *testing.T) { if runtime.GOOS == "windows" { return @@ -264,7 +310,7 @@ func TestUnixSocketProxy(t *testing.T) { defer ts.Close() url := strings.Replace(ts.URL, "http://", "unix:", 1) - p := newWebSocketTestProxy(url) + p := newWebSocketTestProxy(url, false) echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { p.ServeHTTP(w, r) @@ -982,10 +1028,14 @@ func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time. // redirect to the specified backendAddr. The function // also sets up the rules/environment for testing WebSocket // proxy. -func newWebSocketTestProxy(backendAddr string) *Proxy { +func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy { return &Proxy{ - Next: httpserver.EmptyNext, // prevents panic in some cases when test fails - Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}}, + Next: httpserver.EmptyNext, // prevents panic in some cases when test fails + Upstreams: []Upstream{&fakeWsUpstream{ + name: backendAddr, + without: "", + insecure: insecure, + }}, } } @@ -997,8 +1047,9 @@ func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy { } type fakeWsUpstream struct { - name string - without string + name string + without string + insecure bool } func (u *fakeWsUpstream) From() string { @@ -1007,13 +1058,17 @@ func (u *fakeWsUpstream) From() string { func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost { uri, _ := url.Parse(u.name) - return &UpstreamHost{ + host := &UpstreamHost{ Name: u.name, ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost), UpstreamHeaders: http.Header{ "Connection": {"{>Connection}"}, "Upgrade": {"{>Upgrade}"}}, } + if u.insecure { + host.ReverseProxy.UseInsecureTransport() + } + return host } func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true } From 153d4a5ac62ffcef77df8ad29ae13e7d68b188b6 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Wed, 28 Dec 2016 17:17:52 +0100 Subject: [PATCH 5/8] proxy: Improved handling of bufferPool --- caddyhttp/proxy/reverseproxy.go | 39 +++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index c980e905..d32be7cd 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -27,15 +27,28 @@ import ( "github.com/mholt/caddy/caddyhttp/httpserver" ) -var defaultDialer = &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, -} +var ( + defaultDialer = &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } -var bufferPool = sync.Pool{New: createBuffer} + bufferPool = sync.Pool{New: createBuffer} +) func createBuffer() interface{} { - return make([]byte, 32*1024) + return make([]byte, 0, 32*1024) +} + +func pooledIoCopy(dst io.Writer, src io.Reader) { + buf := bufferPool.Get().([]byte) + defer bufferPool.Put(buf) + + // CopyBuffer only uses buf up to its length and panics if it's 0. + // Due to that we extend buf's length to its capacity here and + // ensure it's always non-zero. + bufCap := cap(buf) + io.CopyBuffer(dst, src, buf[0:bufCap:bufCap]) } // onExitFlushLoop is a callback set by tests to detect the state of the @@ -234,10 +247,8 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, } defer backendConn.Close() - go func() { - io.Copy(backendConn, conn) // write tcp stream to backend. - }() - io.Copy(conn, backendConn) // read tcp stream from backend. + go pooledIoCopy(backendConn, conn) // write tcp stream to backend + pooledIoCopy(conn, backendConn) // read tcp stream from backend } else { defer res.Body.Close() for _, h := range hopHeaders { @@ -252,9 +263,6 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, } func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { - buf := bufferPool.Get().([]byte) - defer bufferPool.Put(buf) - if rp.FlushInterval != 0 { if wf, ok := dst.(writeFlusher); ok { mlw := &maxLatencyWriter{ @@ -267,10 +275,7 @@ func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { dst = mlw } } - - // `CopyBuffer` only uses `buf` up to it's length and - // panics if it's 0 => Extend it's length up to it's capacity. - io.CopyBuffer(dst, src, buf[:cap(buf)]) + pooledIoCopy(dst, src) } // skip these headers if they already exist. From b857265f9c759fde2c3a10580ea85a9b2f83d968 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Wed, 28 Dec 2016 17:20:31 +0100 Subject: [PATCH 6/8] proxy: Fixed support for TLS verification of WebSocket connections --- caddyhttp/proxy/reverseproxy.go | 69 +++++++++++++++++++++------------ 1 file changed, 45 insertions(+), 24 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index d32be7cd..021025a0 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -349,9 +349,14 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { MaxIdleConnsPerHost: -1, } if b, _ := base.(*http.Transport); b != nil { + tlsClientConfig := b.TLSClientConfig + if tlsClientConfig.NextProtos != nil { + tlsClientConfig = cloneTLSClientConfig(tlsClientConfig) + tlsClientConfig.NextProtos = nil + } + t.Proxy = b.Proxy - t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig) - t.TLSClientConfig.NextProtos = nil + t.TLSClientConfig = tlsClientConfig t.TLSHandshakeTimeout = b.TLSHandshakeTimeout t.Dial = b.Dial t.DialTLS = b.DialTLS @@ -363,19 +368,15 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { dial := getTransportDial(t) dialTLS := getTransportDialTLS(t) - t.Dial = func(network, addr string) (net.Conn, error) { c, err := dial(network, addr) hj.Conn = c return &hijackedConn{c, hj}, err } - - if dialTLS != nil { - t.DialTLS = func(network, addr string) (net.Conn, error) { - c, err := dialTLS(network, addr) - hj.Conn = c - return &hijackedConn{c, hj}, err - } + t.DialTLS = func(network, addr string) (net.Conn, error) { + c, err := dialTLS(network, addr) + hj.Conn = c + return &hijackedConn{c, hj}, err } return hj @@ -390,27 +391,35 @@ func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, e return defaultDialer.Dial } -// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil +// getTransportDial always returns a TLS Dialer // and defaults to the existing t.DialTLS. func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) { if t.DialTLS != nil { return t.DialTLS } - if t.TLSClientConfig == nil { - return nil - } // newConnHijackerTransport will modify t.Dial after calling this method // => Create a backup reference. plainDial := getTransportDial(t) + // The following DialTLS implementation stems from the Go stdlib and + // is identical to what happens if DialTLS is not provided. + // Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051 return func(network, addr string) (net.Conn, error) { plainConn, err := plainDial(network, addr) if err != nil { return nil, err } - tlsConn := tls.Client(plainConn, t.TLSClientConfig) + tlsClientConfig := t.TLSClientConfig + if tlsClientConfig == nil { + tlsClientConfig = &tls.Config{} + } + if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" { + tlsClientConfig.ServerName = stripPort(addr) + } + + tlsConn := tls.Client(plainConn, tlsClientConfig) errc := make(chan error, 2) var timer *time.Timer if d := t.TLSHandshakeTimeout; d != 0 { @@ -429,16 +438,12 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn plainConn.Close() return nil, err } - if !t.TLSClientConfig.InsecureSkipVerify { - serverName := t.TLSClientConfig.ServerName - if serverName == "" { - serverName = addr - idx := strings.LastIndex(serverName, ":") - if idx != -1 { - serverName = serverName[:idx] - } + if !tlsClientConfig.InsecureSkipVerify { + hostname := tlsClientConfig.ServerName + if hostname == "" { + hostname = stripPort(addr) } - if err := tlsConn.VerifyHostname(serverName); err != nil { + if err := tlsConn.VerifyHostname(hostname); err != nil { plainConn.Close() return nil, err } @@ -448,6 +453,22 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn } } +// stripPort returns address without its port if it has one and +// works with IP addresses as well as hostnames formatted as host:port. +// +// IPv6 addresses (excluding the port) must be enclosed in +// square brackets similar to the requirements of Go's stdlib. +func stripPort(address string) string { + // Keep in mind that the address might be a IPv6 address + // and thus contain a colon, but not have a port. + portIdx := strings.LastIndex(address, ":") + ipv6Idx := strings.LastIndex(address, "]") + if portIdx > ipv6Idx { + address = address[:portIdx] + } + return address +} + type tlsHandshakeTimeoutError struct{} func (tlsHandshakeTimeoutError) Timeout() bool { return true } From 533039e6d8cfe9c2a6c379f614bb7aa0645aca80 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Thu, 29 Dec 2016 16:07:22 +0100 Subject: [PATCH 7/8] proxy: Removed leftover restriction to HTTP/1.1 --- caddyhttp/proxy/reverseproxy.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 021025a0..552c1ab9 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -205,10 +205,6 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, } rp.Director(outreq) - outreq.Proto = "HTTP/1.1" - outreq.ProtoMajor = 1 - outreq.ProtoMinor = 1 - outreq.Close = false res, err := transport.RoundTrip(outreq) if err != nil { From 4babe4b201eef3e9794851b74f734a03338f7a20 Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Fri, 30 Dec 2016 18:13:14 +0100 Subject: [PATCH 8/8] proxy: Added support for HTTP trailers --- caddyhttp/proxy/proxy.go | 24 ++++++++++++--- caddyhttp/proxy/proxy_test.go | 31 +++++++++++++++++++ caddyhttp/proxy/reverseproxy.go | 53 +++++++++++++++++++++++++++------ 3 files changed, 95 insertions(+), 13 deletions(-) diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index fe959791..0f48a61f 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -247,12 +247,28 @@ func createUpstreamRequest(r *http.Request) *http.Request { outreq.URL.Opaque = outreq.URL.RawPath } + // We are modifying the same underlying map from req (shallow + // copied above) so we only copy it if necessary. + copiedHeaders := false + + // Remove hop-by-hop headers listed in the "Connection" header. + // See RFC 2616, section 14.10. + if c := outreq.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, r.Header) + copiedHeaders = true + } + outreq.Header.Del(f) + } + } + } + // Remove hop-by-hop headers to the backend. Especially // important is "Connection" because we want a persistent - // connection, regardless of what the client sent to us. This - // is modifying the same underlying map from r (shallow - // copied above) so we only copy it if necessary. - var copiedHeaders bool + // connection, regardless of what the client sent to us. for _, h := range hopHeaders { if outreq.Header.Get(h) != "" { if !copiedHeaders { diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 85897c00..686a79c5 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -42,10 +42,32 @@ func TestReverseProxy(t *testing.T) { log.SetOutput(ioutil.Discard) defer log.SetOutput(os.Stderr) + verifyHeaders := func(headers http.Header, trailers http.Header) { + if headers.Get("X-Header") != "header-value" { + t.Error("Expected header 'X-Header' to be proxied properly") + } + + if trailers == nil { + t.Error("Expected to receive trailers") + } + if trailers.Get("X-Trailer") != "trailer-value" { + t.Error("Expected header 'X-Trailer' to be proxied properly") + } + } + var requestReceived bool backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // read the body (even if it's empty) to make Go parse trailers + io.Copy(ioutil.Discard, r.Body) + verifyHeaders(r.Header, r.Trailer) + requestReceived = true + + w.Header().Set("Trailer", "X-Trailer") + w.Header().Set("X-Header", "header-value") + w.WriteHeader(http.StatusOK) w.Write([]byte("Hello, client")) + w.Header().Set("X-Trailer", "trailer-value") })) defer backend.Close() @@ -59,12 +81,21 @@ func TestReverseProxy(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() + r.ContentLength = -1 // force chunked encoding (required for trailers) + r.Header.Set("X-Header", "header-value") + r.Trailer = map[string][]string{ + "X-Trailer": {"trailer-value"}, + } + p.ServeHTTP(w, r) if !requestReceived { t.Error("Expected backend to receive request, but it didn't") } + res := w.Result() + verifyHeaders(res.Header, res.Trailer) + // Make sure {upstream} placeholder is set rr := httpserver.NewResponseRecorder(httptest.NewRecorder()) rr.Replacer = httpserver.NewReplacer(r, rr, "-") diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 552c1ab9..a59f4bc8 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -211,10 +211,27 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, return err } + isWebsocket := res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" + + // Remove hop-by-hop headers listed in the + // "Connection" header of the response. + if c := res.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + res.Header.Del(f) + } + } + } + + for _, h := range hopHeaders { + res.Header.Del(h) + } + if respUpdateFn != nil { respUpdateFn(res) } - if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" { + + if isWebsocket { res.Body.Close() hj, ok := rw.(http.Hijacker) if !ok { @@ -246,13 +263,30 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, go pooledIoCopy(backendConn, conn) // write tcp stream to backend pooledIoCopy(conn, backendConn) // read tcp stream from backend } else { - defer res.Body.Close() - for _, h := range hopHeaders { - res.Header.Del(h) - } copyHeader(rw.Header(), res.Header) + + // The "Trailer" header isn't included in the Transport's response, + // at least for *http.Transport. Build it up from Trailer. + if len(res.Trailer) > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for k := range res.Trailer { + trailerKeys = append(trailerKeys, k) + } + rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + rw.WriteHeader(res.StatusCode) + if len(res.Trailer) > 0 { + // Force chunking if we saw a response trailer. + // This prevents net/http from calculating the length for short + // bodies and adding a Content-Length. + if fl, ok := rw.(http.Flusher); ok { + fl.Flush() + } + } rp.copyResponse(rw, res.Body) + res.Body.Close() // close now, instead of defer, to populate res.Trailer + copyHeader(rw.Header(), res.Trailer) } return nil @@ -305,16 +339,17 @@ func copyHeader(dst, src http.Header) { // Hop-by-hop headers. These are removed when sent to the backend. // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html var hopHeaders = []string{ + "Alt-Svc", + "Alternate-Protocol", "Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization", - "Te", // canonicalized version of "TE" - "Trailers", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522 "Transfer-Encoding", "Upgrade", - "Alternate-Protocol", - "Alt-Svc", } type respUpdateFn func(resp *http.Response)