diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 305a0b38..7fc61ae6 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -368,7 +368,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht var proxyErr error for { // choose an available upstream - upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, r) + upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, r, w) if upstream == nil { if proxyErr == nil { proxyErr = fmt.Errorf("no upstreams available") @@ -816,7 +816,7 @@ type LoadBalancing struct { // Selector selects an available upstream from the pool. type Selector interface { - Select(UpstreamPool, *http.Request) *Upstream + Select(UpstreamPool, *http.Request, http.ResponseWriter) *Upstream } // Hop-by-hop headers. These are removed when sent to the backend. diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go index 2aef63dc..a1010f4b 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go @@ -15,6 +15,9 @@ package reverseproxy import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "fmt" "hash/fnv" weakrand "math/rand" @@ -37,6 +40,7 @@ func init() { caddy.RegisterModule(IPHashSelection{}) caddy.RegisterModule(URIHashSelection{}) caddy.RegisterModule(HeaderHashSelection{}) + caddy.RegisterModule(CookieHashSelection{}) weakrand.Seed(time.Now().UTC().UnixNano()) } @@ -54,24 +58,8 @@ func (RandomSelection) CaddyModule() caddy.ModuleInfo { } // Select returns an available host, if any. -func (r RandomSelection) Select(pool UpstreamPool, request *http.Request) *Upstream { - // use reservoir sampling because the number of available - // hosts isn't known: https://en.wikipedia.org/wiki/Reservoir_sampling - var randomHost *Upstream - var count int - for _, upstream := range pool { - if !upstream.Available() { - continue - } - // (n % 1 == 0) holds for all n, therefore a - // upstream will always be chosen if there is at - // least one available - count++ - if (weakrand.Int() % count) == 0 { - randomHost = upstream - } - } - return randomHost +func (r RandomSelection) Select(pool UpstreamPool, request *http.Request, _ http.ResponseWriter) *Upstream { + return selectRandomHost(pool) } // UnmarshalCaddyfile sets up the module from Caddyfile tokens. @@ -134,7 +122,7 @@ func (r RandomChoiceSelection) Validate() error { } // Select returns an available host, if any. -func (r RandomChoiceSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { +func (r RandomChoiceSelection) Select(pool UpstreamPool, _ *http.Request, _ http.ResponseWriter) *Upstream { k := r.Choose if k > len(pool) { k = len(pool) @@ -174,7 +162,7 @@ func (LeastConnSelection) CaddyModule() caddy.ModuleInfo { // Select selects the up host with the least number of connections in the // pool. If more than one host has the same least number of connections, // one of the hosts is chosen at random. -func (LeastConnSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { +func (LeastConnSelection) Select(pool UpstreamPool, _ *http.Request, _ http.ResponseWriter) *Upstream { var bestHost *Upstream var count int leastReqs := -1 @@ -227,7 +215,7 @@ func (RoundRobinSelection) CaddyModule() caddy.ModuleInfo { } // Select returns an available host, if any. -func (r *RoundRobinSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { +func (r *RoundRobinSelection) Select(pool UpstreamPool, _ *http.Request, _ http.ResponseWriter) *Upstream { n := uint32(len(pool)) if n == 0 { return nil @@ -265,7 +253,7 @@ func (FirstSelection) CaddyModule() caddy.ModuleInfo { } // Select returns an available host, if any. -func (FirstSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { +func (FirstSelection) Select(pool UpstreamPool, _ *http.Request, _ http.ResponseWriter) *Upstream { for _, host := range pool { if host.Available() { return host @@ -297,7 +285,7 @@ func (IPHashSelection) CaddyModule() caddy.ModuleInfo { } // Select returns an available host, if any. -func (IPHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream { +func (IPHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream { clientIP, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { clientIP = req.RemoteAddr @@ -328,7 +316,7 @@ func (URIHashSelection) CaddyModule() caddy.ModuleInfo { } // Select returns an available host, if any. -func (URIHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream { +func (URIHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream { return hostByHashing(pool, req.RequestURI) } @@ -358,7 +346,7 @@ func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo { } // Select returns an available host, if any. -func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream { +func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream { if s.Field == "" { return nil } @@ -371,7 +359,7 @@ func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstr val := req.Header.Get(s.Field) if val == "" { - return RandomSelection{}.Select(pool, req) + return RandomSelection{}.Select(pool, req, nil) } return hostByHashing(pool, val) } @@ -387,6 +375,114 @@ func (s *HeaderHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { return nil } +// CookieHashSelection is a policy that selects +// a host based on a given cookie name. +type CookieHashSelection struct { + // The HTTP cookie name whose value is to be hashed and used for upstream selection. + Name string `json:"name,omitempty"` + // Secret to hash (Hmac256) chosen upstream in cookie + Secret string `json:"secret,omitempty"` +} + +// CaddyModule returns the Caddy module information. +func (CookieHashSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "http.reverse_proxy.selection_policies.cookie", + New: func() caddy.Module { return new(CookieHashSelection) }, + } +} + +// Select returns an available host, if any. +func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream { + if s.Name == "" { + s.Name = "lb" + } + cookie, err := req.Cookie(s.Name) + // If there's no cookie, select new random host + if err != nil || cookie == nil { + return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name) + } else { + // If the cookie is present, loop over the available upstreams until we find a match + cookieValue := cookie.Value + for _, upstream := range pool { + if !upstream.Available() { + continue + } + sha, err := hashCookie(s.Secret, upstream.Dial) + if err == nil && sha == cookieValue { + return upstream + } + } + } + // If there is no matching host, select new random host + return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name) +} + +// UnmarshalCaddyfile sets up the module from Caddyfile tokens. Syntax: +// lb_policy cookie [ []] +// +// By default name is `lb` +func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + args := d.RemainingArgs() + switch len(args) { + case 1: + case 2: + s.Name = args[1] + case 3: + s.Name = args[1] + s.Secret = args[2] + default: + return d.ArgErr() + } + return nil +} + +// Select a new Host randomly and add a sticky session cookie +func selectNewHostWithCookieHashSelection(pool []*Upstream, w http.ResponseWriter, cookieSecret string, cookieName string) *Upstream { + randomHost := selectRandomHost(pool) + + if randomHost != nil { + // Hash (HMAC with some key for privacy) the upstream.Dial string as the cookie value + sha, err := hashCookie(cookieSecret, randomHost.Dial) + if err == nil { + // write the cookie. + http.SetCookie(w, &http.Cookie{Name: cookieName, Value: sha, Secure: false}) + } + } + return randomHost +} + +// hashCookie hashes (HMAC 256) some data with the secret +func hashCookie(secret string, data string) (string, error) { + h := hmac.New(sha256.New, []byte(secret)) + _, err := h.Write([]byte(data)) + if err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil +} + +// selectRandomHost returns a random available host +func selectRandomHost(pool []*Upstream) *Upstream { + // use reservoir sampling because the number of available + // hosts isn't known: https://en.wikipedia.org/wiki/Reservoir_sampling + var randomHost *Upstream + var count int + for _, upstream := range pool { + if !upstream.Available() { + continue + } + // (n % 1 == 0) holds for all n, therefore a + // upstream will always be chosen if there is at + // least one available + count++ + if (weakrand.Int() % count) == 0 { + randomHost = upstream + } + } + return randomHost +} + // leastRequests returns the host with the // least number of active requests to it. // If more than one host has the same @@ -454,6 +550,7 @@ var ( _ Selector = (*IPHashSelection)(nil) _ Selector = (*URIHashSelection)(nil) _ Selector = (*HeaderHashSelection)(nil) + _ Selector = (*CookieHashSelection)(nil) _ caddy.Validator = (*RandomChoiceSelection)(nil) _ caddy.Provisioner = (*RandomChoiceSelection)(nil) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go index 49585da4..5368a1ac 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go @@ -33,37 +33,37 @@ func TestRoundRobinPolicy(t *testing.T) { rrPolicy := new(RoundRobinSelection) req, _ := http.NewRequest("GET", "/", nil) - h := rrPolicy.Select(pool, req) + h := rrPolicy.Select(pool, req, nil) // First selected host is 1, because counter starts at 0 // and increments before host is selected if h != pool[1] { t.Error("Expected first round robin host to be second host in the pool.") } - h = rrPolicy.Select(pool, req) + h = rrPolicy.Select(pool, req, nil) if h != pool[2] { t.Error("Expected second round robin host to be third host in the pool.") } - h = rrPolicy.Select(pool, req) + h = rrPolicy.Select(pool, req, nil) if h != pool[0] { t.Error("Expected third round robin host to be first host in the pool.") } // mark host as down pool[1].SetHealthy(false) - h = rrPolicy.Select(pool, req) + h = rrPolicy.Select(pool, req, nil) if h != pool[2] { t.Error("Expected to skip down host.") } // mark host as up pool[1].SetHealthy(true) - h = rrPolicy.Select(pool, req) + h = rrPolicy.Select(pool, req, nil) if h == pool[2] { t.Error("Expected to balance evenly among healthy hosts") } // mark host as full pool[1].CountRequest(1) pool[1].MaxRequests = 1 - h = rrPolicy.Select(pool, req) + h = rrPolicy.Select(pool, req, nil) if h != pool[2] { t.Error("Expected to skip full host.") } @@ -76,12 +76,12 @@ func TestLeastConnPolicy(t *testing.T) { pool[0].CountRequest(10) pool[1].CountRequest(10) - h := lcPolicy.Select(pool, req) + h := lcPolicy.Select(pool, req, nil) if h != pool[2] { t.Error("Expected least connection host to be third host.") } pool[2].CountRequest(100) - h = lcPolicy.Select(pool, req) + h = lcPolicy.Select(pool, req, nil) if h != pool[0] && h != pool[1] { t.Error("Expected least connection host to be first or second host.") } @@ -94,44 +94,44 @@ func TestIPHashPolicy(t *testing.T) { // We should be able to predict where every request is routed. req.RemoteAddr = "172.0.0.1:80" - h := ipHash.Select(pool, req) + h := ipHash.Select(pool, req, nil) if h != pool[1] { t.Error("Expected ip hash policy host to be the second host.") } req.RemoteAddr = "172.0.0.2:80" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[1] { t.Error("Expected ip hash policy host to be the second host.") } req.RemoteAddr = "172.0.0.3:80" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[2] { t.Error("Expected ip hash policy host to be the third host.") } req.RemoteAddr = "172.0.0.4:80" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[1] { t.Error("Expected ip hash policy host to be the second host.") } // we should get the same results without a port req.RemoteAddr = "172.0.0.1" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[1] { t.Error("Expected ip hash policy host to be the second host.") } req.RemoteAddr = "172.0.0.2" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[1] { t.Error("Expected ip hash policy host to be the second host.") } req.RemoteAddr = "172.0.0.3" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[2] { t.Error("Expected ip hash policy host to be the third host.") } req.RemoteAddr = "172.0.0.4" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[1] { t.Error("Expected ip hash policy host to be the second host.") } @@ -140,13 +140,13 @@ func TestIPHashPolicy(t *testing.T) { // healthy host is available req.RemoteAddr = "172.0.0.1" pool[1].SetHealthy(false) - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[2] { t.Error("Expected ip hash policy host to be the third host.") } req.RemoteAddr = "172.0.0.2" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[2] { t.Error("Expected ip hash policy host to be the third host.") } @@ -154,12 +154,12 @@ func TestIPHashPolicy(t *testing.T) { req.RemoteAddr = "172.0.0.3" pool[2].SetHealthy(false) - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[0] { t.Error("Expected ip hash policy host to be the first host.") } req.RemoteAddr = "172.0.0.4" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[1] { t.Error("Expected ip hash policy host to be the second host.") } @@ -171,22 +171,22 @@ func TestIPHashPolicy(t *testing.T) { {Host: new(upstreamHost)}, } req.RemoteAddr = "172.0.0.1:80" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[0] { t.Error("Expected ip hash policy host to be the first host.") } req.RemoteAddr = "172.0.0.2:80" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[1] { t.Error("Expected ip hash policy host to be the second host.") } req.RemoteAddr = "172.0.0.3:80" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[0] { t.Error("Expected ip hash policy host to be the first host.") } req.RemoteAddr = "172.0.0.4:80" - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != pool[1] { t.Error("Expected ip hash policy host to be the second host.") } @@ -194,7 +194,7 @@ func TestIPHashPolicy(t *testing.T) { // We should get nil when there are no healthy hosts pool[0].SetHealthy(false) pool[1].SetHealthy(false) - h = ipHash.Select(pool, req) + h = ipHash.Select(pool, req, nil) if h != nil { t.Error("Expected ip hash policy host to be nil.") } @@ -205,13 +205,13 @@ func TestFirstPolicy(t *testing.T) { firstPolicy := new(FirstSelection) req := httptest.NewRequest(http.MethodGet, "/", nil) - h := firstPolicy.Select(pool, req) + h := firstPolicy.Select(pool, req, nil) if h != pool[0] { t.Error("Expected first policy host to be the first host.") } pool[0].SetHealthy(false) - h = firstPolicy.Select(pool, req) + h = firstPolicy.Select(pool, req, nil) if h != pool[1] { t.Error("Expected first policy host to be the second host.") } @@ -222,19 +222,19 @@ func TestURIHashPolicy(t *testing.T) { uriPolicy := new(URIHashSelection) request := httptest.NewRequest(http.MethodGet, "/test", nil) - h := uriPolicy.Select(pool, request) + h := uriPolicy.Select(pool, request, nil) if h != pool[0] { t.Error("Expected uri policy host to be the first host.") } pool[0].SetHealthy(false) - h = uriPolicy.Select(pool, request) + h = uriPolicy.Select(pool, request, nil) if h != pool[1] { t.Error("Expected uri policy host to be the first host.") } request = httptest.NewRequest(http.MethodGet, "/test_2", nil) - h = uriPolicy.Select(pool, request) + h = uriPolicy.Select(pool, request, nil) if h != pool[1] { t.Error("Expected uri policy host to be the second host.") } @@ -247,26 +247,26 @@ func TestURIHashPolicy(t *testing.T) { } request = httptest.NewRequest(http.MethodGet, "/test", nil) - h = uriPolicy.Select(pool, request) + h = uriPolicy.Select(pool, request, nil) if h != pool[0] { t.Error("Expected uri policy host to be the first host.") } pool[0].SetHealthy(false) - h = uriPolicy.Select(pool, request) + h = uriPolicy.Select(pool, request, nil) if h != pool[1] { t.Error("Expected uri policy host to be the first host.") } request = httptest.NewRequest(http.MethodGet, "/test_2", nil) - h = uriPolicy.Select(pool, request) + h = uriPolicy.Select(pool, request, nil) if h != pool[1] { t.Error("Expected uri policy host to be the second host.") } pool[0].SetHealthy(false) pool[1].SetHealthy(false) - h = uriPolicy.Select(pool, request) + h = uriPolicy.Select(pool, request, nil) if h != nil { t.Error("Expected uri policy policy host to be nil.") } @@ -311,7 +311,7 @@ func TestRandomChoicePolicy(t *testing.T) { randomChoicePolicy := new(RandomChoiceSelection) randomChoicePolicy.Choose = 2 - h := randomChoicePolicy.Select(pool, request) + h := randomChoicePolicy.Select(pool, request, nil) if h == nil { t.Error("RandomChoicePolicy should not return nil") @@ -322,3 +322,51 @@ func TestRandomChoicePolicy(t *testing.T) { } } + +func TestCookieHashPolicy(t *testing.T) { + pool := testPool() + pool[0].Dial = "localhost:8080" + pool[1].Dial = "localhost:8081" + pool[2].Dial = "localhost:8082" + pool[0].SetHealthy(true) + pool[1].SetHealthy(false) + pool[2].SetHealthy(false) + request := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + cookieHashPolicy := new(CookieHashSelection) + h := cookieHashPolicy.Select(pool, request, w) + cookie_server1 := w.Result().Cookies()[0] + if cookie_server1 == nil { + t.Error("cookieHashPolicy should set a cookie") + } + if cookie_server1.Name != "lb" { + t.Error("cookieHashPolicy should set a cookie with name lb") + } + if h != pool[0] { + t.Error("Expected cookieHashPolicy host to be the first only available host.") + } + pool[1].SetHealthy(true) + pool[2].SetHealthy(true) + request = httptest.NewRequest(http.MethodGet, "/test", nil) + w = httptest.NewRecorder() + request.AddCookie(cookie_server1) + h = cookieHashPolicy.Select(pool, request, w) + if h != pool[0] { + t.Error("Expected cookieHashPolicy host to stick to the first host (matching cookie).") + } + s := w.Result().Cookies() + if len(s) != 0 { + t.Error("Expected cookieHashPolicy to not set a new cookie.") + } + pool[0].SetHealthy(false) + request = httptest.NewRequest(http.MethodGet, "/test", nil) + w = httptest.NewRecorder() + request.AddCookie(cookie_server1) + h = cookieHashPolicy.Select(pool, request, w) + if h == pool[0] { + t.Error("Expected cookieHashPolicy to select a new host.") + } + if w.Result().Cookies() == nil { + t.Error("Expected cookieHashPolicy to set a new cookie.") + } +}