mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-30 22:34:15 -05:00
Reconcile upstream dial addresses and request host/URL information
My goodness that was complicated Blessed be request.Context Sort of
This commit is contained in:
parent
a60d54dbfd
commit
0830fbad03
9 changed files with 237 additions and 183 deletions
18
listeners.go
18
listeners.go
|
@ -165,19 +165,19 @@ var (
|
|||
listenersMu sync.Mutex
|
||||
)
|
||||
|
||||
// ParseListenAddr parses addr, a string of the form "network/host:port"
|
||||
// ParseNetworkAddress parses addr, a string of the form "network/host:port"
|
||||
// (with any part optional) into its component parts. Because a port can
|
||||
// also be a port range, there may be multiple addresses returned.
|
||||
func ParseListenAddr(addr string) (network string, addrs []string, err error) {
|
||||
func ParseNetworkAddress(addr string) (network string, addrs []string, err error) {
|
||||
var host, port string
|
||||
network, host, port, err = SplitListenAddr(addr)
|
||||
network, host, port, err = SplitNetworkAddress(addr)
|
||||
if network == "" {
|
||||
network = "tcp"
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if network == "unix" {
|
||||
if network == "unix" || network == "unixgram" || network == "unixpacket" {
|
||||
addrs = []string{host}
|
||||
return
|
||||
}
|
||||
|
@ -204,14 +204,14 @@ func ParseListenAddr(addr string) (network string, addrs []string, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// SplitListenAddr splits a into its network, host, and port components.
|
||||
// SplitNetworkAddress splits a into its network, host, and port components.
|
||||
// Note that port may be a port range, or omitted for unix sockets.
|
||||
func SplitListenAddr(a string) (network, host, port string, err error) {
|
||||
func SplitNetworkAddress(a string) (network, host, port string, err error) {
|
||||
if idx := strings.Index(a, "/"); idx >= 0 {
|
||||
network = strings.ToLower(strings.TrimSpace(a[:idx]))
|
||||
a = a[idx+1:]
|
||||
}
|
||||
if network == "unix" {
|
||||
if network == "unix" || network == "unixgram" || network == "unixpacket" {
|
||||
host = a
|
||||
return
|
||||
}
|
||||
|
@ -219,11 +219,11 @@ func SplitListenAddr(a string) (network, host, port string, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// JoinListenAddr combines network, host, and port into a single
|
||||
// JoinNetworkAddress combines network, host, and port into a single
|
||||
// address string of the form "network/host:port". Port may be a
|
||||
// port range. For unix sockets, the network should be "unix" and
|
||||
// the path to the socket should be given in the host argument.
|
||||
func JoinListenAddr(network, host, port string) string {
|
||||
func JoinNetworkAddress(network, host, port string) string {
|
||||
var a string
|
||||
if network != "" {
|
||||
a = network + "/"
|
||||
|
|
|
@ -19,7 +19,7 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func TestSplitListenerAddr(t *testing.T) {
|
||||
func TestSplitNetworkAddress(t *testing.T) {
|
||||
for i, tc := range []struct {
|
||||
input string
|
||||
expectNetwork string
|
||||
|
@ -67,8 +67,18 @@ func TestSplitListenerAddr(t *testing.T) {
|
|||
expectNetwork: "unix",
|
||||
expectHost: "/foo/bar",
|
||||
},
|
||||
{
|
||||
input: "unixgram//foo/bar",
|
||||
expectNetwork: "unixgram",
|
||||
expectHost: "/foo/bar",
|
||||
},
|
||||
{
|
||||
input: "unixpacket//foo/bar",
|
||||
expectNetwork: "unixpacket",
|
||||
expectHost: "/foo/bar",
|
||||
},
|
||||
} {
|
||||
actualNetwork, actualHost, actualPort, err := SplitListenAddr(tc.input)
|
||||
actualNetwork, actualHost, actualPort, err := SplitNetworkAddress(tc.input)
|
||||
if tc.expectErr && err == nil {
|
||||
t.Errorf("Test %d: Expected error but got: %v", i, err)
|
||||
}
|
||||
|
@ -87,7 +97,7 @@ func TestSplitListenerAddr(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestJoinListenerAddr(t *testing.T) {
|
||||
func TestJoinNetworkAddress(t *testing.T) {
|
||||
for i, tc := range []struct {
|
||||
network, host, port string
|
||||
expect string
|
||||
|
@ -129,14 +139,14 @@ func TestJoinListenerAddr(t *testing.T) {
|
|||
expect: "unix//foo/bar",
|
||||
},
|
||||
} {
|
||||
actual := JoinListenAddr(tc.network, tc.host, tc.port)
|
||||
actual := JoinNetworkAddress(tc.network, tc.host, tc.port)
|
||||
if actual != tc.expect {
|
||||
t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseListenerAddr(t *testing.T) {
|
||||
func TestParseNetworkAddress(t *testing.T) {
|
||||
for i, tc := range []struct {
|
||||
input string
|
||||
expectNetwork string
|
||||
|
@ -194,7 +204,7 @@ func TestParseListenerAddr(t *testing.T) {
|
|||
expectAddrs: []string{"localhost:0"},
|
||||
},
|
||||
} {
|
||||
actualNetwork, actualAddrs, err := ParseListenAddr(tc.input)
|
||||
actualNetwork, actualAddrs, err := ParseNetworkAddress(tc.input)
|
||||
if tc.expectErr && err == nil {
|
||||
t.Errorf("Test %d: Expected error but got: %v", i, err)
|
||||
}
|
||||
|
|
|
@ -108,7 +108,7 @@ func (app *App) Validate() error {
|
|||
lnAddrs := make(map[string]string)
|
||||
for srvName, srv := range app.Servers {
|
||||
for _, addr := range srv.Listen {
|
||||
netw, expanded, err := caddy.ParseListenAddr(addr)
|
||||
netw, expanded, err := caddy.ParseNetworkAddress(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid listener address '%s': %v", addr, err)
|
||||
}
|
||||
|
@ -149,7 +149,7 @@ func (app *App) Start() error {
|
|||
}
|
||||
|
||||
for _, lnAddr := range srv.Listen {
|
||||
network, addrs, err := caddy.ParseListenAddr(lnAddr)
|
||||
network, addrs, err := caddy.ParseNetworkAddress(lnAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err)
|
||||
}
|
||||
|
@ -309,7 +309,7 @@ func (app *App) automaticHTTPS() error {
|
|||
|
||||
// create HTTP->HTTPS redirects
|
||||
for _, addr := range srv.Listen {
|
||||
netw, host, port, err := caddy.SplitListenAddr(addr)
|
||||
netw, host, port, err := caddy.SplitNetworkAddress(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: invalid listener address: %v", srvName, addr)
|
||||
}
|
||||
|
@ -318,7 +318,7 @@ func (app *App) automaticHTTPS() error {
|
|||
if httpPort == 0 {
|
||||
httpPort = DefaultHTTPPort
|
||||
}
|
||||
httpRedirLnAddr := caddy.JoinListenAddr(netw, host, strconv.Itoa(httpPort))
|
||||
httpRedirLnAddr := caddy.JoinNetworkAddress(netw, host, strconv.Itoa(httpPort))
|
||||
lnAddrMap[httpRedirLnAddr] = struct{}{}
|
||||
|
||||
if parts := strings.SplitN(port, "-", 2); len(parts) == 2 {
|
||||
|
@ -361,7 +361,7 @@ func (app *App) automaticHTTPS() error {
|
|||
var lnAddrs []string
|
||||
mapLoop:
|
||||
for addr := range lnAddrMap {
|
||||
netw, addrs, err := caddy.ParseListenAddr(addr)
|
||||
netw, addrs, err := caddy.ParseNetworkAddress(addr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
@ -386,7 +386,7 @@ func (app *App) automaticHTTPS() error {
|
|||
func (app *App) listenerTaken(network, address string) bool {
|
||||
for _, srv := range app.Servers {
|
||||
for _, addr := range srv.Listen {
|
||||
netw, addrs, err := caddy.ParseListenAddr(addr)
|
||||
netw, addrs, err := caddy.ParseNetworkAddress(addr)
|
||||
if err != nil || netw != network {
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddytls"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
|
@ -34,6 +35,7 @@ func init() {
|
|||
caddy.RegisterModule(Transport{})
|
||||
}
|
||||
|
||||
// Transport facilitates FastCGI communication.
|
||||
type Transport struct {
|
||||
//////////////////////////////
|
||||
// TODO: taken from v1 Handler type
|
||||
|
@ -57,32 +59,32 @@ type Transport struct {
|
|||
|
||||
// Use this directory as the fastcgi root directory. Defaults to the root
|
||||
// directory of the parent virtual host.
|
||||
Root string
|
||||
Root string `json:"root,omitempty"`
|
||||
|
||||
// The path in the URL will be split into two, with the first piece ending
|
||||
// with the value of SplitPath. The first piece will be assumed as the
|
||||
// actual resource (CGI script) name, and the second piece will be set to
|
||||
// PATH_INFO for the CGI script to use.
|
||||
SplitPath string
|
||||
SplitPath string `json:"split_path,omitempty"`
|
||||
|
||||
// If the URL ends with '/' (which indicates a directory), these index
|
||||
// files will be tried instead.
|
||||
IndexFiles []string
|
||||
// IndexFiles []string
|
||||
|
||||
// Environment Variables
|
||||
EnvVars [][2]string
|
||||
EnvVars [][2]string `json:"env,omitempty"`
|
||||
|
||||
// Ignored paths
|
||||
IgnoredSubPaths []string
|
||||
// IgnoredSubPaths []string
|
||||
|
||||
// The duration used to set a deadline when connecting to an upstream.
|
||||
DialTimeout time.Duration
|
||||
DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
|
||||
|
||||
// The duration used to set a deadline when reading from the FastCGI server.
|
||||
ReadTimeout time.Duration
|
||||
ReadTimeout caddy.Duration `json:"read_timeout,omitempty"`
|
||||
|
||||
// The duration used to set a deadline when sending to the FastCGI server.
|
||||
WriteTimeout time.Duration
|
||||
WriteTimeout caddy.Duration `json:"write_timeout,omitempty"`
|
||||
}
|
||||
|
||||
// CaddyModule returns the Caddy module information.
|
||||
|
@ -93,102 +95,62 @@ func (Transport) CaddyModule() caddy.ModuleInfo {
|
|||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper.
|
||||
func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
// Create environment for CGI script
|
||||
env, err := t.buildEnv(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("building environment: %v", err)
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// Connect to FastCGI gateway
|
||||
// address, err := f.Address()
|
||||
// if err != nil {
|
||||
// return http.StatusBadGateway, err
|
||||
// }
|
||||
// network, address := parseAddress(address)
|
||||
network, address := "tcp", r.URL.Host // TODO:
|
||||
|
||||
// TODO: doesn't dialer have a Timeout field?
|
||||
ctx := context.Background()
|
||||
if t.DialTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, t.DialTimeout)
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(t.DialTimeout))
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// extract dial information from request (this
|
||||
// should embedded by the reverse proxy)
|
||||
network, address := "tcp", r.URL.Host
|
||||
if dialInfoVal := ctx.Value(reverseproxy.DialInfoCtxKey); dialInfoVal != nil {
|
||||
dialInfo := dialInfoVal.(reverseproxy.DialInfo)
|
||||
network = dialInfo.Network
|
||||
address = dialInfo.Address
|
||||
}
|
||||
|
||||
fcgiBackend, err := DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing backend: %v", err)
|
||||
}
|
||||
// fcgiBackend is closed when response body is closed (see clientCloser)
|
||||
// fcgiBackend gets closed when response body is closed (see clientCloser)
|
||||
|
||||
// read/write timeouts
|
||||
if err := fcgiBackend.SetReadTimeout(t.ReadTimeout); err != nil {
|
||||
if err := fcgiBackend.SetReadTimeout(time.Duration(t.ReadTimeout)); err != nil {
|
||||
return nil, fmt.Errorf("setting read timeout: %v", err)
|
||||
}
|
||||
if err := fcgiBackend.SetWriteTimeout(t.WriteTimeout); err != nil {
|
||||
if err := fcgiBackend.SetWriteTimeout(time.Duration(t.WriteTimeout)); err != nil {
|
||||
return nil, fmt.Errorf("setting write timeout: %v", err)
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
|
||||
var contentLength int64
|
||||
// if ContentLength is already set
|
||||
if r.ContentLength > 0 {
|
||||
contentLength = r.ContentLength
|
||||
} else {
|
||||
contentLength := r.ContentLength
|
||||
if contentLength == 0 {
|
||||
contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
switch r.Method {
|
||||
case "HEAD":
|
||||
case http.MethodHead:
|
||||
resp, err = fcgiBackend.Head(env)
|
||||
case "GET":
|
||||
case http.MethodGet:
|
||||
resp, err = fcgiBackend.Get(env, r.Body, contentLength)
|
||||
case "OPTIONS":
|
||||
case http.MethodOptions:
|
||||
resp, err = fcgiBackend.Options(env)
|
||||
default:
|
||||
resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength)
|
||||
}
|
||||
|
||||
// TODO:
|
||||
return resp, err
|
||||
|
||||
// Stuff brought over from v1 that might not be necessary here:
|
||||
|
||||
// if resp != nil && resp.Body != nil {
|
||||
// defer resp.Body.Close()
|
||||
// }
|
||||
|
||||
// if err != nil {
|
||||
// if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||
// return http.StatusGatewayTimeout, err
|
||||
// } else if err != io.EOF {
|
||||
// return http.StatusBadGateway, err
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Write response header
|
||||
// writeHeader(w, resp)
|
||||
|
||||
// // Write the response body
|
||||
// _, err = io.Copy(w, resp.Body)
|
||||
// if err != nil {
|
||||
// return http.StatusBadGateway, err
|
||||
// }
|
||||
|
||||
// // Log any stderr output from upstream
|
||||
// if fcgiBackend.stderr.Len() != 0 {
|
||||
// // Remove trailing newline, error logger already does this.
|
||||
// err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
|
||||
// }
|
||||
|
||||
// // Normally we would return the status code if it is an error status (>= 400),
|
||||
// // however, upstream FastCGI apps don't know about our contract and have
|
||||
// // probably already written an error page. So we just return 0, indicating
|
||||
// // that the response body is already written. However, we do return any
|
||||
// // error value so it can be logged.
|
||||
// // Note that the proxy middleware works the same way, returning status=0.
|
||||
// return 0, err
|
||||
}
|
||||
|
||||
// buildEnv returns a set of CGI environment variables for the request.
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
@ -93,15 +94,31 @@ func (h *Handler) activeHealthChecker() {
|
|||
// health checks for all hosts in the global repository.
|
||||
func (h *Handler) doActiveHealthChecksForAllHosts() {
|
||||
hosts.Range(func(key, value interface{}) bool {
|
||||
addr := key.(string)
|
||||
networkAddr := key.(string)
|
||||
host := value.(Host)
|
||||
|
||||
go func(addr string, host Host) {
|
||||
err := h.doActiveHealthCheck(addr, host)
|
||||
go func(networkAddr string, host Host) {
|
||||
network, addrs, err := caddy.ParseNetworkAddress(networkAddr)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", addr, err)
|
||||
log.Printf("[ERROR] reverse_proxy: active health check for host %s: bad network address: %v", networkAddr, err)
|
||||
return
|
||||
}
|
||||
}(addr, host)
|
||||
if len(addrs) != 1 {
|
||||
log.Printf("[ERROR] reverse_proxy: active health check for host %s: multiple addresses (upstream must map to only one address)", networkAddr)
|
||||
return
|
||||
}
|
||||
hostAddr := addrs[0]
|
||||
if network == "unix" || network == "unixgram" || network == "unixpacket" {
|
||||
// this will be used as the Host portion of a http.Request URL, and
|
||||
// paths to socket files would produce an error when creating URL,
|
||||
// so use a fake Host value instead
|
||||
hostAddr = network
|
||||
}
|
||||
err = h.doActiveHealthCheck(DialInfo{network, addrs[0]}, hostAddr, host)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", networkAddr, err)
|
||||
}
|
||||
}(networkAddr, host)
|
||||
|
||||
// continue to iterate all hosts
|
||||
return true
|
||||
|
@ -115,26 +132,39 @@ func (h *Handler) doActiveHealthChecksForAllHosts() {
|
|||
// according to whether it passes the health check. An error is
|
||||
// returned only if the health check fails to occur or if marking
|
||||
// the host's health status fails.
|
||||
func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error {
|
||||
// create the URL for the health check
|
||||
u, err := url.Parse(hostAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host Host) error {
|
||||
// create the URL for the request that acts as a health check
|
||||
scheme := "http"
|
||||
if ht, ok := h.Transport.(*http.Transport); ok && ht.TLSClientConfig != nil {
|
||||
// this is kind of a hacky way to know if we should use HTTPS, but whatever
|
||||
scheme = "https"
|
||||
}
|
||||
if h.HealthChecks.Active.Path != "" {
|
||||
u.Path = h.HealthChecks.Active.Path
|
||||
u := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: hostAddr,
|
||||
Path: h.HealthChecks.Active.Path,
|
||||
}
|
||||
|
||||
// adjust the port, if configured to be different
|
||||
if h.HealthChecks.Active.Port != 0 {
|
||||
portStr := strconv.Itoa(h.HealthChecks.Active.Port)
|
||||
u.Host = net.JoinHostPort(u.Hostname(), portStr)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
host, _, err := net.SplitHostPort(hostAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
host = hostAddr
|
||||
}
|
||||
u.Host = net.JoinHostPort(host, portStr)
|
||||
}
|
||||
|
||||
// do the request, careful to tame the response body
|
||||
// attach dialing information to this request
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, caddy.ReplacerCtxKey, caddy.NewReplacer())
|
||||
ctx = context.WithValue(ctx, DialInfoCtxKey, dialInfo)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("making request: %v", err)
|
||||
}
|
||||
|
||||
// do the request, being careful to tame the response body
|
||||
resp, err := h.HealthChecks.Active.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[INFO] reverse_proxy: active health check: %s is down (HTTP request failed: %v)", hostAddr, err)
|
||||
|
@ -149,7 +179,7 @@ func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error {
|
|||
body = io.LimitReader(body, h.HealthChecks.Active.MaxSize)
|
||||
}
|
||||
defer func() {
|
||||
// drain any remaining body so connection can be re-used
|
||||
// drain any remaining body so connection could be re-used
|
||||
io.Copy(ioutil.Discard, body)
|
||||
resp.Body.Close()
|
||||
}()
|
||||
|
@ -225,7 +255,7 @@ func (h *Handler) countFailure(upstream *Upstream) {
|
|||
err := upstream.Host.CountFail(1)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] proxy: upstream %s: counting failure: %v",
|
||||
upstream.hostURL, err)
|
||||
upstream.dialInfo, err)
|
||||
}
|
||||
|
||||
// forget it later
|
||||
|
@ -234,7 +264,7 @@ func (h *Handler) countFailure(upstream *Upstream) {
|
|||
err := host.CountFail(-1)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v",
|
||||
upstream.hostURL, err)
|
||||
upstream.dialInfo, err)
|
||||
}
|
||||
}(upstream.Host, failDuration)
|
||||
}
|
||||
|
|
|
@ -16,7 +16,6 @@ package reverseproxy
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
|
@ -59,7 +58,7 @@ type UpstreamPool []*Upstream
|
|||
type Upstream struct {
|
||||
Host `json:"-"`
|
||||
|
||||
Address string `json:"address,omitempty"`
|
||||
Dial string `json:"dial,omitempty"`
|
||||
MaxRequests int `json:"max_requests,omitempty"`
|
||||
|
||||
// TODO: This could be really useful, to bind requests
|
||||
|
@ -68,8 +67,8 @@ type Upstream struct {
|
|||
// IPAffinity string
|
||||
|
||||
healthCheckPolicy *PassiveHealthChecks
|
||||
hostURL *url.URL
|
||||
cb CircuitBreaker
|
||||
dialInfo DialInfo
|
||||
}
|
||||
|
||||
// Available returns true if the remote host
|
||||
|
@ -101,11 +100,6 @@ func (u *Upstream) Full() bool {
|
|||
return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests
|
||||
}
|
||||
|
||||
// URL returns the upstream host's endpoint URL.
|
||||
func (u *Upstream) URL() *url.URL {
|
||||
return u.hostURL
|
||||
}
|
||||
|
||||
// upstreamHost is the basic, in-memory representation
|
||||
// of the state of a remote host. It implements the
|
||||
// Host interface.
|
||||
|
@ -162,6 +156,34 @@ func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) {
|
|||
return swapped, nil
|
||||
}
|
||||
|
||||
// DialInfo contains information needed to dial a
|
||||
// connection to an upstream host. This information
|
||||
// may be different than that which is represented
|
||||
// in a URL (for example, unix sockets don't have
|
||||
// a host that can be represented in a URL, but
|
||||
// they certainly have a network name and address).
|
||||
type DialInfo struct {
|
||||
// The network to use. This should be one of the
|
||||
// values that is accepted by net.Dial:
|
||||
// https://golang.org/pkg/net/#Dial
|
||||
Network string
|
||||
|
||||
// The address to dial. Follows the same
|
||||
// semantics and rules as net.Dial.
|
||||
Address string
|
||||
}
|
||||
|
||||
// String returns the Caddy network address form
|
||||
// by joining the network and address with a
|
||||
// forward slash.
|
||||
func (di DialInfo) String() string {
|
||||
return di.Network + "/" + di.Address
|
||||
}
|
||||
|
||||
// DialInfoCtxKey is used to store a DialInfo
|
||||
// in a context.Context.
|
||||
const DialInfoCtxKey = caddy.CtxKey("dial_info")
|
||||
|
||||
// hosts is the global repository for hosts that are
|
||||
// currently in use by active configuration(s). This
|
||||
// allows the state of remote hosts to be preserved
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
|
@ -63,14 +64,23 @@ func (HTTPTransport) CaddyModule() caddy.ModuleInfo {
|
|||
|
||||
// Provision sets up h.RoundTripper with a http.Transport
|
||||
// that is ready to use.
|
||||
func (h *HTTPTransport) Provision(ctx caddy.Context) error {
|
||||
func (h *HTTPTransport) Provision(_ caddy.Context) error {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Duration(h.DialTimeout),
|
||||
FallbackDelay: time.Duration(h.FallbackDelay),
|
||||
// TODO: Resolver
|
||||
}
|
||||
|
||||
rt := &http.Transport{
|
||||
DialContext: dialer.DialContext,
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// the proper dialing information should be embedded into the request's context
|
||||
if dialInfoVal := ctx.Value(DialInfoCtxKey); dialInfoVal != nil {
|
||||
dialInfo := dialInfoVal.(DialInfo)
|
||||
network = dialInfo.Network
|
||||
address = dialInfo.Address
|
||||
}
|
||||
return dialer.DialContext(ctx, network, address)
|
||||
},
|
||||
MaxConnsPerHost: h.MaxConnsPerHost,
|
||||
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
|
||||
ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout),
|
||||
|
@ -91,7 +101,6 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error {
|
|||
|
||||
if h.KeepAlive != nil {
|
||||
dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval)
|
||||
|
||||
if enabled := h.KeepAlive.Enabled; enabled != nil {
|
||||
rt.DisableKeepAlives = !*enabled
|
||||
}
|
||||
|
@ -191,16 +200,3 @@ type KeepAlive struct {
|
|||
MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"`
|
||||
IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle
|
||||
}
|
||||
|
||||
var (
|
||||
defaultDialer = net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
|
||||
defaultTransport = &http.Transport{
|
||||
DialContext: defaultDialer.DialContext,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
IdleConnTimeout: 2 * time.Minute,
|
||||
}
|
||||
)
|
||||
|
|
|
@ -20,7 +20,6 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -86,7 +85,18 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
|||
}
|
||||
|
||||
if h.Transport == nil {
|
||||
h.Transport = defaultTransport
|
||||
t := &HTTPTransport{
|
||||
KeepAlive: &KeepAlive{
|
||||
ProbeInterval: caddy.Duration(30 * time.Second),
|
||||
IdleConnTimeout: caddy.Duration(2 * time.Minute),
|
||||
},
|
||||
DialTimeout: caddy.Duration(10 * time.Second),
|
||||
}
|
||||
err := t.Provision(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("provisioning default transport: %v", err)
|
||||
}
|
||||
h.Transport = t
|
||||
}
|
||||
|
||||
if h.LoadBalancing == nil {
|
||||
|
@ -133,51 +143,65 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
|||
go h.activeHealthChecker()
|
||||
}
|
||||
|
||||
var allUpstreams []*Upstream
|
||||
for _, upstream := range h.Upstreams {
|
||||
upstream.cb = h.CB
|
||||
|
||||
// url parser requires a scheme
|
||||
if !strings.Contains(upstream.Address, "://") {
|
||||
upstream.Address = "http://" + upstream.Address
|
||||
}
|
||||
u, err := url.Parse(upstream.Address)
|
||||
// upstreams are allowed to map to only a single host,
|
||||
// but an upstream's address may semantically represent
|
||||
// multiple addresses, so make sure to handle each
|
||||
// one in turn based on this one upstream config
|
||||
network, addresses, err := caddy.ParseNetworkAddress(upstream.Dial)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid upstream address %s: %v", upstream.Address, err)
|
||||
return fmt.Errorf("parsing dial address: %v", err)
|
||||
}
|
||||
upstream.hostURL = u
|
||||
|
||||
for _, addr := range addresses {
|
||||
// make a new upstream based on the original
|
||||
// that has a singular dial address
|
||||
upstreamCopy := *upstream
|
||||
upstreamCopy.dialInfo = DialInfo{network, addr}
|
||||
upstreamCopy.Dial = upstreamCopy.dialInfo.String()
|
||||
upstreamCopy.cb = h.CB
|
||||
|
||||
// if host already exists from a current config,
|
||||
// use that instead; otherwise, add it
|
||||
// TODO: make hosts modular, so that their state can be distributed in enterprise for example
|
||||
// TODO: If distributed, the pool should be stored in storage...
|
||||
var host Host = new(upstreamHost)
|
||||
activeHost, loaded := hosts.LoadOrStore(u.String(), host)
|
||||
activeHost, loaded := hosts.LoadOrStore(upstreamCopy.Dial, host)
|
||||
if loaded {
|
||||
host = activeHost.(Host)
|
||||
}
|
||||
upstream.Host = host
|
||||
upstreamCopy.Host = host
|
||||
|
||||
// if the passive health checker has a non-zero "unhealthy
|
||||
// request count" but the upstream has no MaxRequests set
|
||||
// (they are the same thing, but one is a default value for
|
||||
// for upstreams with a zero MaxRequests), copy the default
|
||||
// value into this upstream, since the value in the upstream
|
||||
// is what is used during availability checks
|
||||
// (MaxRequests) is what is used during availability checks
|
||||
if h.HealthChecks != nil &&
|
||||
h.HealthChecks.Passive != nil &&
|
||||
h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
|
||||
upstream.MaxRequests == 0 {
|
||||
upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
|
||||
upstreamCopy.MaxRequests == 0 {
|
||||
upstreamCopy.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
|
||||
}
|
||||
|
||||
if h.HealthChecks != nil {
|
||||
// upstreams need independent access to the passive
|
||||
// health check policy so they can, you know, passively
|
||||
// do health checks
|
||||
upstream.healthCheckPolicy = h.HealthChecks.Passive
|
||||
// health check policy because they run outside of the
|
||||
// scope of a request handler
|
||||
if h.HealthChecks != nil {
|
||||
upstreamCopy.healthCheckPolicy = h.HealthChecks.Passive
|
||||
}
|
||||
|
||||
allUpstreams = append(allUpstreams, &upstreamCopy)
|
||||
}
|
||||
}
|
||||
|
||||
// replace the unmarshaled upstreams (possible 1:many
|
||||
// address mapping) with our list, which is mapped 1:1,
|
||||
// thus may have expanded the original list
|
||||
h.Upstreams = allUpstreams
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -192,7 +216,7 @@ func (h *Handler) Cleanup() error {
|
|||
|
||||
// remove hosts from our config from the pool
|
||||
for _, upstream := range h.Upstreams {
|
||||
hosts.Delete(upstream.hostURL.String())
|
||||
hosts.Delete(upstream.dialInfo.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -222,6 +246,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
|
|||
continue
|
||||
}
|
||||
|
||||
// attach to the request information about how to dial the upstream;
|
||||
// this is necessary because the information cannot be sufficiently
|
||||
// or satisfactorily represented in a URL
|
||||
ctx := context.WithValue(r.Context(), DialInfoCtxKey, upstream.dialInfo)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// proxy the request to that upstream
|
||||
proxyErr = h.reverseProxy(w, r, upstream)
|
||||
if proxyErr == nil || proxyErr == context.Canceled {
|
||||
|
@ -249,6 +279,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
|
|||
// This assumes that no mutations of the request are performed
|
||||
// by h during or after proxying.
|
||||
func (h Handler) prepareRequest(req *http.Request) error {
|
||||
// as a special (but very common) case, if the transport
|
||||
// is HTTP, then ensure the request has the proper scheme
|
||||
// because incoming requests by default are lacking it
|
||||
if req.URL.Scheme == "" {
|
||||
req.URL.Scheme = "http"
|
||||
if ht, ok := h.Transport.(*HTTPTransport); ok && ht.TLS != nil {
|
||||
req.URL.Scheme = "https"
|
||||
}
|
||||
}
|
||||
|
||||
if req.ContentLength == 0 {
|
||||
req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries
|
||||
}
|
||||
|
@ -433,14 +473,8 @@ func (h Handler) tryAgain(start time.Time, proxyErr error) bool {
|
|||
// directRequest modifies only req.URL so that it points to the
|
||||
// given upstream host. It must modify ONLY the request URL.
|
||||
func (h Handler) directRequest(req *http.Request, upstream *Upstream) {
|
||||
target := upstream.hostURL
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) // TODO: This might be a bug (if any part of the path was augmented from a previously-tried upstream; need to start from clean original path of request, same for query string!)
|
||||
if target.RawQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = target.RawQuery + req.URL.RawQuery
|
||||
} else {
|
||||
req.URL.RawQuery = target.RawQuery + "&" + req.URL.RawQuery
|
||||
if req.URL.Host == "" {
|
||||
req.URL.Host = upstream.dialInfo.Address
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -168,7 +168,7 @@ func (s *Server) enforcementHandler(w http.ResponseWriter, r *http.Request, next
|
|||
// listeners in s that use a port which is not otherPort.
|
||||
func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool {
|
||||
for _, lnAddr := range s.Listen {
|
||||
_, addrs, err := caddy.ParseListenAddr(lnAddr)
|
||||
_, addrs, err := caddy.ParseNetworkAddress(lnAddr)
|
||||
if err == nil {
|
||||
for _, a := range addrs {
|
||||
_, port, err := net.SplitHostPort(a)
|
||||
|
|
Loading…
Reference in a new issue