mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-13 22:51:08 -05:00
Merge pull request #987 from nemothekid/proxy/single-webconn
Proxy: Single WebSocket connection
This commit is contained in:
commit
4c6082df64
2 changed files with 91 additions and 7 deletions
|
@ -15,6 +15,7 @@ import (
|
|||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -102,7 +103,8 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) {
|
|||
func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
|
||||
// No-op websocket backend simply allows the WS connection to be
|
||||
// accepted then it will be immediately closed. Perfect for testing.
|
||||
wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {}))
|
||||
var connCount int32
|
||||
wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { atomic.AddInt32(&connCount, 1) }))
|
||||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
|
@ -135,6 +137,9 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
|
|||
if !bytes.Equal(actual, expected) {
|
||||
t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual)
|
||||
}
|
||||
if atomic.LoadInt32(&connCount) != 1 {
|
||||
t.Errorf("Expected 1 websocket connection, got %d", connCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
||||
|
|
|
@ -186,9 +186,80 @@ var hopHeaders = []string{
|
|||
|
||||
type respUpdateFn func(resp *http.Response)
|
||||
|
||||
type hijackedConn struct {
|
||||
net.Conn
|
||||
hj *connHijackerTransport
|
||||
}
|
||||
|
||||
func (c *hijackedConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(b)
|
||||
c.hj.Replay = append(c.hj.Replay, b[:n]...)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *hijackedConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type connHijackerTransport struct {
|
||||
*http.Transport
|
||||
Conn net.Conn
|
||||
Replay []byte
|
||||
}
|
||||
|
||||
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
DisableKeepAlives: true,
|
||||
}
|
||||
if base != nil {
|
||||
if baseTransport, ok := base.(*http.Transport); ok {
|
||||
transport.Proxy = baseTransport.Proxy
|
||||
transport.TLSClientConfig = baseTransport.TLSClientConfig
|
||||
transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout
|
||||
transport.Dial = baseTransport.Dial
|
||||
transport.DialTLS = baseTransport.DialTLS
|
||||
transport.DisableKeepAlives = true
|
||||
}
|
||||
}
|
||||
hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
|
||||
oldDial := transport.Dial
|
||||
oldDialTLS := transport.DialTLS
|
||||
if oldDial == nil {
|
||||
oldDial = (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial
|
||||
}
|
||||
hjTransport.Dial = func(network, addr string) (net.Conn, error) {
|
||||
c, err := oldDial(network, addr)
|
||||
hjTransport.Conn = c
|
||||
return &hijackedConn{c, hjTransport}, err
|
||||
}
|
||||
if oldDialTLS != nil {
|
||||
hjTransport.DialTLS = func(network, addr string) (net.Conn, error) {
|
||||
c, err := oldDialTLS(network, addr)
|
||||
hjTransport.Conn = c
|
||||
return &hijackedConn{c, hjTransport}, err
|
||||
}
|
||||
}
|
||||
return hjTransport
|
||||
}
|
||||
|
||||
func requestIsWebsocket(req *http.Request) bool {
|
||||
return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"))
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
|
||||
transport := p.Transport
|
||||
if transport == nil {
|
||||
if requestIsWebsocket(outreq) {
|
||||
transport = newConnHijackerTransport(transport)
|
||||
} else if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
|
||||
|
@ -219,14 +290,22 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, r
|
|||
}
|
||||
defer conn.Close()
|
||||
|
||||
backendConn, err := net.Dial("tcp", outreq.URL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
var backendConn net.Conn
|
||||
if hj, ok := transport.(*connHijackerTransport); ok {
|
||||
backendConn = hj.Conn
|
||||
if _, err := conn.Write(hj.Replay); err != nil {
|
||||
return err
|
||||
}
|
||||
bufferPool.Put(hj.Replay)
|
||||
} else {
|
||||
backendConn, err = net.Dial("tcp", outreq.URL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
outreq.Write(backendConn)
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
outreq.Write(backendConn)
|
||||
|
||||
go func() {
|
||||
io.Copy(backendConn, conn) // write tcp stream to backend.
|
||||
}()
|
||||
|
|
Loading…
Reference in a new issue