diff --git a/middleware/proxy/policy.go b/middleware/proxy/policy.go index 27e02a17..a2522bcb 100644 --- a/middleware/proxy/policy.go +++ b/middleware/proxy/policy.go @@ -13,6 +13,12 @@ type Policy interface { Select(pool HostPool) *UpstreamHost } +func init() { + RegisterPolicy("random", func() Policy { return &Random{} }) + RegisterPolicy("least_conn", func() Policy { return &LeastConn{} }) + RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} }) +} + // Random is a policy that selects up hosts from a pool at random. type Random struct{} diff --git a/middleware/proxy/policy_test.go b/middleware/proxy/policy_test.go index 11269a4f..d4c75225 100644 --- a/middleware/proxy/policy_test.go +++ b/middleware/proxy/policy_test.go @@ -4,6 +4,12 @@ import ( "testing" ) +type customPolicy struct{} + +func (r *customPolicy) Select(pool HostPool) *UpstreamHost { + return pool[0] +} + func testPool() HostPool { pool := []*UpstreamHost{ &UpstreamHost{ @@ -55,3 +61,12 @@ func TestLeastConnPolicy(t *testing.T) { t.Error("Expected least connection host to be first or second host.") } } + +func TestCustomPolicy(t *testing.T) { + pool := testPool() + customPolicy := &customPolicy{} + h := customPolicy.Select(pool) + if h != pool[0] { + t.Error("Expected custom policy host to be the first host.") + } +} diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index c724da78..a657a088 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -12,6 +12,8 @@ import ( "github.com/mholt/caddy/config/parse" ) +var supportedPolicies map[string]func() Policy = make(map[string]func() Policy) + type staticUpstream struct { from string Hosts HostPool @@ -53,14 +55,10 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { if !c.NextArg() { return upstreams, c.ArgErr() } - switch c.Val() { - case "random": - upstream.Policy = &Random{} - case "round_robin": - upstream.Policy = &RoundRobin{} - case "least_conn": - upstream.Policy = &LeastConn{} - default: + + if policyCreateFunc, ok := supportedPolicies[c.Val()]; ok { + upstream.Policy = policyCreateFunc() + } else { return upstreams, c.ArgErr() } case "fail_timeout": @@ -147,6 +145,11 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { return upstreams, nil } +// RegisterPolicy adds a custom policy to the proxy. +func RegisterPolicy(name string, policy func() Policy) { + supportedPolicies[name] = policy +} + func (u *staticUpstream) From() string { return u.from } diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go index 6be3f6ce..f3df1613 100644 --- a/middleware/proxy/upstream_test.go +++ b/middleware/proxy/upstream_test.go @@ -41,3 +41,13 @@ func TestSelect(t *testing.T) { t.Error("Expected select to not return nil") } } + +func TestRegisterPolicy(t *testing.T) { + name := "custom" + customPolicy := &customPolicy{} + RegisterPolicy(name, func() Policy { return customPolicy }) + if _, ok := supportedPolicies[name]; !ok { + t.Error("Expected supportedPolicies to have a custom policy.") + } + +}