diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index ada456da..e68cf3a2 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -3,8 +3,10 @@ package proxy import ( "errors" + "net" "net/http" "net/url" + "strings" "sync/atomic" "time" @@ -75,71 +77,108 @@ var tryDuration = 60 * time.Second // ServeHTTP satisfies the middleware.Handler interface. func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { for _, upstream := range p.Upstreams { - if middleware.Path(r.URL.Path).Matches(upstream.From()) && upstream.AllowedPath(r.URL.Path) { - var replacer middleware.Replacer - start := time.Now() - requestHost := r.Host + if !middleware.Path(r.URL.Path).Matches(upstream.From()) || + !upstream.AllowedPath(r.URL.Path) { + continue + } - // Since Select() should give us "up" hosts, keep retrying - // hosts until timeout (or until we get a nil host). - for time.Now().Sub(start) < tryDuration { - host := upstream.Select() - if host == nil { - return http.StatusBadGateway, errUnreachable - } - proxy := host.ReverseProxy - r.Host = host.Name - if rr, ok := w.(*middleware.ResponseRecorder); ok && rr.Replacer != nil { - rr.Replacer.Set("upstream", host.Name) - } + var replacer middleware.Replacer + start := time.Now() - if baseURL, err := url.Parse(host.Name); err == nil { - r.Host = baseURL.Host - if proxy == nil { - proxy = NewSingleHostReverseProxy(baseURL, host.WithoutPathPrefix) - } - } else if proxy == nil { - return http.StatusInternalServerError, err + outreq := createUpstreamRequest(r) + + // Since Select() should give us "up" hosts, keep retrying + // hosts until timeout (or until we get a nil host). + for time.Now().Sub(start) < tryDuration { + host := upstream.Select() + if host == nil { + return http.StatusBadGateway, errUnreachable + } + if rr, ok := w.(*middleware.ResponseRecorder); ok && rr.Replacer != nil { + rr.Replacer.Set("upstream", host.Name) + } + + outreq.Host = host.Name + if host.ExtraHeaders != nil { + extraHeaders := make(http.Header) + if replacer == nil { + rHost := r.Host + replacer = middleware.NewReplacer(r, nil, "") + outreq.Host = rHost } - var extraHeaders http.Header - if host.ExtraHeaders != nil { - extraHeaders = make(http.Header) - if replacer == nil { - rHost := r.Host - r.Host = requestHost - replacer = middleware.NewReplacer(r, nil, "") - r.Host = rHost - } - for header, values := range host.ExtraHeaders { - for _, value := range values { - extraHeaders.Add(header, - replacer.Replace(value)) - if header == "Host" { - r.Host = replacer.Replace(value) - } + for header, values := range host.ExtraHeaders { + for _, value := range values { + extraHeaders.Add(header, replacer.Replace(value)) + if header == "Host" { + outreq.Host = replacer.Replace(value) } } } - - atomic.AddInt64(&host.Conns, 1) - backendErr := proxy.ServeHTTP(w, r, extraHeaders) - atomic.AddInt64(&host.Conns, -1) - if backendErr == nil { - return 0, nil + for k, v := range extraHeaders { + outreq.Header[k] = v } - timeout := host.FailTimeout - if timeout == 0 { - timeout = 10 * time.Second - } - atomic.AddInt32(&host.Fails, 1) - go func(host *UpstreamHost, timeout time.Duration) { - time.Sleep(timeout) - atomic.AddInt32(&host.Fails, -1) - }(host, timeout) } - return http.StatusBadGateway, errUnreachable + + proxy := host.ReverseProxy + if baseURL, err := url.Parse(host.Name); err == nil { + r.Host = baseURL.Host + if proxy == nil { + proxy = NewSingleHostReverseProxy(baseURL, host.WithoutPathPrefix) + } + } else if proxy == nil { + return http.StatusInternalServerError, err + } + + atomic.AddInt64(&host.Conns, 1) + backendErr := proxy.ServeHTTP(w, outreq) + atomic.AddInt64(&host.Conns, -1) + if backendErr == nil { + return 0, nil + } + timeout := host.FailTimeout + if timeout == 0 { + timeout = 10 * time.Second + } + atomic.AddInt32(&host.Fails, 1) + go func(host *UpstreamHost, timeout time.Duration) { + time.Sleep(timeout) + atomic.AddInt32(&host.Fails, -1) + }(host, timeout) } + return http.StatusBadGateway, errUnreachable } return p.Next.ServeHTTP(w, r) } + +// createUpstremRequest shallow-copies r into a new request +// that can be sent upstream. +func createUpstreamRequest(r *http.Request) *http.Request { + outreq := new(http.Request) + *outreq = *r // includes shallow copies of maps, but okay + + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. This + // is modifying the same underlying map from r (shallow + // copied above) so we only copy it if necessary. + for _, h := range hopHeaders { + if outreq.Header.Get(h) != "" { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, r.Header) + outreq.Header.Del(h) + } + } + + if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + // If we aren't the first proxy, retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + outreq.Header.Set("X-Forwarded-For", clientIP) + } + + return outreq +} diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index 05538dfc..fd630876 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -154,54 +154,18 @@ var InsecureTransport http.RoundTripper = &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } -func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extraHeaders http.Header) error { +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request) error { transport := p.Transport if transport == nil { transport = http.DefaultTransport } - outreq := new(http.Request) - *outreq = *req // includes shallow copies of maps, but okay - p.Director(outreq) outreq.Proto = "HTTP/1.1" outreq.ProtoMajor = 1 outreq.ProtoMinor = 1 outreq.Close = false - // Remove hop-by-hop headers to the backend. Especially - // important is "Connection" because we want a persistent - // connection, regardless of what the client sent to us. This - // is modifying the same underlying map from req (shallow - // copied above) so we only copy it if necessary. - copiedHeaders := false - for _, h := range hopHeaders { - if outreq.Header.Get(h) != "" { - if !copiedHeaders { - outreq.Header = make(http.Header) - copyHeader(outreq.Header, req.Header) - copiedHeaders = true - } - outreq.Header.Del(h) - } - } - - if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - // If we aren't the first proxy retain prior - // X-Forwarded-For information as a comma+space - // separated list and fold multiple headers into one. - if prior, ok := outreq.Header["X-Forwarded-For"]; ok { - clientIP = strings.Join(prior, ", ") + ", " + clientIP - } - outreq.Header.Set("X-Forwarded-For", clientIP) - } - - if extraHeaders != nil { - for k, v := range extraHeaders { - outreq.Header[k] = v - } - } - res, err := transport.RoundTrip(outreq) if err != nil { return err @@ -237,9 +201,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extr for _, h := range hopHeaders { res.Header.Del(h) } - copyHeader(rw.Header(), res.Header) - rw.WriteHeader(res.StatusCode) p.copyResponse(rw, res.Body) } @@ -260,7 +222,6 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { dst = mlw } } - io.Copy(dst, src) }