0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2025-01-13 22:51:08 -05:00

Reconcile upstream dial addresses and request host/URL information

My goodness that was complicated

Blessed be request.Context

Sort of
This commit is contained in:
Matthew Holt 2019-09-05 13:14:39 -06:00
parent a60d54dbfd
commit 0830fbad03
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5
9 changed files with 237 additions and 183 deletions

View file

@ -165,19 +165,19 @@ var (
listenersMu sync.Mutex
)
// ParseListenAddr parses addr, a string of the form "network/host:port"
// ParseNetworkAddress parses addr, a string of the form "network/host:port"
// (with any part optional) into its component parts. Because a port can
// also be a port range, there may be multiple addresses returned.
func ParseListenAddr(addr string) (network string, addrs []string, err error) {
func ParseNetworkAddress(addr string) (network string, addrs []string, err error) {
var host, port string
network, host, port, err = SplitListenAddr(addr)
network, host, port, err = SplitNetworkAddress(addr)
if network == "" {
network = "tcp"
}
if err != nil {
return
}
if network == "unix" {
if network == "unix" || network == "unixgram" || network == "unixpacket" {
addrs = []string{host}
return
}
@ -204,14 +204,14 @@ func ParseListenAddr(addr string) (network string, addrs []string, err error) {
return
}
// SplitListenAddr splits a into its network, host, and port components.
// SplitNetworkAddress splits a into its network, host, and port components.
// Note that port may be a port range, or omitted for unix sockets.
func SplitListenAddr(a string) (network, host, port string, err error) {
func SplitNetworkAddress(a string) (network, host, port string, err error) {
if idx := strings.Index(a, "/"); idx >= 0 {
network = strings.ToLower(strings.TrimSpace(a[:idx]))
a = a[idx+1:]
}
if network == "unix" {
if network == "unix" || network == "unixgram" || network == "unixpacket" {
host = a
return
}
@ -219,11 +219,11 @@ func SplitListenAddr(a string) (network, host, port string, err error) {
return
}
// JoinListenAddr combines network, host, and port into a single
// JoinNetworkAddress combines network, host, and port into a single
// address string of the form "network/host:port". Port may be a
// port range. For unix sockets, the network should be "unix" and
// the path to the socket should be given in the host argument.
func JoinListenAddr(network, host, port string) string {
func JoinNetworkAddress(network, host, port string) string {
var a string
if network != "" {
a = network + "/"

View file

@ -19,7 +19,7 @@ import (
"testing"
)
func TestSplitListenerAddr(t *testing.T) {
func TestSplitNetworkAddress(t *testing.T) {
for i, tc := range []struct {
input string
expectNetwork string
@ -67,8 +67,18 @@ func TestSplitListenerAddr(t *testing.T) {
expectNetwork: "unix",
expectHost: "/foo/bar",
},
{
input: "unixgram//foo/bar",
expectNetwork: "unixgram",
expectHost: "/foo/bar",
},
{
input: "unixpacket//foo/bar",
expectNetwork: "unixpacket",
expectHost: "/foo/bar",
},
} {
actualNetwork, actualHost, actualPort, err := SplitListenAddr(tc.input)
actualNetwork, actualHost, actualPort, err := SplitNetworkAddress(tc.input)
if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err)
}
@ -87,7 +97,7 @@ func TestSplitListenerAddr(t *testing.T) {
}
}
func TestJoinListenerAddr(t *testing.T) {
func TestJoinNetworkAddress(t *testing.T) {
for i, tc := range []struct {
network, host, port string
expect string
@ -129,14 +139,14 @@ func TestJoinListenerAddr(t *testing.T) {
expect: "unix//foo/bar",
},
} {
actual := JoinListenAddr(tc.network, tc.host, tc.port)
actual := JoinNetworkAddress(tc.network, tc.host, tc.port)
if actual != tc.expect {
t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual)
}
}
}
func TestParseListenerAddr(t *testing.T) {
func TestParseNetworkAddress(t *testing.T) {
for i, tc := range []struct {
input string
expectNetwork string
@ -194,7 +204,7 @@ func TestParseListenerAddr(t *testing.T) {
expectAddrs: []string{"localhost:0"},
},
} {
actualNetwork, actualAddrs, err := ParseListenAddr(tc.input)
actualNetwork, actualAddrs, err := ParseNetworkAddress(tc.input)
if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err)
}

View file

@ -108,7 +108,7 @@ func (app *App) Validate() error {
lnAddrs := make(map[string]string)
for srvName, srv := range app.Servers {
for _, addr := range srv.Listen {
netw, expanded, err := caddy.ParseListenAddr(addr)
netw, expanded, err := caddy.ParseNetworkAddress(addr)
if err != nil {
return fmt.Errorf("invalid listener address '%s': %v", addr, err)
}
@ -149,7 +149,7 @@ func (app *App) Start() error {
}
for _, lnAddr := range srv.Listen {
network, addrs, err := caddy.ParseListenAddr(lnAddr)
network, addrs, err := caddy.ParseNetworkAddress(lnAddr)
if err != nil {
return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err)
}
@ -309,7 +309,7 @@ func (app *App) automaticHTTPS() error {
// create HTTP->HTTPS redirects
for _, addr := range srv.Listen {
netw, host, port, err := caddy.SplitListenAddr(addr)
netw, host, port, err := caddy.SplitNetworkAddress(addr)
if err != nil {
return fmt.Errorf("%s: invalid listener address: %v", srvName, addr)
}
@ -318,7 +318,7 @@ func (app *App) automaticHTTPS() error {
if httpPort == 0 {
httpPort = DefaultHTTPPort
}
httpRedirLnAddr := caddy.JoinListenAddr(netw, host, strconv.Itoa(httpPort))
httpRedirLnAddr := caddy.JoinNetworkAddress(netw, host, strconv.Itoa(httpPort))
lnAddrMap[httpRedirLnAddr] = struct{}{}
if parts := strings.SplitN(port, "-", 2); len(parts) == 2 {
@ -361,7 +361,7 @@ func (app *App) automaticHTTPS() error {
var lnAddrs []string
mapLoop:
for addr := range lnAddrMap {
netw, addrs, err := caddy.ParseListenAddr(addr)
netw, addrs, err := caddy.ParseNetworkAddress(addr)
if err != nil {
continue
}
@ -386,7 +386,7 @@ func (app *App) automaticHTTPS() error {
func (app *App) listenerTaken(network, address string) bool {
for _, srv := range app.Servers {
for _, addr := range srv.Listen {
netw, addrs, err := caddy.ParseListenAddr(addr)
netw, addrs, err := caddy.ParseNetworkAddress(addr)
if err != nil || netw != network {
continue
}

View file

@ -25,6 +25,7 @@ import (
"strings"
"time"
"github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy"
"github.com/caddyserver/caddy/v2/modules/caddytls"
"github.com/caddyserver/caddy/v2"
@ -34,6 +35,7 @@ func init() {
caddy.RegisterModule(Transport{})
}
// Transport facilitates FastCGI communication.
type Transport struct {
//////////////////////////////
// TODO: taken from v1 Handler type
@ -57,32 +59,32 @@ type Transport struct {
// Use this directory as the fastcgi root directory. Defaults to the root
// directory of the parent virtual host.
Root string
Root string `json:"root,omitempty"`
// The path in the URL will be split into two, with the first piece ending
// with the value of SplitPath. The first piece will be assumed as the
// actual resource (CGI script) name, and the second piece will be set to
// PATH_INFO for the CGI script to use.
SplitPath string
SplitPath string `json:"split_path,omitempty"`
// If the URL ends with '/' (which indicates a directory), these index
// files will be tried instead.
IndexFiles []string
// IndexFiles []string
// Environment Variables
EnvVars [][2]string
EnvVars [][2]string `json:"env,omitempty"`
// Ignored paths
IgnoredSubPaths []string
// IgnoredSubPaths []string
// The duration used to set a deadline when connecting to an upstream.
DialTimeout time.Duration
DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
// The duration used to set a deadline when reading from the FastCGI server.
ReadTimeout time.Duration
ReadTimeout caddy.Duration `json:"read_timeout,omitempty"`
// The duration used to set a deadline when sending to the FastCGI server.
WriteTimeout time.Duration
WriteTimeout caddy.Duration `json:"write_timeout,omitempty"`
}
// CaddyModule returns the Caddy module information.
@ -93,102 +95,62 @@ func (Transport) CaddyModule() caddy.ModuleInfo {
}
}
// RoundTrip implements http.RoundTripper.
func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
// Create environment for CGI script
env, err := t.buildEnv(r)
if err != nil {
return nil, fmt.Errorf("building environment: %v", err)
}
// TODO:
// Connect to FastCGI gateway
// address, err := f.Address()
// if err != nil {
// return http.StatusBadGateway, err
// }
// network, address := parseAddress(address)
network, address := "tcp", r.URL.Host // TODO:
// TODO: doesn't dialer have a Timeout field?
ctx := context.Background()
if t.DialTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, t.DialTimeout)
ctx, cancel = context.WithTimeout(ctx, time.Duration(t.DialTimeout))
defer cancel()
}
// extract dial information from request (this
// should embedded by the reverse proxy)
network, address := "tcp", r.URL.Host
if dialInfoVal := ctx.Value(reverseproxy.DialInfoCtxKey); dialInfoVal != nil {
dialInfo := dialInfoVal.(reverseproxy.DialInfo)
network = dialInfo.Network
address = dialInfo.Address
}
fcgiBackend, err := DialContext(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("dialing backend: %v", err)
}
// fcgiBackend is closed when response body is closed (see clientCloser)
// fcgiBackend gets closed when response body is closed (see clientCloser)
// read/write timeouts
if err := fcgiBackend.SetReadTimeout(t.ReadTimeout); err != nil {
if err := fcgiBackend.SetReadTimeout(time.Duration(t.ReadTimeout)); err != nil {
return nil, fmt.Errorf("setting read timeout: %v", err)
}
if err := fcgiBackend.SetWriteTimeout(t.WriteTimeout); err != nil {
if err := fcgiBackend.SetWriteTimeout(time.Duration(t.WriteTimeout)); err != nil {
return nil, fmt.Errorf("setting write timeout: %v", err)
}
var resp *http.Response
var contentLength int64
// if ContentLength is already set
if r.ContentLength > 0 {
contentLength = r.ContentLength
} else {
contentLength := r.ContentLength
if contentLength == 0 {
contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
}
var resp *http.Response
switch r.Method {
case "HEAD":
case http.MethodHead:
resp, err = fcgiBackend.Head(env)
case "GET":
case http.MethodGet:
resp, err = fcgiBackend.Get(env, r.Body, contentLength)
case "OPTIONS":
case http.MethodOptions:
resp, err = fcgiBackend.Options(env)
default:
resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength)
}
// TODO:
return resp, err
// Stuff brought over from v1 that might not be necessary here:
// if resp != nil && resp.Body != nil {
// defer resp.Body.Close()
// }
// if err != nil {
// if err, ok := err.(net.Error); ok && err.Timeout() {
// return http.StatusGatewayTimeout, err
// } else if err != io.EOF {
// return http.StatusBadGateway, err
// }
// }
// // Write response header
// writeHeader(w, resp)
// // Write the response body
// _, err = io.Copy(w, resp.Body)
// if err != nil {
// return http.StatusBadGateway, err
// }
// // Log any stderr output from upstream
// if fcgiBackend.stderr.Len() != 0 {
// // Remove trailing newline, error logger already does this.
// err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
// }
// // Normally we would return the status code if it is an error status (>= 400),
// // however, upstream FastCGI apps don't know about our contract and have
// // probably already written an error page. So we just return 0, indicating
// // that the response body is already written. However, we do return any
// // error value so it can be logged.
// // Note that the proxy middleware works the same way, returning status=0.
// return 0, err
}
// buildEnv returns a set of CGI environment variables for the request.

View file

@ -15,6 +15,7 @@
package reverseproxy
import (
"context"
"fmt"
"io"
"io/ioutil"
@ -93,15 +94,31 @@ func (h *Handler) activeHealthChecker() {
// health checks for all hosts in the global repository.
func (h *Handler) doActiveHealthChecksForAllHosts() {
hosts.Range(func(key, value interface{}) bool {
addr := key.(string)
networkAddr := key.(string)
host := value.(Host)
go func(addr string, host Host) {
err := h.doActiveHealthCheck(addr, host)
go func(networkAddr string, host Host) {
network, addrs, err := caddy.ParseNetworkAddress(networkAddr)
if err != nil {
log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", addr, err)
log.Printf("[ERROR] reverse_proxy: active health check for host %s: bad network address: %v", networkAddr, err)
return
}
}(addr, host)
if len(addrs) != 1 {
log.Printf("[ERROR] reverse_proxy: active health check for host %s: multiple addresses (upstream must map to only one address)", networkAddr)
return
}
hostAddr := addrs[0]
if network == "unix" || network == "unixgram" || network == "unixpacket" {
// this will be used as the Host portion of a http.Request URL, and
// paths to socket files would produce an error when creating URL,
// so use a fake Host value instead
hostAddr = network
}
err = h.doActiveHealthCheck(DialInfo{network, addrs[0]}, hostAddr, host)
if err != nil {
log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", networkAddr, err)
}
}(networkAddr, host)
// continue to iterate all hosts
return true
@ -115,26 +132,39 @@ func (h *Handler) doActiveHealthChecksForAllHosts() {
// according to whether it passes the health check. An error is
// returned only if the health check fails to occur or if marking
// the host's health status fails.
func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error {
// create the URL for the health check
u, err := url.Parse(hostAddr)
if err != nil {
return err
func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host Host) error {
// create the URL for the request that acts as a health check
scheme := "http"
if ht, ok := h.Transport.(*http.Transport); ok && ht.TLSClientConfig != nil {
// this is kind of a hacky way to know if we should use HTTPS, but whatever
scheme = "https"
}
if h.HealthChecks.Active.Path != "" {
u.Path = h.HealthChecks.Active.Path
u := &url.URL{
Scheme: scheme,
Host: hostAddr,
Path: h.HealthChecks.Active.Path,
}
// adjust the port, if configured to be different
if h.HealthChecks.Active.Port != 0 {
portStr := strconv.Itoa(h.HealthChecks.Active.Port)
u.Host = net.JoinHostPort(u.Hostname(), portStr)
host, _, err := net.SplitHostPort(hostAddr)
if err != nil {
host = hostAddr
}
u.Host = net.JoinHostPort(host, portStr)
}
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
// attach dialing information to this request
ctx := context.Background()
ctx = context.WithValue(ctx, caddy.ReplacerCtxKey, caddy.NewReplacer())
ctx = context.WithValue(ctx, DialInfoCtxKey, dialInfo)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return err
return fmt.Errorf("making request: %v", err)
}
// do the request, careful to tame the response body
// do the request, being careful to tame the response body
resp, err := h.HealthChecks.Active.httpClient.Do(req)
if err != nil {
log.Printf("[INFO] reverse_proxy: active health check: %s is down (HTTP request failed: %v)", hostAddr, err)
@ -149,7 +179,7 @@ func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error {
body = io.LimitReader(body, h.HealthChecks.Active.MaxSize)
}
defer func() {
// drain any remaining body so connection can be re-used
// drain any remaining body so connection could be re-used
io.Copy(ioutil.Discard, body)
resp.Body.Close()
}()
@ -225,7 +255,7 @@ func (h *Handler) countFailure(upstream *Upstream) {
err := upstream.Host.CountFail(1)
if err != nil {
log.Printf("[ERROR] proxy: upstream %s: counting failure: %v",
upstream.hostURL, err)
upstream.dialInfo, err)
}
// forget it later
@ -234,7 +264,7 @@ func (h *Handler) countFailure(upstream *Upstream) {
err := host.CountFail(-1)
if err != nil {
log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v",
upstream.hostURL, err)
upstream.dialInfo, err)
}
}(upstream.Host, failDuration)
}

View file

@ -16,7 +16,6 @@ package reverseproxy
import (
"fmt"
"net/url"
"sync/atomic"
"github.com/caddyserver/caddy/v2"
@ -59,7 +58,7 @@ type UpstreamPool []*Upstream
type Upstream struct {
Host `json:"-"`
Address string `json:"address,omitempty"`
Dial string `json:"dial,omitempty"`
MaxRequests int `json:"max_requests,omitempty"`
// TODO: This could be really useful, to bind requests
@ -68,8 +67,8 @@ type Upstream struct {
// IPAffinity string
healthCheckPolicy *PassiveHealthChecks
hostURL *url.URL
cb CircuitBreaker
dialInfo DialInfo
}
// Available returns true if the remote host
@ -101,11 +100,6 @@ func (u *Upstream) Full() bool {
return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests
}
// URL returns the upstream host's endpoint URL.
func (u *Upstream) URL() *url.URL {
return u.hostURL
}
// upstreamHost is the basic, in-memory representation
// of the state of a remote host. It implements the
// Host interface.
@ -162,6 +156,34 @@ func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) {
return swapped, nil
}
// DialInfo contains information needed to dial a
// connection to an upstream host. This information
// may be different than that which is represented
// in a URL (for example, unix sockets don't have
// a host that can be represented in a URL, but
// they certainly have a network name and address).
type DialInfo struct {
// The network to use. This should be one of the
// values that is accepted by net.Dial:
// https://golang.org/pkg/net/#Dial
Network string
// The address to dial. Follows the same
// semantics and rules as net.Dial.
Address string
}
// String returns the Caddy network address form
// by joining the network and address with a
// forward slash.
func (di DialInfo) String() string {
return di.Network + "/" + di.Address
}
// DialInfoCtxKey is used to store a DialInfo
// in a context.Context.
const DialInfoCtxKey = caddy.CtxKey("dial_info")
// hosts is the global repository for hosts that are
// currently in use by active configuration(s). This
// allows the state of remote hosts to be preserved

View file

@ -15,6 +15,7 @@
package reverseproxy
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
@ -63,14 +64,23 @@ func (HTTPTransport) CaddyModule() caddy.ModuleInfo {
// Provision sets up h.RoundTripper with a http.Transport
// that is ready to use.
func (h *HTTPTransport) Provision(ctx caddy.Context) error {
func (h *HTTPTransport) Provision(_ caddy.Context) error {
dialer := &net.Dialer{
Timeout: time.Duration(h.DialTimeout),
FallbackDelay: time.Duration(h.FallbackDelay),
// TODO: Resolver
}
rt := &http.Transport{
DialContext: dialer.DialContext,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
// the proper dialing information should be embedded into the request's context
if dialInfoVal := ctx.Value(DialInfoCtxKey); dialInfoVal != nil {
dialInfo := dialInfoVal.(DialInfo)
network = dialInfo.Network
address = dialInfo.Address
}
return dialer.DialContext(ctx, network, address)
},
MaxConnsPerHost: h.MaxConnsPerHost,
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout),
@ -91,7 +101,6 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error {
if h.KeepAlive != nil {
dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval)
if enabled := h.KeepAlive.Enabled; enabled != nil {
rt.DisableKeepAlives = !*enabled
}
@ -191,16 +200,3 @@ type KeepAlive struct {
MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"`
IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle
}
var (
defaultDialer = net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
defaultTransport = &http.Transport{
DialContext: defaultDialer.DialContext,
TLSHandshakeTimeout: 5 * time.Second,
IdleConnTimeout: 2 * time.Minute,
}
)

View file

@ -20,7 +20,6 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"regexp"
"strings"
"time"
@ -86,7 +85,18 @@ func (h *Handler) Provision(ctx caddy.Context) error {
}
if h.Transport == nil {
h.Transport = defaultTransport
t := &HTTPTransport{
KeepAlive: &KeepAlive{
ProbeInterval: caddy.Duration(30 * time.Second),
IdleConnTimeout: caddy.Duration(2 * time.Minute),
},
DialTimeout: caddy.Duration(10 * time.Second),
}
err := t.Provision(ctx)
if err != nil {
return fmt.Errorf("provisioning default transport: %v", err)
}
h.Transport = t
}
if h.LoadBalancing == nil {
@ -133,51 +143,65 @@ func (h *Handler) Provision(ctx caddy.Context) error {
go h.activeHealthChecker()
}
var allUpstreams []*Upstream
for _, upstream := range h.Upstreams {
upstream.cb = h.CB
// url parser requires a scheme
if !strings.Contains(upstream.Address, "://") {
upstream.Address = "http://" + upstream.Address
}
u, err := url.Parse(upstream.Address)
// upstreams are allowed to map to only a single host,
// but an upstream's address may semantically represent
// multiple addresses, so make sure to handle each
// one in turn based on this one upstream config
network, addresses, err := caddy.ParseNetworkAddress(upstream.Dial)
if err != nil {
return fmt.Errorf("invalid upstream address %s: %v", upstream.Address, err)
}
upstream.hostURL = u
// if host already exists from a current config,
// use that instead; otherwise, add it
// TODO: make hosts modular, so that their state can be distributed in enterprise for example
// TODO: If distributed, the pool should be stored in storage...
var host Host = new(upstreamHost)
activeHost, loaded := hosts.LoadOrStore(u.String(), host)
if loaded {
host = activeHost.(Host)
}
upstream.Host = host
// if the passive health checker has a non-zero "unhealthy
// request count" but the upstream has no MaxRequests set
// (they are the same thing, but one is a default value for
// for upstreams with a zero MaxRequests), copy the default
// value into this upstream, since the value in the upstream
// is what is used during availability checks
if h.HealthChecks != nil &&
h.HealthChecks.Passive != nil &&
h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
upstream.MaxRequests == 0 {
upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
return fmt.Errorf("parsing dial address: %v", err)
}
if h.HealthChecks != nil {
for _, addr := range addresses {
// make a new upstream based on the original
// that has a singular dial address
upstreamCopy := *upstream
upstreamCopy.dialInfo = DialInfo{network, addr}
upstreamCopy.Dial = upstreamCopy.dialInfo.String()
upstreamCopy.cb = h.CB
// if host already exists from a current config,
// use that instead; otherwise, add it
// TODO: make hosts modular, so that their state can be distributed in enterprise for example
// TODO: If distributed, the pool should be stored in storage...
var host Host = new(upstreamHost)
activeHost, loaded := hosts.LoadOrStore(upstreamCopy.Dial, host)
if loaded {
host = activeHost.(Host)
}
upstreamCopy.Host = host
// if the passive health checker has a non-zero "unhealthy
// request count" but the upstream has no MaxRequests set
// (they are the same thing, but one is a default value for
// for upstreams with a zero MaxRequests), copy the default
// value into this upstream, since the value in the upstream
// (MaxRequests) is what is used during availability checks
if h.HealthChecks != nil &&
h.HealthChecks.Passive != nil &&
h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
upstreamCopy.MaxRequests == 0 {
upstreamCopy.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
}
// upstreams need independent access to the passive
// health check policy so they can, you know, passively
// do health checks
upstream.healthCheckPolicy = h.HealthChecks.Passive
// health check policy because they run outside of the
// scope of a request handler
if h.HealthChecks != nil {
upstreamCopy.healthCheckPolicy = h.HealthChecks.Passive
}
allUpstreams = append(allUpstreams, &upstreamCopy)
}
}
// replace the unmarshaled upstreams (possible 1:many
// address mapping) with our list, which is mapped 1:1,
// thus may have expanded the original list
h.Upstreams = allUpstreams
return nil
}
@ -192,7 +216,7 @@ func (h *Handler) Cleanup() error {
// remove hosts from our config from the pool
for _, upstream := range h.Upstreams {
hosts.Delete(upstream.hostURL.String())
hosts.Delete(upstream.dialInfo.String())
}
return nil
@ -222,6 +246,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
continue
}
// attach to the request information about how to dial the upstream;
// this is necessary because the information cannot be sufficiently
// or satisfactorily represented in a URL
ctx := context.WithValue(r.Context(), DialInfoCtxKey, upstream.dialInfo)
r = r.WithContext(ctx)
// proxy the request to that upstream
proxyErr = h.reverseProxy(w, r, upstream)
if proxyErr == nil || proxyErr == context.Canceled {
@ -249,6 +279,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
// This assumes that no mutations of the request are performed
// by h during or after proxying.
func (h Handler) prepareRequest(req *http.Request) error {
// as a special (but very common) case, if the transport
// is HTTP, then ensure the request has the proper scheme
// because incoming requests by default are lacking it
if req.URL.Scheme == "" {
req.URL.Scheme = "http"
if ht, ok := h.Transport.(*HTTPTransport); ok && ht.TLS != nil {
req.URL.Scheme = "https"
}
}
if req.ContentLength == 0 {
req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries
}
@ -433,14 +473,8 @@ func (h Handler) tryAgain(start time.Time, proxyErr error) bool {
// directRequest modifies only req.URL so that it points to the
// given upstream host. It must modify ONLY the request URL.
func (h Handler) directRequest(req *http.Request, upstream *Upstream) {
target := upstream.hostURL
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) // TODO: This might be a bug (if any part of the path was augmented from a previously-tried upstream; need to start from clean original path of request, same for query string!)
if target.RawQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = target.RawQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = target.RawQuery + "&" + req.URL.RawQuery
if req.URL.Host == "" {
req.URL.Host = upstream.dialInfo.Address
}
}

View file

@ -168,7 +168,7 @@ func (s *Server) enforcementHandler(w http.ResponseWriter, r *http.Request, next
// listeners in s that use a port which is not otherPort.
func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool {
for _, lnAddr := range s.Listen {
_, addrs, err := caddy.ParseListenAddr(lnAddr)
_, addrs, err := caddy.ParseNetworkAddress(lnAddr)
if err == nil {
for _, a := range addrs {
_, port, err := net.SplitHostPort(a)