diff --git a/middleware/proxy/policy.go b/middleware/proxy/policy.go index f90a1e23..27e02a17 100644 --- a/middleware/proxy/policy.go +++ b/middleware/proxy/policy.go @@ -11,7 +11,6 @@ type HostPool []*UpstreamHost // Policy decides how a host will be selected from a pool. type Policy interface { Select(pool HostPool) *UpstreamHost - Name() string } // Random is a policy that selects up hosts from a pool at random. @@ -40,11 +39,6 @@ func (r *Random) Select(pool HostPool) *UpstreamHost { return randHost } -// Name returns the name of the policy. -func (r *Random) Name() string { - return "random" -} - // LeastConn is a policy that selects the host with the least connections. type LeastConn struct{} @@ -80,11 +74,6 @@ func (r *LeastConn) Select(pool HostPool) *UpstreamHost { return bestHost } -// Name returns the name of the policy. -func (r *LeastConn) Name() string { - return "least_conn" -} - // RoundRobin is a policy that selects hosts based on round robin ordering. type RoundRobin struct { Robin uint32 @@ -104,8 +93,3 @@ func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { } return host } - -// Name returns the name of the policy. -func (r *RoundRobin) Name() string { - return "round_robin" -} diff --git a/middleware/proxy/policy_test.go b/middleware/proxy/policy_test.go index 4272e332..d4c75225 100644 --- a/middleware/proxy/policy_test.go +++ b/middleware/proxy/policy_test.go @@ -10,10 +10,6 @@ func (r *customPolicy) Select(pool HostPool) *UpstreamHost { return pool[0] } -func (r *customPolicy) Name() string { - return "custom" -} - func testPool() HostPool { pool := []*UpstreamHost{ &UpstreamHost{ diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index 58325f1d..03bba0e6 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -12,7 +12,7 @@ import ( "github.com/mholt/caddy/config/parse" ) -var supportedPolicies map[string]Policy = make(map[string]Policy) +var supportedPolicies map[string]func() Policy = make(map[string]func() Policy) type staticUpstream struct { from string @@ -27,15 +27,17 @@ type staticUpstream struct { } } +func init() { + RegisterPolicy("random", func() Policy { return &Random{} }) + RegisterPolicy("least_conn", func() Policy { return &LeastConn{} }) + RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} }) +} + // NewStaticUpstreams parses the configuration input and sets up // static upstreams for the proxy middleware. func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { var upstreams []Upstream - RegisterPolicy(&Random{}) - RegisterPolicy(&LeastConn{}) - RegisterPolicy(&RoundRobin{}) - for c.Next() { upstream := &staticUpstream{ from: "", @@ -60,11 +62,11 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { return upstreams, c.ArgErr() } - policy, ok := supportedPolicies[c.Val()] - if !ok { + if policyCreateFunc, ok := supportedPolicies[c.Val()]; ok { + upstream.Policy = policyCreateFunc() + } else { return upstreams, c.ArgErr() } - upstream.Policy = policy case "fail_timeout": if !c.NextArg() { return upstreams, c.ArgErr() @@ -150,8 +152,8 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { } // RegisterPolicy adds a custom policy to the proxy. -func RegisterPolicy(policy Policy) { - supportedPolicies[policy.Name()] = policy +func RegisterPolicy(name string, policy func() Policy) { + supportedPolicies[name] = policy } func (u *staticUpstream) From() string { diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go index 1d1cc317..f3df1613 100644 --- a/middleware/proxy/upstream_test.go +++ b/middleware/proxy/upstream_test.go @@ -43,9 +43,10 @@ func TestSelect(t *testing.T) { } func TestRegisterPolicy(t *testing.T) { + name := "custom" customPolicy := &customPolicy{} - RegisterPolicy(customPolicy) - if _, ok := supportedPolicies[customPolicy.Name()]; !ok { + RegisterPolicy(name, func() Policy { return customPolicy }) + if _, ok := supportedPolicies[name]; !ok { t.Error("Expected supportedPolicies to have a custom policy.") }