mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-23 22:27:38 -05:00
Proxy: When connecting to websocket backend, reuse the connection isntead of starting a new one.
This commit is contained in:
parent
c4e65df262
commit
d534a2139f
1 changed files with 87 additions and 7 deletions
|
@ -183,9 +183,80 @@ var hopHeaders = []string{
|
||||||
|
|
||||||
type respUpdateFn func(resp *http.Response)
|
type respUpdateFn func(resp *http.Response)
|
||||||
|
|
||||||
|
type hijackedConn struct {
|
||||||
|
net.Conn
|
||||||
|
hj *connHijackerTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *hijackedConn) Read(b []byte) (n int, err error) {
|
||||||
|
n, err = c.Conn.Read(b)
|
||||||
|
c.hj.Replay = append(c.hj.Replay, b[:n]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *hijackedConn) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type connHijackerTransport struct {
|
||||||
|
*http.Transport
|
||||||
|
Conn net.Conn
|
||||||
|
Replay []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
||||||
|
transport := &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
Dial: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).Dial,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
}
|
||||||
|
if base != nil {
|
||||||
|
if baseTransport, ok := base.(*http.Transport); ok {
|
||||||
|
transport.Proxy = baseTransport.Proxy
|
||||||
|
transport.TLSClientConfig = baseTransport.TLSClientConfig
|
||||||
|
transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout
|
||||||
|
transport.Dial = baseTransport.Dial
|
||||||
|
transport.DialTLS = baseTransport.DialTLS
|
||||||
|
transport.DisableKeepAlives = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
|
||||||
|
oldDial := transport.Dial
|
||||||
|
oldDialTLS := transport.DialTLS
|
||||||
|
if oldDial == nil {
|
||||||
|
oldDial = (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).Dial
|
||||||
|
}
|
||||||
|
hjTransport.Dial = func(network, addr string) (net.Conn, error) {
|
||||||
|
c, err := oldDial(network, addr)
|
||||||
|
hjTransport.Conn = c
|
||||||
|
return &hijackedConn{c, hjTransport}, err
|
||||||
|
}
|
||||||
|
if oldDialTLS != nil {
|
||||||
|
hjTransport.DialTLS = func(network, addr string) (net.Conn, error) {
|
||||||
|
c, err := oldDialTLS(network, addr)
|
||||||
|
hjTransport.Conn = c
|
||||||
|
return &hijackedConn{c, hjTransport}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hjTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestIsWebsocket(req *http.Request) bool {
|
||||||
|
return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"))
|
||||||
|
}
|
||||||
|
|
||||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
|
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
|
||||||
transport := p.Transport
|
transport := p.Transport
|
||||||
if transport == nil {
|
if requestIsWebsocket(outreq) {
|
||||||
|
transport = newConnHijackerTransport(transport)
|
||||||
|
} else if transport == nil {
|
||||||
transport = http.DefaultTransport
|
transport = http.DefaultTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -216,13 +287,22 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, r
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
backendConn, err := net.Dial("tcp", outreq.URL.Host)
|
var backendConn net.Conn
|
||||||
if err != nil {
|
if hj, ok := transport.(*connHijackerTransport); ok {
|
||||||
return err
|
backendConn = hj.Conn
|
||||||
}
|
if _, err := conn.Write(hj.Replay); err != nil {
|
||||||
defer backendConn.Close()
|
return err
|
||||||
|
}
|
||||||
|
bufferPool.Put(hj.Replay)
|
||||||
|
} else {
|
||||||
|
backendConn, err = net.Dial("tcp", outreq.URL.Host)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer backendConn.Close()
|
||||||
|
|
||||||
outreq.Write(backendConn)
|
outreq.Write(backendConn)
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
io.Copy(backendConn, conn) // write tcp stream to backend.
|
io.Copy(backendConn, conn) // write tcp stream to backend.
|
||||||
|
|
Loading…
Reference in a new issue