diff --git a/config/setup/rewrite_test.go b/config/setup/rewrite_test.go index 9ff294ef..5747dee3 100644 --- a/config/setup/rewrite_test.go +++ b/config/setup/rewrite_test.go @@ -4,8 +4,9 @@ import ( "testing" "fmt" - "github.com/mholt/caddy/middleware/rewrite" "regexp" + + "github.com/mholt/caddy/middleware/rewrite" ) func TestRewrite(t *testing.T) { diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index 027f2266..1db1131e 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -16,6 +16,8 @@ import ( "time" ) +const HTTPSwitchProtocols = 101 + // onExitFlushLoop is a callback set by tests to detect the state of the // flushLoop() goroutine. var onExitFlushLoop func() @@ -153,8 +155,37 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extr copyHeader(rw.Header(), res.Header) - rw.WriteHeader(res.StatusCode) - p.copyResponse(rw, res.Body) + if res.StatusCode == HTTPSwitchProtocols { + hj, ok := rw.(http.Hijacker) + if !ok { + return nil + } + + conn, _, err := hj.Hijack() + if err != nil { + return err + } + + backendConn, err := net.Dial("tcp", outreq.Host) + if err != nil { + conn.Close() + return err + } + + outreq.Write(backendConn) + + go func() { + io.Copy(backendConn, conn) // write tcp stream to backend. + backendConn.Close() + }() + + io.Copy(conn, backendConn) // read tcp stream from backend. + conn.Close() + } else { + rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) + } + return nil }