mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-23 22:27:38 -05:00
Merge branch 'master' of github.com:mholt/caddy
This commit is contained in:
commit
a1a8d0f655
6 changed files with 545 additions and 85 deletions
76
caddyhttp/httpserver/pathcleaner.go
Normal file
76
caddyhttp/httpserver/pathcleaner.go
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
package httpserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CleanMaskedPath prevents one or more of the path cleanup operations:
|
||||||
|
// - collapse multiple slashes into one
|
||||||
|
// - eliminate "/." (current directory)
|
||||||
|
// - eliminate "<parent_directory>/.."
|
||||||
|
// by masking certain patterns in the path with a temporary random string.
|
||||||
|
// This could be helpful when certain patterns in the path are desired to be preserved
|
||||||
|
// that would otherwise be changed by path.Clean().
|
||||||
|
// One such use case is the presence of the double slashes as protocol separator
|
||||||
|
// (e.g., /api/endpoint/http://example.com).
|
||||||
|
// This is a common pattern in many applications to allow passing URIs as path argument.
|
||||||
|
func CleanMaskedPath(reqPath string, masks ...string) string {
|
||||||
|
var replacerVal string
|
||||||
|
maskMap := make(map[string]string)
|
||||||
|
|
||||||
|
// Iterate over supplied masks and create temporary replacement strings
|
||||||
|
// only for the masks that are present in the path, then replace all occurrences
|
||||||
|
for _, mask := range masks {
|
||||||
|
if strings.Index(reqPath, mask) >= 0 {
|
||||||
|
replacerVal = "/_caddy" + generateRandomString() + "__"
|
||||||
|
maskMap[mask] = replacerVal
|
||||||
|
reqPath = strings.Replace(reqPath, mask, replacerVal, -1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reqPath = path.Clean(reqPath)
|
||||||
|
|
||||||
|
// Revert the replaced masks after path cleanup
|
||||||
|
for mask, replacerVal := range maskMap {
|
||||||
|
reqPath = strings.Replace(reqPath, replacerVal, mask, -1)
|
||||||
|
}
|
||||||
|
return reqPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanPath calls CleanMaskedPath() with the default mask of "://"
|
||||||
|
// to preserve double slashes of protocols
|
||||||
|
// such as "http://", "https://", and "ftp://" etc.
|
||||||
|
func CleanPath(reqPath string) string {
|
||||||
|
return CleanMaskedPath(reqPath, "://")
|
||||||
|
}
|
||||||
|
|
||||||
|
// An efficient and fast method for random string generation.
|
||||||
|
// Inspired by http://stackoverflow.com/a/31832326.
|
||||||
|
const randomStringLength = 4
|
||||||
|
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
|
const (
|
||||||
|
letterIdxBits = 6
|
||||||
|
letterIdxMask = 1<<letterIdxBits - 1
|
||||||
|
letterIdxMax = 63 / letterIdxBits
|
||||||
|
)
|
||||||
|
|
||||||
|
var src = rand.NewSource(time.Now().UnixNano())
|
||||||
|
|
||||||
|
func generateRandomString() string {
|
||||||
|
b := make([]byte, randomStringLength)
|
||||||
|
for i, cache, remain := randomStringLength-1, src.Int63(), letterIdxMax; i >= 0; {
|
||||||
|
if remain == 0 {
|
||||||
|
cache, remain = src.Int63(), letterIdxMax
|
||||||
|
}
|
||||||
|
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
|
||||||
|
b[i] = letterBytes[idx]
|
||||||
|
i--
|
||||||
|
}
|
||||||
|
cache >>= letterIdxBits
|
||||||
|
remain--
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
120
caddyhttp/httpserver/pathcleaner_test.go
Normal file
120
caddyhttp/httpserver/pathcleaner_test.go
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
package httpserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var paths = map[string]map[string]string{
|
||||||
|
"/../a/b/../././/c": {
|
||||||
|
"preserve_all": "/../a/b/../././/c",
|
||||||
|
"preserve_protocol": "/a/c",
|
||||||
|
"preserve_slashes": "/a//c",
|
||||||
|
"preserve_dots": "/../a/b/../././c",
|
||||||
|
"clean_all": "/a/c",
|
||||||
|
},
|
||||||
|
"/path/https://www.google.com": {
|
||||||
|
"preserve_all": "/path/https://www.google.com",
|
||||||
|
"preserve_protocol": "/path/https://www.google.com",
|
||||||
|
"preserve_slashes": "/path/https://www.google.com",
|
||||||
|
"preserve_dots": "/path/https:/www.google.com",
|
||||||
|
"clean_all": "/path/https:/www.google.com",
|
||||||
|
},
|
||||||
|
"/a/b/../././/c/http://example.com/foo//bar/../blah": {
|
||||||
|
"preserve_all": "/a/b/../././/c/http://example.com/foo//bar/../blah",
|
||||||
|
"preserve_protocol": "/a/c/http://example.com/foo/blah",
|
||||||
|
"preserve_slashes": "/a//c/http://example.com/foo/blah",
|
||||||
|
"preserve_dots": "/a/b/../././c/http:/example.com/foo/bar/../blah",
|
||||||
|
"clean_all": "/a/c/http:/example.com/foo/blah",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertEqual(t *testing.T, expected, received string) {
|
||||||
|
if expected != received {
|
||||||
|
t.Errorf("\tExpected: %s\n\t\t\tRecieved: %s", expected, received)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func maskedTestRunner(t *testing.T, variation string, masks ...string) {
|
||||||
|
for reqPath, transformation := range paths {
|
||||||
|
assertEqual(t, transformation[variation], CleanMaskedPath(reqPath, masks...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No need to test the built-in path.Clean() function.
|
||||||
|
// However, it could be useful to cross-examine the test dataset.
|
||||||
|
func TestPathClean(t *testing.T) {
|
||||||
|
for reqPath, transformation := range paths {
|
||||||
|
assertEqual(t, transformation["clean_all"], path.Clean(reqPath))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanAll(t *testing.T) {
|
||||||
|
maskedTestRunner(t, "clean_all")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreserveAll(t *testing.T) {
|
||||||
|
maskedTestRunner(t, "preserve_all", "//", "/..", "/.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreserveProtocol(t *testing.T) {
|
||||||
|
maskedTestRunner(t, "preserve_protocol", "://")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreserveSlashes(t *testing.T) {
|
||||||
|
maskedTestRunner(t, "preserve_slashes", "//")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreserveDots(t *testing.T) {
|
||||||
|
maskedTestRunner(t, "preserve_dots", "/..", "/.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultMask(t *testing.T) {
|
||||||
|
for reqPath, transformation := range paths {
|
||||||
|
assertEqual(t, transformation["preserve_protocol"], CleanPath(reqPath))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func maskedBenchmarkRunner(b *testing.B, masks ...string) {
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
for reqPath := range paths {
|
||||||
|
CleanMaskedPath(reqPath, masks...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPathClean(b *testing.B) {
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
for reqPath := range paths {
|
||||||
|
path.Clean(reqPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCleanAll(b *testing.B) {
|
||||||
|
maskedBenchmarkRunner(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPreserveAll(b *testing.B) {
|
||||||
|
maskedBenchmarkRunner(b, "//", "/..", "/.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPreserveProtocol(b *testing.B) {
|
||||||
|
maskedBenchmarkRunner(b, "://")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPreserveSlashes(b *testing.B) {
|
||||||
|
maskedBenchmarkRunner(b, "//")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPreserveDots(b *testing.B) {
|
||||||
|
maskedBenchmarkRunner(b, "/..", "/.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDefaultMask(b *testing.B) {
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
for reqPath := range paths {
|
||||||
|
CleanPath(reqPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -351,7 +350,7 @@ func sanitizePath(r *http.Request) {
|
||||||
if r.URL.Path == "/" {
|
if r.URL.Path == "/" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cleanedPath := path.Clean(r.URL.Path)
|
cleanedPath := CleanPath(r.URL.Path)
|
||||||
if cleanedPath == "." {
|
if cleanedPath == "." {
|
||||||
r.URL.Path = "/"
|
r.URL.Path = "/"
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -247,12 +247,28 @@ func createUpstreamRequest(r *http.Request) *http.Request {
|
||||||
outreq.URL.Opaque = outreq.URL.RawPath
|
outreq.URL.Opaque = outreq.URL.RawPath
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We are modifying the same underlying map from req (shallow
|
||||||
|
// copied above) so we only copy it if necessary.
|
||||||
|
copiedHeaders := false
|
||||||
|
|
||||||
|
// Remove hop-by-hop headers listed in the "Connection" header.
|
||||||
|
// See RFC 2616, section 14.10.
|
||||||
|
if c := outreq.Header.Get("Connection"); c != "" {
|
||||||
|
for _, f := range strings.Split(c, ",") {
|
||||||
|
if f = strings.TrimSpace(f); f != "" {
|
||||||
|
if !copiedHeaders {
|
||||||
|
outreq.Header = make(http.Header)
|
||||||
|
copyHeader(outreq.Header, r.Header)
|
||||||
|
copiedHeaders = true
|
||||||
|
}
|
||||||
|
outreq.Header.Del(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Remove hop-by-hop headers to the backend. Especially
|
// Remove hop-by-hop headers to the backend. Especially
|
||||||
// important is "Connection" because we want a persistent
|
// important is "Connection" because we want a persistent
|
||||||
// connection, regardless of what the client sent to us. This
|
// connection, regardless of what the client sent to us.
|
||||||
// is modifying the same underlying map from r (shallow
|
|
||||||
// copied above) so we only copy it if necessary.
|
|
||||||
var copiedHeaders bool
|
|
||||||
for _, h := range hopHeaders {
|
for _, h := range hopHeaders {
|
||||||
if outreq.Header.Get(h) != "" {
|
if outreq.Header.Get(h) != "" {
|
||||||
if !copiedHeaders {
|
if !copiedHeaders {
|
||||||
|
|
|
@ -42,10 +42,32 @@ func TestReverseProxy(t *testing.T) {
|
||||||
log.SetOutput(ioutil.Discard)
|
log.SetOutput(ioutil.Discard)
|
||||||
defer log.SetOutput(os.Stderr)
|
defer log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
verifyHeaders := func(headers http.Header, trailers http.Header) {
|
||||||
|
if headers.Get("X-Header") != "header-value" {
|
||||||
|
t.Error("Expected header 'X-Header' to be proxied properly")
|
||||||
|
}
|
||||||
|
|
||||||
|
if trailers == nil {
|
||||||
|
t.Error("Expected to receive trailers")
|
||||||
|
}
|
||||||
|
if trailers.Get("X-Trailer") != "trailer-value" {
|
||||||
|
t.Error("Expected header 'X-Trailer' to be proxied properly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var requestReceived bool
|
var requestReceived bool
|
||||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// read the body (even if it's empty) to make Go parse trailers
|
||||||
|
io.Copy(ioutil.Discard, r.Body)
|
||||||
|
verifyHeaders(r.Header, r.Trailer)
|
||||||
|
|
||||||
requestReceived = true
|
requestReceived = true
|
||||||
|
|
||||||
|
w.Header().Set("Trailer", "X-Trailer")
|
||||||
|
w.Header().Set("X-Header", "header-value")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("Hello, client"))
|
w.Write([]byte("Hello, client"))
|
||||||
|
w.Header().Set("X-Trailer", "trailer-value")
|
||||||
}))
|
}))
|
||||||
defer backend.Close()
|
defer backend.Close()
|
||||||
|
|
||||||
|
@ -59,12 +81,21 @@ func TestReverseProxy(t *testing.T) {
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
r.ContentLength = -1 // force chunked encoding (required for trailers)
|
||||||
|
r.Header.Set("X-Header", "header-value")
|
||||||
|
r.Trailer = map[string][]string{
|
||||||
|
"X-Trailer": {"trailer-value"},
|
||||||
|
}
|
||||||
|
|
||||||
p.ServeHTTP(w, r)
|
p.ServeHTTP(w, r)
|
||||||
|
|
||||||
if !requestReceived {
|
if !requestReceived {
|
||||||
t.Error("Expected backend to receive request, but it didn't")
|
t.Error("Expected backend to receive request, but it didn't")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
res := w.Result()
|
||||||
|
verifyHeaders(res.Header, res.Trailer)
|
||||||
|
|
||||||
// Make sure {upstream} placeholder is set
|
// Make sure {upstream} placeholder is set
|
||||||
rr := httpserver.NewResponseRecorder(httptest.NewRecorder())
|
rr := httpserver.NewResponseRecorder(httptest.NewRecorder())
|
||||||
rr.Replacer = httpserver.NewReplacer(r, rr, "-")
|
rr.Replacer = httpserver.NewReplacer(r, rr, "-")
|
||||||
|
@ -123,7 +154,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
|
||||||
defer wsNop.Close()
|
defer wsNop.Close()
|
||||||
|
|
||||||
// Get proxy to use for the test
|
// Get proxy to use for the test
|
||||||
p := newWebSocketTestProxy(wsNop.URL)
|
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||||
|
|
||||||
// Create client request
|
// Create client request
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
@ -148,7 +179,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
|
||||||
defer wsNop.Close()
|
defer wsNop.Close()
|
||||||
|
|
||||||
// Get proxy to use for the test
|
// Get proxy to use for the test
|
||||||
p := newWebSocketTestProxy(wsNop.URL)
|
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||||
|
|
||||||
// Create client request
|
// Create client request
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
@ -189,7 +220,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
||||||
defer wsEcho.Close()
|
defer wsEcho.Close()
|
||||||
|
|
||||||
// Get proxy to use for the test
|
// Get proxy to use for the test
|
||||||
p := newWebSocketTestProxy(wsEcho.URL)
|
p := newWebSocketTestProxy(wsEcho.URL, false)
|
||||||
|
|
||||||
// This is a full end-end test, so the proxy handler
|
// This is a full end-end test, so the proxy handler
|
||||||
// has to be part of a server listening on a port. Our
|
// has to be part of a server listening on a port. Our
|
||||||
|
@ -228,6 +259,52 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWebSocketReverseProxyFromWSSClient(t *testing.T) {
|
||||||
|
wsEcho := newTLSServer(websocket.Handler(func(ws *websocket.Conn) {
|
||||||
|
io.Copy(ws, ws)
|
||||||
|
}))
|
||||||
|
defer wsEcho.Close()
|
||||||
|
|
||||||
|
p := newWebSocketTestProxy(wsEcho.URL, true)
|
||||||
|
|
||||||
|
echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
p.ServeHTTP(w, r)
|
||||||
|
}))
|
||||||
|
defer echoProxy.Close()
|
||||||
|
|
||||||
|
// Set up WebSocket client
|
||||||
|
url := strings.Replace(echoProxy.URL, "https://", "wss://", 1)
|
||||||
|
wsCfg, err := websocket.NewConfig(url, echoProxy.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
wsCfg.TlsConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
|
ws, err := websocket.DialConfig(wsCfg)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
|
||||||
|
// Send test message
|
||||||
|
trialMsg := "Is it working?"
|
||||||
|
|
||||||
|
if sendErr := websocket.Message.Send(ws, trialMsg); sendErr != nil {
|
||||||
|
t.Fatal(sendErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// It should be echoed back to us
|
||||||
|
var actualMsg string
|
||||||
|
|
||||||
|
if rcvErr := websocket.Message.Receive(ws, &actualMsg); rcvErr != nil {
|
||||||
|
t.Fatal(rcvErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actualMsg != trialMsg {
|
||||||
|
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnixSocketProxy(t *testing.T) {
|
func TestUnixSocketProxy(t *testing.T) {
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
return
|
return
|
||||||
|
@ -264,7 +341,7 @@ func TestUnixSocketProxy(t *testing.T) {
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
url := strings.Replace(ts.URL, "http://", "unix:", 1)
|
url := strings.Replace(ts.URL, "http://", "unix:", 1)
|
||||||
p := newWebSocketTestProxy(url)
|
p := newWebSocketTestProxy(url, false)
|
||||||
|
|
||||||
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
p.ServeHTTP(w, r)
|
p.ServeHTTP(w, r)
|
||||||
|
@ -982,10 +1059,14 @@ func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.
|
||||||
// redirect to the specified backendAddr. The function
|
// redirect to the specified backendAddr. The function
|
||||||
// also sets up the rules/environment for testing WebSocket
|
// also sets up the rules/environment for testing WebSocket
|
||||||
// proxy.
|
// proxy.
|
||||||
func newWebSocketTestProxy(backendAddr string) *Proxy {
|
func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
|
||||||
return &Proxy{
|
return &Proxy{
|
||||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||||
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}},
|
Upstreams: []Upstream{&fakeWsUpstream{
|
||||||
|
name: backendAddr,
|
||||||
|
without: "",
|
||||||
|
insecure: insecure,
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -999,6 +1080,7 @@ func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
|
||||||
type fakeWsUpstream struct {
|
type fakeWsUpstream struct {
|
||||||
name string
|
name string
|
||||||
without string
|
without string
|
||||||
|
insecure bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *fakeWsUpstream) From() string {
|
func (u *fakeWsUpstream) From() string {
|
||||||
|
@ -1007,13 +1089,17 @@ func (u *fakeWsUpstream) From() string {
|
||||||
|
|
||||||
func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
|
func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
|
||||||
uri, _ := url.Parse(u.name)
|
uri, _ := url.Parse(u.name)
|
||||||
return &UpstreamHost{
|
host := &UpstreamHost{
|
||||||
Name: u.name,
|
Name: u.name,
|
||||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
|
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
|
||||||
UpstreamHeaders: http.Header{
|
UpstreamHeaders: http.Header{
|
||||||
"Connection": {"{>Connection}"},
|
"Connection": {"{>Connection}"},
|
||||||
"Upgrade": {"{>Upgrade}"}},
|
"Upgrade": {"{>Upgrade}"}},
|
||||||
}
|
}
|
||||||
|
if u.insecure {
|
||||||
|
host.ReverseProxy.UseInsecureTransport()
|
||||||
|
}
|
||||||
|
return host
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
|
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
|
||||||
|
|
|
@ -27,10 +27,28 @@ import (
|
||||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||||
)
|
)
|
||||||
|
|
||||||
var bufferPool = sync.Pool{New: createBuffer}
|
var (
|
||||||
|
defaultDialer = &net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
bufferPool = sync.Pool{New: createBuffer}
|
||||||
|
)
|
||||||
|
|
||||||
func createBuffer() interface{} {
|
func createBuffer() interface{} {
|
||||||
return make([]byte, 32*1024)
|
return make([]byte, 0, 32*1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pooledIoCopy(dst io.Writer, src io.Reader) {
|
||||||
|
buf := bufferPool.Get().([]byte)
|
||||||
|
defer bufferPool.Put(buf)
|
||||||
|
|
||||||
|
// CopyBuffer only uses buf up to its length and panics if it's 0.
|
||||||
|
// Due to that we extend buf's length to its capacity here and
|
||||||
|
// ensure it's always non-zero.
|
||||||
|
bufCap := cap(buf)
|
||||||
|
io.CopyBuffer(dst, src, buf[0:bufCap:bufCap])
|
||||||
}
|
}
|
||||||
|
|
||||||
// onExitFlushLoop is a callback set by tests to detect the state of the
|
// onExitFlushLoop is a callback set by tests to detect the state of the
|
||||||
|
@ -136,10 +154,7 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||||
// a brand new transport
|
// a brand new transport
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
Dial: (&net.Dialer{
|
Dial: defaultDialer.Dial,
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).Dial,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
}
|
}
|
||||||
|
@ -148,7 +163,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||||
} else {
|
} else {
|
||||||
transport.MaxIdleConnsPerHost = keepalive
|
transport.MaxIdleConnsPerHost = keepalive
|
||||||
}
|
}
|
||||||
|
if httpserver.HTTP2 {
|
||||||
http2.ConfigureTransport(transport)
|
http2.ConfigureTransport(transport)
|
||||||
|
}
|
||||||
rp.Transport = transport
|
rp.Transport = transport
|
||||||
}
|
}
|
||||||
return rp
|
return rp
|
||||||
|
@ -161,17 +178,19 @@ func (rp *ReverseProxy) UseInsecureTransport() {
|
||||||
if rp.Transport == nil {
|
if rp.Transport == nil {
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
Dial: (&net.Dialer{
|
Dial: defaultDialer.Dial,
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).Dial,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
}
|
}
|
||||||
|
if httpserver.HTTP2 {
|
||||||
http2.ConfigureTransport(transport)
|
http2.ConfigureTransport(transport)
|
||||||
|
}
|
||||||
rp.Transport = transport
|
rp.Transport = transport
|
||||||
} else if transport, ok := rp.Transport.(*http.Transport); ok {
|
} else if transport, ok := rp.Transport.(*http.Transport); ok {
|
||||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
|
// No http2.ConfigureTransport() here.
|
||||||
|
// For now this is only added in places where
|
||||||
|
// an http.Transport is actually created.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,20 +205,33 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
||||||
}
|
}
|
||||||
|
|
||||||
rp.Director(outreq)
|
rp.Director(outreq)
|
||||||
outreq.Proto = "HTTP/1.1"
|
|
||||||
outreq.ProtoMajor = 1
|
|
||||||
outreq.ProtoMinor = 1
|
|
||||||
outreq.Close = false
|
|
||||||
|
|
||||||
res, err := transport.RoundTrip(outreq)
|
res, err := transport.RoundTrip(outreq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isWebsocket := res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket"
|
||||||
|
|
||||||
|
// Remove hop-by-hop headers listed in the
|
||||||
|
// "Connection" header of the response.
|
||||||
|
if c := res.Header.Get("Connection"); c != "" {
|
||||||
|
for _, f := range strings.Split(c, ",") {
|
||||||
|
if f = strings.TrimSpace(f); f != "" {
|
||||||
|
res.Header.Del(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, h := range hopHeaders {
|
||||||
|
res.Header.Del(h)
|
||||||
|
}
|
||||||
|
|
||||||
if respUpdateFn != nil {
|
if respUpdateFn != nil {
|
||||||
respUpdateFn(res)
|
respUpdateFn(res)
|
||||||
}
|
}
|
||||||
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
|
|
||||||
|
if isWebsocket {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
hj, ok := rw.(http.Hijacker)
|
hj, ok := rw.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -228,27 +260,39 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
||||||
}
|
}
|
||||||
defer backendConn.Close()
|
defer backendConn.Close()
|
||||||
|
|
||||||
go func() {
|
go pooledIoCopy(backendConn, conn) // write tcp stream to backend
|
||||||
io.Copy(backendConn, conn) // write tcp stream to backend.
|
pooledIoCopy(conn, backendConn) // read tcp stream from backend
|
||||||
}()
|
|
||||||
io.Copy(conn, backendConn) // read tcp stream from backend.
|
|
||||||
} else {
|
} else {
|
||||||
defer res.Body.Close()
|
|
||||||
for _, h := range hopHeaders {
|
|
||||||
res.Header.Del(h)
|
|
||||||
}
|
|
||||||
copyHeader(rw.Header(), res.Header)
|
copyHeader(rw.Header(), res.Header)
|
||||||
|
|
||||||
|
// The "Trailer" header isn't included in the Transport's response,
|
||||||
|
// at least for *http.Transport. Build it up from Trailer.
|
||||||
|
if len(res.Trailer) > 0 {
|
||||||
|
trailerKeys := make([]string, 0, len(res.Trailer))
|
||||||
|
for k := range res.Trailer {
|
||||||
|
trailerKeys = append(trailerKeys, k)
|
||||||
|
}
|
||||||
|
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
rw.WriteHeader(res.StatusCode)
|
rw.WriteHeader(res.StatusCode)
|
||||||
|
if len(res.Trailer) > 0 {
|
||||||
|
// Force chunking if we saw a response trailer.
|
||||||
|
// This prevents net/http from calculating the length for short
|
||||||
|
// bodies and adding a Content-Length.
|
||||||
|
if fl, ok := rw.(http.Flusher); ok {
|
||||||
|
fl.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
rp.copyResponse(rw, res.Body)
|
rp.copyResponse(rw, res.Body)
|
||||||
|
res.Body.Close() // close now, instead of defer, to populate res.Trailer
|
||||||
|
copyHeader(rw.Header(), res.Trailer)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
||||||
buf := bufferPool.Get()
|
|
||||||
defer bufferPool.Put(buf)
|
|
||||||
|
|
||||||
if rp.FlushInterval != 0 {
|
if rp.FlushInterval != 0 {
|
||||||
if wf, ok := dst.(writeFlusher); ok {
|
if wf, ok := dst.(writeFlusher); ok {
|
||||||
mlw := &maxLatencyWriter{
|
mlw := &maxLatencyWriter{
|
||||||
|
@ -261,7 +305,7 @@ func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
||||||
dst = mlw
|
dst = mlw
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
io.CopyBuffer(dst, src, buf.([]byte))
|
pooledIoCopy(dst, src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// skip these headers if they already exist.
|
// skip these headers if they already exist.
|
||||||
|
@ -295,16 +339,17 @@ func copyHeader(dst, src http.Header) {
|
||||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||||
var hopHeaders = []string{
|
var hopHeaders = []string{
|
||||||
|
"Alt-Svc",
|
||||||
|
"Alternate-Protocol",
|
||||||
"Connection",
|
"Connection",
|
||||||
"Keep-Alive",
|
"Keep-Alive",
|
||||||
"Proxy-Authenticate",
|
"Proxy-Authenticate",
|
||||||
"Proxy-Authorization",
|
"Proxy-Authorization",
|
||||||
|
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||||
"Te", // canonicalized version of "TE"
|
"Te", // canonicalized version of "TE"
|
||||||
"Trailers",
|
"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
|
||||||
"Transfer-Encoding",
|
"Transfer-Encoding",
|
||||||
"Upgrade",
|
"Upgrade",
|
||||||
"Alternate-Protocol",
|
|
||||||
"Alt-Svc",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type respUpdateFn func(resp *http.Response)
|
type respUpdateFn func(resp *http.Response)
|
||||||
|
@ -331,51 +376,169 @@ type connHijackerTransport struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
||||||
transport := &http.Transport{
|
t := &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
Dial: (&net.Dialer{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).Dial,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
MaxIdleConnsPerHost: -1,
|
MaxIdleConnsPerHost: -1,
|
||||||
}
|
}
|
||||||
if base != nil {
|
if b, _ := base.(*http.Transport); b != nil {
|
||||||
if baseTransport, ok := base.(*http.Transport); ok {
|
tlsClientConfig := b.TLSClientConfig
|
||||||
transport.Proxy = baseTransport.Proxy
|
if tlsClientConfig.NextProtos != nil {
|
||||||
transport.TLSClientConfig = baseTransport.TLSClientConfig
|
tlsClientConfig = cloneTLSClientConfig(tlsClientConfig)
|
||||||
transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout
|
tlsClientConfig.NextProtos = nil
|
||||||
transport.Dial = baseTransport.Dial
|
}
|
||||||
transport.DialTLS = baseTransport.DialTLS
|
|
||||||
transport.MaxIdleConnsPerHost = -1
|
t.Proxy = b.Proxy
|
||||||
|
t.TLSClientConfig = tlsClientConfig
|
||||||
|
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
|
||||||
|
t.Dial = b.Dial
|
||||||
|
t.DialTLS = b.DialTLS
|
||||||
|
} else {
|
||||||
|
t.Proxy = http.ProxyFromEnvironment
|
||||||
|
t.TLSHandshakeTimeout = 10 * time.Second
|
||||||
|
}
|
||||||
|
hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}
|
||||||
|
|
||||||
|
dial := getTransportDial(t)
|
||||||
|
dialTLS := getTransportDialTLS(t)
|
||||||
|
t.Dial = func(network, addr string) (net.Conn, error) {
|
||||||
|
c, err := dial(network, addr)
|
||||||
|
hj.Conn = c
|
||||||
|
return &hijackedConn{c, hj}, err
|
||||||
|
}
|
||||||
|
t.DialTLS = func(network, addr string) (net.Conn, error) {
|
||||||
|
c, err := dialTLS(network, addr)
|
||||||
|
hj.Conn = c
|
||||||
|
return &hijackedConn{c, hj}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return hj
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTransportDial always returns a plain Dialer
|
||||||
|
// and defaults to the existing t.Dial.
|
||||||
|
func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
||||||
|
if t.Dial != nil {
|
||||||
|
return t.Dial
|
||||||
|
}
|
||||||
|
return defaultDialer.Dial
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTransportDial always returns a TLS Dialer
|
||||||
|
// and defaults to the existing t.DialTLS.
|
||||||
|
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
||||||
|
if t.DialTLS != nil {
|
||||||
|
return t.DialTLS
|
||||||
|
}
|
||||||
|
|
||||||
|
// newConnHijackerTransport will modify t.Dial after calling this method
|
||||||
|
// => Create a backup reference.
|
||||||
|
plainDial := getTransportDial(t)
|
||||||
|
|
||||||
|
// The following DialTLS implementation stems from the Go stdlib and
|
||||||
|
// is identical to what happens if DialTLS is not provided.
|
||||||
|
// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
|
||||||
|
return func(network, addr string) (net.Conn, error) {
|
||||||
|
plainConn, err := plainDial(network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsClientConfig := t.TLSClientConfig
|
||||||
|
if tlsClientConfig == nil {
|
||||||
|
tlsClientConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
|
||||||
|
tlsClientConfig.ServerName = stripPort(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn := tls.Client(plainConn, tlsClientConfig)
|
||||||
|
errc := make(chan error, 2)
|
||||||
|
var timer *time.Timer
|
||||||
|
if d := t.TLSHandshakeTimeout; d != 0 {
|
||||||
|
timer = time.AfterFunc(d, func() {
|
||||||
|
errc <- tlsHandshakeTimeoutError{}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
err := tlsConn.Handshake()
|
||||||
|
if timer != nil {
|
||||||
|
timer.Stop()
|
||||||
|
}
|
||||||
|
errc <- err
|
||||||
|
}()
|
||||||
|
if err := <-errc; err != nil {
|
||||||
|
plainConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !tlsClientConfig.InsecureSkipVerify {
|
||||||
|
hostname := tlsClientConfig.ServerName
|
||||||
|
if hostname == "" {
|
||||||
|
hostname = stripPort(addr)
|
||||||
|
}
|
||||||
|
if err := tlsConn.VerifyHostname(hostname); err != nil {
|
||||||
|
plainConn.Close()
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
|
|
||||||
oldDial := transport.Dial
|
return tlsConn, nil
|
||||||
oldDialTLS := transport.DialTLS
|
|
||||||
if oldDial == nil {
|
|
||||||
oldDial = (&net.Dialer{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).Dial
|
|
||||||
}
|
}
|
||||||
hjTransport.Dial = func(network, addr string) (net.Conn, error) {
|
}
|
||||||
c, err := oldDial(network, addr)
|
|
||||||
hjTransport.Conn = c
|
// stripPort returns address without its port if it has one and
|
||||||
return &hijackedConn{c, hjTransport}, err
|
// works with IP addresses as well as hostnames formatted as host:port.
|
||||||
|
//
|
||||||
|
// IPv6 addresses (excluding the port) must be enclosed in
|
||||||
|
// square brackets similar to the requirements of Go's stdlib.
|
||||||
|
func stripPort(address string) string {
|
||||||
|
// Keep in mind that the address might be a IPv6 address
|
||||||
|
// and thus contain a colon, but not have a port.
|
||||||
|
portIdx := strings.LastIndex(address, ":")
|
||||||
|
ipv6Idx := strings.LastIndex(address, "]")
|
||||||
|
if portIdx > ipv6Idx {
|
||||||
|
address = address[:portIdx]
|
||||||
}
|
}
|
||||||
if oldDialTLS != nil {
|
return address
|
||||||
hjTransport.DialTLS = func(network, addr string) (net.Conn, error) {
|
}
|
||||||
c, err := oldDialTLS(network, addr)
|
|
||||||
hjTransport.Conn = c
|
type tlsHandshakeTimeoutError struct{}
|
||||||
return &hijackedConn{c, hjTransport}, err
|
|
||||||
|
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
|
||||||
|
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
|
||||||
|
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
|
||||||
|
|
||||||
|
// cloneTLSClientConfig is like cloneTLSConfig but omits
|
||||||
|
// the fields SessionTicketsDisabled and SessionTicketKey.
|
||||||
|
// This makes it safe to call cloneTLSClientConfig on a config
|
||||||
|
// in active use by a server.
|
||||||
|
func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
|
||||||
|
if cfg == nil {
|
||||||
|
return &tls.Config{}
|
||||||
}
|
}
|
||||||
|
return &tls.Config{
|
||||||
|
Rand: cfg.Rand,
|
||||||
|
Time: cfg.Time,
|
||||||
|
Certificates: cfg.Certificates,
|
||||||
|
NameToCertificate: cfg.NameToCertificate,
|
||||||
|
GetCertificate: cfg.GetCertificate,
|
||||||
|
RootCAs: cfg.RootCAs,
|
||||||
|
NextProtos: cfg.NextProtos,
|
||||||
|
ServerName: cfg.ServerName,
|
||||||
|
ClientAuth: cfg.ClientAuth,
|
||||||
|
ClientCAs: cfg.ClientCAs,
|
||||||
|
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||||
|
CipherSuites: cfg.CipherSuites,
|
||||||
|
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
|
||||||
|
ClientSessionCache: cfg.ClientSessionCache,
|
||||||
|
MinVersion: cfg.MinVersion,
|
||||||
|
MaxVersion: cfg.MaxVersion,
|
||||||
|
CurvePreferences: cfg.CurvePreferences,
|
||||||
|
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
|
||||||
|
Renegotiation: cfg.Renegotiation,
|
||||||
}
|
}
|
||||||
return hjTransport
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestIsWebsocket(req *http.Request) bool {
|
func requestIsWebsocket(req *http.Request) bool {
|
||||||
return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"))
|
return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
|
||||||
}
|
}
|
||||||
|
|
||||||
type writeFlusher interface {
|
type writeFlusher interface {
|
||||||
|
|
Loading…
Reference in a new issue