0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2025-01-20 22:52:58 -05:00

proxy: Fixed support for TLS verification of WebSocket connections

This commit is contained in:
Leonard Hecker 2016-12-28 17:20:31 +01:00
parent 153d4a5ac6
commit b857265f9c

View file

@ -349,9 +349,14 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
MaxIdleConnsPerHost: -1,
}
if b, _ := base.(*http.Transport); b != nil {
tlsClientConfig := b.TLSClientConfig
if tlsClientConfig.NextProtos != nil {
tlsClientConfig = cloneTLSClientConfig(tlsClientConfig)
tlsClientConfig.NextProtos = nil
}
t.Proxy = b.Proxy
t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig)
t.TLSClientConfig.NextProtos = nil
t.TLSClientConfig = tlsClientConfig
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
t.Dial = b.Dial
t.DialTLS = b.DialTLS
@ -363,19 +368,15 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
dial := getTransportDial(t)
dialTLS := getTransportDialTLS(t)
t.Dial = func(network, addr string) (net.Conn, error) {
c, err := dial(network, addr)
hj.Conn = c
return &hijackedConn{c, hj}, err
}
if dialTLS != nil {
t.DialTLS = func(network, addr string) (net.Conn, error) {
c, err := dialTLS(network, addr)
hj.Conn = c
return &hijackedConn{c, hj}, err
}
t.DialTLS = func(network, addr string) (net.Conn, error) {
c, err := dialTLS(network, addr)
hj.Conn = c
return &hijackedConn{c, hj}, err
}
return hj
@ -390,27 +391,35 @@ func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, e
return defaultDialer.Dial
}
// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil
// getTransportDial always returns a TLS Dialer
// and defaults to the existing t.DialTLS.
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
if t.DialTLS != nil {
return t.DialTLS
}
if t.TLSClientConfig == nil {
return nil
}
// newConnHijackerTransport will modify t.Dial after calling this method
// => Create a backup reference.
plainDial := getTransportDial(t)
// The following DialTLS implementation stems from the Go stdlib and
// is identical to what happens if DialTLS is not provided.
// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
return func(network, addr string) (net.Conn, error) {
plainConn, err := plainDial(network, addr)
if err != nil {
return nil, err
}
tlsConn := tls.Client(plainConn, t.TLSClientConfig)
tlsClientConfig := t.TLSClientConfig
if tlsClientConfig == nil {
tlsClientConfig = &tls.Config{}
}
if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
tlsClientConfig.ServerName = stripPort(addr)
}
tlsConn := tls.Client(plainConn, tlsClientConfig)
errc := make(chan error, 2)
var timer *time.Timer
if d := t.TLSHandshakeTimeout; d != 0 {
@ -429,16 +438,12 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn
plainConn.Close()
return nil, err
}
if !t.TLSClientConfig.InsecureSkipVerify {
serverName := t.TLSClientConfig.ServerName
if serverName == "" {
serverName = addr
idx := strings.LastIndex(serverName, ":")
if idx != -1 {
serverName = serverName[:idx]
}
if !tlsClientConfig.InsecureSkipVerify {
hostname := tlsClientConfig.ServerName
if hostname == "" {
hostname = stripPort(addr)
}
if err := tlsConn.VerifyHostname(serverName); err != nil {
if err := tlsConn.VerifyHostname(hostname); err != nil {
plainConn.Close()
return nil, err
}
@ -448,6 +453,22 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn
}
}
// stripPort returns address without its port if it has one and
// works with IP addresses as well as hostnames formatted as host:port.
//
// IPv6 addresses (excluding the port) must be enclosed in
// square brackets similar to the requirements of Go's stdlib.
func stripPort(address string) string {
// Keep in mind that the address might be a IPv6 address
// and thus contain a colon, but not have a port.
portIdx := strings.LastIndex(address, ":")
ipv6Idx := strings.LastIndex(address, "]")
if portIdx > ipv6Idx {
address = address[:portIdx]
}
return address
}
type tlsHandshakeTimeoutError struct{}
func (tlsHandshakeTimeoutError) Timeout() bool { return true }