mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-23 22:27:38 -05:00
Upgrade proxy middleware. Add support for: multiple backends, load balancing, health checks, and pluggable backends
This commit is contained in:
parent
782ba32457
commit
4a4b80450a
6 changed files with 763 additions and 56 deletions
91
middleware/proxy/policy.go
Normal file
91
middleware/proxy/policy.go
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HostPool []*UpstreamHost
|
||||||
|
|
||||||
|
// Policy decides how a host will be selected from a pool.
|
||||||
|
type Policy interface {
|
||||||
|
Select(pool HostPool) *UpstreamHost
|
||||||
|
}
|
||||||
|
|
||||||
|
// The random policy randomly selected an up host from the pool.
|
||||||
|
type Random struct{}
|
||||||
|
|
||||||
|
func (r *Random) Select(pool HostPool) *UpstreamHost {
|
||||||
|
// instead of just generating a random index
|
||||||
|
// this is done to prevent selecting a down host
|
||||||
|
var randHost *UpstreamHost
|
||||||
|
count := 0
|
||||||
|
for _, host := range pool {
|
||||||
|
if host.Down() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
if count == 1 {
|
||||||
|
randHost = host
|
||||||
|
} else {
|
||||||
|
r := rand.Int() % count
|
||||||
|
if r == (count - 1) {
|
||||||
|
randHost = host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return randHost
|
||||||
|
}
|
||||||
|
|
||||||
|
// The least_conn policy selects a host with the least connections.
|
||||||
|
// If multiple hosts have the least amount of connections, one is randomly
|
||||||
|
// chosen.
|
||||||
|
type LeastConn struct{}
|
||||||
|
|
||||||
|
func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
|
||||||
|
var bestHost *UpstreamHost
|
||||||
|
count := 0
|
||||||
|
leastConn := int64(1<<63 - 1)
|
||||||
|
for _, host := range pool {
|
||||||
|
if host.Down() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hostConns := host.Conns
|
||||||
|
if hostConns < leastConn {
|
||||||
|
bestHost = host
|
||||||
|
leastConn = hostConns
|
||||||
|
count = 1
|
||||||
|
} else if hostConns == leastConn {
|
||||||
|
// randomly select host among hosts with least connections
|
||||||
|
count++
|
||||||
|
if count == 1 {
|
||||||
|
bestHost = host
|
||||||
|
} else {
|
||||||
|
r := rand.Int() % count
|
||||||
|
if r == (count - 1) {
|
||||||
|
bestHost = host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bestHost
|
||||||
|
}
|
||||||
|
|
||||||
|
// The round_robin policy selects a host based on round robin ordering.
|
||||||
|
type RoundRobin struct {
|
||||||
|
Robin uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
|
||||||
|
poolLen := uint32(len(pool))
|
||||||
|
selection := atomic.AddUint32(&r.Robin, 1) % poolLen
|
||||||
|
host := pool[selection]
|
||||||
|
// if the currently selected host is down, just ffwd to up host
|
||||||
|
for i := uint32(1); host.Down() && i < poolLen; i++ {
|
||||||
|
host = pool[(selection+i)%poolLen]
|
||||||
|
}
|
||||||
|
if host.Down() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
57
middleware/proxy/policy_test.go
Normal file
57
middleware/proxy/policy_test.go
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testPool() HostPool {
|
||||||
|
pool := []*UpstreamHost{
|
||||||
|
&UpstreamHost{
|
||||||
|
Name: "http://google.com", // this should resolve (healthcheck test)
|
||||||
|
},
|
||||||
|
&UpstreamHost{
|
||||||
|
Name: "http://shouldnot.resolve", // this shouldn't
|
||||||
|
},
|
||||||
|
&UpstreamHost{
|
||||||
|
Name: "http://C",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return HostPool(pool)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinPolicy(t *testing.T) {
|
||||||
|
pool := testPool()
|
||||||
|
rrPolicy := &RoundRobin{}
|
||||||
|
h := rrPolicy.Select(pool)
|
||||||
|
// First selected host is 1, because counter starts at 0
|
||||||
|
// and increments before host is selected
|
||||||
|
if h != pool[1] {
|
||||||
|
t.Error("Expected first round robin host to be second host in the pool.")
|
||||||
|
}
|
||||||
|
h = rrPolicy.Select(pool)
|
||||||
|
if h != pool[2] {
|
||||||
|
t.Error("Expected second round robin host to be third host in the pool.")
|
||||||
|
}
|
||||||
|
// mark host as down
|
||||||
|
pool[0].Unhealthy = true
|
||||||
|
h = rrPolicy.Select(pool)
|
||||||
|
if h != pool[1] {
|
||||||
|
t.Error("Expected third round robin host to be first host in the pool.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLeastConnPolicy(t *testing.T) {
|
||||||
|
pool := testPool()
|
||||||
|
lcPolicy := &LeastConn{}
|
||||||
|
pool[0].Conns = 10
|
||||||
|
pool[1].Conns = 10
|
||||||
|
h := lcPolicy.Select(pool)
|
||||||
|
if h != pool[2] {
|
||||||
|
t.Error("Expected least connection host to be third host.")
|
||||||
|
}
|
||||||
|
pool[2].Conns = 100
|
||||||
|
h = lcPolicy.Select(pool)
|
||||||
|
if h != pool[0] && h != pool[1] {
|
||||||
|
t.Error("Expected least connection host to be first or second host.")
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,52 +2,169 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"errors"
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/mholt/caddy/middleware"
|
"github.com/mholt/caddy/middleware"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errUnreachable = errors.New("Unreachable backend")
|
||||||
|
|
||||||
// Proxy represents a middleware instance that can proxy requests.
|
// Proxy represents a middleware instance that can proxy requests.
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
Next middleware.Handler
|
Next middleware.Handler
|
||||||
Rules []Rule
|
Upstreams []Upstream
|
||||||
|
}
|
||||||
|
|
||||||
|
// An upstream manages a pool of proxy upstream hosts. Select should return a
|
||||||
|
// suitable upstream host, or nil if no such hosts are available.
|
||||||
|
type Upstream interface {
|
||||||
|
// The path this upstream host should be routed on
|
||||||
|
From() string
|
||||||
|
// Selects an upstream host to be routed to.
|
||||||
|
Select() *UpstreamHost
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpstreamHostDownFunc func(*UpstreamHost) bool
|
||||||
|
|
||||||
|
// An UpstreamHost represents a single proxy upstream
|
||||||
|
type UpstreamHost struct {
|
||||||
|
Name string
|
||||||
|
ReverseProxy *ReverseProxy
|
||||||
|
Conns int64
|
||||||
|
Fails int32
|
||||||
|
FailTimeout time.Duration
|
||||||
|
Unhealthy bool
|
||||||
|
ExtraHeaders http.Header
|
||||||
|
CheckDown UpstreamHostDownFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uh *UpstreamHost) Down() bool {
|
||||||
|
if uh.CheckDown == nil {
|
||||||
|
// Default settings
|
||||||
|
return uh.Unhealthy || uh.Fails > 0
|
||||||
|
}
|
||||||
|
return uh.CheckDown(uh)
|
||||||
|
}
|
||||||
|
|
||||||
|
//https://github.com/mgutz/str
|
||||||
|
var tRe = regexp.MustCompile(`([\-\[\]()*\s])`)
|
||||||
|
var tRe2 = regexp.MustCompile(`\$`)
|
||||||
|
var openDelim = tRe2.ReplaceAllString(tRe.ReplaceAllString("{{", "\\$1"), "\\$")
|
||||||
|
var closDelim = tRe2.ReplaceAllString(tRe.ReplaceAllString("}}", "\\$1"), "\\$")
|
||||||
|
var templateDelim = regexp.MustCompile(openDelim + `(.+?)` + closDelim)
|
||||||
|
|
||||||
|
type requestVars struct {
|
||||||
|
Host string
|
||||||
|
RemoteIp string
|
||||||
|
Scheme string
|
||||||
|
Upstream string
|
||||||
|
UpstreamHost string
|
||||||
|
}
|
||||||
|
|
||||||
|
func templateWithDelimiters(s string, vars requestVars) string {
|
||||||
|
matches := templateDelim.FindAllStringSubmatch(s, -1)
|
||||||
|
for _, submatches := range matches {
|
||||||
|
match := submatches[0]
|
||||||
|
key := submatches[1]
|
||||||
|
found := true
|
||||||
|
repl := ""
|
||||||
|
switch key {
|
||||||
|
case "http_host":
|
||||||
|
repl = vars.Host
|
||||||
|
case "remote_addr":
|
||||||
|
repl = vars.RemoteIp
|
||||||
|
case "scheme":
|
||||||
|
repl = vars.Scheme
|
||||||
|
case "upstream":
|
||||||
|
repl = vars.Upstream
|
||||||
|
case "upstream_host":
|
||||||
|
repl = vars.UpstreamHost
|
||||||
|
default:
|
||||||
|
found = false
|
||||||
|
}
|
||||||
|
if found {
|
||||||
|
s = strings.Replace(s, match, repl, -1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeHTTP satisfies the middleware.Handler interface.
|
// ServeHTTP satisfies the middleware.Handler interface.
|
||||||
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
|
|
||||||
for _, rule := range p.Rules {
|
for _, upstream := range p.Upstreams {
|
||||||
if middleware.Path(r.URL.Path).Matches(rule.From) {
|
if middleware.Path(r.URL.Path).Matches(upstream.From()) {
|
||||||
var base string
|
vars := requestVars{
|
||||||
|
Host: r.Host,
|
||||||
if strings.HasPrefix(rule.To, "http") { // includes https
|
Scheme: "http",
|
||||||
// destination includes a scheme! no need to guess
|
|
||||||
base = rule.To
|
|
||||||
} else {
|
|
||||||
// no scheme specified; assume same as request
|
|
||||||
var scheme string
|
|
||||||
if r.TLS == nil {
|
|
||||||
scheme = "http"
|
|
||||||
} else {
|
|
||||||
scheme = "https"
|
|
||||||
}
|
}
|
||||||
base = scheme + "://" + rule.To
|
if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||||
|
vars.RemoteIp = clientIP
|
||||||
}
|
}
|
||||||
|
if fFor := r.Header.Get("X-Forwarded-For"); fFor != "" {
|
||||||
|
vars.RemoteIp = fFor
|
||||||
|
}
|
||||||
|
if r.TLS != nil {
|
||||||
|
vars.Scheme = "https"
|
||||||
|
}
|
||||||
|
// Since Select() should give us "up" hosts, keep retrying
|
||||||
|
// hosts until timeout (or until we get a nil host).
|
||||||
|
start := time.Now()
|
||||||
|
for time.Now().Sub(start) < (60 * time.Second) {
|
||||||
|
host := upstream.Select()
|
||||||
|
if host == nil {
|
||||||
|
return http.StatusBadGateway, errUnreachable
|
||||||
|
}
|
||||||
|
proxy := host.ReverseProxy
|
||||||
|
vars.Upstream = host.Name
|
||||||
|
r.Host = host.Name
|
||||||
|
|
||||||
baseUrl, err := url.Parse(base)
|
if baseUrl, err := url.Parse(host.Name); err == nil {
|
||||||
if err != nil {
|
vars.UpstreamHost = baseUrl.Host
|
||||||
|
if proxy == nil {
|
||||||
|
proxy = NewSingleHostReverseProxy(baseUrl)
|
||||||
|
}
|
||||||
|
} else if proxy == nil {
|
||||||
return http.StatusInternalServerError, err
|
return http.StatusInternalServerError, err
|
||||||
}
|
}
|
||||||
r.Host = baseUrl.Host
|
var extraHeaders http.Header
|
||||||
|
if host.ExtraHeaders != nil {
|
||||||
|
extraHeaders = make(http.Header)
|
||||||
|
for header, values := range host.ExtraHeaders {
|
||||||
|
for _, value := range values {
|
||||||
|
extraHeaders.Add(header,
|
||||||
|
templateWithDelimiters(value, vars))
|
||||||
|
if header == "Host" {
|
||||||
|
r.Host = templateWithDelimiters(value, vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Construct this before; not during every request, if possible
|
atomic.AddInt64(&host.Conns, 1)
|
||||||
proxy := httputil.NewSingleHostReverseProxy(baseUrl)
|
backendErr := proxy.ServeHTTP(w, r, extraHeaders)
|
||||||
proxy.ServeHTTP(w, r)
|
atomic.AddInt64(&host.Conns, -1)
|
||||||
|
if backendErr == nil {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
timeout := host.FailTimeout
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = 10 * time.Second
|
||||||
|
}
|
||||||
|
atomic.AddInt32(&host.Fails, 1)
|
||||||
|
go func(host *UpstreamHost, timeout time.Duration) {
|
||||||
|
time.Sleep(timeout)
|
||||||
|
atomic.AddInt32(&host.Fails, -1)
|
||||||
|
}(host, timeout)
|
||||||
|
}
|
||||||
|
return http.StatusBadGateway, errUnreachable
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.Next.ServeHTTP(w, r)
|
return p.Next.ServeHTTP(w, r)
|
||||||
|
@ -55,30 +172,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
|
|
||||||
// New creates a new instance of proxy middleware.
|
// New creates a new instance of proxy middleware.
|
||||||
func New(c middleware.Controller) (middleware.Middleware, error) {
|
func New(c middleware.Controller) (middleware.Middleware, error) {
|
||||||
rules, err := parse(c)
|
if upstreams, err := newStaticUpstreams(c); err == nil {
|
||||||
if err != nil {
|
return func(next middleware.Handler) middleware.Handler {
|
||||||
|
return Proxy{Next: next, Upstreams: upstreams}
|
||||||
|
}, nil
|
||||||
|
} else {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(next middleware.Handler) middleware.Handler {
|
|
||||||
return Proxy{Next: next, Rules: rules}
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parse(c middleware.Controller) ([]Rule, error) {
|
|
||||||
var rules []Rule
|
|
||||||
|
|
||||||
for c.Next() {
|
|
||||||
var rule Rule
|
|
||||||
if !c.Args(&rule.From, &rule.To) {
|
|
||||||
return rules, c.ArgErr()
|
|
||||||
}
|
|
||||||
rules = append(rules, rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rules, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Rule struct {
|
|
||||||
From, To string
|
|
||||||
}
|
}
|
||||||
|
|
215
middleware/proxy/reverseproxy.go
Normal file
215
middleware/proxy/reverseproxy.go
Normal file
|
@ -0,0 +1,215 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// HTTP reverse proxy handler
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// onExitFlushLoop is a callback set by tests to detect the state of the
|
||||||
|
// flushLoop() goroutine.
|
||||||
|
var onExitFlushLoop func()
|
||||||
|
|
||||||
|
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
||||||
|
// sends it to another server, proxying the response back to the
|
||||||
|
// client.
|
||||||
|
type ReverseProxy struct {
|
||||||
|
// Director must be a function which modifies
|
||||||
|
// the request into a new request to be sent
|
||||||
|
// using Transport. Its response is then copied
|
||||||
|
// back to the original client unmodified.
|
||||||
|
Director func(*http.Request)
|
||||||
|
|
||||||
|
// The transport used to perform proxy requests.
|
||||||
|
// If nil, http.DefaultTransport is used.
|
||||||
|
Transport http.RoundTripper
|
||||||
|
|
||||||
|
// FlushInterval specifies the flush interval
|
||||||
|
// to flush to the client while copying the
|
||||||
|
// response body.
|
||||||
|
// If zero, no periodic flushing is done.
|
||||||
|
FlushInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func singleJoiningSlash(a, b string) string {
|
||||||
|
aslash := strings.HasSuffix(a, "/")
|
||||||
|
bslash := strings.HasPrefix(b, "/")
|
||||||
|
switch {
|
||||||
|
case aslash && bslash:
|
||||||
|
return a + b[1:]
|
||||||
|
case !aslash && !bslash:
|
||||||
|
return a + "/" + b
|
||||||
|
}
|
||||||
|
return a + b
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
|
||||||
|
// URLs to the scheme, host, and base path provided in target. If the
|
||||||
|
// target's path is "/base" and the incoming request was for "/dir",
|
||||||
|
// the target request will be for /base/dir.
|
||||||
|
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
|
||||||
|
targetQuery := target.RawQuery
|
||||||
|
director := func(req *http.Request) {
|
||||||
|
req.URL.Scheme = target.Scheme
|
||||||
|
req.URL.Host = target.Host
|
||||||
|
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
|
||||||
|
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||||
|
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||||
|
} else {
|
||||||
|
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &ReverseProxy{Director: director}
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyHeader(dst, src http.Header) {
|
||||||
|
for k, vv := range src {
|
||||||
|
for _, v := range vv {
|
||||||
|
dst.Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||||
|
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||||
|
var hopHeaders = []string{
|
||||||
|
"Connection",
|
||||||
|
"Keep-Alive",
|
||||||
|
"Proxy-Authenticate",
|
||||||
|
"Proxy-Authorization",
|
||||||
|
"Te", // canonicalized version of "TE"
|
||||||
|
"Trailers",
|
||||||
|
"Transfer-Encoding",
|
||||||
|
"Upgrade",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extraHeaders http.Header) error {
|
||||||
|
transport := p.Transport
|
||||||
|
if transport == nil {
|
||||||
|
transport = http.DefaultTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
outreq := new(http.Request)
|
||||||
|
*outreq = *req // includes shallow copies of maps, but okay
|
||||||
|
|
||||||
|
p.Director(outreq)
|
||||||
|
outreq.Proto = "HTTP/1.1"
|
||||||
|
outreq.ProtoMajor = 1
|
||||||
|
outreq.ProtoMinor = 1
|
||||||
|
outreq.Close = false
|
||||||
|
|
||||||
|
// Remove hop-by-hop headers to the backend. Especially
|
||||||
|
// important is "Connection" because we want a persistent
|
||||||
|
// connection, regardless of what the client sent to us. This
|
||||||
|
// is modifying the same underlying map from req (shallow
|
||||||
|
// copied above) so we only copy it if necessary.
|
||||||
|
copiedHeaders := false
|
||||||
|
for _, h := range hopHeaders {
|
||||||
|
if outreq.Header.Get(h) != "" {
|
||||||
|
if !copiedHeaders {
|
||||||
|
outreq.Header = make(http.Header)
|
||||||
|
copyHeader(outreq.Header, req.Header)
|
||||||
|
copiedHeaders = true
|
||||||
|
}
|
||||||
|
outreq.Header.Del(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||||
|
// If we aren't the first proxy retain prior
|
||||||
|
// X-Forwarded-For information as a comma+space
|
||||||
|
// separated list and fold multiple headers into one.
|
||||||
|
if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
|
||||||
|
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||||
|
}
|
||||||
|
outreq.Header.Set("X-Forwarded-For", clientIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
if extraHeaders != nil {
|
||||||
|
for k, v := range extraHeaders {
|
||||||
|
outreq.Header[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := transport.RoundTrip(outreq)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
for _, h := range hopHeaders {
|
||||||
|
res.Header.Del(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
copyHeader(rw.Header(), res.Header)
|
||||||
|
|
||||||
|
rw.WriteHeader(res.StatusCode)
|
||||||
|
p.copyResponse(rw, res.Body)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
||||||
|
if p.FlushInterval != 0 {
|
||||||
|
if wf, ok := dst.(writeFlusher); ok {
|
||||||
|
mlw := &maxLatencyWriter{
|
||||||
|
dst: wf,
|
||||||
|
latency: p.FlushInterval,
|
||||||
|
done: make(chan bool),
|
||||||
|
}
|
||||||
|
go mlw.flushLoop()
|
||||||
|
defer mlw.stop()
|
||||||
|
dst = mlw
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
io.Copy(dst, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
type writeFlusher interface {
|
||||||
|
io.Writer
|
||||||
|
http.Flusher
|
||||||
|
}
|
||||||
|
|
||||||
|
type maxLatencyWriter struct {
|
||||||
|
dst writeFlusher
|
||||||
|
latency time.Duration
|
||||||
|
|
||||||
|
lk sync.Mutex // protects Write + Flush
|
||||||
|
done chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
|
||||||
|
m.lk.Lock()
|
||||||
|
defer m.lk.Unlock()
|
||||||
|
return m.dst.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *maxLatencyWriter) flushLoop() {
|
||||||
|
t := time.NewTicker(m.latency)
|
||||||
|
defer t.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.done:
|
||||||
|
if onExitFlushLoop != nil {
|
||||||
|
onExitFlushLoop()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case <-t.C:
|
||||||
|
m.lk.Lock()
|
||||||
|
m.dst.Flush()
|
||||||
|
m.lk.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *maxLatencyWriter) stop() { m.done <- true }
|
203
middleware/proxy/upstream.go
Normal file
203
middleware/proxy/upstream.go
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/mholt/caddy/middleware"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type staticUpstream struct {
|
||||||
|
from string
|
||||||
|
Hosts HostPool
|
||||||
|
Policy Policy
|
||||||
|
|
||||||
|
FailTimeout time.Duration
|
||||||
|
MaxFails int32
|
||||||
|
HealthCheck struct {
|
||||||
|
Path string
|
||||||
|
Interval time.Duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStaticUpstreams(c middleware.Controller) ([]Upstream, error) {
|
||||||
|
var upstreams []Upstream
|
||||||
|
|
||||||
|
for c.Next() {
|
||||||
|
upstream := &staticUpstream{
|
||||||
|
from: "",
|
||||||
|
Hosts: nil,
|
||||||
|
Policy: &Random{},
|
||||||
|
FailTimeout: 10 * time.Second,
|
||||||
|
MaxFails: 1,
|
||||||
|
}
|
||||||
|
var proxyHeaders http.Header
|
||||||
|
if !c.Args(&upstream.from) {
|
||||||
|
return upstreams, c.ArgErr()
|
||||||
|
}
|
||||||
|
to := c.RemainingArgs()
|
||||||
|
if len(to) == 0 {
|
||||||
|
return upstreams, c.ArgErr()
|
||||||
|
}
|
||||||
|
|
||||||
|
for c.NextBlock() {
|
||||||
|
switch c.Val() {
|
||||||
|
case "policy":
|
||||||
|
if !c.NextArg() {
|
||||||
|
return upstreams, c.ArgErr()
|
||||||
|
}
|
||||||
|
switch c.Val() {
|
||||||
|
case "random":
|
||||||
|
upstream.Policy = &Random{}
|
||||||
|
case "round_robin":
|
||||||
|
upstream.Policy = &RoundRobin{}
|
||||||
|
case "least_conn":
|
||||||
|
upstream.Policy = &LeastConn{}
|
||||||
|
default:
|
||||||
|
return upstreams, c.ArgErr()
|
||||||
|
}
|
||||||
|
case "fail_timeout":
|
||||||
|
if !c.NextArg() {
|
||||||
|
return upstreams, c.ArgErr()
|
||||||
|
}
|
||||||
|
if dur, err := time.ParseDuration(c.Val()); err == nil {
|
||||||
|
upstream.FailTimeout = dur
|
||||||
|
} else {
|
||||||
|
return upstreams, err
|
||||||
|
}
|
||||||
|
case "max_fails":
|
||||||
|
if !c.NextArg() {
|
||||||
|
return upstreams, c.ArgErr()
|
||||||
|
}
|
||||||
|
if n, err := strconv.Atoi(c.Val()); err == nil {
|
||||||
|
upstream.MaxFails = int32(n)
|
||||||
|
} else {
|
||||||
|
return upstreams, err
|
||||||
|
}
|
||||||
|
case "health_check":
|
||||||
|
if !c.NextArg() {
|
||||||
|
return upstreams, c.ArgErr()
|
||||||
|
}
|
||||||
|
upstream.HealthCheck.Path = c.Val()
|
||||||
|
upstream.HealthCheck.Interval = 30 * time.Second
|
||||||
|
if c.NextArg() {
|
||||||
|
if dur, err := time.ParseDuration(c.Val()); err == nil {
|
||||||
|
upstream.HealthCheck.Interval = dur
|
||||||
|
} else {
|
||||||
|
return upstreams, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "proxy_header":
|
||||||
|
var header, value string
|
||||||
|
if !c.Args(&header, &value) {
|
||||||
|
return upstreams, c.ArgErr()
|
||||||
|
}
|
||||||
|
if proxyHeaders == nil {
|
||||||
|
proxyHeaders = make(map[string][]string)
|
||||||
|
}
|
||||||
|
proxyHeaders.Add(header, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
upstream.Hosts = make([]*UpstreamHost, len(to))
|
||||||
|
for i, host := range to {
|
||||||
|
if !strings.HasPrefix(host, "http") {
|
||||||
|
host = "http://" + host
|
||||||
|
}
|
||||||
|
uh := &UpstreamHost{
|
||||||
|
Name: host,
|
||||||
|
Conns: 0,
|
||||||
|
Fails: 0,
|
||||||
|
FailTimeout: upstream.FailTimeout,
|
||||||
|
Unhealthy: false,
|
||||||
|
ExtraHeaders: proxyHeaders,
|
||||||
|
CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc {
|
||||||
|
return func(uh *UpstreamHost) bool {
|
||||||
|
if uh.Unhealthy {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if uh.Fails >= upstream.MaxFails &&
|
||||||
|
upstream.MaxFails != 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}(upstream),
|
||||||
|
}
|
||||||
|
if baseUrl, err := url.Parse(uh.Name); err == nil {
|
||||||
|
uh.ReverseProxy = NewSingleHostReverseProxy(baseUrl)
|
||||||
|
} else {
|
||||||
|
return upstreams, err
|
||||||
|
}
|
||||||
|
upstream.Hosts[i] = uh
|
||||||
|
}
|
||||||
|
|
||||||
|
if upstream.HealthCheck.Path != "" {
|
||||||
|
go upstream.healthCheckWorker(nil)
|
||||||
|
}
|
||||||
|
upstreams = append(upstreams, upstream)
|
||||||
|
}
|
||||||
|
return upstreams, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *staticUpstream) healthCheck() {
|
||||||
|
for _, host := range u.Hosts {
|
||||||
|
hostUrl := host.Name + u.HealthCheck.Path
|
||||||
|
if r, err := http.Get(hostUrl); err == nil {
|
||||||
|
io.Copy(ioutil.Discard, r.Body)
|
||||||
|
r.Body.Close()
|
||||||
|
host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
|
||||||
|
} else {
|
||||||
|
host.Unhealthy = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *staticUpstream) healthCheckWorker(stop chan struct{}) {
|
||||||
|
ticker := time.NewTicker(u.HealthCheck.Interval)
|
||||||
|
u.healthCheck()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
u.healthCheck()
|
||||||
|
case <-stop:
|
||||||
|
// TODO: the library should provide a stop channel and global
|
||||||
|
// waitgroup to allow goroutines started by plugins a chance
|
||||||
|
// to clean themselves up.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *staticUpstream) From() string {
|
||||||
|
return u.from
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *staticUpstream) Select() *UpstreamHost {
|
||||||
|
pool := u.Hosts
|
||||||
|
if len(pool) == 1 {
|
||||||
|
if pool[0].Down() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return pool[0]
|
||||||
|
}
|
||||||
|
allDown := true
|
||||||
|
for _, host := range pool {
|
||||||
|
if !host.Down() {
|
||||||
|
allDown = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if allDown {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Policy == nil {
|
||||||
|
return (&Random{}).Select(pool)
|
||||||
|
} else {
|
||||||
|
return u.Policy.Select(pool)
|
||||||
|
}
|
||||||
|
}
|
43
middleware/proxy/upstream_test.go
Normal file
43
middleware/proxy/upstream_test.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHealthCheck(t *testing.T) {
|
||||||
|
upstream := &staticUpstream{
|
||||||
|
from: "",
|
||||||
|
Hosts: testPool(),
|
||||||
|
Policy: &Random{},
|
||||||
|
FailTimeout: 10 * time.Second,
|
||||||
|
MaxFails: 1,
|
||||||
|
}
|
||||||
|
upstream.healthCheck()
|
||||||
|
if upstream.Hosts[0].Down() {
|
||||||
|
t.Error("Expected first host in testpool to not fail healthcheck.")
|
||||||
|
}
|
||||||
|
if !upstream.Hosts[1].Down() {
|
||||||
|
t.Error("Expected second host in testpool to fail healthcheck.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelect(t *testing.T) {
|
||||||
|
upstream := &staticUpstream{
|
||||||
|
from: "",
|
||||||
|
Hosts: testPool()[:3],
|
||||||
|
Policy: &Random{},
|
||||||
|
FailTimeout: 10 * time.Second,
|
||||||
|
MaxFails: 1,
|
||||||
|
}
|
||||||
|
upstream.Hosts[0].Unhealthy = true
|
||||||
|
upstream.Hosts[1].Unhealthy = true
|
||||||
|
upstream.Hosts[2].Unhealthy = true
|
||||||
|
if h := upstream.Select(); h != nil {
|
||||||
|
t.Error("Expected select to return nil as all host are down")
|
||||||
|
}
|
||||||
|
upstream.Hosts[2].Unhealthy = false
|
||||||
|
if h := upstream.Select(); h == nil {
|
||||||
|
t.Error("Expected select to not return nil")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue