diff --git a/caddyhttp/httpserver/pathcleaner.go b/caddyhttp/httpserver/pathcleaner.go new file mode 100644 index 00000000..cf5f1aa0 --- /dev/null +++ b/caddyhttp/httpserver/pathcleaner.go @@ -0,0 +1,76 @@ +package httpserver + +import ( + "math/rand" + "path" + "strings" + "time" +) + +// CleanMaskedPath prevents one or more of the path cleanup operations: +// - collapse multiple slashes into one +// - eliminate "/." (current directory) +// - eliminate "/.." +// by masking certain patterns in the path with a temporary random string. +// This could be helpful when certain patterns in the path are desired to be preserved +// that would otherwise be changed by path.Clean(). +// One such use case is the presence of the double slashes as protocol separator +// (e.g., /api/endpoint/http://example.com). +// This is a common pattern in many applications to allow passing URIs as path argument. +func CleanMaskedPath(reqPath string, masks ...string) string { + var replacerVal string + maskMap := make(map[string]string) + + // Iterate over supplied masks and create temporary replacement strings + // only for the masks that are present in the path, then replace all occurrences + for _, mask := range masks { + if strings.Index(reqPath, mask) >= 0 { + replacerVal = "/_caddy" + generateRandomString() + "__" + maskMap[mask] = replacerVal + reqPath = strings.Replace(reqPath, mask, replacerVal, -1) + } + } + + reqPath = path.Clean(reqPath) + + // Revert the replaced masks after path cleanup + for mask, replacerVal := range maskMap { + reqPath = strings.Replace(reqPath, replacerVal, mask, -1) + } + return reqPath +} + +// CleanPath calls CleanMaskedPath() with the default mask of "://" +// to preserve double slashes of protocols +// such as "http://", "https://", and "ftp://" etc. +func CleanPath(reqPath string) string { + return CleanMaskedPath(reqPath, "://") +} + +// An efficient and fast method for random string generation. +// Inspired by http://stackoverflow.com/a/31832326. +const randomStringLength = 4 +const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +const ( + letterIdxBits = 6 + letterIdxMask = 1<= 0; { + if remain == 0 { + cache, remain = src.Int63(), letterIdxMax + } + if idx := int(cache & letterIdxMask); idx < len(letterBytes) { + b[i] = letterBytes[idx] + i-- + } + cache >>= letterIdxBits + remain-- + } + return string(b) +} diff --git a/caddyhttp/httpserver/pathcleaner_test.go b/caddyhttp/httpserver/pathcleaner_test.go new file mode 100644 index 00000000..34f39590 --- /dev/null +++ b/caddyhttp/httpserver/pathcleaner_test.go @@ -0,0 +1,120 @@ +package httpserver + +import ( + "path" + "testing" +) + +var paths = map[string]map[string]string{ + "/../a/b/../././/c": { + "preserve_all": "/../a/b/../././/c", + "preserve_protocol": "/a/c", + "preserve_slashes": "/a//c", + "preserve_dots": "/../a/b/../././c", + "clean_all": "/a/c", + }, + "/path/https://www.google.com": { + "preserve_all": "/path/https://www.google.com", + "preserve_protocol": "/path/https://www.google.com", + "preserve_slashes": "/path/https://www.google.com", + "preserve_dots": "/path/https:/www.google.com", + "clean_all": "/path/https:/www.google.com", + }, + "/a/b/../././/c/http://example.com/foo//bar/../blah": { + "preserve_all": "/a/b/../././/c/http://example.com/foo//bar/../blah", + "preserve_protocol": "/a/c/http://example.com/foo/blah", + "preserve_slashes": "/a//c/http://example.com/foo/blah", + "preserve_dots": "/a/b/../././c/http:/example.com/foo/bar/../blah", + "clean_all": "/a/c/http:/example.com/foo/blah", + }, +} + +func assertEqual(t *testing.T, expected, received string) { + if expected != received { + t.Errorf("\tExpected: %s\n\t\t\tRecieved: %s", expected, received) + } +} + +func maskedTestRunner(t *testing.T, variation string, masks ...string) { + for reqPath, transformation := range paths { + assertEqual(t, transformation[variation], CleanMaskedPath(reqPath, masks...)) + } +} + +// No need to test the built-in path.Clean() function. +// However, it could be useful to cross-examine the test dataset. +func TestPathClean(t *testing.T) { + for reqPath, transformation := range paths { + assertEqual(t, transformation["clean_all"], path.Clean(reqPath)) + } +} + +func TestCleanAll(t *testing.T) { + maskedTestRunner(t, "clean_all") +} + +func TestPreserveAll(t *testing.T) { + maskedTestRunner(t, "preserve_all", "//", "/..", "/.") +} + +func TestPreserveProtocol(t *testing.T) { + maskedTestRunner(t, "preserve_protocol", "://") +} + +func TestPreserveSlashes(t *testing.T) { + maskedTestRunner(t, "preserve_slashes", "//") +} + +func TestPreserveDots(t *testing.T) { + maskedTestRunner(t, "preserve_dots", "/..", "/.") +} + +func TestDefaultMask(t *testing.T) { + for reqPath, transformation := range paths { + assertEqual(t, transformation["preserve_protocol"], CleanPath(reqPath)) + } +} + +func maskedBenchmarkRunner(b *testing.B, masks ...string) { + for n := 0; n < b.N; n++ { + for reqPath := range paths { + CleanMaskedPath(reqPath, masks...) + } + } +} + +func BenchmarkPathClean(b *testing.B) { + for n := 0; n < b.N; n++ { + for reqPath := range paths { + path.Clean(reqPath) + } + } +} + +func BenchmarkCleanAll(b *testing.B) { + maskedBenchmarkRunner(b) +} + +func BenchmarkPreserveAll(b *testing.B) { + maskedBenchmarkRunner(b, "//", "/..", "/.") +} + +func BenchmarkPreserveProtocol(b *testing.B) { + maskedBenchmarkRunner(b, "://") +} + +func BenchmarkPreserveSlashes(b *testing.B) { + maskedBenchmarkRunner(b, "//") +} + +func BenchmarkPreserveDots(b *testing.B) { + maskedBenchmarkRunner(b, "/..", "/.") +} + +func BenchmarkDefaultMask(b *testing.B) { + for n := 0; n < b.N; n++ { + for reqPath := range paths { + CleanPath(reqPath) + } + } +} diff --git a/caddyhttp/httpserver/server.go b/caddyhttp/httpserver/server.go index 49956ab3..7cc36040 100644 --- a/caddyhttp/httpserver/server.go +++ b/caddyhttp/httpserver/server.go @@ -9,7 +9,6 @@ import ( "net" "net/http" "os" - "path" "runtime" "strings" "sync" @@ -351,7 +350,7 @@ func sanitizePath(r *http.Request) { if r.URL.Path == "/" { return } - cleanedPath := path.Clean(r.URL.Path) + cleanedPath := CleanPath(r.URL.Path) if cleanedPath == "." { r.URL.Path = "/" } else { diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index fe959791..0f48a61f 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -247,12 +247,28 @@ func createUpstreamRequest(r *http.Request) *http.Request { outreq.URL.Opaque = outreq.URL.RawPath } + // We are modifying the same underlying map from req (shallow + // copied above) so we only copy it if necessary. + copiedHeaders := false + + // Remove hop-by-hop headers listed in the "Connection" header. + // See RFC 2616, section 14.10. + if c := outreq.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, r.Header) + copiedHeaders = true + } + outreq.Header.Del(f) + } + } + } + // Remove hop-by-hop headers to the backend. Especially // important is "Connection" because we want a persistent - // connection, regardless of what the client sent to us. This - // is modifying the same underlying map from r (shallow - // copied above) so we only copy it if necessary. - var copiedHeaders bool + // connection, regardless of what the client sent to us. for _, h := range hopHeaders { if outreq.Header.Get(h) != "" { if !copiedHeaders { diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 6359596c..686a79c5 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -42,10 +42,32 @@ func TestReverseProxy(t *testing.T) { log.SetOutput(ioutil.Discard) defer log.SetOutput(os.Stderr) + verifyHeaders := func(headers http.Header, trailers http.Header) { + if headers.Get("X-Header") != "header-value" { + t.Error("Expected header 'X-Header' to be proxied properly") + } + + if trailers == nil { + t.Error("Expected to receive trailers") + } + if trailers.Get("X-Trailer") != "trailer-value" { + t.Error("Expected header 'X-Trailer' to be proxied properly") + } + } + var requestReceived bool backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // read the body (even if it's empty) to make Go parse trailers + io.Copy(ioutil.Discard, r.Body) + verifyHeaders(r.Header, r.Trailer) + requestReceived = true + + w.Header().Set("Trailer", "X-Trailer") + w.Header().Set("X-Header", "header-value") + w.WriteHeader(http.StatusOK) w.Write([]byte("Hello, client")) + w.Header().Set("X-Trailer", "trailer-value") })) defer backend.Close() @@ -59,12 +81,21 @@ func TestReverseProxy(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() + r.ContentLength = -1 // force chunked encoding (required for trailers) + r.Header.Set("X-Header", "header-value") + r.Trailer = map[string][]string{ + "X-Trailer": {"trailer-value"}, + } + p.ServeHTTP(w, r) if !requestReceived { t.Error("Expected backend to receive request, but it didn't") } + res := w.Result() + verifyHeaders(res.Header, res.Trailer) + // Make sure {upstream} placeholder is set rr := httpserver.NewResponseRecorder(httptest.NewRecorder()) rr.Replacer = httpserver.NewReplacer(r, rr, "-") @@ -123,7 +154,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) { defer wsNop.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsNop.URL) + p := newWebSocketTestProxy(wsNop.URL, false) // Create client request r := httptest.NewRequest("GET", "/", nil) @@ -148,7 +179,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { defer wsNop.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsNop.URL) + p := newWebSocketTestProxy(wsNop.URL, false) // Create client request r := httptest.NewRequest("GET", "/", nil) @@ -189,7 +220,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) { defer wsEcho.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsEcho.URL) + p := newWebSocketTestProxy(wsEcho.URL, false) // This is a full end-end test, so the proxy handler // has to be part of a server listening on a port. Our @@ -228,6 +259,52 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) { } } +func TestWebSocketReverseProxyFromWSSClient(t *testing.T) { + wsEcho := newTLSServer(websocket.Handler(func(ws *websocket.Conn) { + io.Copy(ws, ws) + })) + defer wsEcho.Close() + + p := newWebSocketTestProxy(wsEcho.URL, true) + + echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + // Set up WebSocket client + url := strings.Replace(echoProxy.URL, "https://", "wss://", 1) + wsCfg, err := websocket.NewConfig(url, echoProxy.URL) + if err != nil { + t.Fatal(err) + } + wsCfg.TlsConfig = &tls.Config{InsecureSkipVerify: true} + ws, err := websocket.DialConfig(wsCfg) + + if err != nil { + t.Fatal(err) + } + defer ws.Close() + + // Send test message + trialMsg := "Is it working?" + + if sendErr := websocket.Message.Send(ws, trialMsg); sendErr != nil { + t.Fatal(sendErr) + } + + // It should be echoed back to us + var actualMsg string + + if rcvErr := websocket.Message.Receive(ws, &actualMsg); rcvErr != nil { + t.Fatal(rcvErr) + } + + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + func TestUnixSocketProxy(t *testing.T) { if runtime.GOOS == "windows" { return @@ -264,7 +341,7 @@ func TestUnixSocketProxy(t *testing.T) { defer ts.Close() url := strings.Replace(ts.URL, "http://", "unix:", 1) - p := newWebSocketTestProxy(url) + p := newWebSocketTestProxy(url, false) echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { p.ServeHTTP(w, r) @@ -982,10 +1059,14 @@ func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time. // redirect to the specified backendAddr. The function // also sets up the rules/environment for testing WebSocket // proxy. -func newWebSocketTestProxy(backendAddr string) *Proxy { +func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy { return &Proxy{ - Next: httpserver.EmptyNext, // prevents panic in some cases when test fails - Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}}, + Next: httpserver.EmptyNext, // prevents panic in some cases when test fails + Upstreams: []Upstream{&fakeWsUpstream{ + name: backendAddr, + without: "", + insecure: insecure, + }}, } } @@ -997,8 +1078,9 @@ func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy { } type fakeWsUpstream struct { - name string - without string + name string + without string + insecure bool } func (u *fakeWsUpstream) From() string { @@ -1007,13 +1089,17 @@ func (u *fakeWsUpstream) From() string { func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost { uri, _ := url.Parse(u.name) - return &UpstreamHost{ + host := &UpstreamHost{ Name: u.name, ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost), UpstreamHeaders: http.Header{ "Connection": {"{>Connection}"}, "Upgrade": {"{>Upgrade}"}}, } + if u.insecure { + host.ReverseProxy.UseInsecureTransport() + } + return host } func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true } diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index cfb466c7..a59f4bc8 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -27,10 +27,28 @@ import ( "github.com/mholt/caddy/caddyhttp/httpserver" ) -var bufferPool = sync.Pool{New: createBuffer} +var ( + defaultDialer = &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + + bufferPool = sync.Pool{New: createBuffer} +) func createBuffer() interface{} { - return make([]byte, 32*1024) + return make([]byte, 0, 32*1024) +} + +func pooledIoCopy(dst io.Writer, src io.Reader) { + buf := bufferPool.Get().([]byte) + defer bufferPool.Put(buf) + + // CopyBuffer only uses buf up to its length and panics if it's 0. + // Due to that we extend buf's length to its capacity here and + // ensure it's always non-zero. + bufCap := cap(buf) + io.CopyBuffer(dst, src, buf[0:bufCap:bufCap]) } // onExitFlushLoop is a callback set by tests to detect the state of the @@ -135,11 +153,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * // just use default transport, to avoid creating // a brand new transport transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, + Proxy: http.ProxyFromEnvironment, + Dial: defaultDialer.Dial, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } @@ -148,7 +163,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * } else { transport.MaxIdleConnsPerHost = keepalive } - http2.ConfigureTransport(transport) + if httpserver.HTTP2 { + http2.ConfigureTransport(transport) + } rp.Transport = transport } return rp @@ -160,18 +177,20 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * func (rp *ReverseProxy) UseInsecureTransport() { if rp.Transport == nil { transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, + Proxy: http.ProxyFromEnvironment, + Dial: defaultDialer.Dial, TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } - http2.ConfigureTransport(transport) + if httpserver.HTTP2 { + http2.ConfigureTransport(transport) + } rp.Transport = transport } else if transport, ok := rp.Transport.(*http.Transport); ok { transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + // No http2.ConfigureTransport() here. + // For now this is only added in places where + // an http.Transport is actually created. } } @@ -186,20 +205,33 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, } rp.Director(outreq) - outreq.Proto = "HTTP/1.1" - outreq.ProtoMajor = 1 - outreq.ProtoMinor = 1 - outreq.Close = false res, err := transport.RoundTrip(outreq) if err != nil { return err } + isWebsocket := res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" + + // Remove hop-by-hop headers listed in the + // "Connection" header of the response. + if c := res.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + res.Header.Del(f) + } + } + } + + for _, h := range hopHeaders { + res.Header.Del(h) + } + if respUpdateFn != nil { respUpdateFn(res) } - if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" { + + if isWebsocket { res.Body.Close() hj, ok := rw.(http.Hijacker) if !ok { @@ -228,27 +260,39 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, } defer backendConn.Close() - go func() { - io.Copy(backendConn, conn) // write tcp stream to backend. - }() - io.Copy(conn, backendConn) // read tcp stream from backend. + go pooledIoCopy(backendConn, conn) // write tcp stream to backend + pooledIoCopy(conn, backendConn) // read tcp stream from backend } else { - defer res.Body.Close() - for _, h := range hopHeaders { - res.Header.Del(h) - } copyHeader(rw.Header(), res.Header) + + // The "Trailer" header isn't included in the Transport's response, + // at least for *http.Transport. Build it up from Trailer. + if len(res.Trailer) > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for k := range res.Trailer { + trailerKeys = append(trailerKeys, k) + } + rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + rw.WriteHeader(res.StatusCode) + if len(res.Trailer) > 0 { + // Force chunking if we saw a response trailer. + // This prevents net/http from calculating the length for short + // bodies and adding a Content-Length. + if fl, ok := rw.(http.Flusher); ok { + fl.Flush() + } + } rp.copyResponse(rw, res.Body) + res.Body.Close() // close now, instead of defer, to populate res.Trailer + copyHeader(rw.Header(), res.Trailer) } return nil } func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { - buf := bufferPool.Get() - defer bufferPool.Put(buf) - if rp.FlushInterval != 0 { if wf, ok := dst.(writeFlusher); ok { mlw := &maxLatencyWriter{ @@ -261,7 +305,7 @@ func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { dst = mlw } } - io.CopyBuffer(dst, src, buf.([]byte)) + pooledIoCopy(dst, src) } // skip these headers if they already exist. @@ -295,16 +339,17 @@ func copyHeader(dst, src http.Header) { // Hop-by-hop headers. These are removed when sent to the backend. // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html var hopHeaders = []string{ + "Alt-Svc", + "Alternate-Protocol", "Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization", - "Te", // canonicalized version of "TE" - "Trailers", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522 "Transfer-Encoding", "Upgrade", - "Alternate-Protocol", - "Alt-Svc", } type respUpdateFn func(resp *http.Response) @@ -331,51 +376,169 @@ type connHijackerTransport struct { } func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { - transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, + t := &http.Transport{ MaxIdleConnsPerHost: -1, } - if base != nil { - if baseTransport, ok := base.(*http.Transport); ok { - transport.Proxy = baseTransport.Proxy - transport.TLSClientConfig = baseTransport.TLSClientConfig - transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout - transport.Dial = baseTransport.Dial - transport.DialTLS = baseTransport.DialTLS - transport.MaxIdleConnsPerHost = -1 + if b, _ := base.(*http.Transport); b != nil { + tlsClientConfig := b.TLSClientConfig + if tlsClientConfig.NextProtos != nil { + tlsClientConfig = cloneTLSClientConfig(tlsClientConfig) + tlsClientConfig.NextProtos = nil } + + t.Proxy = b.Proxy + t.TLSClientConfig = tlsClientConfig + t.TLSHandshakeTimeout = b.TLSHandshakeTimeout + t.Dial = b.Dial + t.DialTLS = b.DialTLS + } else { + t.Proxy = http.ProxyFromEnvironment + t.TLSHandshakeTimeout = 10 * time.Second } - hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]} - oldDial := transport.Dial - oldDialTLS := transport.DialTLS - if oldDial == nil { - oldDial = (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial + hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]} + + dial := getTransportDial(t) + dialTLS := getTransportDialTLS(t) + t.Dial = func(network, addr string) (net.Conn, error) { + c, err := dial(network, addr) + hj.Conn = c + return &hijackedConn{c, hj}, err } - hjTransport.Dial = func(network, addr string) (net.Conn, error) { - c, err := oldDial(network, addr) - hjTransport.Conn = c - return &hijackedConn{c, hjTransport}, err + t.DialTLS = func(network, addr string) (net.Conn, error) { + c, err := dialTLS(network, addr) + hj.Conn = c + return &hijackedConn{c, hj}, err } - if oldDialTLS != nil { - hjTransport.DialTLS = func(network, addr string) (net.Conn, error) { - c, err := oldDialTLS(network, addr) - hjTransport.Conn = c - return &hijackedConn{c, hjTransport}, err + + return hj +} + +// getTransportDial always returns a plain Dialer +// and defaults to the existing t.Dial. +func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) { + if t.Dial != nil { + return t.Dial + } + return defaultDialer.Dial +} + +// getTransportDial always returns a TLS Dialer +// and defaults to the existing t.DialTLS. +func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) { + if t.DialTLS != nil { + return t.DialTLS + } + + // newConnHijackerTransport will modify t.Dial after calling this method + // => Create a backup reference. + plainDial := getTransportDial(t) + + // The following DialTLS implementation stems from the Go stdlib and + // is identical to what happens if DialTLS is not provided. + // Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051 + return func(network, addr string) (net.Conn, error) { + plainConn, err := plainDial(network, addr) + if err != nil { + return nil, err } + + tlsClientConfig := t.TLSClientConfig + if tlsClientConfig == nil { + tlsClientConfig = &tls.Config{} + } + if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" { + tlsClientConfig.ServerName = stripPort(addr) + } + + tlsConn := tls.Client(plainConn, tlsClientConfig) + errc := make(chan error, 2) + var timer *time.Timer + if d := t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + err := tlsConn.Handshake() + if timer != nil { + timer.Stop() + } + errc <- err + }() + if err := <-errc; err != nil { + plainConn.Close() + return nil, err + } + if !tlsClientConfig.InsecureSkipVerify { + hostname := tlsClientConfig.ServerName + if hostname == "" { + hostname = stripPort(addr) + } + if err := tlsConn.VerifyHostname(hostname); err != nil { + plainConn.Close() + return nil, err + } + } + + return tlsConn, nil + } +} + +// stripPort returns address without its port if it has one and +// works with IP addresses as well as hostnames formatted as host:port. +// +// IPv6 addresses (excluding the port) must be enclosed in +// square brackets similar to the requirements of Go's stdlib. +func stripPort(address string) string { + // Keep in mind that the address might be a IPv6 address + // and thus contain a colon, but not have a port. + portIdx := strings.LastIndex(address, ":") + ipv6Idx := strings.LastIndex(address, "]") + if portIdx > ipv6Idx { + address = address[:portIdx] + } + return address +} + +type tlsHandshakeTimeoutError struct{} + +func (tlsHandshakeTimeoutError) Timeout() bool { return true } +func (tlsHandshakeTimeoutError) Temporary() bool { return true } +func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } + +// cloneTLSClientConfig is like cloneTLSConfig but omits +// the fields SessionTicketsDisabled and SessionTicketKey. +// This makes it safe to call cloneTLSClientConfig on a config +// in active use by a server. +func cloneTLSClientConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled, + Renegotiation: cfg.Renegotiation, } - return hjTransport } func requestIsWebsocket(req *http.Request) bool { - return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")) + return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") } type writeFlusher interface {