mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-23 22:27:38 -05:00
Merge branch 'master' of github.com:mholt/caddy
This commit is contained in:
commit
a1a8d0f655
6 changed files with 545 additions and 85 deletions
76
caddyhttp/httpserver/pathcleaner.go
Normal file
76
caddyhttp/httpserver/pathcleaner.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package httpserver
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CleanMaskedPath prevents one or more of the path cleanup operations:
|
||||
// - collapse multiple slashes into one
|
||||
// - eliminate "/." (current directory)
|
||||
// - eliminate "<parent_directory>/.."
|
||||
// by masking certain patterns in the path with a temporary random string.
|
||||
// This could be helpful when certain patterns in the path are desired to be preserved
|
||||
// that would otherwise be changed by path.Clean().
|
||||
// One such use case is the presence of the double slashes as protocol separator
|
||||
// (e.g., /api/endpoint/http://example.com).
|
||||
// This is a common pattern in many applications to allow passing URIs as path argument.
|
||||
func CleanMaskedPath(reqPath string, masks ...string) string {
|
||||
var replacerVal string
|
||||
maskMap := make(map[string]string)
|
||||
|
||||
// Iterate over supplied masks and create temporary replacement strings
|
||||
// only for the masks that are present in the path, then replace all occurrences
|
||||
for _, mask := range masks {
|
||||
if strings.Index(reqPath, mask) >= 0 {
|
||||
replacerVal = "/_caddy" + generateRandomString() + "__"
|
||||
maskMap[mask] = replacerVal
|
||||
reqPath = strings.Replace(reqPath, mask, replacerVal, -1)
|
||||
}
|
||||
}
|
||||
|
||||
reqPath = path.Clean(reqPath)
|
||||
|
||||
// Revert the replaced masks after path cleanup
|
||||
for mask, replacerVal := range maskMap {
|
||||
reqPath = strings.Replace(reqPath, replacerVal, mask, -1)
|
||||
}
|
||||
return reqPath
|
||||
}
|
||||
|
||||
// CleanPath calls CleanMaskedPath() with the default mask of "://"
|
||||
// to preserve double slashes of protocols
|
||||
// such as "http://", "https://", and "ftp://" etc.
|
||||
func CleanPath(reqPath string) string {
|
||||
return CleanMaskedPath(reqPath, "://")
|
||||
}
|
||||
|
||||
// An efficient and fast method for random string generation.
|
||||
// Inspired by http://stackoverflow.com/a/31832326.
|
||||
const randomStringLength = 4
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
const (
|
||||
letterIdxBits = 6
|
||||
letterIdxMask = 1<<letterIdxBits - 1
|
||||
letterIdxMax = 63 / letterIdxBits
|
||||
)
|
||||
|
||||
var src = rand.NewSource(time.Now().UnixNano())
|
||||
|
||||
func generateRandomString() string {
|
||||
b := make([]byte, randomStringLength)
|
||||
for i, cache, remain := randomStringLength-1, src.Int63(), letterIdxMax; i >= 0; {
|
||||
if remain == 0 {
|
||||
cache, remain = src.Int63(), letterIdxMax
|
||||
}
|
||||
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
|
||||
b[i] = letterBytes[idx]
|
||||
i--
|
||||
}
|
||||
cache >>= letterIdxBits
|
||||
remain--
|
||||
}
|
||||
return string(b)
|
||||
}
|
120
caddyhttp/httpserver/pathcleaner_test.go
Normal file
120
caddyhttp/httpserver/pathcleaner_test.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package httpserver
|
||||
|
||||
import (
|
||||
"path"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var paths = map[string]map[string]string{
|
||||
"/../a/b/../././/c": {
|
||||
"preserve_all": "/../a/b/../././/c",
|
||||
"preserve_protocol": "/a/c",
|
||||
"preserve_slashes": "/a//c",
|
||||
"preserve_dots": "/../a/b/../././c",
|
||||
"clean_all": "/a/c",
|
||||
},
|
||||
"/path/https://www.google.com": {
|
||||
"preserve_all": "/path/https://www.google.com",
|
||||
"preserve_protocol": "/path/https://www.google.com",
|
||||
"preserve_slashes": "/path/https://www.google.com",
|
||||
"preserve_dots": "/path/https:/www.google.com",
|
||||
"clean_all": "/path/https:/www.google.com",
|
||||
},
|
||||
"/a/b/../././/c/http://example.com/foo//bar/../blah": {
|
||||
"preserve_all": "/a/b/../././/c/http://example.com/foo//bar/../blah",
|
||||
"preserve_protocol": "/a/c/http://example.com/foo/blah",
|
||||
"preserve_slashes": "/a//c/http://example.com/foo/blah",
|
||||
"preserve_dots": "/a/b/../././c/http:/example.com/foo/bar/../blah",
|
||||
"clean_all": "/a/c/http:/example.com/foo/blah",
|
||||
},
|
||||
}
|
||||
|
||||
func assertEqual(t *testing.T, expected, received string) {
|
||||
if expected != received {
|
||||
t.Errorf("\tExpected: %s\n\t\t\tRecieved: %s", expected, received)
|
||||
}
|
||||
}
|
||||
|
||||
func maskedTestRunner(t *testing.T, variation string, masks ...string) {
|
||||
for reqPath, transformation := range paths {
|
||||
assertEqual(t, transformation[variation], CleanMaskedPath(reqPath, masks...))
|
||||
}
|
||||
}
|
||||
|
||||
// No need to test the built-in path.Clean() function.
|
||||
// However, it could be useful to cross-examine the test dataset.
|
||||
func TestPathClean(t *testing.T) {
|
||||
for reqPath, transformation := range paths {
|
||||
assertEqual(t, transformation["clean_all"], path.Clean(reqPath))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanAll(t *testing.T) {
|
||||
maskedTestRunner(t, "clean_all")
|
||||
}
|
||||
|
||||
func TestPreserveAll(t *testing.T) {
|
||||
maskedTestRunner(t, "preserve_all", "//", "/..", "/.")
|
||||
}
|
||||
|
||||
func TestPreserveProtocol(t *testing.T) {
|
||||
maskedTestRunner(t, "preserve_protocol", "://")
|
||||
}
|
||||
|
||||
func TestPreserveSlashes(t *testing.T) {
|
||||
maskedTestRunner(t, "preserve_slashes", "//")
|
||||
}
|
||||
|
||||
func TestPreserveDots(t *testing.T) {
|
||||
maskedTestRunner(t, "preserve_dots", "/..", "/.")
|
||||
}
|
||||
|
||||
func TestDefaultMask(t *testing.T) {
|
||||
for reqPath, transformation := range paths {
|
||||
assertEqual(t, transformation["preserve_protocol"], CleanPath(reqPath))
|
||||
}
|
||||
}
|
||||
|
||||
func maskedBenchmarkRunner(b *testing.B, masks ...string) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
for reqPath := range paths {
|
||||
CleanMaskedPath(reqPath, masks...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPathClean(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
for reqPath := range paths {
|
||||
path.Clean(reqPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCleanAll(b *testing.B) {
|
||||
maskedBenchmarkRunner(b)
|
||||
}
|
||||
|
||||
func BenchmarkPreserveAll(b *testing.B) {
|
||||
maskedBenchmarkRunner(b, "//", "/..", "/.")
|
||||
}
|
||||
|
||||
func BenchmarkPreserveProtocol(b *testing.B) {
|
||||
maskedBenchmarkRunner(b, "://")
|
||||
}
|
||||
|
||||
func BenchmarkPreserveSlashes(b *testing.B) {
|
||||
maskedBenchmarkRunner(b, "//")
|
||||
}
|
||||
|
||||
func BenchmarkPreserveDots(b *testing.B) {
|
||||
maskedBenchmarkRunner(b, "/..", "/.")
|
||||
}
|
||||
|
||||
func BenchmarkDefaultMask(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
for reqPath := range paths {
|
||||
CleanPath(reqPath)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -9,7 +9,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -351,7 +350,7 @@ func sanitizePath(r *http.Request) {
|
|||
if r.URL.Path == "/" {
|
||||
return
|
||||
}
|
||||
cleanedPath := path.Clean(r.URL.Path)
|
||||
cleanedPath := CleanPath(r.URL.Path)
|
||||
if cleanedPath == "." {
|
||||
r.URL.Path = "/"
|
||||
} else {
|
||||
|
|
|
@ -247,12 +247,28 @@ func createUpstreamRequest(r *http.Request) *http.Request {
|
|||
outreq.URL.Opaque = outreq.URL.RawPath
|
||||
}
|
||||
|
||||
// We are modifying the same underlying map from req (shallow
|
||||
// copied above) so we only copy it if necessary.
|
||||
copiedHeaders := false
|
||||
|
||||
// Remove hop-by-hop headers listed in the "Connection" header.
|
||||
// See RFC 2616, section 14.10.
|
||||
if c := outreq.Header.Get("Connection"); c != "" {
|
||||
for _, f := range strings.Split(c, ",") {
|
||||
if f = strings.TrimSpace(f); f != "" {
|
||||
if !copiedHeaders {
|
||||
outreq.Header = make(http.Header)
|
||||
copyHeader(outreq.Header, r.Header)
|
||||
copiedHeaders = true
|
||||
}
|
||||
outreq.Header.Del(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove hop-by-hop headers to the backend. Especially
|
||||
// important is "Connection" because we want a persistent
|
||||
// connection, regardless of what the client sent to us. This
|
||||
// is modifying the same underlying map from r (shallow
|
||||
// copied above) so we only copy it if necessary.
|
||||
var copiedHeaders bool
|
||||
// connection, regardless of what the client sent to us.
|
||||
for _, h := range hopHeaders {
|
||||
if outreq.Header.Get(h) != "" {
|
||||
if !copiedHeaders {
|
||||
|
|
|
@ -42,10 +42,32 @@ func TestReverseProxy(t *testing.T) {
|
|||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
verifyHeaders := func(headers http.Header, trailers http.Header) {
|
||||
if headers.Get("X-Header") != "header-value" {
|
||||
t.Error("Expected header 'X-Header' to be proxied properly")
|
||||
}
|
||||
|
||||
if trailers == nil {
|
||||
t.Error("Expected to receive trailers")
|
||||
}
|
||||
if trailers.Get("X-Trailer") != "trailer-value" {
|
||||
t.Error("Expected header 'X-Trailer' to be proxied properly")
|
||||
}
|
||||
}
|
||||
|
||||
var requestReceived bool
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// read the body (even if it's empty) to make Go parse trailers
|
||||
io.Copy(ioutil.Discard, r.Body)
|
||||
verifyHeaders(r.Header, r.Trailer)
|
||||
|
||||
requestReceived = true
|
||||
|
||||
w.Header().Set("Trailer", "X-Trailer")
|
||||
w.Header().Set("X-Header", "header-value")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Hello, client"))
|
||||
w.Header().Set("X-Trailer", "trailer-value")
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
|
@ -59,12 +81,21 @@ func TestReverseProxy(t *testing.T) {
|
|||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.ContentLength = -1 // force chunked encoding (required for trailers)
|
||||
r.Header.Set("X-Header", "header-value")
|
||||
r.Trailer = map[string][]string{
|
||||
"X-Trailer": {"trailer-value"},
|
||||
}
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
if !requestReceived {
|
||||
t.Error("Expected backend to receive request, but it didn't")
|
||||
}
|
||||
|
||||
res := w.Result()
|
||||
verifyHeaders(res.Header, res.Trailer)
|
||||
|
||||
// Make sure {upstream} placeholder is set
|
||||
rr := httpserver.NewResponseRecorder(httptest.NewRecorder())
|
||||
rr.Replacer = httpserver.NewReplacer(r, rr, "-")
|
||||
|
@ -123,7 +154,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
|
|||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsNop.URL)
|
||||
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||
|
||||
// Create client request
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
@ -148,7 +179,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
|
|||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsNop.URL)
|
||||
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||
|
||||
// Create client request
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
@ -189,7 +220,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
|||
defer wsEcho.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsEcho.URL)
|
||||
p := newWebSocketTestProxy(wsEcho.URL, false)
|
||||
|
||||
// This is a full end-end test, so the proxy handler
|
||||
// has to be part of a server listening on a port. Our
|
||||
|
@ -228,6 +259,52 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestWebSocketReverseProxyFromWSSClient(t *testing.T) {
|
||||
wsEcho := newTLSServer(websocket.Handler(func(ws *websocket.Conn) {
|
||||
io.Copy(ws, ws)
|
||||
}))
|
||||
defer wsEcho.Close()
|
||||
|
||||
p := newWebSocketTestProxy(wsEcho.URL, true)
|
||||
|
||||
echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
}))
|
||||
defer echoProxy.Close()
|
||||
|
||||
// Set up WebSocket client
|
||||
url := strings.Replace(echoProxy.URL, "https://", "wss://", 1)
|
||||
wsCfg, err := websocket.NewConfig(url, echoProxy.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wsCfg.TlsConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
ws, err := websocket.DialConfig(wsCfg)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
// Send test message
|
||||
trialMsg := "Is it working?"
|
||||
|
||||
if sendErr := websocket.Message.Send(ws, trialMsg); sendErr != nil {
|
||||
t.Fatal(sendErr)
|
||||
}
|
||||
|
||||
// It should be echoed back to us
|
||||
var actualMsg string
|
||||
|
||||
if rcvErr := websocket.Message.Receive(ws, &actualMsg); rcvErr != nil {
|
||||
t.Fatal(rcvErr)
|
||||
}
|
||||
|
||||
if actualMsg != trialMsg {
|
||||
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketProxy(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
return
|
||||
|
@ -264,7 +341,7 @@ func TestUnixSocketProxy(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
url := strings.Replace(ts.URL, "http://", "unix:", 1)
|
||||
p := newWebSocketTestProxy(url)
|
||||
p := newWebSocketTestProxy(url, false)
|
||||
|
||||
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
|
@ -982,10 +1059,14 @@ func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.
|
|||
// redirect to the specified backendAddr. The function
|
||||
// also sets up the rules/environment for testing WebSocket
|
||||
// proxy.
|
||||
func newWebSocketTestProxy(backendAddr string) *Proxy {
|
||||
func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
|
||||
return &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}},
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{&fakeWsUpstream{
|
||||
name: backendAddr,
|
||||
without: "",
|
||||
insecure: insecure,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -997,8 +1078,9 @@ func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
|
|||
}
|
||||
|
||||
type fakeWsUpstream struct {
|
||||
name string
|
||||
without string
|
||||
name string
|
||||
without string
|
||||
insecure bool
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) From() string {
|
||||
|
@ -1007,13 +1089,17 @@ func (u *fakeWsUpstream) From() string {
|
|||
|
||||
func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
|
||||
uri, _ := url.Parse(u.name)
|
||||
return &UpstreamHost{
|
||||
host := &UpstreamHost{
|
||||
Name: u.name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
|
||||
UpstreamHeaders: http.Header{
|
||||
"Connection": {"{>Connection}"},
|
||||
"Upgrade": {"{>Upgrade}"}},
|
||||
}
|
||||
if u.insecure {
|
||||
host.ReverseProxy.UseInsecureTransport()
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
|
||||
|
|
|
@ -27,10 +27,28 @@ import (
|
|||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
var bufferPool = sync.Pool{New: createBuffer}
|
||||
var (
|
||||
defaultDialer = &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
|
||||
bufferPool = sync.Pool{New: createBuffer}
|
||||
)
|
||||
|
||||
func createBuffer() interface{} {
|
||||
return make([]byte, 32*1024)
|
||||
return make([]byte, 0, 32*1024)
|
||||
}
|
||||
|
||||
func pooledIoCopy(dst io.Writer, src io.Reader) {
|
||||
buf := bufferPool.Get().([]byte)
|
||||
defer bufferPool.Put(buf)
|
||||
|
||||
// CopyBuffer only uses buf up to its length and panics if it's 0.
|
||||
// Due to that we extend buf's length to its capacity here and
|
||||
// ensure it's always non-zero.
|
||||
bufCap := cap(buf)
|
||||
io.CopyBuffer(dst, src, buf[0:bufCap:bufCap])
|
||||
}
|
||||
|
||||
// onExitFlushLoop is a callback set by tests to detect the state of the
|
||||
|
@ -135,11 +153,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
|||
// just use default transport, to avoid creating
|
||||
// a brand new transport
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: defaultDialer.Dial,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
@ -148,7 +163,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
|||
} else {
|
||||
transport.MaxIdleConnsPerHost = keepalive
|
||||
}
|
||||
http2.ConfigureTransport(transport)
|
||||
if httpserver.HTTP2 {
|
||||
http2.ConfigureTransport(transport)
|
||||
}
|
||||
rp.Transport = transport
|
||||
}
|
||||
return rp
|
||||
|
@ -160,18 +177,20 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
|||
func (rp *ReverseProxy) UseInsecureTransport() {
|
||||
if rp.Transport == nil {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: defaultDialer.Dial,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
http2.ConfigureTransport(transport)
|
||||
if httpserver.HTTP2 {
|
||||
http2.ConfigureTransport(transport)
|
||||
}
|
||||
rp.Transport = transport
|
||||
} else if transport, ok := rp.Transport.(*http.Transport); ok {
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
// No http2.ConfigureTransport() here.
|
||||
// For now this is only added in places where
|
||||
// an http.Transport is actually created.
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -186,20 +205,33 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
|||
}
|
||||
|
||||
rp.Director(outreq)
|
||||
outreq.Proto = "HTTP/1.1"
|
||||
outreq.ProtoMajor = 1
|
||||
outreq.ProtoMinor = 1
|
||||
outreq.Close = false
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isWebsocket := res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket"
|
||||
|
||||
// Remove hop-by-hop headers listed in the
|
||||
// "Connection" header of the response.
|
||||
if c := res.Header.Get("Connection"); c != "" {
|
||||
for _, f := range strings.Split(c, ",") {
|
||||
if f = strings.TrimSpace(f); f != "" {
|
||||
res.Header.Del(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, h := range hopHeaders {
|
||||
res.Header.Del(h)
|
||||
}
|
||||
|
||||
if respUpdateFn != nil {
|
||||
respUpdateFn(res)
|
||||
}
|
||||
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
|
||||
|
||||
if isWebsocket {
|
||||
res.Body.Close()
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
|
@ -228,27 +260,39 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
|||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
go func() {
|
||||
io.Copy(backendConn, conn) // write tcp stream to backend.
|
||||
}()
|
||||
io.Copy(conn, backendConn) // read tcp stream from backend.
|
||||
go pooledIoCopy(backendConn, conn) // write tcp stream to backend
|
||||
pooledIoCopy(conn, backendConn) // read tcp stream from backend
|
||||
} else {
|
||||
defer res.Body.Close()
|
||||
for _, h := range hopHeaders {
|
||||
res.Header.Del(h)
|
||||
}
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
// The "Trailer" header isn't included in the Transport's response,
|
||||
// at least for *http.Transport. Build it up from Trailer.
|
||||
if len(res.Trailer) > 0 {
|
||||
trailerKeys := make([]string, 0, len(res.Trailer))
|
||||
for k := range res.Trailer {
|
||||
trailerKeys = append(trailerKeys, k)
|
||||
}
|
||||
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
|
||||
}
|
||||
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
if len(res.Trailer) > 0 {
|
||||
// Force chunking if we saw a response trailer.
|
||||
// This prevents net/http from calculating the length for short
|
||||
// bodies and adding a Content-Length.
|
||||
if fl, ok := rw.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
rp.copyResponse(rw, res.Body)
|
||||
res.Body.Close() // close now, instead of defer, to populate res.Trailer
|
||||
copyHeader(rw.Header(), res.Trailer)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
||||
buf := bufferPool.Get()
|
||||
defer bufferPool.Put(buf)
|
||||
|
||||
if rp.FlushInterval != 0 {
|
||||
if wf, ok := dst.(writeFlusher); ok {
|
||||
mlw := &maxLatencyWriter{
|
||||
|
@ -261,7 +305,7 @@ func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
|||
dst = mlw
|
||||
}
|
||||
}
|
||||
io.CopyBuffer(dst, src, buf.([]byte))
|
||||
pooledIoCopy(dst, src)
|
||||
}
|
||||
|
||||
// skip these headers if they already exist.
|
||||
|
@ -295,16 +339,17 @@ func copyHeader(dst, src http.Header) {
|
|||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||
var hopHeaders = []string{
|
||||
"Alt-Svc",
|
||||
"Alternate-Protocol",
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailers",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
"Alternate-Protocol",
|
||||
"Alt-Svc",
|
||||
}
|
||||
|
||||
type respUpdateFn func(resp *http.Response)
|
||||
|
@ -331,51 +376,169 @@ type connHijackerTransport struct {
|
|||
}
|
||||
|
||||
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
t := &http.Transport{
|
||||
MaxIdleConnsPerHost: -1,
|
||||
}
|
||||
if base != nil {
|
||||
if baseTransport, ok := base.(*http.Transport); ok {
|
||||
transport.Proxy = baseTransport.Proxy
|
||||
transport.TLSClientConfig = baseTransport.TLSClientConfig
|
||||
transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout
|
||||
transport.Dial = baseTransport.Dial
|
||||
transport.DialTLS = baseTransport.DialTLS
|
||||
transport.MaxIdleConnsPerHost = -1
|
||||
if b, _ := base.(*http.Transport); b != nil {
|
||||
tlsClientConfig := b.TLSClientConfig
|
||||
if tlsClientConfig.NextProtos != nil {
|
||||
tlsClientConfig = cloneTLSClientConfig(tlsClientConfig)
|
||||
tlsClientConfig.NextProtos = nil
|
||||
}
|
||||
|
||||
t.Proxy = b.Proxy
|
||||
t.TLSClientConfig = tlsClientConfig
|
||||
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
|
||||
t.Dial = b.Dial
|
||||
t.DialTLS = b.DialTLS
|
||||
} else {
|
||||
t.Proxy = http.ProxyFromEnvironment
|
||||
t.TLSHandshakeTimeout = 10 * time.Second
|
||||
}
|
||||
hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
|
||||
oldDial := transport.Dial
|
||||
oldDialTLS := transport.DialTLS
|
||||
if oldDial == nil {
|
||||
oldDial = (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial
|
||||
hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}
|
||||
|
||||
dial := getTransportDial(t)
|
||||
dialTLS := getTransportDialTLS(t)
|
||||
t.Dial = func(network, addr string) (net.Conn, error) {
|
||||
c, err := dial(network, addr)
|
||||
hj.Conn = c
|
||||
return &hijackedConn{c, hj}, err
|
||||
}
|
||||
hjTransport.Dial = func(network, addr string) (net.Conn, error) {
|
||||
c, err := oldDial(network, addr)
|
||||
hjTransport.Conn = c
|
||||
return &hijackedConn{c, hjTransport}, err
|
||||
t.DialTLS = func(network, addr string) (net.Conn, error) {
|
||||
c, err := dialTLS(network, addr)
|
||||
hj.Conn = c
|
||||
return &hijackedConn{c, hj}, err
|
||||
}
|
||||
if oldDialTLS != nil {
|
||||
hjTransport.DialTLS = func(network, addr string) (net.Conn, error) {
|
||||
c, err := oldDialTLS(network, addr)
|
||||
hjTransport.Conn = c
|
||||
return &hijackedConn{c, hjTransport}, err
|
||||
|
||||
return hj
|
||||
}
|
||||
|
||||
// getTransportDial always returns a plain Dialer
|
||||
// and defaults to the existing t.Dial.
|
||||
func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
||||
if t.Dial != nil {
|
||||
return t.Dial
|
||||
}
|
||||
return defaultDialer.Dial
|
||||
}
|
||||
|
||||
// getTransportDial always returns a TLS Dialer
|
||||
// and defaults to the existing t.DialTLS.
|
||||
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
||||
if t.DialTLS != nil {
|
||||
return t.DialTLS
|
||||
}
|
||||
|
||||
// newConnHijackerTransport will modify t.Dial after calling this method
|
||||
// => Create a backup reference.
|
||||
plainDial := getTransportDial(t)
|
||||
|
||||
// The following DialTLS implementation stems from the Go stdlib and
|
||||
// is identical to what happens if DialTLS is not provided.
|
||||
// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
|
||||
return func(network, addr string) (net.Conn, error) {
|
||||
plainConn, err := plainDial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsClientConfig := t.TLSClientConfig
|
||||
if tlsClientConfig == nil {
|
||||
tlsClientConfig = &tls.Config{}
|
||||
}
|
||||
if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
|
||||
tlsClientConfig.ServerName = stripPort(addr)
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(plainConn, tlsClientConfig)
|
||||
errc := make(chan error, 2)
|
||||
var timer *time.Timer
|
||||
if d := t.TLSHandshakeTimeout; d != 0 {
|
||||
timer = time.AfterFunc(d, func() {
|
||||
errc <- tlsHandshakeTimeoutError{}
|
||||
})
|
||||
}
|
||||
go func() {
|
||||
err := tlsConn.Handshake()
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
errc <- err
|
||||
}()
|
||||
if err := <-errc; err != nil {
|
||||
plainConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !tlsClientConfig.InsecureSkipVerify {
|
||||
hostname := tlsClientConfig.ServerName
|
||||
if hostname == "" {
|
||||
hostname = stripPort(addr)
|
||||
}
|
||||
if err := tlsConn.VerifyHostname(hostname); err != nil {
|
||||
plainConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
}
|
||||
|
||||
// stripPort returns address without its port if it has one and
|
||||
// works with IP addresses as well as hostnames formatted as host:port.
|
||||
//
|
||||
// IPv6 addresses (excluding the port) must be enclosed in
|
||||
// square brackets similar to the requirements of Go's stdlib.
|
||||
func stripPort(address string) string {
|
||||
// Keep in mind that the address might be a IPv6 address
|
||||
// and thus contain a colon, but not have a port.
|
||||
portIdx := strings.LastIndex(address, ":")
|
||||
ipv6Idx := strings.LastIndex(address, "]")
|
||||
if portIdx > ipv6Idx {
|
||||
address = address[:portIdx]
|
||||
}
|
||||
return address
|
||||
}
|
||||
|
||||
type tlsHandshakeTimeoutError struct{}
|
||||
|
||||
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
|
||||
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
|
||||
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
|
||||
|
||||
// cloneTLSClientConfig is like cloneTLSConfig but omits
|
||||
// the fields SessionTicketsDisabled and SessionTicketKey.
|
||||
// This makes it safe to call cloneTLSClientConfig on a config
|
||||
// in active use by a server.
|
||||
func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
|
||||
if cfg == nil {
|
||||
return &tls.Config{}
|
||||
}
|
||||
return &tls.Config{
|
||||
Rand: cfg.Rand,
|
||||
Time: cfg.Time,
|
||||
Certificates: cfg.Certificates,
|
||||
NameToCertificate: cfg.NameToCertificate,
|
||||
GetCertificate: cfg.GetCertificate,
|
||||
RootCAs: cfg.RootCAs,
|
||||
NextProtos: cfg.NextProtos,
|
||||
ServerName: cfg.ServerName,
|
||||
ClientAuth: cfg.ClientAuth,
|
||||
ClientCAs: cfg.ClientCAs,
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
CipherSuites: cfg.CipherSuites,
|
||||
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
|
||||
ClientSessionCache: cfg.ClientSessionCache,
|
||||
MinVersion: cfg.MinVersion,
|
||||
MaxVersion: cfg.MaxVersion,
|
||||
CurvePreferences: cfg.CurvePreferences,
|
||||
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
|
||||
Renegotiation: cfg.Renegotiation,
|
||||
}
|
||||
return hjTransport
|
||||
}
|
||||
|
||||
func requestIsWebsocket(req *http.Request) bool {
|
||||
return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"))
|
||||
return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
|
||||
}
|
||||
|
||||
type writeFlusher interface {
|
||||
|
|
Loading…
Reference in a new issue