mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-30 22:34:15 -05:00
59bf71c293
* Add a shutdown function and context to staticUpstream so that running goroutines can be cancelled. Add a GetShutdownFunc to Upstream interface to expose the shutdown function to the caddy Controller for performing it on restarts. * Make fakeUpstream implement new Upstream methods. Implement new Upstream method for fakeWSUpstream as well. * Rename GetShutdownFunc to Stop(). Add a waitgroup to the staticUpstream for controlling individual object's goroutines. Add the Stop function to OnRestart and OnShutdown. Add tests for checking to see if healthchecks continue hitting a backend server after stop has been called. * Go back to using a stop channel since the context adds no additional benefit. Only register stop function for onShutdown since it's called as part of restart. * Remove assignment to atomic value * Incrementing WaitGroup outside of goroutine to avoid race condition. Loading atomic values in test. * Linting: change counter to just use the default zero value instead of setting it * Clarify Stop method comments, add comments to stop channel and waitgroup and remove out of date comment about handling stopping the proxy. Stop the ticker when the stop signal is sent
349 lines
11 KiB
Go
349 lines
11 KiB
Go
// Package proxy is middleware that proxies HTTP requests.
|
|
package proxy
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/mholt/caddy/caddyhttp/httpserver"
|
|
)
|
|
|
|
// Proxy represents a middleware instance that can proxy requests.
// Requests for which no Upstream matches (see match) are handed to Next.
|
|
type Proxy struct {
|
|
// Next is the handler that receives requests no upstream matches.
Next httpserver.Handler
|
|
// Upstreams are the candidate upstream configurations. For each request
// the one with the longest From() path that matches the request path (and
// whose AllowedPath accepts it) is used.
Upstreams []Upstream
|
|
}
|
|
|
|
// Upstream manages a pool of proxy upstream hosts.
// Proxy uses it to pick the upstream config for a request (From,
// AllowedPath), to choose a host (Select) and to control retrying
// (GetTryDuration, GetTryInterval, GetHostCount).
|
|
type Upstream interface {
|
|
// From returns the path this upstream should be routed on. When several
// upstreams match a request, Proxy prefers the longest From.
|
|
From() string
|
|
|
|
// Select picks an upstream host to route the request to. It
|
|
// should return a suitable upstream host, or nil
|
|
// if no such hosts are available.
|
|
Select(*http.Request) *UpstreamHost
|
|
|
|
// AllowedPath reports whether the given request path is allowed, i.e.
// is not one of this upstream's ignored paths.
|
|
AllowedPath(string) bool
|
|
|
|
// GetTryDuration returns how long to keep trying to select upstream hosts
|
|
// in the case of cascading failures. With a non-positive duration a failed
// attempt is not retried; a zero duration also disables request-body buffering.
|
|
GetTryDuration() time.Duration
|
|
|
|
// GetTryInterval returns how long to wait between selecting upstream
|
|
// hosts in the case of cascading failures.
|
|
GetTryInterval() time.Duration
|
|
|
|
// GetHostCount returns the number of upstream hosts. A count above one,
// together with a non-zero try duration, makes ServeHTTP buffer the
// request body so it can be replayed to another host.
|
|
GetHostCount() int
|
|
|
|
// Stop stops the upstream from proxying requests, so that its background
// goroutines (such as health checks) shut down cleanly.
|
|
Stop() error
|
|
}
|
|
|
|
// UpstreamHostDownFunc can be used to customize how Down behaves.
// When set as UpstreamHost.CheckDown it is called with the host being
// checked, and its result replaces Down's default criteria.
|
|
type UpstreamHostDownFunc func(*UpstreamHost) bool
|
|
|
|
// UpstreamHost represents a single proxy upstream
// host, together with its live connection and failure counters.
|
|
type UpstreamHost struct {
|
|
// Conns counts the requests currently being proxied to this host.
// This field is read & written to concurrently, so all access must use
|
|
// atomic operations.
|
|
Conns int64 // must be first field to be 64-bit aligned on 32-bit systems
|
|
// MaxConns caps Conns (see Full); zero or a negative value means no limit.
MaxConns int64
|
|
Name string // hostname of this upstream host; may be a full URL, which ServeHTTP parses with url.Parse
|
|
// UpstreamHeaders are header rules applied (via mutateHeadersByRules) to
// the request sent to this host.
UpstreamHeaders http.Header
|
|
// DownstreamHeaders are header rules applied to the response headers
// coming back from this host.
DownstreamHeaders http.Header
|
|
// FailTimeout is how long a failed request keeps counting in Fails.
// Zero or a negative value disables failure counting.
FailTimeout time.Duration
|
|
// CheckDown, if non-nil, replaces the default criteria used by Down.
CheckDown UpstreamHostDownFunc
|
|
// WithoutPathPrefix is passed to NewSingleHostReverseProxy when
// ReverseProxy is nil; presumably a path prefix stripped from requests
// before proxying - TODO confirm against NewSingleHostReverseProxy.
WithoutPathPrefix string
|
|
// ReverseProxy performs the actual proxying. If nil, ServeHTTP builds one
// from Name for each attempt.
ReverseProxy *ReverseProxy
|
|
// Fails counts recent failed requests that are still within FailTimeout.
// Like Conns, it is accessed concurrently and must only be used through
// atomic operations.
Fails int32
|
|
// This is an int32 so that we can use atomic operations to do concurrent
|
|
// reads & writes to this value. The default value of 0 indicates that it
|
|
// is healthy and any non-zero value indicates unhealthy.
|
|
Unhealthy int32
|
|
}
|
|
|
|
// Down checks whether the upstream host is down or not.
|
|
// Down will try to use uh.CheckDown first, and will fall
|
|
// back to some default criteria if necessary.
|
|
func (uh *UpstreamHost) Down() bool {
|
|
if uh.CheckDown == nil {
|
|
// Default settings
|
|
return atomic.LoadInt32(&uh.Unhealthy) != 0 || atomic.LoadInt32(&uh.Fails) > 0
|
|
}
|
|
return uh.CheckDown(uh)
|
|
}
|
|
|
|
// Full checks whether the upstream host has reached its maximum connections
|
|
func (uh *UpstreamHost) Full() bool {
|
|
return uh.MaxConns > 0 && atomic.LoadInt64(&uh.Conns) >= uh.MaxConns
|
|
}
|
|
|
|
// Available checks whether the upstream host is available for proxying to
|
|
func (uh *UpstreamHost) Available() bool {
|
|
return !uh.Down() && !uh.Full()
|
|
}
|
|
|
|
// ServeHTTP satisfies the httpserver.Handler interface.
|
|
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|
// start by selecting most specific matching upstream config
|
|
upstream := p.match(r)
|
|
if upstream == nil {
|
|
return p.Next.ServeHTTP(w, r)
|
|
}
|
|
|
|
// this replacer is used to fill in header field values
|
|
replacer := httpserver.NewReplacer(r, nil, "")
|
|
|
|
// outreq is the request that makes a roundtrip to the backend
|
|
outreq := createUpstreamRequest(r)
|
|
|
|
// If we have more than one upstream host defined and if retrying is enabled
|
|
// by setting try_duration to a non-zero value, caddy will try to
|
|
// retry the request at a different host if the first one failed.
|
|
//
|
|
// This requires us to possibly rewind and replay the request body though,
|
|
// which in turn requires us to buffer the request body first.
|
|
//
|
|
// An unbuffered request is usually preferrable, because it reduces latency
|
|
// as well as memory usage. Furthermore it enables different kinds of
|
|
// HTTP streaming applications like gRPC for instance.
|
|
requiresBuffering := upstream.GetHostCount() > 1 && upstream.GetTryDuration() != 0
|
|
|
|
if requiresBuffering {
|
|
body, err := newBufferedBody(outreq.Body)
|
|
if err != nil {
|
|
return http.StatusBadRequest, errors.New("failed to read downstream request body")
|
|
}
|
|
if body != nil {
|
|
outreq.Body = body
|
|
}
|
|
}
|
|
|
|
// The keepRetrying function will return true if we should
|
|
// loop and try to select another host, or false if we
|
|
// should break and stop retrying.
|
|
start := time.Now()
|
|
keepRetrying := func() bool {
|
|
// if we've tried long enough, break
|
|
if time.Since(start) >= upstream.GetTryDuration() {
|
|
return false
|
|
}
|
|
// otherwise, wait and try the next available host
|
|
time.Sleep(upstream.GetTryInterval())
|
|
return true
|
|
}
|
|
|
|
var backendErr error
|
|
for {
|
|
// since Select() should give us "up" hosts, keep retrying
|
|
// hosts until timeout (or until we get a nil host).
|
|
host := upstream.Select(r)
|
|
if host == nil {
|
|
if backendErr == nil {
|
|
backendErr = errors.New("no hosts available upstream")
|
|
}
|
|
if !keepRetrying() {
|
|
break
|
|
}
|
|
continue
|
|
}
|
|
if rr, ok := w.(*httpserver.ResponseRecorder); ok && rr.Replacer != nil {
|
|
rr.Replacer.Set("upstream", host.Name)
|
|
}
|
|
|
|
proxy := host.ReverseProxy
|
|
|
|
// a backend's name may contain more than just the host,
|
|
// so we parse it as a URL to try to isolate the host.
|
|
if nameURL, err := url.Parse(host.Name); err == nil {
|
|
outreq.Host = nameURL.Host
|
|
if proxy == nil {
|
|
proxy = NewSingleHostReverseProxy(nameURL, host.WithoutPathPrefix, http.DefaultMaxIdleConnsPerHost)
|
|
}
|
|
|
|
// use upstream credentials by default
|
|
if outreq.Header.Get("Authorization") == "" && nameURL.User != nil {
|
|
pwd, _ := nameURL.User.Password()
|
|
outreq.SetBasicAuth(nameURL.User.Username(), pwd)
|
|
}
|
|
} else {
|
|
outreq.Host = host.Name
|
|
}
|
|
if proxy == nil {
|
|
return http.StatusInternalServerError, errors.New("proxy for host '" + host.Name + "' is nil")
|
|
}
|
|
|
|
// set headers for request going upstream
|
|
if host.UpstreamHeaders != nil {
|
|
// modify headers for request that will be sent to the upstream host
|
|
mutateHeadersByRules(outreq.Header, host.UpstreamHeaders, replacer)
|
|
if hostHeaders, ok := outreq.Header["Host"]; ok && len(hostHeaders) > 0 {
|
|
outreq.Host = hostHeaders[len(hostHeaders)-1]
|
|
}
|
|
}
|
|
|
|
// prepare a function that will update response
|
|
// headers coming back downstream
|
|
var downHeaderUpdateFn respUpdateFn
|
|
if host.DownstreamHeaders != nil {
|
|
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
|
|
}
|
|
|
|
// Before we retry the request we have to make sure
|
|
// that the body is rewound to it's beginning.
|
|
if bb, ok := outreq.Body.(*bufferedBody); ok {
|
|
if err := bb.rewind(); err != nil {
|
|
return http.StatusInternalServerError, errors.New("unable to rewind downstream request body")
|
|
}
|
|
}
|
|
|
|
// tell the proxy to serve the request
|
|
//
|
|
// NOTE:
|
|
// The call to proxy.ServeHTTP can theoretically panic.
|
|
// To prevent host.Conns from getting out-of-sync we thus have to
|
|
// make sure that it's _always_ correctly decremented afterwards.
|
|
func() {
|
|
atomic.AddInt64(&host.Conns, 1)
|
|
defer atomic.AddInt64(&host.Conns, -1)
|
|
backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
|
|
}()
|
|
|
|
// if no errors, we're done here
|
|
if backendErr == nil {
|
|
return 0, nil
|
|
}
|
|
|
|
if _, ok := backendErr.(httpserver.MaxBytesExceeded); ok {
|
|
return http.StatusRequestEntityTooLarge, backendErr
|
|
}
|
|
|
|
// failover; remember this failure for some time if
|
|
// request failure counting is enabled
|
|
timeout := host.FailTimeout
|
|
if timeout > 0 {
|
|
atomic.AddInt32(&host.Fails, 1)
|
|
go func(host *UpstreamHost, timeout time.Duration) {
|
|
time.Sleep(timeout)
|
|
atomic.AddInt32(&host.Fails, -1)
|
|
}(host, timeout)
|
|
}
|
|
|
|
// if we've tried long enough, break
|
|
if !keepRetrying() {
|
|
break
|
|
}
|
|
}
|
|
|
|
return http.StatusBadGateway, backendErr
|
|
}
|
|
|
|
// match finds the best match for a proxy config based on r.
|
|
func (p Proxy) match(r *http.Request) Upstream {
|
|
var u Upstream
|
|
var longestMatch int
|
|
for _, upstream := range p.Upstreams {
|
|
basePath := upstream.From()
|
|
if !httpserver.Path(r.URL.Path).Matches(basePath) || !upstream.AllowedPath(r.URL.Path) {
|
|
continue
|
|
}
|
|
if len(basePath) > longestMatch {
|
|
longestMatch = len(basePath)
|
|
u = upstream
|
|
}
|
|
}
|
|
return u
|
|
}
|
|
|
|
// createUpstremRequest shallow-copies r into a new request
|
|
// that can be sent upstream.
|
|
//
|
|
// Derived from reverseproxy.go in the standard Go httputil package.
|
|
func createUpstreamRequest(r *http.Request) *http.Request {
|
|
outreq := new(http.Request)
|
|
*outreq = *r // includes shallow copies of maps, but okay
|
|
// We should set body to nil explicitly if request body is empty.
|
|
// For server requests the Request Body is always non-nil.
|
|
if r.ContentLength == 0 {
|
|
outreq.Body = nil
|
|
}
|
|
|
|
// We are modifying the same underlying map from req (shallow
|
|
// copied above) so we only copy it if necessary.
|
|
copiedHeaders := false
|
|
|
|
// Remove hop-by-hop headers listed in the "Connection" header.
|
|
// See RFC 2616, section 14.10.
|
|
if c := outreq.Header.Get("Connection"); c != "" {
|
|
for _, f := range strings.Split(c, ",") {
|
|
if f = strings.TrimSpace(f); f != "" {
|
|
if !copiedHeaders {
|
|
outreq.Header = make(http.Header)
|
|
copyHeader(outreq.Header, r.Header)
|
|
copiedHeaders = true
|
|
}
|
|
outreq.Header.Del(f)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Remove hop-by-hop headers to the backend. Especially
|
|
// important is "Connection" because we want a persistent
|
|
// connection, regardless of what the client sent to us.
|
|
for _, h := range hopHeaders {
|
|
if outreq.Header.Get(h) != "" {
|
|
if !copiedHeaders {
|
|
outreq.Header = make(http.Header)
|
|
copyHeader(outreq.Header, r.Header)
|
|
copiedHeaders = true
|
|
}
|
|
outreq.Header.Del(h)
|
|
}
|
|
}
|
|
|
|
if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
|
// If we aren't the first proxy, retain prior
|
|
// X-Forwarded-For information as a comma+space
|
|
// separated list and fold multiple headers into one.
|
|
if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
|
|
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
|
}
|
|
outreq.Header.Set("X-Forwarded-For", clientIP)
|
|
}
|
|
|
|
return outreq
|
|
}
|
|
|
|
func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn {
|
|
return func(resp *http.Response) {
|
|
mutateHeadersByRules(resp.Header, rules, replacer)
|
|
}
|
|
}
|
|
|
|
func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer) {
|
|
for ruleField, ruleValues := range rules {
|
|
if strings.HasPrefix(ruleField, "+") {
|
|
for _, ruleValue := range ruleValues {
|
|
replacement := repl.Replace(ruleValue)
|
|
if len(replacement) > 0 {
|
|
headers.Add(strings.TrimPrefix(ruleField, "+"), replacement)
|
|
}
|
|
}
|
|
} else if strings.HasPrefix(ruleField, "-") {
|
|
headers.Del(strings.TrimPrefix(ruleField, "-"))
|
|
} else if len(ruleValues) > 0 {
|
|
replacement := repl.Replace(ruleValues[len(ruleValues)-1])
|
|
if len(replacement) > 0 {
|
|
headers.Set(ruleField, replacement)
|
|
}
|
|
}
|
|
}
|
|
}
|