mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-13 22:51:08 -05:00
Fixed #1292: Failure to proxy WebSockets over HTTPS
This issue was caused by connHijackerTransport trying to record HTTP response headers by "hijacking" the Read() method of the plain net.Conn. This does not simply work over TLS though since this will record the TLS handshake and encrypted data instead of the actual content. This commit fixes the problem by providing an alternative transport.DialTLS which correctly hijacks the overlying tls.Conn instead.
This commit is contained in:
parent
53635ba538
commit
9f9ad21aaa
1 changed files with 140 additions and 44 deletions
|
@ -27,6 +27,11 @@ import (
|
||||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var defaultDialer = &net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
var bufferPool = sync.Pool{New: createBuffer}
|
var bufferPool = sync.Pool{New: createBuffer}
|
||||||
|
|
||||||
func createBuffer() interface{} {
|
func createBuffer() interface{} {
|
||||||
|
@ -136,10 +141,7 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||||
// a brand new transport
|
// a brand new transport
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
Dial: (&net.Dialer{
|
Dial: defaultDialer.Dial,
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).Dial,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
}
|
}
|
||||||
|
@ -163,10 +165,7 @@ func (rp *ReverseProxy) UseInsecureTransport() {
|
||||||
if rp.Transport == nil {
|
if rp.Transport == nil {
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
Dial: (&net.Dialer{
|
Dial: defaultDialer.Dial,
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).Dial,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
}
|
}
|
||||||
|
@ -341,51 +340,148 @@ type connHijackerTransport struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
||||||
transport := &http.Transport{
|
t := &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
Dial: (&net.Dialer{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).Dial,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
MaxIdleConnsPerHost: -1,
|
MaxIdleConnsPerHost: -1,
|
||||||
}
|
}
|
||||||
if base != nil {
|
if b, _ := base.(*http.Transport); b != nil {
|
||||||
if baseTransport, ok := base.(*http.Transport); ok {
|
t.Proxy = b.Proxy
|
||||||
transport.Proxy = baseTransport.Proxy
|
t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig)
|
||||||
transport.TLSClientConfig = baseTransport.TLSClientConfig
|
t.TLSClientConfig.NextProtos = nil
|
||||||
transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout
|
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
|
||||||
transport.Dial = baseTransport.Dial
|
t.Dial = b.Dial
|
||||||
transport.DialTLS = baseTransport.DialTLS
|
t.DialTLS = b.DialTLS
|
||||||
transport.MaxIdleConnsPerHost = -1
|
} else {
|
||||||
|
t.Proxy = http.ProxyFromEnvironment
|
||||||
|
t.TLSHandshakeTimeout = 10 * time.Second
|
||||||
|
}
|
||||||
|
hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}
|
||||||
|
|
||||||
|
dial := getTransportDial(t)
|
||||||
|
dialTLS := getTransportDialTLS(t)
|
||||||
|
|
||||||
|
t.Dial = func(network, addr string) (net.Conn, error) {
|
||||||
|
c, err := dial(network, addr)
|
||||||
|
hj.Conn = c
|
||||||
|
return &hijackedConn{c, hj}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if dialTLS != nil {
|
||||||
|
t.DialTLS = func(network, addr string) (net.Conn, error) {
|
||||||
|
c, err := dialTLS(network, addr)
|
||||||
|
hj.Conn = c
|
||||||
|
return &hijackedConn{c, hj}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
|
|
||||||
oldDial := transport.Dial
|
return hj
|
||||||
oldDialTLS := transport.DialTLS
|
|
||||||
if oldDial == nil {
|
|
||||||
oldDial = (&net.Dialer{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).Dial
|
|
||||||
}
|
}
|
||||||
hjTransport.Dial = func(network, addr string) (net.Conn, error) {
|
|
||||||
c, err := oldDial(network, addr)
|
// getTransportDial always returns a plain Dialer
|
||||||
hjTransport.Conn = c
|
// and defaults to the existing t.Dial.
|
||||||
return &hijackedConn{c, hjTransport}, err
|
func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
||||||
|
if t.Dial != nil {
|
||||||
|
return t.Dial
|
||||||
}
|
}
|
||||||
if oldDialTLS != nil {
|
return defaultDialer.Dial
|
||||||
hjTransport.DialTLS = func(network, addr string) (net.Conn, error) {
|
}
|
||||||
c, err := oldDialTLS(network, addr)
|
|
||||||
hjTransport.Conn = c
|
// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil
|
||||||
return &hijackedConn{c, hjTransport}, err
|
// and defaults to the existing t.DialTLS.
|
||||||
|
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
||||||
|
if t.DialTLS != nil {
|
||||||
|
return t.DialTLS
|
||||||
|
}
|
||||||
|
if t.TLSClientConfig == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newConnHijackerTransport will modify t.Dial after calling this method
|
||||||
|
// => Create a backup reference.
|
||||||
|
plainDial := getTransportDial(t)
|
||||||
|
|
||||||
|
return func(network, addr string) (net.Conn, error) {
|
||||||
|
plainConn, err := plainDial(network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn := tls.Client(plainConn, t.TLSClientConfig)
|
||||||
|
errc := make(chan error, 2)
|
||||||
|
var timer *time.Timer
|
||||||
|
if d := t.TLSHandshakeTimeout; d != 0 {
|
||||||
|
timer = time.AfterFunc(d, func() {
|
||||||
|
errc <- tlsHandshakeTimeoutError{}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
err := tlsConn.Handshake()
|
||||||
|
if timer != nil {
|
||||||
|
timer.Stop()
|
||||||
|
}
|
||||||
|
errc <- err
|
||||||
|
}()
|
||||||
|
if err := <-errc; err != nil {
|
||||||
|
plainConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !t.TLSClientConfig.InsecureSkipVerify {
|
||||||
|
serverName := t.TLSClientConfig.ServerName
|
||||||
|
if serverName == "" {
|
||||||
|
serverName = addr
|
||||||
|
idx := strings.LastIndex(serverName, ":")
|
||||||
|
if idx != -1 {
|
||||||
|
serverName = serverName[:idx]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return hjTransport
|
if err := tlsConn.VerifyHostname(serverName); err != nil {
|
||||||
|
plainConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConn, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type tlsHandshakeTimeoutError struct{}
|
||||||
|
|
||||||
|
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
|
||||||
|
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
|
||||||
|
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
|
||||||
|
|
||||||
|
// cloneTLSClientConfig is like cloneTLSConfig but omits
|
||||||
|
// the fields SessionTicketsDisabled and SessionTicketKey.
|
||||||
|
// This makes it safe to call cloneTLSClientConfig on a config
|
||||||
|
// in active use by a server.
|
||||||
|
func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
|
||||||
|
if cfg == nil {
|
||||||
|
return &tls.Config{}
|
||||||
|
}
|
||||||
|
return &tls.Config{
|
||||||
|
Rand: cfg.Rand,
|
||||||
|
Time: cfg.Time,
|
||||||
|
Certificates: cfg.Certificates,
|
||||||
|
NameToCertificate: cfg.NameToCertificate,
|
||||||
|
GetCertificate: cfg.GetCertificate,
|
||||||
|
RootCAs: cfg.RootCAs,
|
||||||
|
NextProtos: cfg.NextProtos,
|
||||||
|
ServerName: cfg.ServerName,
|
||||||
|
ClientAuth: cfg.ClientAuth,
|
||||||
|
ClientCAs: cfg.ClientCAs,
|
||||||
|
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||||
|
CipherSuites: cfg.CipherSuites,
|
||||||
|
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
|
||||||
|
ClientSessionCache: cfg.ClientSessionCache,
|
||||||
|
MinVersion: cfg.MinVersion,
|
||||||
|
MaxVersion: cfg.MaxVersion,
|
||||||
|
CurvePreferences: cfg.CurvePreferences,
|
||||||
|
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
|
||||||
|
Renegotiation: cfg.Renegotiation,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestIsWebsocket(req *http.Request) bool {
|
func requestIsWebsocket(req *http.Request) bool {
|
||||||
return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"))
|
return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
|
||||||
}
|
}
|
||||||
|
|
||||||
type writeFlusher interface {
|
type writeFlusher interface {
|
||||||
|
|
Loading…
Reference in a new issue