0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2024-12-23 22:27:38 -05:00

Merge branch 'master' of github.com:mholt/caddy

This commit is contained in:
Matthew Holt 2017-01-01 10:27:58 -07:00
commit a1a8d0f655
6 changed files with 545 additions and 85 deletions

View file

@ -0,0 +1,76 @@
package httpserver
import (
"math/rand"
"path"
"strings"
"time"
)
// CleanMaskedPath prevents one or more of the path cleanup operations:
// - collapse multiple slashes into one
// - eliminate "/." (current directory)
// - eliminate "<parent_directory>/.."
// by masking certain patterns in the path with a temporary random string.
// This could be helpful when certain patterns in the path are desired to be preserved
// that would otherwise be changed by path.Clean().
// One such use case is the presence of the double slashes as protocol separator
// (e.g., /api/endpoint/http://example.com).
// This is a common pattern in many applications to allow passing URIs as path argument.
func CleanMaskedPath(reqPath string, masks ...string) string {
var replacerVal string
maskMap := make(map[string]string)
// Iterate over supplied masks and create temporary replacement strings
// only for the masks that are present in the path, then replace all occurrences
for _, mask := range masks {
if strings.Index(reqPath, mask) >= 0 {
replacerVal = "/_caddy" + generateRandomString() + "__"
maskMap[mask] = replacerVal
reqPath = strings.Replace(reqPath, mask, replacerVal, -1)
}
}
reqPath = path.Clean(reqPath)
// Revert the replaced masks after path cleanup
for mask, replacerVal := range maskMap {
reqPath = strings.Replace(reqPath, replacerVal, mask, -1)
}
return reqPath
}
// CleanPath calls CleanMaskedPath() with the default mask of "://"
// to preserve double slashes of protocols
// such as "http://", "https://", and "ftp://" etc.
func CleanPath(reqPath string) string {
return CleanMaskedPath(reqPath, "://")
}
// An efficient and fast method for random string generation.
// Inspired by http://stackoverflow.com/a/31832326.
const randomStringLength = 4
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const (
letterIdxBits = 6
letterIdxMask = 1<<letterIdxBits - 1
letterIdxMax = 63 / letterIdxBits
)
var src = rand.NewSource(time.Now().UnixNano())
func generateRandomString() string {
b := make([]byte, randomStringLength)
for i, cache, remain := randomStringLength-1, src.Int63(), letterIdxMax; i >= 0; {
if remain == 0 {
cache, remain = src.Int63(), letterIdxMax
}
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
b[i] = letterBytes[idx]
i--
}
cache >>= letterIdxBits
remain--
}
return string(b)
}

View file

@ -0,0 +1,120 @@
package httpserver
import (
"path"
"testing"
)
var paths = map[string]map[string]string{
"/../a/b/../././/c": {
"preserve_all": "/../a/b/../././/c",
"preserve_protocol": "/a/c",
"preserve_slashes": "/a//c",
"preserve_dots": "/../a/b/../././c",
"clean_all": "/a/c",
},
"/path/https://www.google.com": {
"preserve_all": "/path/https://www.google.com",
"preserve_protocol": "/path/https://www.google.com",
"preserve_slashes": "/path/https://www.google.com",
"preserve_dots": "/path/https:/www.google.com",
"clean_all": "/path/https:/www.google.com",
},
"/a/b/../././/c/http://example.com/foo//bar/../blah": {
"preserve_all": "/a/b/../././/c/http://example.com/foo//bar/../blah",
"preserve_protocol": "/a/c/http://example.com/foo/blah",
"preserve_slashes": "/a//c/http://example.com/foo/blah",
"preserve_dots": "/a/b/../././c/http:/example.com/foo/bar/../blah",
"clean_all": "/a/c/http:/example.com/foo/blah",
},
}
func assertEqual(t *testing.T, expected, received string) {
if expected != received {
t.Errorf("\tExpected: %s\n\t\t\tRecieved: %s", expected, received)
}
}
func maskedTestRunner(t *testing.T, variation string, masks ...string) {
for reqPath, transformation := range paths {
assertEqual(t, transformation[variation], CleanMaskedPath(reqPath, masks...))
}
}
// No need to test the built-in path.Clean() function.
// However, it could be useful to cross-examine the test dataset.
func TestPathClean(t *testing.T) {
for reqPath, transformation := range paths {
assertEqual(t, transformation["clean_all"], path.Clean(reqPath))
}
}
func TestCleanAll(t *testing.T) {
maskedTestRunner(t, "clean_all")
}
func TestPreserveAll(t *testing.T) {
maskedTestRunner(t, "preserve_all", "//", "/..", "/.")
}
func TestPreserveProtocol(t *testing.T) {
maskedTestRunner(t, "preserve_protocol", "://")
}
func TestPreserveSlashes(t *testing.T) {
maskedTestRunner(t, "preserve_slashes", "//")
}
func TestPreserveDots(t *testing.T) {
maskedTestRunner(t, "preserve_dots", "/..", "/.")
}
func TestDefaultMask(t *testing.T) {
for reqPath, transformation := range paths {
assertEqual(t, transformation["preserve_protocol"], CleanPath(reqPath))
}
}
func maskedBenchmarkRunner(b *testing.B, masks ...string) {
for n := 0; n < b.N; n++ {
for reqPath := range paths {
CleanMaskedPath(reqPath, masks...)
}
}
}
func BenchmarkPathClean(b *testing.B) {
for n := 0; n < b.N; n++ {
for reqPath := range paths {
path.Clean(reqPath)
}
}
}
func BenchmarkCleanAll(b *testing.B) {
maskedBenchmarkRunner(b)
}
func BenchmarkPreserveAll(b *testing.B) {
maskedBenchmarkRunner(b, "//", "/..", "/.")
}
func BenchmarkPreserveProtocol(b *testing.B) {
maskedBenchmarkRunner(b, "://")
}
func BenchmarkPreserveSlashes(b *testing.B) {
maskedBenchmarkRunner(b, "//")
}
func BenchmarkPreserveDots(b *testing.B) {
maskedBenchmarkRunner(b, "/..", "/.")
}
func BenchmarkDefaultMask(b *testing.B) {
for n := 0; n < b.N; n++ {
for reqPath := range paths {
CleanPath(reqPath)
}
}
}

View file

@ -9,7 +9,6 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"path"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@ -351,7 +350,7 @@ func sanitizePath(r *http.Request) {
if r.URL.Path == "/" { if r.URL.Path == "/" {
return return
} }
cleanedPath := path.Clean(r.URL.Path) cleanedPath := CleanPath(r.URL.Path)
if cleanedPath == "." { if cleanedPath == "." {
r.URL.Path = "/" r.URL.Path = "/"
} else { } else {

View file

@ -247,12 +247,28 @@ func createUpstreamRequest(r *http.Request) *http.Request {
outreq.URL.Opaque = outreq.URL.RawPath outreq.URL.Opaque = outreq.URL.RawPath
} }
// We are modifying the same underlying map from req (shallow
// copied above) so we only copy it if necessary.
copiedHeaders := false
// Remove hop-by-hop headers listed in the "Connection" header.
// See RFC 2616, section 14.10.
if c := outreq.Header.Get("Connection"); c != "" {
for _, f := range strings.Split(c, ",") {
if f = strings.TrimSpace(f); f != "" {
if !copiedHeaders {
outreq.Header = make(http.Header)
copyHeader(outreq.Header, r.Header)
copiedHeaders = true
}
outreq.Header.Del(f)
}
}
}
// Remove hop-by-hop headers to the backend. Especially // Remove hop-by-hop headers to the backend. Especially
// important is "Connection" because we want a persistent // important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us. This // connection, regardless of what the client sent to us.
// is modifying the same underlying map from r (shallow
// copied above) so we only copy it if necessary.
var copiedHeaders bool
for _, h := range hopHeaders { for _, h := range hopHeaders {
if outreq.Header.Get(h) != "" { if outreq.Header.Get(h) != "" {
if !copiedHeaders { if !copiedHeaders {

View file

@ -42,10 +42,32 @@ func TestReverseProxy(t *testing.T) {
log.SetOutput(ioutil.Discard) log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr) defer log.SetOutput(os.Stderr)
verifyHeaders := func(headers http.Header, trailers http.Header) {
if headers.Get("X-Header") != "header-value" {
t.Error("Expected header 'X-Header' to be proxied properly")
}
if trailers == nil {
t.Error("Expected to receive trailers")
}
if trailers.Get("X-Trailer") != "trailer-value" {
t.Error("Expected header 'X-Trailer' to be proxied properly")
}
}
var requestReceived bool var requestReceived bool
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// read the body (even if it's empty) to make Go parse trailers
io.Copy(ioutil.Discard, r.Body)
verifyHeaders(r.Header, r.Trailer)
requestReceived = true requestReceived = true
w.Header().Set("Trailer", "X-Trailer")
w.Header().Set("X-Header", "header-value")
w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello, client")) w.Write([]byte("Hello, client"))
w.Header().Set("X-Trailer", "trailer-value")
})) }))
defer backend.Close() defer backend.Close()
@ -59,12 +81,21 @@ func TestReverseProxy(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
r.ContentLength = -1 // force chunked encoding (required for trailers)
r.Header.Set("X-Header", "header-value")
r.Trailer = map[string][]string{
"X-Trailer": {"trailer-value"},
}
p.ServeHTTP(w, r) p.ServeHTTP(w, r)
if !requestReceived { if !requestReceived {
t.Error("Expected backend to receive request, but it didn't") t.Error("Expected backend to receive request, but it didn't")
} }
res := w.Result()
verifyHeaders(res.Header, res.Trailer)
// Make sure {upstream} placeholder is set // Make sure {upstream} placeholder is set
rr := httpserver.NewResponseRecorder(httptest.NewRecorder()) rr := httpserver.NewResponseRecorder(httptest.NewRecorder())
rr.Replacer = httpserver.NewReplacer(r, rr, "-") rr.Replacer = httpserver.NewReplacer(r, rr, "-")
@ -123,7 +154,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
defer wsNop.Close() defer wsNop.Close()
// Get proxy to use for the test // Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL) p := newWebSocketTestProxy(wsNop.URL, false)
// Create client request // Create client request
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
@ -148,7 +179,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
defer wsNop.Close() defer wsNop.Close()
// Get proxy to use for the test // Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL) p := newWebSocketTestProxy(wsNop.URL, false)
// Create client request // Create client request
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
@ -189,7 +220,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
defer wsEcho.Close() defer wsEcho.Close()
// Get proxy to use for the test // Get proxy to use for the test
p := newWebSocketTestProxy(wsEcho.URL) p := newWebSocketTestProxy(wsEcho.URL, false)
// This is a full end-end test, so the proxy handler // This is a full end-end test, so the proxy handler
// has to be part of a server listening on a port. Our // has to be part of a server listening on a port. Our
@ -228,6 +259,52 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
} }
} }
func TestWebSocketReverseProxyFromWSSClient(t *testing.T) {
wsEcho := newTLSServer(websocket.Handler(func(ws *websocket.Conn) {
io.Copy(ws, ws)
}))
defer wsEcho.Close()
p := newWebSocketTestProxy(wsEcho.URL, true)
echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer echoProxy.Close()
// Set up WebSocket client
url := strings.Replace(echoProxy.URL, "https://", "wss://", 1)
wsCfg, err := websocket.NewConfig(url, echoProxy.URL)
if err != nil {
t.Fatal(err)
}
wsCfg.TlsConfig = &tls.Config{InsecureSkipVerify: true}
ws, err := websocket.DialConfig(wsCfg)
if err != nil {
t.Fatal(err)
}
defer ws.Close()
// Send test message
trialMsg := "Is it working?"
if sendErr := websocket.Message.Send(ws, trialMsg); sendErr != nil {
t.Fatal(sendErr)
}
// It should be echoed back to us
var actualMsg string
if rcvErr := websocket.Message.Receive(ws, &actualMsg); rcvErr != nil {
t.Fatal(rcvErr)
}
if actualMsg != trialMsg {
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
}
}
func TestUnixSocketProxy(t *testing.T) { func TestUnixSocketProxy(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
return return
@ -264,7 +341,7 @@ func TestUnixSocketProxy(t *testing.T) {
defer ts.Close() defer ts.Close()
url := strings.Replace(ts.URL, "http://", "unix:", 1) url := strings.Replace(ts.URL, "http://", "unix:", 1)
p := newWebSocketTestProxy(url) p := newWebSocketTestProxy(url, false)
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r) p.ServeHTTP(w, r)
@ -982,10 +1059,14 @@ func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.
// redirect to the specified backendAddr. The function // redirect to the specified backendAddr. The function
// also sets up the rules/environment for testing WebSocket // also sets up the rules/environment for testing WebSocket
// proxy. // proxy.
func newWebSocketTestProxy(backendAddr string) *Proxy { func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
return &Proxy{ return &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}}, Upstreams: []Upstream{&fakeWsUpstream{
name: backendAddr,
without: "",
insecure: insecure,
}},
} }
} }
@ -999,6 +1080,7 @@ func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
type fakeWsUpstream struct { type fakeWsUpstream struct {
name string name string
without string without string
insecure bool
} }
func (u *fakeWsUpstream) From() string { func (u *fakeWsUpstream) From() string {
@ -1007,13 +1089,17 @@ func (u *fakeWsUpstream) From() string {
func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost { func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
uri, _ := url.Parse(u.name) uri, _ := url.Parse(u.name)
return &UpstreamHost{ host := &UpstreamHost{
Name: u.name, Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost), ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
UpstreamHeaders: http.Header{ UpstreamHeaders: http.Header{
"Connection": {"{>Connection}"}, "Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}}, "Upgrade": {"{>Upgrade}"}},
} }
if u.insecure {
host.ReverseProxy.UseInsecureTransport()
}
return host
} }
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true } func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }

View file

@ -27,10 +27,28 @@ import (
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
var bufferPool = sync.Pool{New: createBuffer} var (
defaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
bufferPool = sync.Pool{New: createBuffer}
)
func createBuffer() interface{} { func createBuffer() interface{} {
return make([]byte, 32*1024) return make([]byte, 0, 32*1024)
}
func pooledIoCopy(dst io.Writer, src io.Reader) {
buf := bufferPool.Get().([]byte)
defer bufferPool.Put(buf)
// CopyBuffer only uses buf up to its length and panics if it's 0.
// Due to that we extend buf's length to its capacity here and
// ensure it's always non-zero.
bufCap := cap(buf)
io.CopyBuffer(dst, src, buf[0:bufCap:bufCap])
} }
// onExitFlushLoop is a callback set by tests to detect the state of the // onExitFlushLoop is a callback set by tests to detect the state of the
@ -136,10 +154,7 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
// a brand new transport // a brand new transport
transport := &http.Transport{ transport := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{ Dial: defaultDialer.Dial,
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
} }
@ -148,7 +163,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
} else { } else {
transport.MaxIdleConnsPerHost = keepalive transport.MaxIdleConnsPerHost = keepalive
} }
if httpserver.HTTP2 {
http2.ConfigureTransport(transport) http2.ConfigureTransport(transport)
}
rp.Transport = transport rp.Transport = transport
} }
return rp return rp
@ -161,17 +178,19 @@ func (rp *ReverseProxy) UseInsecureTransport() {
if rp.Transport == nil { if rp.Transport == nil {
transport := &http.Transport{ transport := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{ Dial: defaultDialer.Dial,
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
} }
if httpserver.HTTP2 {
http2.ConfigureTransport(transport) http2.ConfigureTransport(transport)
}
rp.Transport = transport rp.Transport = transport
} else if transport, ok := rp.Transport.(*http.Transport); ok { } else if transport, ok := rp.Transport.(*http.Transport); ok {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
// No http2.ConfigureTransport() here.
// For now this is only added in places where
// an http.Transport is actually created.
} }
} }
@ -186,20 +205,33 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
} }
rp.Director(outreq) rp.Director(outreq)
outreq.Proto = "HTTP/1.1"
outreq.ProtoMajor = 1
outreq.ProtoMinor = 1
outreq.Close = false
res, err := transport.RoundTrip(outreq) res, err := transport.RoundTrip(outreq)
if err != nil { if err != nil {
return err return err
} }
isWebsocket := res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket"
// Remove hop-by-hop headers listed in the
// "Connection" header of the response.
if c := res.Header.Get("Connection"); c != "" {
for _, f := range strings.Split(c, ",") {
if f = strings.TrimSpace(f); f != "" {
res.Header.Del(f)
}
}
}
for _, h := range hopHeaders {
res.Header.Del(h)
}
if respUpdateFn != nil { if respUpdateFn != nil {
respUpdateFn(res) respUpdateFn(res)
} }
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
if isWebsocket {
res.Body.Close() res.Body.Close()
hj, ok := rw.(http.Hijacker) hj, ok := rw.(http.Hijacker)
if !ok { if !ok {
@ -228,27 +260,39 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
} }
defer backendConn.Close() defer backendConn.Close()
go func() { go pooledIoCopy(backendConn, conn) // write tcp stream to backend
io.Copy(backendConn, conn) // write tcp stream to backend. pooledIoCopy(conn, backendConn) // read tcp stream from backend
}()
io.Copy(conn, backendConn) // read tcp stream from backend.
} else { } else {
defer res.Body.Close()
for _, h := range hopHeaders {
res.Header.Del(h)
}
copyHeader(rw.Header(), res.Header) copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
if len(res.Trailer) > 0 {
trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
rw.WriteHeader(res.StatusCode) rw.WriteHeader(res.StatusCode)
if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
if fl, ok := rw.(http.Flusher); ok {
fl.Flush()
}
}
rp.copyResponse(rw, res.Body) rp.copyResponse(rw, res.Body)
res.Body.Close() // close now, instead of defer, to populate res.Trailer
copyHeader(rw.Header(), res.Trailer)
} }
return nil return nil
} }
func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
buf := bufferPool.Get()
defer bufferPool.Put(buf)
if rp.FlushInterval != 0 { if rp.FlushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok { if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{ mlw := &maxLatencyWriter{
@ -261,7 +305,7 @@ func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
dst = mlw dst = mlw
} }
} }
io.CopyBuffer(dst, src, buf.([]byte)) pooledIoCopy(dst, src)
} }
// skip these headers if they already exist. // skip these headers if they already exist.
@ -295,16 +339,17 @@ func copyHeader(dst, src http.Header) {
// Hop-by-hop headers. These are removed when sent to the backend. // Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{ var hopHeaders = []string{
"Alt-Svc",
"Alternate-Protocol",
"Connection", "Connection",
"Keep-Alive", "Keep-Alive",
"Proxy-Authenticate", "Proxy-Authenticate",
"Proxy-Authorization", "Proxy-Authorization",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Te", // canonicalized version of "TE" "Te", // canonicalized version of "TE"
"Trailers", "Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding", "Transfer-Encoding",
"Upgrade", "Upgrade",
"Alternate-Protocol",
"Alt-Svc",
} }
type respUpdateFn func(resp *http.Response) type respUpdateFn func(resp *http.Response)
@ -331,51 +376,169 @@ type connHijackerTransport struct {
} }
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
transport := &http.Transport{ t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
MaxIdleConnsPerHost: -1, MaxIdleConnsPerHost: -1,
} }
if base != nil { if b, _ := base.(*http.Transport); b != nil {
if baseTransport, ok := base.(*http.Transport); ok { tlsClientConfig := b.TLSClientConfig
transport.Proxy = baseTransport.Proxy if tlsClientConfig.NextProtos != nil {
transport.TLSClientConfig = baseTransport.TLSClientConfig tlsClientConfig = cloneTLSClientConfig(tlsClientConfig)
transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout tlsClientConfig.NextProtos = nil
transport.Dial = baseTransport.Dial }
transport.DialTLS = baseTransport.DialTLS
transport.MaxIdleConnsPerHost = -1 t.Proxy = b.Proxy
t.TLSClientConfig = tlsClientConfig
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
t.Dial = b.Dial
t.DialTLS = b.DialTLS
} else {
t.Proxy = http.ProxyFromEnvironment
t.TLSHandshakeTimeout = 10 * time.Second
}
hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}
dial := getTransportDial(t)
dialTLS := getTransportDialTLS(t)
t.Dial = func(network, addr string) (net.Conn, error) {
c, err := dial(network, addr)
hj.Conn = c
return &hijackedConn{c, hj}, err
}
t.DialTLS = func(network, addr string) (net.Conn, error) {
c, err := dialTLS(network, addr)
hj.Conn = c
return &hijackedConn{c, hj}, err
}
return hj
}
// getTransportDial always returns a plain Dialer
// and defaults to the existing t.Dial.
func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
if t.Dial != nil {
return t.Dial
}
return defaultDialer.Dial
}
// getTransportDial always returns a TLS Dialer
// and defaults to the existing t.DialTLS.
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
if t.DialTLS != nil {
return t.DialTLS
}
// newConnHijackerTransport will modify t.Dial after calling this method
// => Create a backup reference.
plainDial := getTransportDial(t)
// The following DialTLS implementation stems from the Go stdlib and
// is identical to what happens if DialTLS is not provided.
// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
return func(network, addr string) (net.Conn, error) {
plainConn, err := plainDial(network, addr)
if err != nil {
return nil, err
}
tlsClientConfig := t.TLSClientConfig
if tlsClientConfig == nil {
tlsClientConfig = &tls.Config{}
}
if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
tlsClientConfig.ServerName = stripPort(addr)
}
tlsConn := tls.Client(plainConn, tlsClientConfig)
errc := make(chan error, 2)
var timer *time.Timer
if d := t.TLSHandshakeTimeout; d != 0 {
timer = time.AfterFunc(d, func() {
errc <- tlsHandshakeTimeoutError{}
})
}
go func() {
err := tlsConn.Handshake()
if timer != nil {
timer.Stop()
}
errc <- err
}()
if err := <-errc; err != nil {
plainConn.Close()
return nil, err
}
if !tlsClientConfig.InsecureSkipVerify {
hostname := tlsClientConfig.ServerName
if hostname == "" {
hostname = stripPort(addr)
}
if err := tlsConn.VerifyHostname(hostname); err != nil {
plainConn.Close()
return nil, err
} }
} }
hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
oldDial := transport.Dial return tlsConn, nil
oldDialTLS := transport.DialTLS
if oldDial == nil {
oldDial = (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial
}
hjTransport.Dial = func(network, addr string) (net.Conn, error) {
c, err := oldDial(network, addr)
hjTransport.Conn = c
return &hijackedConn{c, hjTransport}, err
}
if oldDialTLS != nil {
hjTransport.DialTLS = func(network, addr string) (net.Conn, error) {
c, err := oldDialTLS(network, addr)
hjTransport.Conn = c
return &hijackedConn{c, hjTransport}, err
} }
} }
return hjTransport
// stripPort returns address without its port if it has one and
// works with IP addresses as well as hostnames formatted as host:port.
//
// IPv6 addresses (excluding the port) must be enclosed in
// square brackets similar to the requirements of Go's stdlib.
func stripPort(address string) string {
// Keep in mind that the address might be a IPv6 address
// and thus contain a colon, but not have a port.
portIdx := strings.LastIndex(address, ":")
ipv6Idx := strings.LastIndex(address, "]")
if portIdx > ipv6Idx {
address = address[:portIdx]
}
return address
}
type tlsHandshakeTimeoutError struct{}
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
// cloneTLSClientConfig is like cloneTLSConfig but omits
// the fields SessionTicketsDisabled and SessionTicketKey.
// This makes it safe to call cloneTLSClientConfig on a config
// in active use by a server.
func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
Renegotiation: cfg.Renegotiation,
}
} }
func requestIsWebsocket(req *http.Request) bool { func requestIsWebsocket(req *http.Request) bool {
return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")) return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
} }
type writeFlusher interface { type writeFlusher interface {