From 026df7c5cb33331d223afc6a9599274e8c89dfd9 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Mon, 2 Sep 2019 22:01:02 -0600 Subject: [PATCH] reverse_proxy: WIP refactor and support for FastCGI --- cmd/caddy/main.go | 1 + go.mod | 5 +- go.sum | 13 +- listeners.go | 2 + modules/caddyhttp/caddyhttp.go | 14 + modules/caddyhttp/matchers.go | 5 +- .../caddyhttp/reverseproxy/fastcgi/client.go | 578 ++++++++++++ .../reverseproxy/fastcgi/client_test.go | 301 ++++++ .../caddyhttp/reverseproxy/fastcgi/fastcgi.go | 342 +++++++ .../caddyhttp/reverseproxy/healthchecker.go | 86 -- .../caddyhttp/reverseproxy/httptransport.go | 133 +++ modules/caddyhttp/reverseproxy/module.go | 53 -- .../caddyhttp/reverseproxy/reverseproxy.go | 868 ++++++++++++------ .../reverseproxy/selectionpolicies.go | 351 +++++++ .../reverseproxy/selectionpolicies_test.go | 363 ++++++++ modules/caddyhttp/reverseproxy/upstream.go | 450 --------- usagepool.go | 86 ++ 17 files changed, 2752 insertions(+), 899 deletions(-) create mode 100644 modules/caddyhttp/reverseproxy/fastcgi/client.go create mode 100644 modules/caddyhttp/reverseproxy/fastcgi/client_test.go create mode 100644 modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go delete mode 100755 modules/caddyhttp/reverseproxy/healthchecker.go create mode 100644 modules/caddyhttp/reverseproxy/httptransport.go delete mode 100755 modules/caddyhttp/reverseproxy/module.go mode change 100755 => 100644 modules/caddyhttp/reverseproxy/reverseproxy.go create mode 100644 modules/caddyhttp/reverseproxy/selectionpolicies.go create mode 100644 modules/caddyhttp/reverseproxy/selectionpolicies_test.go delete mode 100755 modules/caddyhttp/reverseproxy/upstream.go create mode 100644 usagepool.go diff --git a/cmd/caddy/main.go b/cmd/caddy/main.go index 5f6b8bb3..14b88658 100644 --- a/cmd/caddy/main.go +++ b/cmd/caddy/main.go @@ -29,6 +29,7 @@ import ( _ "github.com/caddyserver/caddy/v2/modules/caddyhttp/markdown" _ "github.com/caddyserver/caddy/v2/modules/caddyhttp/requestbody" _ "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy" + _ "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy/fastcgi" _ "github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite" _ "github.com/caddyserver/caddy/v2/modules/caddyhttp/templates" _ "github.com/caddyserver/caddy/v2/modules/caddytls" diff --git a/go.mod b/go.mod index d63097aa..315dcfdf 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,6 @@ require ( github.com/Masterminds/semver v1.4.2 // indirect github.com/Masterminds/sprig v2.20.0+incompatible github.com/andybalholm/brotli v0.0.0-20190704151324-71eb68cc467c - github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.0 github.com/go-acme/lego v2.6.0+incompatible github.com/google/go-cmp v0.3.1 // indirect @@ -17,7 +16,6 @@ require ( github.com/imdario/mergo v0.3.7 // indirect github.com/klauspost/compress v1.7.1-0.20190613161414-0b31f265a57b github.com/klauspost/cpuid v1.2.1 - github.com/kr/pretty v0.1.0 // indirect github.com/mholt/certmagic v0.6.2 github.com/mitchellh/go-ps v0.0.0-20170309133038-4fdf99ab2936 github.com/rs/cors v1.6.0 @@ -26,8 +24,7 @@ require ( github.com/starlight-go/starlight v0.0.0-20181207205707-b06f321544f3 go.starlark.net v0.0.0-20190604130855-6ddc71c0ba77 golang.org/x/net v0.0.0-20190603091049-60506f45cf65 - golang.org/x/sys v0.0.0-20190422165155-953cdadca894 // indirect + golang.org/x/sys v0.0.0-20190228124157-a34e9553db1e // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 - gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/yaml.v2 v2.2.2 // indirect ) diff --git a/go.sum b/go.sum index 4180648c..8771eb14 100644 --- a/go.sum +++ b/go.sum @@ -12,8 +12,6 @@ github.com/cenkalti/backoff v2.1.1+incompatible h1:tKJnvO2kl0zmb/jA5UKAt4VoEVw1q github.com/cenkalti/backoff v2.1.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/go-acme/lego v2.5.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M= @@ -34,11 +32,6 @@ github.com/klauspost/compress v1.7.1-0.20190613161414-0b31f265a57b/go.mod h1:RyI github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/cpuid v1.2.1 h1:vJi+O/nMdFt0vqm8NZBI6wzALWdA2X+egi0ogNyrC/w= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mholt/certmagic v0.6.2 h1:yy9cKm3rtxdh12SW4E51lzG3Eo6N59LEOfBQ0CTnMms= github.com/mholt/certmagic v0.6.2/go.mod h1:g4cOPxcjV0oFq3qwpjSA30LReKD8AoIfwAY9VvG35NY= github.com/miekg/dns v1.1.3 h1:1g0r1IvskvgL8rR+AcHzUA+oFmGcQlaIm4IqakufeMM= @@ -71,16 +64,14 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20190124100055-b90733256f2e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190228124157-a34e9553db1e h1:ZytStCyV048ZqDsWHiYDdoI2Vd4msMcrDECFxS+tL9c= +golang.org/x/sys v0.0.0-20190228124157-a34e9553db1e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/square/go-jose.v2 v2.2.2 h1:orlkJ3myw8CN1nVQHBFfloD+L3egixIa4FvUP6RosSA= gopkg.in/square/go-jose.v2 v2.2.2/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= diff --git a/listeners.go b/listeners.go index d97674d8..3c3a704a 100644 --- a/listeners.go +++ b/listeners.go @@ -24,6 +24,8 @@ import ( "time" ) +// TODO: Can we use the new UsagePool type? + // Listen returns a listener suitable for use in a Caddy module. // Always be sure to close listeners when you are done with them. func Listen(network, addr string) (net.Listener, error) { diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go index b4b1ec6e..300b5fd1 100644 --- a/modules/caddyhttp/caddyhttp.go +++ b/modules/caddyhttp/caddyhttp.go @@ -518,6 +518,20 @@ func (ws WeakString) String() string { return string(ws) } +// StatusCodeMatches returns true if a real HTTP status code matches +// the configured status code, which may be either a real HTTP status +// code or an integer representing a class of codes (e.g. 4 for all +// 4xx statuses). +func StatusCodeMatches(actual, configured int) bool { + if actual == configured { + return true + } + if configured < 100 && actual >= configured*100 && actual < (configured+1)*100 { + return true + } + return false +} + const ( // DefaultHTTPPort is the default port for HTTP. DefaultHTTPPort = 80 diff --git a/modules/caddyhttp/matchers.go b/modules/caddyhttp/matchers.go index 0dac1518..2ddefd02 100644 --- a/modules/caddyhttp/matchers.go +++ b/modules/caddyhttp/matchers.go @@ -616,10 +616,7 @@ func (rm ResponseMatcher) matchStatusCode(statusCode int) bool { return true } for _, code := range rm.StatusCode { - if statusCode == code { - return true - } - if code < 100 && statusCode >= code*100 && statusCode < (code+1)*100 { + if StatusCodeMatches(statusCode, code) { return true } } diff --git a/modules/caddyhttp/reverseproxy/fastcgi/client.go b/modules/caddyhttp/reverseproxy/fastcgi/client.go new file mode 100644 index 00000000..ae0de006 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/fastcgi/client.go @@ -0,0 +1,578 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Forked Jan. 2015 from http://bitbucket.org/PinIdea/fcgi_client +// (which is forked from https://code.google.com/p/go-fastcgi-client/). +// This fork contains several fixes and improvements by Matt Holt and +// other contributors to the Caddy project. + +// Copyright 2012 Junqing Tan and The Go Authors +// Use of this source code is governed by a BSD-style +// Part of source code is from Go fcgi package + +package fastcgi + +import ( + "bufio" + "bytes" + "context" + "encoding/binary" + "errors" + "io" + "io/ioutil" + "mime/multipart" + "net" + "net/http" + "net/http/httputil" + "net/textproto" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" +) + +// FCGIListenSockFileno describes listen socket file number. +const FCGIListenSockFileno uint8 = 0 + +// FCGIHeaderLen describes header length. +const FCGIHeaderLen uint8 = 8 + +// Version1 describes the version. +const Version1 uint8 = 1 + +// FCGINullRequestID describes the null request ID. +const FCGINullRequestID uint8 = 0 + +// FCGIKeepConn describes keep connection mode. +const FCGIKeepConn uint8 = 1 + +const ( + // BeginRequest is the begin request flag. + BeginRequest uint8 = iota + 1 + // AbortRequest is the abort request flag. + AbortRequest + // EndRequest is the end request flag. + EndRequest + // Params is the parameters flag. + Params + // Stdin is the standard input flag. + Stdin + // Stdout is the standard output flag. + Stdout + // Stderr is the standard error flag. + Stderr + // Data is the data flag. + Data + // GetValues is the get values flag. + GetValues + // GetValuesResult is the get values result flag. + GetValuesResult + // UnknownType is the unknown type flag. + UnknownType + // MaxType is the maximum type flag. + MaxType = UnknownType +) + +const ( + // Responder is the responder flag. + Responder uint8 = iota + 1 + // Authorizer is the authorizer flag. + Authorizer + // Filter is the filter flag. + Filter +) + +const ( + // RequestComplete is the completed request flag. + RequestComplete uint8 = iota + // CantMultiplexConns is the multiplexed connections flag. + CantMultiplexConns + // Overloaded is the overloaded flag. + Overloaded + // UnknownRole is the unknown role flag. + UnknownRole +) + +const ( + // MaxConns is the maximum connections flag. + MaxConns string = "MAX_CONNS" + // MaxRequests is the maximum requests flag. + MaxRequests string = "MAX_REQS" + // MultiplexConns is the multiplex connections flag. + MultiplexConns string = "MPXS_CONNS" +) + +const ( + maxWrite = 65500 // 65530 may work, but for compatibility + maxPad = 255 +) + +type header struct { + Version uint8 + Type uint8 + ID uint16 + ContentLength uint16 + PaddingLength uint8 + Reserved uint8 +} + +// for padding so we don't have to allocate all the time +// not synchronized because we don't care what the contents are +var pad [maxPad]byte + +func (h *header) init(recType uint8, reqID uint16, contentLength int) { + h.Version = 1 + h.Type = recType + h.ID = reqID + h.ContentLength = uint16(contentLength) + h.PaddingLength = uint8(-contentLength & 7) +} + +type record struct { + h header + rbuf []byte +} + +func (rec *record) read(r io.Reader) (buf []byte, err error) { + if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil { + return + } + if rec.h.Version != 1 { + err = errors.New("fcgi: invalid header version") + return + } + if rec.h.Type == EndRequest { + err = io.EOF + return + } + n := int(rec.h.ContentLength) + int(rec.h.PaddingLength) + if len(rec.rbuf) < n { + rec.rbuf = make([]byte, n) + } + if _, err = io.ReadFull(r, rec.rbuf[:n]); err != nil { + return + } + buf = rec.rbuf[:int(rec.h.ContentLength)] + + return +} + +// FCGIClient implements a FastCGI client, which is a standard for +// interfacing external applications with Web servers. +type FCGIClient struct { + mutex sync.Mutex + rwc io.ReadWriteCloser + h header + buf bytes.Buffer + stderr bytes.Buffer + keepAlive bool + reqID uint16 +} + +// DialWithDialerContext connects to the fcgi responder at the specified network address, using custom net.Dialer +// and a context. +// See func net.Dial for a description of the network and address parameters. +func DialWithDialerContext(ctx context.Context, network, address string, dialer net.Dialer) (fcgi *FCGIClient, err error) { + var conn net.Conn + conn, err = dialer.DialContext(ctx, network, address) + if err != nil { + return + } + + fcgi = &FCGIClient{ + rwc: conn, + keepAlive: false, + reqID: 1, + } + + return +} + +// DialContext is like Dial but passes ctx to dialer.Dial. +func DialContext(ctx context.Context, network, address string) (fcgi *FCGIClient, err error) { + // TODO: why not set timeout here? + return DialWithDialerContext(ctx, network, address, net.Dialer{}) +} + +// Dial connects to the fcgi responder at the specified network address, using default net.Dialer. +// See func net.Dial for a description of the network and address parameters. +func Dial(network, address string) (fcgi *FCGIClient, err error) { + return DialContext(context.Background(), network, address) +} + +// Close closes fcgi connection +func (c *FCGIClient) Close() { + c.rwc.Close() +} + +func (c *FCGIClient) writeRecord(recType uint8, content []byte) (err error) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.buf.Reset() + c.h.init(recType, c.reqID, len(content)) + if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { + return err + } + if _, err := c.buf.Write(content); err != nil { + return err + } + if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { + return err + } + _, err = c.rwc.Write(c.buf.Bytes()) + return err +} + +func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error { + b := [8]byte{byte(role >> 8), byte(role), flags} + return c.writeRecord(BeginRequest, b[:]) +} + +func (c *FCGIClient) writeEndRequest(appStatus int, protocolStatus uint8) error { + b := make([]byte, 8) + binary.BigEndian.PutUint32(b, uint32(appStatus)) + b[4] = protocolStatus + return c.writeRecord(EndRequest, b) +} + +func (c *FCGIClient) writePairs(recType uint8, pairs map[string]string) error { + w := newWriter(c, recType) + b := make([]byte, 8) + nn := 0 + for k, v := range pairs { + m := 8 + len(k) + len(v) + if m > maxWrite { + // param data size exceed 65535 bytes" + vl := maxWrite - 8 - len(k) + v = v[:vl] + } + n := encodeSize(b, uint32(len(k))) + n += encodeSize(b[n:], uint32(len(v))) + m = n + len(k) + len(v) + if (nn + m) > maxWrite { + w.Flush() + nn = 0 + } + nn += m + if _, err := w.Write(b[:n]); err != nil { + return err + } + if _, err := w.WriteString(k); err != nil { + return err + } + if _, err := w.WriteString(v); err != nil { + return err + } + } + w.Close() + return nil +} + +func encodeSize(b []byte, size uint32) int { + if size > 127 { + size |= 1 << 31 + binary.BigEndian.PutUint32(b, size) + return 4 + } + b[0] = byte(size) + return 1 +} + +// bufWriter encapsulates bufio.Writer but also closes the underlying stream when +// Closed. +type bufWriter struct { + closer io.Closer + *bufio.Writer +} + +func (w *bufWriter) Close() error { + if err := w.Writer.Flush(); err != nil { + w.closer.Close() + return err + } + return w.closer.Close() +} + +func newWriter(c *FCGIClient, recType uint8) *bufWriter { + s := &streamWriter{c: c, recType: recType} + w := bufio.NewWriterSize(s, maxWrite) + return &bufWriter{s, w} +} + +// streamWriter abstracts out the separation of a stream into discrete records. +// It only writes maxWrite bytes at a time. +type streamWriter struct { + c *FCGIClient + recType uint8 +} + +func (w *streamWriter) Write(p []byte) (int, error) { + nn := 0 + for len(p) > 0 { + n := len(p) + if n > maxWrite { + n = maxWrite + } + if err := w.c.writeRecord(w.recType, p[:n]); err != nil { + return nn, err + } + nn += n + p = p[n:] + } + return nn, nil +} + +func (w *streamWriter) Close() error { + // send empty record to close the stream + return w.c.writeRecord(w.recType, nil) +} + +type streamReader struct { + c *FCGIClient + buf []byte +} + +func (w *streamReader) Read(p []byte) (n int, err error) { + + if len(p) > 0 { + if len(w.buf) == 0 { + + // filter outputs for error log + for { + rec := &record{} + var buf []byte + buf, err = rec.read(w.c.rwc) + if err != nil { + return + } + // standard error output + if rec.h.Type == Stderr { + w.c.stderr.Write(buf) + continue + } + w.buf = buf + break + } + } + + n = len(p) + if n > len(w.buf) { + n = len(w.buf) + } + copy(p, w.buf[:n]) + w.buf = w.buf[n:] + } + + return +} + +// Do made the request and returns a io.Reader that translates the data read +// from fcgi responder out of fcgi packet before returning it. +func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) { + err = c.writeBeginRequest(uint16(Responder), 0) + if err != nil { + return + } + + err = c.writePairs(Params, p) + if err != nil { + return + } + + body := newWriter(c, Stdin) + if req != nil { + _, _ = io.Copy(body, req) + } + body.Close() + + r = &streamReader{c: c} + return +} + +// clientCloser is a io.ReadCloser. It wraps a io.Reader with a Closer +// that closes FCGIClient connection. +type clientCloser struct { + *FCGIClient + io.Reader +} + +func (f clientCloser) Close() error { return f.rwc.Close() } + +// Request returns a HTTP Response with Header and Body +// from fcgi responder +func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Response, err error) { + r, err := c.Do(p, req) + if err != nil { + return + } + + rb := bufio.NewReader(r) + tp := textproto.NewReader(rb) + resp = new(http.Response) + + // Parse the response headers. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil && err != io.EOF { + return + } + resp.Header = http.Header(mimeHeader) + + if resp.Header.Get("Status") != "" { + statusParts := strings.SplitN(resp.Header.Get("Status"), " ", 2) + resp.StatusCode, err = strconv.Atoi(statusParts[0]) + if err != nil { + return + } + if len(statusParts) > 1 { + resp.Status = statusParts[1] + } + + } else { + resp.StatusCode = http.StatusOK + } + + // TODO: fixTransferEncoding ? + resp.TransferEncoding = resp.Header["Transfer-Encoding"] + resp.ContentLength, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + + if chunked(resp.TransferEncoding) { + resp.Body = clientCloser{c, httputil.NewChunkedReader(rb)} + } else { + resp.Body = clientCloser{c, ioutil.NopCloser(rb)} + } + return +} + +// Get issues a GET request to the fcgi responder. +func (c *FCGIClient) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) { + + p["REQUEST_METHOD"] = "GET" + p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10) + + return c.Request(p, body) +} + +// Head issues a HEAD request to the fcgi responder. +func (c *FCGIClient) Head(p map[string]string) (resp *http.Response, err error) { + + p["REQUEST_METHOD"] = "HEAD" + p["CONTENT_LENGTH"] = "0" + + return c.Request(p, nil) +} + +// Options issues an OPTIONS request to the fcgi responder. +func (c *FCGIClient) Options(p map[string]string) (resp *http.Response, err error) { + + p["REQUEST_METHOD"] = "OPTIONS" + p["CONTENT_LENGTH"] = "0" + + return c.Request(p, nil) +} + +// Post issues a POST request to the fcgi responder. with request body +// in the format that bodyType specified +func (c *FCGIClient) Post(p map[string]string, method string, bodyType string, body io.Reader, l int64) (resp *http.Response, err error) { + if p == nil { + p = make(map[string]string) + } + + p["REQUEST_METHOD"] = strings.ToUpper(method) + + if len(p["REQUEST_METHOD"]) == 0 || p["REQUEST_METHOD"] == "GET" { + p["REQUEST_METHOD"] = "POST" + } + + p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10) + if len(bodyType) > 0 { + p["CONTENT_TYPE"] = bodyType + } else { + p["CONTENT_TYPE"] = "application/x-www-form-urlencoded" + } + + return c.Request(p, body) +} + +// PostForm issues a POST to the fcgi responder, with form +// as a string key to a list values (url.Values) +func (c *FCGIClient) PostForm(p map[string]string, data url.Values) (resp *http.Response, err error) { + body := bytes.NewReader([]byte(data.Encode())) + return c.Post(p, "POST", "application/x-www-form-urlencoded", body, int64(body.Len())) +} + +// PostFile issues a POST to the fcgi responder in multipart(RFC 2046) standard, +// with form as a string key to a list values (url.Values), +// and/or with file as a string key to a list file path. +func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[string]string) (resp *http.Response, err error) { + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + bodyType := writer.FormDataContentType() + + for key, val := range data { + for _, v0 := range val { + err = writer.WriteField(key, v0) + if err != nil { + return + } + } + } + + for key, val := range file { + fd, e := os.Open(val) + if e != nil { + return nil, e + } + defer fd.Close() + + part, e := writer.CreateFormFile(key, filepath.Base(val)) + if e != nil { + return nil, e + } + _, err = io.Copy(part, fd) + if err != nil { + return + } + } + + err = writer.Close() + if err != nil { + return + } + + return c.Post(p, "POST", bodyType, buf, int64(buf.Len())) +} + +// SetReadTimeout sets the read timeout for future calls that read from the +// fcgi responder. A zero value for t means no timeout will be set. +func (c *FCGIClient) SetReadTimeout(t time.Duration) error { + if conn, ok := c.rwc.(net.Conn); ok && t != 0 { + return conn.SetReadDeadline(time.Now().Add(t)) + } + return nil +} + +// SetWriteTimeout sets the write timeout for future calls that send data to +// the fcgi responder. A zero value for t means no timeout will be set. +func (c *FCGIClient) SetWriteTimeout(t time.Duration) error { + if conn, ok := c.rwc.(net.Conn); ok && t != 0 { + return conn.SetWriteDeadline(time.Now().Add(t)) + } + return nil +} + +// Checks whether chunked is part of the encodings stack +func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } diff --git a/modules/caddyhttp/reverseproxy/fastcgi/client_test.go b/modules/caddyhttp/reverseproxy/fastcgi/client_test.go new file mode 100644 index 00000000..c090f3cc --- /dev/null +++ b/modules/caddyhttp/reverseproxy/fastcgi/client_test.go @@ -0,0 +1,301 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// NOTE: These tests were adapted from the original +// repository from which this package was forked. +// The tests are slow (~10s) and in dire need of rewriting. +// As such, the tests have been disabled to speed up +// automated builds until they can be properly written. + +package fastcgi + +import ( + "bytes" + "crypto/md5" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "log" + "math/rand" + "net" + "net/http" + "net/http/fcgi" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" +) + +// test fcgi protocol includes: +// Get, Post, Post in multipart/form-data, and Post with files +// each key should be the md5 of the value or the file uploaded +// specify remote fcgi responder ip:port to test with php +// test failed if the remote fcgi(script) failed md5 verification +// and output "FAILED" in response +const ( + scriptFile = "/tank/www/fcgic_test.php" + //ipPort = "remote-php-serv:59000" + ipPort = "127.0.0.1:59000" +) + +var globalt *testing.T + +type FastCGIServer struct{} + +func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + + if err := req.ParseMultipartForm(100000000); err != nil { + log.Printf("[ERROR] failed to parse: %v", err) + } + + stat := "PASSED" + fmt.Fprintln(resp, "-") + fileNum := 0 + { + length := 0 + for k0, v0 := range req.Form { + h := md5.New() + _, _ = io.WriteString(h, v0[0]) + _md5 := fmt.Sprintf("%x", h.Sum(nil)) + + length += len(k0) + length += len(v0[0]) + + // echo error when key != _md5(val) + if _md5 != k0 { + fmt.Fprintln(resp, "server:err ", _md5, k0) + stat = "FAILED" + } + } + if req.MultipartForm != nil { + fileNum = len(req.MultipartForm.File) + for kn, fns := range req.MultipartForm.File { + //fmt.Fprintln(resp, "server:filekey ", kn ) + length += len(kn) + for _, f := range fns { + fd, err := f.Open() + if err != nil { + log.Println("server:", err) + return + } + h := md5.New() + l0, err := io.Copy(h, fd) + if err != nil { + log.Println(err) + return + } + length += int(l0) + defer fd.Close() + md5 := fmt.Sprintf("%x", h.Sum(nil)) + //fmt.Fprintln(resp, "server:filemd5 ", md5 ) + + if kn != md5 { + fmt.Fprintln(resp, "server:err ", md5, kn) + stat = "FAILED" + } + //fmt.Fprintln(resp, "server:filename ", f.Filename ) + } + } + } + + fmt.Fprintln(resp, "server:got data length", length) + } + fmt.Fprintln(resp, "-"+stat+"-POST(", len(req.Form), ")-FILE(", fileNum, ")--") +} + +func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) { + fcgi, err := Dial("tcp", ipPort) + if err != nil { + log.Println("err:", err) + return + } + + length := 0 + + var resp *http.Response + switch reqType { + case 0: + if len(data) > 0 { + length = len(data) + rd := bytes.NewReader(data) + resp, err = fcgi.Post(fcgiParams, "", "", rd, int64(rd.Len())) + } else if len(posts) > 0 { + values := url.Values{} + for k, v := range posts { + values.Set(k, v) + length += len(k) + 2 + len(v) + } + resp, err = fcgi.PostForm(fcgiParams, values) + } else { + rd := bytes.NewReader(data) + resp, err = fcgi.Get(fcgiParams, rd, int64(rd.Len())) + } + + default: + values := url.Values{} + for k, v := range posts { + values.Set(k, v) + length += len(k) + 2 + len(v) + } + + for k, v := range files { + fi, _ := os.Lstat(v) + length += len(k) + int(fi.Size()) + } + resp, err = fcgi.PostFile(fcgiParams, values, files) + } + + if err != nil { + log.Println("err:", err) + return + } + + defer resp.Body.Close() + content, _ = ioutil.ReadAll(resp.Body) + + log.Println("c: send data length ≈", length, string(content)) + fcgi.Close() + time.Sleep(1 * time.Second) + + if bytes.Contains(content, []byte("FAILED")) { + globalt.Error("Server return failed message") + } + + return +} + +func generateRandFile(size int) (p string, m string) { + + p = filepath.Join(os.TempDir(), "fcgict"+strconv.Itoa(rand.Int())) + + // open output file + fo, err := os.Create(p) + if err != nil { + panic(err) + } + // close fo on exit and check for its returned error + defer func() { + if err := fo.Close(); err != nil { + panic(err) + } + }() + + h := md5.New() + for i := 0; i < size/16; i++ { + buf := make([]byte, 16) + binary.PutVarint(buf, rand.Int63()) + if _, err := fo.Write(buf); err != nil { + log.Printf("[ERROR] failed to write buffer: %v\n", err) + } + if _, err := h.Write(buf); err != nil { + log.Printf("[ERROR] failed to write buffer: %v\n", err) + } + } + m = fmt.Sprintf("%x", h.Sum(nil)) + return +} + +func DisabledTest(t *testing.T) { + // TODO: test chunked reader + globalt = t + + rand.Seed(time.Now().UTC().UnixNano()) + + // server + go func() { + listener, err := net.Listen("tcp", ipPort) + if err != nil { + log.Println("listener creation failed: ", err) + } + + srv := new(FastCGIServer) + if err := fcgi.Serve(listener, srv); err != nil { + log.Print("[ERROR] failed to start server: ", err) + } + }() + + time.Sleep(1 * time.Second) + + // init + fcgiParams := make(map[string]string) + fcgiParams["REQUEST_METHOD"] = "GET" + fcgiParams["SERVER_PROTOCOL"] = "HTTP/1.1" + //fcgi_params["GATEWAY_INTERFACE"] = "CGI/1.1" + fcgiParams["SCRIPT_FILENAME"] = scriptFile + + // simple GET + log.Println("test:", "get") + sendFcgi(0, fcgiParams, nil, nil, nil) + + // simple post data + log.Println("test:", "post") + sendFcgi(0, fcgiParams, []byte("c4ca4238a0b923820dcc509a6f75849b=1&7b8b965ad4bca0e41ab51de7b31363a1=n"), nil, nil) + + log.Println("test:", "post data (more than 60KB)") + data := "" + for i := 0x00; i < 0xff; i++ { + v0 := strings.Repeat(string(i), 256) + h := md5.New() + _, _ = io.WriteString(h, v0) + k0 := fmt.Sprintf("%x", h.Sum(nil)) + data += k0 + "=" + url.QueryEscape(v0) + "&" + } + sendFcgi(0, fcgiParams, []byte(data), nil, nil) + + log.Println("test:", "post form (use url.Values)") + p0 := make(map[string]string, 1) + p0["c4ca4238a0b923820dcc509a6f75849b"] = "1" + p0["7b8b965ad4bca0e41ab51de7b31363a1"] = "n" + sendFcgi(1, fcgiParams, nil, p0, nil) + + log.Println("test:", "post forms (256 keys, more than 1MB)") + p1 := make(map[string]string, 1) + for i := 0x00; i < 0xff; i++ { + v0 := strings.Repeat(string(i), 4096) + h := md5.New() + _, _ = io.WriteString(h, v0) + k0 := fmt.Sprintf("%x", h.Sum(nil)) + p1[k0] = v0 + } + sendFcgi(1, fcgiParams, nil, p1, nil) + + log.Println("test:", "post file (1 file, 500KB)) ") + f0 := make(map[string]string, 1) + path0, m0 := generateRandFile(500000) + f0[m0] = path0 + sendFcgi(1, fcgiParams, nil, p1, f0) + + log.Println("test:", "post multiple files (2 files, 5M each) and forms (256 keys, more than 1MB data") + path1, m1 := generateRandFile(5000000) + f0[m1] = path1 + sendFcgi(1, fcgiParams, nil, p1, f0) + + log.Println("test:", "post only files (2 files, 5M each)") + sendFcgi(1, fcgiParams, nil, nil, f0) + + log.Println("test:", "post only 1 file") + delete(f0, "m0") + sendFcgi(1, fcgiParams, nil, nil, f0) + + if err := os.Remove(path0); err != nil { + log.Println("[ERROR] failed to remove path: ", err) + } + if err := os.Remove(path1); err != nil { + log.Println("[ERROR] failed to remove path: ", err) + } +} diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go new file mode 100644 index 00000000..32f094b9 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go @@ -0,0 +1,342 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fastcgi + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "path" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/caddyserver/caddy/v2/modules/caddytls" + + "github.com/caddyserver/caddy/v2" +) + +func init() { + caddy.RegisterModule(Transport{}) +} + +type Transport struct { + ////////////////////////////// + // TODO: taken from v1 Handler type + + SoftwareName string + SoftwareVersion string + ServerName string + ServerPort string + + ////////////////////////// + // TODO: taken from v1 Rule type + + // The base path to match. Required. + // Path string + + // upstream load balancer + // balancer + + // Always process files with this extension with fastcgi. + // Ext string + + // Use this directory as the fastcgi root directory. Defaults to the root + // directory of the parent virtual host. + Root string + + // The path in the URL will be split into two, with the first piece ending + // with the value of SplitPath. The first piece will be assumed as the + // actual resource (CGI script) name, and the second piece will be set to + // PATH_INFO for the CGI script to use. + SplitPath string + + // If the URL ends with '/' (which indicates a directory), these index + // files will be tried instead. + IndexFiles []string + + // Environment Variables + EnvVars [][2]string + + // Ignored paths + IgnoredSubPaths []string + + // The duration used to set a deadline when connecting to an upstream. + DialTimeout time.Duration + + // The duration used to set a deadline when reading from the FastCGI server. + ReadTimeout time.Duration + + // The duration used to set a deadline when sending to the FastCGI server. + WriteTimeout time.Duration +} + +// CaddyModule returns the Caddy module information. +func (Transport) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.transport.fastcgi", + New: func() caddy.Module { return new(Transport) }, + } +} + +func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) { + // Create environment for CGI script + env, err := t.buildEnv(r) + if err != nil { + return nil, fmt.Errorf("building environment: %v", err) + } + + // TODO: + // Connect to FastCGI gateway + // address, err := f.Address() + // if err != nil { + // return http.StatusBadGateway, err + // } + // network, address := parseAddress(address) + network, address := "tcp", r.URL.Host // TODO: + + ctx := context.Background() + if t.DialTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, t.DialTimeout) + defer cancel() + } + + fcgiBackend, err := DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("dialing backend: %v", err) + } + // fcgiBackend is closed when response body is closed (see clientCloser) + + // read/write timeouts + if err := fcgiBackend.SetReadTimeout(t.ReadTimeout); err != nil { + return nil, fmt.Errorf("setting read timeout: %v", err) + } + if err := fcgiBackend.SetWriteTimeout(t.WriteTimeout); err != nil { + return nil, fmt.Errorf("setting write timeout: %v", err) + } + + var resp *http.Response + + var contentLength int64 + // if ContentLength is already set + if r.ContentLength > 0 { + contentLength = r.ContentLength + } else { + contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64) + } + switch r.Method { + case "HEAD": + resp, err = fcgiBackend.Head(env) + case "GET": + resp, err = fcgiBackend.Get(env, r.Body, contentLength) + case "OPTIONS": + resp, err = fcgiBackend.Options(env) + default: + resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength) + } + + // TODO: + return resp, err + + // Stuff brought over from v1 that might not be necessary here: + + // if resp != nil && resp.Body != nil { + // defer resp.Body.Close() + // } + + // if err != nil { + // if err, ok := err.(net.Error); ok && err.Timeout() { + // return http.StatusGatewayTimeout, err + // } else if err != io.EOF { + // return http.StatusBadGateway, err + // } + // } + + // // Write response header + // writeHeader(w, resp) + + // // Write the response body + // _, err = io.Copy(w, resp.Body) + // if err != nil { + // return http.StatusBadGateway, err + // } + + // // Log any stderr output from upstream + // if fcgiBackend.stderr.Len() != 0 { + // // Remove trailing newline, error logger already does this. + // err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n")) + // } + + // // Normally we would return the status code if it is an error status (>= 400), + // // however, upstream FastCGI apps don't know about our contract and have + // // probably already written an error page. So we just return 0, indicating + // // that the response body is already written. However, we do return any + // // error value so it can be logged. + // // Note that the proxy middleware works the same way, returning status=0. + // return 0, err +} + +// buildEnv returns a set of CGI environment variables for the request. +func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { + var env map[string]string + + // Separate remote IP and port; more lenient than net.SplitHostPort + var ip, port string + if idx := strings.LastIndex(r.RemoteAddr, ":"); idx > -1 { + ip = r.RemoteAddr[:idx] + port = r.RemoteAddr[idx+1:] + } else { + ip = r.RemoteAddr + } + + // Remove [] from IPv6 addresses + ip = strings.Replace(ip, "[", "", 1) + ip = strings.Replace(ip, "]", "", 1) + + // TODO: respect index files? or leave that to matcher/rewrite (I prefer that)? + fpath := r.URL.Path + + // Split path in preparation for env variables. + // Previous canSplit checks ensure this can never be -1. + // TODO: I haven't brought over canSplit; make sure this doesn't break + splitPos := t.splitPos(fpath) + + // Request has the extension; path was split successfully + docURI := fpath[:splitPos+len(t.SplitPath)] + pathInfo := fpath[splitPos+len(t.SplitPath):] + scriptName := fpath + + // Strip PATH_INFO from SCRIPT_NAME + scriptName = strings.TrimSuffix(scriptName, pathInfo) + + // SCRIPT_FILENAME is the absolute path of SCRIPT_NAME + scriptFilename := filepath.Join(t.Root, scriptName) + + // Add vhost path prefix to scriptName. Otherwise, some PHP software will + // have difficulty discovering its URL. + pathPrefix, _ := r.Context().Value(caddy.CtxKey("path_prefix")).(string) + scriptName = path.Join(pathPrefix, scriptName) + + // TODO: Disabled for now + // // Get the request URI from context. The context stores the original URI in case + // // it was changed by a middleware such as rewrite. By default, we pass the + // // original URI in as the value of REQUEST_URI (the user can overwrite this + // // if desired). Most PHP apps seem to want the original URI. Besides, this is + // // how nginx defaults: http://stackoverflow.com/a/12485156/1048862 + // reqURL, _ := r.Context().Value(httpserver.OriginalURLCtxKey).(url.URL) + + // // Retrieve name of remote user that was set by some downstream middleware such as basicauth. + // remoteUser, _ := r.Context().Value(httpserver.RemoteUserCtxKey).(string) + + requestScheme := "http" + if r.TLS != nil { + requestScheme = "https" + } + + // Some variables are unused but cleared explicitly to prevent + // the parent environment from interfering. + env = map[string]string{ + // Variables defined in CGI 1.1 spec + "AUTH_TYPE": "", // Not used + "CONTENT_LENGTH": r.Header.Get("Content-Length"), + "CONTENT_TYPE": r.Header.Get("Content-Type"), + "GATEWAY_INTERFACE": "CGI/1.1", + "PATH_INFO": pathInfo, + "QUERY_STRING": r.URL.RawQuery, + "REMOTE_ADDR": ip, + "REMOTE_HOST": ip, // For speed, remote host lookups disabled + "REMOTE_PORT": port, + "REMOTE_IDENT": "", // Not used + // "REMOTE_USER": remoteUser, // TODO: + "REQUEST_METHOD": r.Method, + "REQUEST_SCHEME": requestScheme, + "SERVER_NAME": t.ServerName, + "SERVER_PORT": t.ServerPort, + "SERVER_PROTOCOL": r.Proto, + "SERVER_SOFTWARE": t.SoftwareName + "/" + t.SoftwareVersion, + + // Other variables + // "DOCUMENT_ROOT": rule.Root, + "DOCUMENT_URI": docURI, + "HTTP_HOST": r.Host, // added here, since not always part of headers + // "REQUEST_URI": reqURL.RequestURI(), // TODO: + "SCRIPT_FILENAME": scriptFilename, + "SCRIPT_NAME": scriptName, + } + + // compliance with the CGI specification requires that + // PATH_TRANSLATED should only exist if PATH_INFO is defined. + // Info: https://www.ietf.org/rfc/rfc3875 Page 14 + if env["PATH_INFO"] != "" { + env["PATH_TRANSLATED"] = filepath.Join(t.Root, pathInfo) // Info: http://www.oreilly.com/openbook/cgi/ch02_04.html + } + + // Some web apps rely on knowing HTTPS or not + if r.TLS != nil { + env["HTTPS"] = "on" + // and pass the protocol details in a manner compatible with apache's mod_ssl + // (which is why these have a SSL_ prefix and not TLS_). + v, ok := tlsProtocolStrings[r.TLS.Version] + if ok { + env["SSL_PROTOCOL"] = v + } + // and pass the cipher suite in a manner compatible with apache's mod_ssl + for k, v := range caddytls.SupportedCipherSuites { + if v == r.TLS.CipherSuite { + env["SSL_CIPHER"] = k + break + } + } + } + + // Add env variables from config (with support for placeholders in values) + repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) + for _, envVar := range t.EnvVars { + env[envVar[0]] = repl.ReplaceAll(envVar[1], "") + } + + // Add all HTTP headers to env variables + for field, val := range r.Header { + header := strings.ToUpper(field) + header = headerNameReplacer.Replace(header) + env["HTTP_"+header] = strings.Join(val, ", ") + } + return env, nil +} + +// splitPos returns the index where path should +// be split based on t.SplitPath. +func (t Transport) splitPos(path string) int { + // TODO: + // if httpserver.CaseSensitivePath { + // return strings.Index(path, r.SplitPath) + // } + return strings.Index(strings.ToLower(path), strings.ToLower(t.SplitPath)) +} + +// TODO: +// Map of supported protocols to Apache ssl_mod format +// Note that these are slightly different from SupportedProtocols in caddytls/config.go +var tlsProtocolStrings = map[uint16]string{ + tls.VersionTLS10: "TLSv1", + tls.VersionTLS11: "TLSv1.1", + tls.VersionTLS12: "TLSv1.2", + tls.VersionTLS13: "TLSv1.3", +} + +var headerNameReplacer = strings.NewReplacer(" ", "_", "-", "_") diff --git a/modules/caddyhttp/reverseproxy/healthchecker.go b/modules/caddyhttp/reverseproxy/healthchecker.go deleted file mode 100755 index c557d3fe..00000000 --- a/modules/caddyhttp/reverseproxy/healthchecker.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2015 Matthew Holt and The Caddy Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package reverseproxy - -import ( - "net/http" - "time" -) - -// Upstream represents the interface that must be satisfied to use the healthchecker. -type Upstream interface { - SetHealthiness(bool) -} - -// HealthChecker represents a worker that periodically evaluates if proxy upstream host is healthy. -type HealthChecker struct { - upstream Upstream - Ticker *time.Ticker - HTTPClient *http.Client - StopChan chan bool -} - -// ScheduleChecks periodically runs health checks against an upstream host. -func (h *HealthChecker) ScheduleChecks(url string) { - // check if a host is healthy on start vs waiting for timer - h.upstream.SetHealthiness(h.IsHealthy(url)) - stop := make(chan bool) - h.StopChan = stop - - go func() { - for { - select { - case <-h.Ticker.C: - h.upstream.SetHealthiness(h.IsHealthy(url)) - case <-stop: - return - } - } - }() -} - -// Stop stops the healthchecker from makeing further requests. -func (h *HealthChecker) Stop() { - h.Ticker.Stop() - close(h.StopChan) -} - -// IsHealthy attempts to check if a upstream host is healthy. -func (h *HealthChecker) IsHealthy(url string) bool { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return false - } - - resp, err := h.HTTPClient.Do(req) - if err != nil { - return false - } - - if resp.StatusCode < 200 || resp.StatusCode >= 400 { - return false - } - - return true -} - -// NewHealthCheckWorker returns a new instance of a HealthChecker. -func NewHealthCheckWorker(u Upstream, interval time.Duration, client *http.Client) *HealthChecker { - return &HealthChecker{ - upstream: u, - Ticker: time.NewTicker(interval), - HTTPClient: client, - } -} diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go new file mode 100644 index 00000000..36dd7767 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -0,0 +1,133 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "net" + "net/http" + "time" + + "github.com/caddyserver/caddy/v2" +) + +func init() { + caddy.RegisterModule(HTTPTransport{}) +} + +// TODO: This is the default transport, basically just http.Transport, but we define JSON struct tags... +type HTTPTransport struct { + // TODO: Actually this is where the TLS config should go, technically... + // as well as keepalives and dial timeouts... + // TODO: It's possible that other transports (like fastcgi) might be + // able to borrow/use at least some of these config fields; if so, + // move them into a type called CommonTransport and embed it + + TLS *TLSConfig `json:"tls,omitempty"` + KeepAlive *KeepAlive `json:"keep_alive,omitempty"` + Compression *bool `json:"compression,omitempty"` + MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // TODO: NOTE: we use our health check stuff to enforce max REQUESTS per host, but this is connections + DialTimeout caddy.Duration `json:"dial_timeout,omitempty"` + FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"` + ResponseHeaderTimeout caddy.Duration `json:"response_header_timeout,omitempty"` + ExpectContinueTimeout caddy.Duration `json:"expect_continue_timeout,omitempty"` + MaxResponseHeaderSize int64 `json:"max_response_header_size,omitempty"` + WriteBufferSize int `json:"write_buffer_size,omitempty"` + ReadBufferSize int `json:"read_buffer_size,omitempty"` + // TODO: ProxyConnectHeader? + + RoundTripper http.RoundTripper `json:"-"` +} + +// CaddyModule returns the Caddy module information. +func (HTTPTransport) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.transport.http", + New: func() caddy.Module { return new(HTTPTransport) }, + } +} + +func (h *HTTPTransport) Provision(ctx caddy.Context) error { + dialer := &net.Dialer{ + Timeout: time.Duration(h.DialTimeout), + FallbackDelay: time.Duration(h.FallbackDelay), + // TODO: Resolver + } + rt := &http.Transport{ + DialContext: dialer.DialContext, + MaxConnsPerHost: h.MaxConnsPerHost, + ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout), + ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout), + MaxResponseHeaderBytes: h.MaxResponseHeaderSize, + WriteBufferSize: h.WriteBufferSize, + ReadBufferSize: h.ReadBufferSize, + } + + if h.TLS != nil { + rt.TLSHandshakeTimeout = time.Duration(h.TLS.HandshakeTimeout) + // TODO: rest of TLS config + } + + if h.KeepAlive != nil { + dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval) + + if enabled := h.KeepAlive.Enabled; enabled != nil { + rt.DisableKeepAlives = !*enabled + } + rt.MaxIdleConns = h.KeepAlive.MaxIdleConns + rt.MaxIdleConnsPerHost = h.KeepAlive.MaxIdleConnsPerHost + rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout) + } + + if h.Compression != nil { + rt.DisableCompression = !*h.Compression + } + + h.RoundTripper = rt + + return nil +} + +func (h HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return h.RoundTripper.RoundTrip(req) +} + +type TLSConfig struct { + CAPool []string `json:"ca_pool,omitempty"` + ClientCertificate string `json:"client_certificate,omitempty"` + InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` + HandshakeTimeout caddy.Duration `json:"handshake_timeout,omitempty"` +} + +type KeepAlive struct { + Enabled *bool `json:"enabled,omitempty"` + ProbeInterval caddy.Duration `json:"probe_interval,omitempty"` + MaxIdleConns int `json:"max_idle_conns,omitempty"` + MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"` + IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle +} + +var ( + defaultDialer = net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + } + + // TODO: does this need to be configured to enable HTTP/2? + defaultTransport = &http.Transport{ + DialContext: defaultDialer.DialContext, + TLSHandshakeTimeout: 5 * time.Second, + IdleConnTimeout: 2 * time.Minute, + } +) diff --git a/modules/caddyhttp/reverseproxy/module.go b/modules/caddyhttp/reverseproxy/module.go deleted file mode 100755 index 21aca1d4..00000000 --- a/modules/caddyhttp/reverseproxy/module.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2015 Matthew Holt and The Caddy Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package reverseproxy - -import ( - "github.com/caddyserver/caddy/v2" - "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" - "github.com/caddyserver/caddy/v2/modules/caddyhttp" -) - -func init() { - caddy.RegisterModule(new(LoadBalanced)) - httpcaddyfile.RegisterHandlerDirective("reverse_proxy", parseCaddyfile) // TODO: "proxy"? -} - -// CaddyModule returns the Caddy module information. -func (*LoadBalanced) CaddyModule() caddy.ModuleInfo { - return caddy.ModuleInfo{ - Name: "http.handlers.reverse_proxy", - New: func() caddy.Module { return new(LoadBalanced) }, - } -} - -// parseCaddyfile sets up the handler from Caddyfile tokens. Syntax: -// -// proxy [] -// -// TODO: This needs to be finished. It definitely needs to be able to open a block... -func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) { - lb := new(LoadBalanced) - for h.Next() { - allTo := h.RemainingArgs() - if len(allTo) == 0 { - return nil, h.ArgErr() - } - for _, to := range allTo { - lb.Upstreams = append(lb.Upstreams, &UpstreamConfig{Host: to}) - } - } - return lb, nil -} diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go old mode 100755 new mode 100644 index 68393de5..e312d714 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -15,7 +15,9 @@ package reverseproxy import ( + "bytes" "context" + "encoding/json" "fmt" "io" "log" @@ -24,219 +26,229 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "time" + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" "golang.org/x/net/http/httpguts" ) -// ReverseProxy is an HTTP Handler that takes an incoming request and -// sends it to another server, proxying the response back to the -// client. -type ReverseProxy struct { - // Director must be a function which modifies - // the request into a new request to be sent - // using Transport. Its response is then copied - // back to the original client unmodified. - // Director must not access the provided Request - // after returning. - Director func(*http.Request) - - // The transport used to perform proxy requests. - // If nil, http.DefaultTransport is used. - Transport http.RoundTripper - - // FlushInterval specifies the flush interval - // to flush to the client while copying the - // response body. - // If zero, no periodic flushing is done. - // A negative value means to flush immediately - // after each write to the client. - // The FlushInterval is ignored when ReverseProxy - // recognizes a response as a streaming response; - // for such responses, writes are flushed to the client - // immediately. - FlushInterval time.Duration - - // ErrorLog specifies an optional logger for errors - // that occur when attempting to proxy the request. - // If nil, logging goes to os.Stderr via the log package's - // standard logger. - ErrorLog *log.Logger - - // BufferPool optionally specifies a buffer pool to - // get byte slices for use by io.CopyBuffer when - // copying HTTP response bodies. - BufferPool BufferPool - - // ModifyResponse is an optional function that modifies the - // Response from the backend. It is called if the backend - // returns a response at all, with any HTTP status code. - // If the backend is unreachable, the optional ErrorHandler is - // called without any call to ModifyResponse. - // - // If ModifyResponse returns an error, ErrorHandler is called - // with its error value. If ErrorHandler is nil, its default - // implementation is used. - ModifyResponse func(*http.Response) error - - // ErrorHandler is an optional function that handles errors - // reaching the backend or errors from ModifyResponse. - // - // If nil, the default is to log the provided error and return - // a 502 Status Bad Gateway response. - ErrorHandler func(http.ResponseWriter, *http.Request, error) +func init() { + caddy.RegisterModule(Handler{}) } -// A BufferPool is an interface for getting and returning temporary -// byte slices for use by io.CopyBuffer. -type BufferPool interface { - Get() []byte - Put([]byte) +type Handler struct { + TransportRaw json.RawMessage `json:"transport,omitempty"` + LoadBalancing *LoadBalancing `json:"load_balancing,omitempty"` + HealthChecks *HealthChecks `json:"health_checks,omitempty"` + // UpstreamStorageRaw json.RawMessage `json:"upstream_storage,omitempty"` // TODO: + Upstreams HostPool `json:"upstreams,omitempty"` + + // UpstreamProvider UpstreamProvider `json:"-"` // TODO: + Transport http.RoundTripper `json:"-"` } -func singleJoiningSlash(a, b string) string { - aslash := strings.HasSuffix(a, "/") - bslash := strings.HasPrefix(b, "/") - switch { - case aslash && bslash: - return a + b[1:] - case !aslash && !bslash: - return a + "/" + b +// CaddyModule returns the Caddy module information. +func (Handler) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy", + New: func() caddy.Module { return new(Handler) }, } - return a + b } -// NewSingleHostReverseProxy returns a new ReverseProxy that routes -// URLs to the scheme, host, and base path provided in target. If the -// target's path is "/base" and the incoming request was for "/dir", -// the target request will be for /base/dir. -// NewSingleHostReverseProxy does not rewrite the Host header. -// To rewrite Host headers, use ReverseProxy directly with a custom -// Director policy. -func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { - targetQuery := target.RawQuery - director := func(req *http.Request) { - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) - if targetQuery == "" || req.URL.RawQuery == "" { - req.URL.RawQuery = targetQuery + req.URL.RawQuery - } else { - req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery +func (h *Handler) Provision(ctx caddy.Context) error { + if h.TransportRaw != nil { + val, err := ctx.LoadModuleInline("protocol", "http.handlers.reverse_proxy.transport", h.TransportRaw) + if err != nil { + return fmt.Errorf("loading transport module: %s", err) } - if _, ok := req.Header["User-Agent"]; !ok { - // explicitly disable User-Agent so it's not set to default value - req.Header.Set("User-Agent", "") + h.Transport = val.(http.RoundTripper) + h.TransportRaw = nil // allow GC to deallocate - TODO: Does this help? + } + if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil { + val, err := ctx.LoadModuleInline("policy", + "http.handlers.reverse_proxy.selection_policies", + h.LoadBalancing.SelectionPolicyRaw) + if err != nil { + return fmt.Errorf("loading load balancing selection module: %s", err) + } + h.LoadBalancing.SelectionPolicy = val.(Selector) + h.LoadBalancing.SelectionPolicyRaw = nil // allow GC to deallocate - TODO: Does this help? + } + + if h.Transport == nil { + h.Transport = defaultTransport + } + + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + if h.LoadBalancing.SelectionPolicy == nil { + h.LoadBalancing.SelectionPolicy = RandomSelection{} + } + if h.LoadBalancing.TryDuration > 0 && h.LoadBalancing.TryInterval == 0 { + // a non-zero try_duration with a zero try_interval + // will always spin the CPU for try_duration if the + // upstream is local or low-latency; default to some + // sane waiting period before try attempts + h.LoadBalancing.TryInterval = caddy.Duration(250 * time.Millisecond) + } + + for _, upstream := range h.Upstreams { + // url parser requires a scheme + if !strings.Contains(upstream.Address, "://") { + upstream.Address = "http://" + upstream.Address + } + u, err := url.Parse(upstream.Address) + if err != nil { + return fmt.Errorf("invalid upstream address %s: %v", upstream.Address, err) + } + upstream.hostURL = u + + // if host already exists from a current config, + // use that instead; otherwise, add it + // TODO: make hosts modular, so that their state can be distributed in enterprise for example + // TODO: If distributed, the pool should be stored in storage... + var host Host = new(upstreamHost) + activeHost, loaded := hosts.LoadOrStore(u.String(), host) + if loaded { + host = activeHost.(Host) + } + upstream.Host = host + + // if the passive health checker has a non-zero "unhealthy + // request count" but the upstream has no MaxRequests set + // (they are the same thing, but one is a default value for + // for upstreams with a zero MaxRequests), copy the default + // value into this upstream, since the value in the upstream + // is what is used during availability checks + if h.HealthChecks != nil && + h.HealthChecks.Passive != nil && + h.HealthChecks.Passive.UnhealthyRequestCount > 0 && + upstream.MaxRequests == 0 { + upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount + } + + // TODO: active health checks + + if h.HealthChecks != nil { + // upstreams need independent access to the passive + // health check policy so they can, you know, passively + // do health checks + upstream.healthCheckPolicy = h.HealthChecks.Passive } } - return &ReverseProxy{Director: director} + + return nil } -func copyHeader(dst, src http.Header) { - for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) - } +func (h *Handler) Cleanup() error { + // TODO: finish this up, make sure it takes care of any active health checkers or whatever + for _, upstream := range h.Upstreams { + hosts.Delete(upstream.hostURL.String()) } + return nil } -func cloneHeader(h http.Header) http.Header { - h2 := make(http.Header, len(h)) - for k, vv := range h { - vv2 := make([]string, len(vv)) - copy(vv2, vv) - h2[k] = vv2 - } - return h2 -} - -// Hop-by-hop headers. These are removed when sent to the backend. -// As of RFC 7230, hop-by-hop headers are required to appear in the -// Connection header field. These are the headers defined by the -// obsoleted RFC 2616 (section 13.5.1) and are used for backward -// compatibility. -var hopHeaders = []string{ - "Connection", - "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "Te", // canonicalized version of "TE" - "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 - "Transfer-Encoding", - "Upgrade", -} - -func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { - p.logf("http: proxy error: %v", err) - rw.WriteHeader(http.StatusBadGateway) -} - -func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { - if p.ErrorHandler != nil { - return p.ErrorHandler - } - return p.defaultErrorHandler -} - -// modifyResponse conditionally runs the optional ModifyResponse hook -// and reports whether the request should proceed. -func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { - if p.ModifyResponse == nil { - return true - } - if err := p.ModifyResponse(res); err != nil { - res.Body.Close() - p.getErrorHandler()(rw, req, err) - return false - } - return true -} - -func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*http.Response, error) { - transport := p.Transport - if transport == nil { - transport = http.DefaultTransport +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { + // prepare the request for proxying; this is needed only once + err := h.prepareRequest(r) + if err != nil { + return caddyhttp.Error(http.StatusInternalServerError, + fmt.Errorf("preparing request for upstream round-trip: %v", err)) } - ctx := req.Context() - if cn, ok := rw.(http.CloseNotifier); ok { - var cancel context.CancelFunc - ctx, cancel = context.WithCancel(ctx) - defer cancel() - notifyChan := cn.CloseNotify() - go func() { - select { - case <-notifyChan: - cancel() - case <-ctx.Done(): + start := time.Now() + + var proxyErr error + for { + // choose an available upstream + upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, r) + if upstream == nil { + if proxyErr == nil { + proxyErr = fmt.Errorf("no available upstreams") } - }() + if !h.tryAgain(start, proxyErr) { + break + } + continue + } + + // proxy the request to that upstream + proxyErr = h.reverseProxy(w, r, upstream) + if proxyErr == nil { + return nil + } + + // remember this failure (if enabled) + h.countFailure(upstream) + + // if we've tried long enough, break + if !h.tryAgain(start, proxyErr) { + break + } } - outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay + return caddyhttp.Error(http.StatusBadGateway, proxyErr) +} + +// prepareRequest modifies req so that it is ready to be proxied, +// except for directing to a specific upstream. This method mutates +// headers and other necessary properties of the request and should +// be done just once (before proxying) regardless of proxy retries. +// This assumes that no mutations of the request are performed +// by h during or after proxying. +func (h Handler) prepareRequest(req *http.Request) error { + // ctx := req.Context() + // TODO: do we need to support CloseNotifier? It was deprecated years ago. + // All this does is wrap CloseNotify with context cancel, for those responsewriters + // which didn't support context, but all the ones we'd use should nowadays, right? + // if cn, ok := rw.(http.CloseNotifier); ok { + // var cancel context.CancelFunc + // ctx, cancel = context.WithCancel(ctx) + // defer cancel() + // notifyChan := cn.CloseNotify() + // go func() { + // select { + // case <-notifyChan: + // cancel() + // case <-ctx.Done(): + // } + // }() + // } + + // TODO: do we need to call WithContext, since we won't be changing req.Context() above if we remove the CloseNotifier stuff? + // TODO: (This is where references to req were originally "outreq", a shallow clone, which I think is unnecessary in our case) + // req = req.WithContext(ctx) // includes shallow copies of maps, but okay if req.ContentLength == 0 { - outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries } - outreq.Header = cloneHeader(req.Header) + // TODO: is this needed? + // req.Header = cloneHeader(req.Header) - p.Director(outreq) - outreq.Close = false + req.Close = false - reqUpType := upgradeType(outreq.Header) - removeConnectionHeaders(outreq.Header) + // if User-Agent is not set by client, then explicitly + // disable it so it's not set to default value by std lib + if _, ok := req.Header["User-Agent"]; !ok { + req.Header.Set("User-Agent", "") + } + + reqUpType := upgradeType(req.Header) + removeConnectionHeaders(req.Header) // Remove hop-by-hop headers to the backend. Especially // important is "Connection" because we want a persistent // connection, regardless of what the client sent to us. for _, h := range hopHeaders { - hv := outreq.Header.Get(h) + hv := req.Header.Get(h) if hv == "" { continue } if h == "Te" && hv == "trailers" { - // Issue 21096: tell backend applications that + // Issue golang/go#21096: tell backend applications that // care about trailer support that we support // trailers. (We do, but we don't go out of // our way to advertise that unless the @@ -244,40 +256,65 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht // worth mentioning) continue } - outreq.Header.Del(h) + req.Header.Del(h) } // After stripping all the hop-by-hop connection headers above, add back any // necessary for protocol upgrades, such as for websockets. if reqUpType != "" { - outreq.Header.Set("Connection", "Upgrade") - outreq.Header.Set("Upgrade", reqUpType) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", reqUpType) } if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { // If we aren't the first proxy retain prior // X-Forwarded-For information as a comma+space // separated list and fold multiple headers into one. - if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + if prior, ok := req.Header["X-Forwarded-For"]; ok { clientIP = strings.Join(prior, ", ") + ", " + clientIP } - outreq.Header.Set("X-Forwarded-For", clientIP) + req.Header.Set("X-Forwarded-For", clientIP) } - res, err := transport.RoundTrip(outreq) + return nil +} + +// TODO: +// this code is the entry point to what was borrowed from the net/http/httputil package in the standard library. +func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, upstream *Upstream) error { + // TODO: count this active request + + // point the request to this upstream + h.directRequest(req, upstream) + + // do the round-trip + start := time.Now() + res, err := h.Transport.RoundTrip(req) + latency := time.Since(start) if err != nil { - p.getErrorHandler()(rw, outreq, err) - return nil, err + return err + } + + // perform passive health checks (if enabled) + if h.HealthChecks != nil && h.HealthChecks.Passive != nil { + // strike if the status code matches one that is "bad" + for _, badStatus := range h.HealthChecks.Passive.UnhealthyStatus { + if caddyhttp.StatusCodeMatches(res.StatusCode, badStatus) { + h.countFailure(upstream) + } + } + + // strike if the roundtrip took too long + if h.HealthChecks.Passive.UnhealthyLatency > 0 && + latency >= time.Duration(h.HealthChecks.Passive.UnhealthyLatency) { + h.countFailure(upstream) + } } // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) if res.StatusCode == http.StatusSwitchingProtocols { - if !p.modifyResponse(rw, res, outreq) { - return res, nil - } - - p.handleUpgradeResponse(rw, outreq, res) - return res, nil + h.handleUpgradeResponse(rw, req, res) + return nil } removeConnectionHeaders(res.Header) @@ -286,10 +323,6 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht res.Header.Del(h) } - if !p.modifyResponse(rw, res, outreq) { - return res, nil - } - copyHeader(rw.Header(), res.Header) // The "Trailer" header isn't included in the Transport's response, @@ -305,15 +338,16 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht rw.WriteHeader(res.StatusCode) - err = p.copyResponse(rw, res.Body, p.flushInterval(req, res)) + err = h.copyResponse(rw, res.Body, h.flushInterval(req, res)) if err != nil { defer res.Body.Close() // Since we're streaming the response, if we run into an error all we can do - // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler + // is abort the request. Issue golang/go#23643: ReverseProxy should use ErrAbortHandler // on read error while copying body. + // TODO: Look into whether we want to panic at all in our case... if !shouldPanicOnCopyError(req) { - p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) - return nil, err + // p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) + return err } panic(http.ErrAbortHandler) @@ -331,7 +365,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht if len(res.Trailer) == announcedTrailers { copyHeader(rw.Header(), res.Trailer) - return res, nil + return nil } for k, vv := range res.Trailer { @@ -341,46 +375,90 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht } } - return res, nil + return nil } -var inOurTests bool // whether we're in our own tests - -// shouldPanicOnCopyError reports whether the reverse proxy should -// panic with http.ErrAbortHandler. This is the right thing to do by -// default, but Go 1.10 and earlier did not, so existing unit tests -// weren't expecting panics. Only panic in our own tests, or when -// running under the HTTP server. -func shouldPanicOnCopyError(req *http.Request) bool { - if inOurTests { - // Our tests know to handle this panic. - return true +// tryAgain takes the time that the handler was initially invoked +// as well as any error currently obtained and returns true if +// another attempt should be made at proxying the request. If +// true is returned, it has already blocked long enough before +// the next retry (i.e. no more sleeping is needed). If false is +// returned, the handler should stop trying to proxy the request. +func (h Handler) tryAgain(start time.Time, proxyErr error) bool { + // if downstream has canceled the request, break + if proxyErr == context.Canceled { + return false } - if req.Context().Value(http.ServerContextKey) != nil { - // We seem to be running under an HTTP server, so - // it'll recover the panic. - return true + // if we've tried long enough, break + if time.Since(start) >= time.Duration(h.LoadBalancing.TryDuration) { + return false } - // Otherwise act like Go 1.10 and earlier to not break - // existing tests. - return false + // otherwise, wait and try the next available host + time.Sleep(time.Duration(h.LoadBalancing.TryInterval)) + return true } -// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. -// See RFC 7230, section 6.1 -func removeConnectionHeaders(h http.Header) { - if c := h.Get("Connection"); c != "" { - for _, f := range strings.Split(c, ",") { - if f = strings.TrimSpace(f); f != "" { - h.Del(f) - } - } +// directRequest modifies only req.URL so that it points to the +// given upstream host. It must modify ONLY the request URL. +func (h Handler) directRequest(req *http.Request, upstream *Upstream) { + target := upstream.hostURL + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) // TODO: This might be a bug (if any part of the path was augmented from a previously-tried upstream; need to start from clean original path of request, same for query string!) + if target.RawQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = target.RawQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = target.RawQuery + "&" + req.URL.RawQuery } } +func (h Handler) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if reqUpType != resUpType { + // p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + + copyHeader(res.Header, rw.Header()) + + hj, ok := rw.(http.Hijacker) + if !ok { + // p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + // p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + defer backConn.Close() + conn, brw, err := hj.Hijack() + if err != nil { + // p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) + return + } + defer conn.Close() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + // p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) + return + } + if err := brw.Flush(); err != nil { + // p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc + return +} + // flushInterval returns the p.FlushInterval value, conditionally // overriding its value for a specific request/response. -func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration { +func (h Handler) flushInterval(req *http.Request, res *http.Response) time.Duration { resCT := res.Header.Get("Content-Type") // For Server-Sent Events responses, flush immediately. @@ -390,10 +468,11 @@ func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time } // TODO: more specific cases? e.g. res.ContentLength == -1? - return p.FlushInterval + // return h.FlushInterval + return 0 } -func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { +func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { if flushInterval != 0 { if wf, ok := dst.(writeFlusher); ok { mlw := &maxLatencyWriter{ @@ -405,18 +484,25 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval } } + // TODO: Figure out how we want to do this... using custom buffer pool type seems unnecessary + // or maybe it is, depending on how we want to handle errors, + // see: https://github.com/golang/go/issues/21814 + // buf := bufPool.Get().(*bytes.Buffer) + // buf.Reset() + // defer bufPool.Put(buf) + // _, err := io.CopyBuffer(dst, src, ) var buf []byte - if p.BufferPool != nil { - buf = p.BufferPool.Get() - defer p.BufferPool.Put(buf) - } - _, err := p.copyBuffer(dst, src, buf) + // if h.BufferPool != nil { + // buf = h.BufferPool.Get() + // defer h.BufferPool.Put(buf) + // } + _, err := h.copyBuffer(dst, src, buf) return err } // copyBuffer returns any write errors or non-EOF read errors, and the amount // of bytes written. -func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { +func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { if len(buf) == 0 { buf = make([]byte, 32*1024) } @@ -424,7 +510,13 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int for { nr, rerr := src.Read(buf) if rerr != nil && rerr != io.EOF && rerr != context.Canceled { - p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) + // TODO: this could be useful to know (indeed, it revealed an error in our + // fastcgi PoC earlier; but it's this single error report here that necessitates + // a function separate from io.CopyBuffer, since io.CopyBuffer does not distinguish + // between read or write errors; in a reverse proxy situation, write errors are not + // something we need to report to the client, but read errors are a problem on our + // end for sure. so we need to decide what we want.) + // p.logf("copyBuffer: ReverseProxy read error during body copy: %v", rerr) } if nr > 0 { nw, werr := dst.Write(buf[:nr]) @@ -447,12 +539,36 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int } } -func (p *ReverseProxy) logf(format string, args ...interface{}) { - if p.ErrorLog != nil { - p.ErrorLog.Printf(format, args...) - } else { - log.Printf(format, args...) +// countFailure remembers 1 failure for upstream for the +// configured duration. If passive health checks are +// disabled or failure expiry is 0, this is a no-op. +func (h Handler) countFailure(upstream *Upstream) { + // only count failures if passive health checking is enabled + // and if failures are configured have a non-zero expiry + if h.HealthChecks == nil || h.HealthChecks.Passive == nil { + return } + failDuration := time.Duration(h.HealthChecks.Passive.FailDuration) + if failDuration == 0 { + return + } + + // count failure immediately + err := upstream.Host.CountFail(1) + if err != nil { + log.Printf("[ERROR] proxy: upstream %s: counting failure: %v", + upstream.hostURL, err) + } + + // forget it later + go func(host Host, failDuration time.Duration) { + time.Sleep(failDuration) + err := host.CountFail(-1) + if err != nil { + log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v", + upstream.hostURL, err) + } + }(upstream.Host, failDuration) } type writeFlusher interface { @@ -508,57 +624,6 @@ func (m *maxLatencyWriter) stop() { } } -func upgradeType(h http.Header) string { - if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { - return "" - } - return strings.ToLower(h.Get("Upgrade")) -} - -func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { - reqUpType := upgradeType(req.Header) - resUpType := upgradeType(res.Header) - if reqUpType != resUpType { - p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) - return - } - - copyHeader(res.Header, rw.Header()) - - hj, ok := rw.(http.Hijacker) - if !ok { - p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) - return - } - backConn, ok := res.Body.(io.ReadWriteCloser) - if !ok { - p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) - return - } - defer backConn.Close() - conn, brw, err := hj.Hijack() - if err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) - return - } - defer conn.Close() - res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above - if err := res.Write(brw); err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) - return - } - if err := brw.Flush(); err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) - return - } - errc := make(chan error, 1) - spc := switchProtocolCopier{user: conn, backend: backConn} - go spc.copyToBackend(errc) - go spc.copyFromBackend(errc) - <-errc - return -} - // switchProtocolCopier exists so goroutines proxying data back and // forth have nice names in stacks. type switchProtocolCopier struct { @@ -574,3 +639,224 @@ func (c switchProtocolCopier) copyToBackend(errc chan<- error) { _, err := io.Copy(c.backend, c.user) errc <- err } + +// shouldPanicOnCopyError reports whether the reverse proxy should +// panic with http.ErrAbortHandler. This is the right thing to do by +// default, but Go 1.10 and earlier did not, so existing unit tests +// weren't expecting panics. Only panic in our own tests, or when +// running under the HTTP server. +// TODO: I don't know if we want this at all... +func shouldPanicOnCopyError(req *http.Request) bool { + // if inOurTests { + // // Our tests know to handle this panic. + // return true + // } + if req.Context().Value(http.ServerContextKey) != nil { + // We seem to be running under an HTTP server, so + // it'll recover the panic. + return true + } + // Otherwise act like Go 1.10 and earlier to not break + // existing tests. + return false +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +func upgradeType(h http.Header) string { + if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { + return "" + } + return strings.ToLower(h.Get("Upgrade")) +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. +// See RFC 7230, section 6.1 +func removeConnectionHeaders(h http.Header) { + if c := h.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + h.Del(f) + } + } + } +} + +type LoadBalancing struct { + SelectionPolicyRaw json.RawMessage `json:"selection_policy,omitempty"` + TryDuration caddy.Duration `json:"try_duration,omitempty"` + TryInterval caddy.Duration `json:"try_interval,omitempty"` + + SelectionPolicy Selector `json:"-"` +} + +type Selector interface { + Select(HostPool, *http.Request) *Upstream +} + +type HealthChecks struct { + Active *ActiveHealthChecks `json:"active,omitempty"` + Passive *PassiveHealthChecks `json:"passive,omitempty"` +} + +type ActiveHealthChecks struct { + Path string `json:"path,omitempty"` + Port int `json:"port,omitempty"` + Interval caddy.Duration `json:"interval,omitempty"` + Timeout caddy.Duration `json:"timeout,omitempty"` + MaxSize int `json:"max_size,omitempty"` + ExpectStatus int `json:"expect_status,omitempty"` + ExpectBody string `json:"expect_body,omitempty"` +} + +type PassiveHealthChecks struct { + MaxFails int `json:"max_fails,omitempty"` + FailDuration caddy.Duration `json:"fail_duration,omitempty"` + UnhealthyRequestCount int `json:"unhealthy_request_count,omitempty"` + UnhealthyStatus []int `json:"unhealthy_status,omitempty"` + UnhealthyLatency caddy.Duration `json:"unhealthy_latency,omitempty"` +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +var bufPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + +////////////////////////////////// +// TODO: + +type Host interface { + NumRequests() int + Fails() int + Unhealthy() bool + + CountRequest(int) error + CountFail(int) error +} + +type HostPool []*Upstream + +type upstreamHost struct { + numRequests int64 // must be first field to be 64-bit aligned on 32-bit systems (see https://golang.org/pkg/sync/atomic/#pkg-note-BUG) + fails int64 + unhealthy int32 +} + +func (uh upstreamHost) NumRequests() int { + return int(atomic.LoadInt64(&uh.numRequests)) +} +func (uh upstreamHost) Fails() int { + return int(atomic.LoadInt64(&uh.fails)) +} +func (uh upstreamHost) Unhealthy() bool { + return atomic.LoadInt32(&uh.unhealthy) == 1 +} +func (uh *upstreamHost) CountRequest(delta int) error { + result := atomic.AddInt64(&uh.numRequests, int64(delta)) + if result < 0 { + return fmt.Errorf("count below 0: %d", result) + } + return nil +} +func (uh *upstreamHost) CountFail(delta int) error { + result := atomic.AddInt64(&uh.fails, int64(delta)) + if result < 0 { + return fmt.Errorf("count below 0: %d", result) + } + return nil +} + +type Upstream struct { + Host `json:"-"` + + Address string `json:"address,omitempty"` + MaxRequests int `json:"max_requests,omitempty"` + + // TODO: This could be really cool, to say that requests with + // certain headers or from certain IPs always go to this upstream + // HeaderAffinity string + // IPAffinity string + + healthCheckPolicy *PassiveHealthChecks + + hostURL *url.URL +} + +func (u Upstream) Available() bool { + return u.Healthy() && !u.Full() +} + +func (u Upstream) Healthy() bool { + healthy := !u.Host.Unhealthy() + if healthy && u.healthCheckPolicy != nil { + healthy = u.Host.Fails() < u.healthCheckPolicy.MaxFails + } + return healthy +} + +func (u Upstream) Full() bool { + return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests +} + +func (u Upstream) URL() *url.URL { + return u.hostURL +} + +var hosts = caddy.NewUsagePool() + +// TODO: ... +type UpstreamProvider interface { +} + +// Interface guards +var ( + _ caddyhttp.MiddlewareHandler = (*Handler)(nil) + _ caddy.Provisioner = (*Handler)(nil) + _ caddy.CleanerUpper = (*Handler)(nil) +) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go new file mode 100644 index 00000000..e0518c90 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go @@ -0,0 +1,351 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "fmt" + "hash/fnv" + weakrand "math/rand" + "net" + "net/http" + "sync/atomic" + "time" + + "github.com/caddyserver/caddy/v2" +) + +func init() { + caddy.RegisterModule(RandomSelection{}) + caddy.RegisterModule(RandomChoiceSelection{}) + caddy.RegisterModule(LeastConnSelection{}) + caddy.RegisterModule(RoundRobinSelection{}) + caddy.RegisterModule(FirstSelection{}) + caddy.RegisterModule(IPHashSelection{}) + caddy.RegisterModule(URIHashSelection{}) + caddy.RegisterModule(HeaderHashSelection{}) + + weakrand.Seed(time.Now().UTC().UnixNano()) +} + +// RandomSelection is a policy that selects +// an available host at random. +type RandomSelection struct{} + +// CaddyModule returns the Caddy module information. +func (RandomSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.random", + New: func() caddy.Module { return new(RandomSelection) }, + } +} + +// Select returns an available host, if any. +func (r RandomSelection) Select(pool HostPool, request *http.Request) *Upstream { + // use reservoir sampling because the number of available + // hosts isn't known: https://en.wikipedia.org/wiki/Reservoir_sampling + var randomHost *Upstream + var count int + for _, upstream := range pool { + if !upstream.Available() { + continue + } + // (n % 1 == 0) holds for all n, therefore a + // upstream will always be chosen if there is at + // least one available + count++ + if (weakrand.Int() % count) == 0 { + randomHost = upstream + } + } + return randomHost +} + +// RandomChoiceSelection is a policy that selects +// two or more available hosts at random, then +// chooses the one with the least load. +type RandomChoiceSelection struct { + Choose int `json:"choose,omitempty"` +} + +// CaddyModule returns the Caddy module information. +func (RandomChoiceSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.random_choice", + New: func() caddy.Module { return new(RandomChoiceSelection) }, + } +} + +func (r *RandomChoiceSelection) Provision(ctx caddy.Context) error { + if r.Choose == 0 { + r.Choose = 2 + } + return nil +} + +func (r RandomChoiceSelection) Validate() error { + if r.Choose < 2 { + return fmt.Errorf("choose must be at least 2") + } + return nil +} + +// Select returns an available host, if any. +func (r RandomChoiceSelection) Select(pool HostPool, _ *http.Request) *Upstream { + k := r.Choose + if k > len(pool) { + k = len(pool) + } + choices := make([]*Upstream, k) + for i, upstream := range pool { + if !upstream.Available() { + continue + } + j := weakrand.Intn(i) + if j < k { + choices[j] = upstream + } + } + return leastRequests(choices) +} + +// LeastConnSelection is a policy that selects the +// host with the least active requests. If multiple +// hosts have the same fewest number, one is chosen +// randomly. The term "conn" or "connection" is used +// in this policy name due to its similar meaning in +// other software, but our load balancer actually +// counts active requests rather than connections, +// since these days requests are multiplexed onto +// shared connections. +type LeastConnSelection struct{} + +// CaddyModule returns the Caddy module information. +func (LeastConnSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.least_conn", + New: func() caddy.Module { return new(LeastConnSelection) }, + } +} + +// Select selects the up host with the least number of connections in the +// pool. If more than one host has the same least number of connections, +// one of the hosts is chosen at random. +func (LeastConnSelection) Select(pool HostPool, _ *http.Request) *Upstream { + var bestHost *Upstream + var count int + var leastReqs int + + for _, host := range pool { + if !host.Available() { + continue + } + numReqs := host.NumRequests() + if numReqs < leastReqs { + leastReqs = numReqs + count = 0 + } + + // among hosts with same least connections, perform a reservoir + // sample: https://en.wikipedia.org/wiki/Reservoir_sampling + if numReqs == leastReqs { + count++ + if (weakrand.Int() % count) == 0 { + bestHost = host + } + } + } + + return bestHost +} + +// RoundRobinSelection is a policy that selects +// a host based on round-robin ordering. +type RoundRobinSelection struct { + robin uint32 +} + +// CaddyModule returns the Caddy module information. +func (RoundRobinSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.round_robin", + New: func() caddy.Module { return new(RoundRobinSelection) }, + } +} + +// Select returns an available host, if any. +func (r *RoundRobinSelection) Select(pool HostPool, _ *http.Request) *Upstream { + n := uint32(len(pool)) + if n == 0 { + return nil + } + for i := uint32(0); i < n; i++ { + atomic.AddUint32(&r.robin, 1) + host := pool[r.robin%n] + if host.Available() { + return host + } + } + return nil +} + +// FirstSelection is a policy that selects +// the first available host. +type FirstSelection struct{} + +// CaddyModule returns the Caddy module information. +func (FirstSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.first", + New: func() caddy.Module { return new(FirstSelection) }, + } +} + +// Select returns an available host, if any. +func (FirstSelection) Select(pool HostPool, _ *http.Request) *Upstream { + for _, host := range pool { + if host.Available() { + return host + } + } + return nil +} + +// IPHashSelection is a policy that selects a host +// based on hashing the remote IP of the request. +type IPHashSelection struct{} + +// CaddyModule returns the Caddy module information. +func (IPHashSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.ip_hash", + New: func() caddy.Module { return new(IPHashSelection) }, + } +} + +// Select returns an available host, if any. +func (IPHashSelection) Select(pool HostPool, req *http.Request) *Upstream { + clientIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + clientIP = req.RemoteAddr + } + return hostByHashing(pool, clientIP) +} + +// URIHashSelection is a policy that selects a +// host by hashing the request URI. +type URIHashSelection struct{} + +// CaddyModule returns the Caddy module information. +func (URIHashSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.uri_hash", + New: func() caddy.Module { return new(URIHashSelection) }, + } +} + +// Select returns an available host, if any. +func (URIHashSelection) Select(pool HostPool, req *http.Request) *Upstream { + return hostByHashing(pool, req.RequestURI) +} + +// HeaderHashSelection is a policy that selects +// a host based on a given request header. +type HeaderHashSelection struct { + Field string `json:"field,omitempty"` +} + +// CaddyModule returns the Caddy module information. +func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.header", + New: func() caddy.Module { return new(HeaderHashSelection) }, + } +} + +// Select returns an available host, if any. +func (s HeaderHashSelection) Select(pool HostPool, req *http.Request) *Upstream { + if s.Field == "" { + return nil + } + val := req.Header.Get(s.Field) + if val == "" { + return RandomSelection{}.Select(pool, req) + } + return hostByHashing(pool, val) +} + +// leastRequests returns the host with the +// least number of active requests to it. +// If more than one host has the same +// least number of active requests, then +// one of those is chosen at random. +func leastRequests(upstreams []*Upstream) *Upstream { + if len(upstreams) == 0 { + return nil + } + var best []*Upstream + var bestReqs int + for _, upstream := range upstreams { + reqs := upstream.NumRequests() + if reqs == 0 { + return upstream + } + if reqs <= bestReqs { + bestReqs = reqs + best = append(best, upstream) + } + } + return best[weakrand.Intn(len(best))] +} + +// hostByHashing returns an available host +// from pool based on a hashable string s. +func hostByHashing(pool []*Upstream, s string) *Upstream { + poolLen := uint32(len(pool)) + if poolLen == 0 { + return nil + } + index := hash(s) % poolLen + for i := uint32(0); i < poolLen; i++ { + index += i + upstream := pool[index%poolLen] + if upstream.Available() { + return upstream + } + } + return nil +} + +// hash calculates a fast hash based on s. +func hash(s string) uint32 { + h := fnv.New32a() + h.Write([]byte(s)) + return h.Sum32() +} + +// Interface guards +var ( + _ Selector = (*RandomSelection)(nil) + _ Selector = (*RandomChoiceSelection)(nil) + _ Selector = (*LeastConnSelection)(nil) + _ Selector = (*RoundRobinSelection)(nil) + _ Selector = (*FirstSelection)(nil) + _ Selector = (*IPHashSelection)(nil) + _ Selector = (*URIHashSelection)(nil) + _ Selector = (*HeaderHashSelection)(nil) + + _ caddy.Validator = (*RandomChoiceSelection)(nil) + _ caddy.Provisioner = (*RandomChoiceSelection)(nil) +) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go new file mode 100644 index 00000000..8006fb1b --- /dev/null +++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go @@ -0,0 +1,363 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +// TODO: finish migrating these + +// import ( +// "net/http" +// "net/http/httptest" +// "os" +// "testing" +// ) + +// var workableServer *httptest.Server + +// func TestMain(m *testing.M) { +// workableServer = httptest.NewServer(http.HandlerFunc( +// func(w http.ResponseWriter, r *http.Request) { +// // do nothing +// })) +// r := m.Run() +// workableServer.Close() +// os.Exit(r) +// } + +// type customPolicy struct{} + +// func (customPolicy) Select(pool HostPool, _ *http.Request) Host { +// return pool[0] +// } + +// func testPool() HostPool { +// pool := []*UpstreamHost{ +// { +// Name: workableServer.URL, // this should resolve (healthcheck test) +// }, +// { +// Name: "http://localhost:99998", // this shouldn't +// }, +// { +// Name: "http://C", +// }, +// } +// return HostPool(pool) +// } + +// func TestRoundRobinPolicy(t *testing.T) { +// pool := testPool() +// rrPolicy := &RoundRobin{} +// request, _ := http.NewRequest("GET", "/", nil) + +// h := rrPolicy.Select(pool, request) +// // First selected host is 1, because counter starts at 0 +// // and increments before host is selected +// if h != pool[1] { +// t.Error("Expected first round robin host to be second host in the pool.") +// } +// h = rrPolicy.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected second round robin host to be third host in the pool.") +// } +// h = rrPolicy.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected third round robin host to be first host in the pool.") +// } +// // mark host as down +// pool[1].Unhealthy = 1 +// h = rrPolicy.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected to skip down host.") +// } +// // mark host as up +// pool[1].Unhealthy = 0 + +// h = rrPolicy.Select(pool, request) +// if h == pool[2] { +// t.Error("Expected to balance evenly among healthy hosts") +// } +// // mark host as full +// pool[1].Conns = 1 +// pool[1].MaxConns = 1 +// h = rrPolicy.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected to skip full host.") +// } +// } + +// func TestLeastConnPolicy(t *testing.T) { +// pool := testPool() +// lcPolicy := &LeastConn{} +// request, _ := http.NewRequest("GET", "/", nil) + +// pool[0].Conns = 10 +// pool[1].Conns = 10 +// h := lcPolicy.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected least connection host to be third host.") +// } +// pool[2].Conns = 100 +// h = lcPolicy.Select(pool, request) +// if h != pool[0] && h != pool[1] { +// t.Error("Expected least connection host to be first or second host.") +// } +// } + +// func TestCustomPolicy(t *testing.T) { +// pool := testPool() +// customPolicy := &customPolicy{} +// request, _ := http.NewRequest("GET", "/", nil) + +// h := customPolicy.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected custom policy host to be the first host.") +// } +// } + +// func TestIPHashPolicy(t *testing.T) { +// pool := testPool() +// ipHash := &IPHash{} +// request, _ := http.NewRequest("GET", "/", nil) +// // We should be able to predict where every request is routed. +// request.RemoteAddr = "172.0.0.1:80" +// h := ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } +// request.RemoteAddr = "172.0.0.2:80" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } +// request.RemoteAddr = "172.0.0.3:80" +// h = ipHash.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected ip hash policy host to be the third host.") +// } +// request.RemoteAddr = "172.0.0.4:80" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } + +// // we should get the same results without a port +// request.RemoteAddr = "172.0.0.1" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } +// request.RemoteAddr = "172.0.0.2" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } +// request.RemoteAddr = "172.0.0.3" +// h = ipHash.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected ip hash policy host to be the third host.") +// } +// request.RemoteAddr = "172.0.0.4" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } + +// // we should get a healthy host if the original host is unhealthy and a +// // healthy host is available +// request.RemoteAddr = "172.0.0.1" +// pool[1].Unhealthy = 1 +// h = ipHash.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected ip hash policy host to be the third host.") +// } + +// request.RemoteAddr = "172.0.0.2" +// h = ipHash.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected ip hash policy host to be the third host.") +// } +// pool[1].Unhealthy = 0 + +// request.RemoteAddr = "172.0.0.3" +// pool[2].Unhealthy = 1 +// h = ipHash.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected ip hash policy host to be the first host.") +// } +// request.RemoteAddr = "172.0.0.4" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } + +// // We should be able to resize the host pool and still be able to predict +// // where a request will be routed with the same IP's used above +// pool = []*UpstreamHost{ +// { +// Name: workableServer.URL, // this should resolve (healthcheck test) +// }, +// { +// Name: "http://localhost:99998", // this shouldn't +// }, +// } +// pool = HostPool(pool) +// request.RemoteAddr = "172.0.0.1:80" +// h = ipHash.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected ip hash policy host to be the first host.") +// } +// request.RemoteAddr = "172.0.0.2:80" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } +// request.RemoteAddr = "172.0.0.3:80" +// h = ipHash.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected ip hash policy host to be the first host.") +// } +// request.RemoteAddr = "172.0.0.4:80" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } + +// // We should get nil when there are no healthy hosts +// pool[0].Unhealthy = 1 +// pool[1].Unhealthy = 1 +// h = ipHash.Select(pool, request) +// if h != nil { +// t.Error("Expected ip hash policy host to be nil.") +// } +// } + +// func TestFirstPolicy(t *testing.T) { +// pool := testPool() +// firstPolicy := &First{} +// req := httptest.NewRequest(http.MethodGet, "/", nil) + +// h := firstPolicy.Select(pool, req) +// if h != pool[0] { +// t.Error("Expected first policy host to be the first host.") +// } + +// pool[0].Unhealthy = 1 +// h = firstPolicy.Select(pool, req) +// if h != pool[1] { +// t.Error("Expected first policy host to be the second host.") +// } +// } + +// func TestUriPolicy(t *testing.T) { +// pool := testPool() +// uriPolicy := &URIHash{} + +// request := httptest.NewRequest(http.MethodGet, "/test", nil) +// h := uriPolicy.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected uri policy host to be the first host.") +// } + +// pool[0].Unhealthy = 1 +// h = uriPolicy.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected uri policy host to be the first host.") +// } + +// request = httptest.NewRequest(http.MethodGet, "/test_2", nil) +// h = uriPolicy.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected uri policy host to be the second host.") +// } + +// // We should be able to resize the host pool and still be able to predict +// // where a request will be routed with the same URI's used above +// pool = []*UpstreamHost{ +// { +// Name: workableServer.URL, // this should resolve (healthcheck test) +// }, +// { +// Name: "http://localhost:99998", // this shouldn't +// }, +// } + +// request = httptest.NewRequest(http.MethodGet, "/test", nil) +// h = uriPolicy.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected uri policy host to be the first host.") +// } + +// pool[0].Unhealthy = 1 +// h = uriPolicy.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected uri policy host to be the first host.") +// } + +// request = httptest.NewRequest(http.MethodGet, "/test_2", nil) +// h = uriPolicy.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected uri policy host to be the second host.") +// } + +// pool[0].Unhealthy = 1 +// pool[1].Unhealthy = 1 +// h = uriPolicy.Select(pool, request) +// if h != nil { +// t.Error("Expected uri policy policy host to be nil.") +// } +// } + +// func TestHeaderPolicy(t *testing.T) { +// pool := testPool() +// tests := []struct { +// Name string +// Policy *Header +// RequestHeaderName string +// RequestHeaderValue string +// NilHost bool +// HostIndex int +// }{ +// {"empty config", &Header{""}, "", "", true, 0}, +// {"empty config+header+value", &Header{""}, "Affinity", "somevalue", true, 0}, +// {"empty config+header", &Header{""}, "Affinity", "", true, 0}, + +// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 1}, +// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 2}, +// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 0}, + +// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue", false, 1}, +// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue2", false, 0}, +// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue3", false, 2}, +// {"hash route with empty value", &Header{"Affinity"}, "Affinity", "", false, 1}, +// } + +// for idx, test := range tests { +// request, _ := http.NewRequest("GET", "/", nil) +// if test.RequestHeaderName != "" { +// request.Header.Add(test.RequestHeaderName, test.RequestHeaderValue) +// } + +// host := test.Policy.Select(pool, request) +// if test.NilHost && host != nil { +// t.Errorf("%d: Expected host to be nil", idx) +// } +// if !test.NilHost && host == nil { +// t.Errorf("%d: Did not expect host to be nil", idx) +// } +// if !test.NilHost && host != pool[test.HostIndex] { +// t.Errorf("%d: Expected Header policy to be host %d", idx, test.HostIndex) +// } +// } +// } diff --git a/modules/caddyhttp/reverseproxy/upstream.go b/modules/caddyhttp/reverseproxy/upstream.go deleted file mode 100755 index 1f0693e1..00000000 --- a/modules/caddyhttp/reverseproxy/upstream.go +++ /dev/null @@ -1,450 +0,0 @@ -// Copyright 2015 Matthew Holt and The Caddy Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package reverseproxy implements a load-balanced reverse proxy. -package reverseproxy - -import ( - "context" - "encoding/json" - "fmt" - "math/rand" - "net" - "net/http" - "net/url" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/caddyserver/caddy/v2" - "github.com/caddyserver/caddy/v2/modules/caddyhttp" -) - -// CircuitBreaker defines the functionality of a circuit breaker module. -type CircuitBreaker interface { - Ok() bool - RecordMetric(statusCode int, latency time.Duration) -} - -type noopCircuitBreaker struct{} - -func (ncb noopCircuitBreaker) RecordMetric(statusCode int, latency time.Duration) {} -func (ncb noopCircuitBreaker) Ok() bool { - return true -} - -const ( - // TypeBalanceRoundRobin represents the value to use for configuring a load balanced reverse proxy to use round robin load balancing. - TypeBalanceRoundRobin = iota - - // TypeBalanceRandom represents the value to use for configuring a load balanced reverse proxy to use random load balancing. - TypeBalanceRandom - - // TODO: add random with two choices - - // msgNoHealthyUpstreams is returned if there are no upstreams that are healthy to proxy a request to - msgNoHealthyUpstreams = "No healthy upstreams." - - // by default perform health checks every 30 seconds - defaultHealthCheckDur = time.Second * 30 - - // used when an upstream is unhealthy, health checks can be configured to perform at a faster rate - defaultFastHealthCheckDur = time.Second * 1 -) - -var ( - // defaultTransport is the default transport to use for the reverse proxy. - defaultTransport = &http.Transport{ - Dial: (&net.Dialer{ - Timeout: 5 * time.Second, - }).Dial, - TLSHandshakeTimeout: 5 * time.Second, - } - - // defaultHTTPClient is the default http client to use for the healthchecker. - defaultHTTPClient = &http.Client{ - Timeout: time.Second * 10, - Transport: defaultTransport, - } - - // typeMap maps caddy load balance configuration to the internal representation of the loadbalance algorithm type. - typeMap = map[string]int{ - "round_robin": TypeBalanceRoundRobin, - "random": TypeBalanceRandom, - } -) - -// NewLoadBalancedReverseProxy returns a collection of Upstreams that are to be loadbalanced. -func NewLoadBalancedReverseProxy(lb *LoadBalanced, ctx caddy.Context) error { - // set defaults - if lb.NoHealthyUpstreamsMessage == "" { - lb.NoHealthyUpstreamsMessage = msgNoHealthyUpstreams - } - - if lb.TryInterval == "" { - lb.TryInterval = "20s" - } - - // set request retry interval - ti, err := time.ParseDuration(lb.TryInterval) - if err != nil { - return fmt.Errorf("NewLoadBalancedReverseProxy: %v", err.Error()) - } - lb.tryInterval = ti - - // set load balance algorithm - t, ok := typeMap[lb.LoadBalanceType] - if !ok { - t = TypeBalanceRandom - } - lb.loadBalanceType = t - - // setup each upstream - var us []*upstream - for _, uc := range lb.Upstreams { - // pass the upstream decr and incr methods to keep track of unhealthy nodes - nu, err := newUpstream(uc, lb.decrUnhealthy, lb.incrUnhealthy) - if err != nil { - return err - } - - // setup any configured circuit breakers - var cbModule = "http.handlers.reverse_proxy.circuit_breaker" - var cb CircuitBreaker - - if uc.CircuitBreaker != nil { - if _, err := caddy.GetModule(cbModule); err == nil { - val, err := ctx.LoadModule(cbModule, uc.CircuitBreaker) - if err == nil { - cbv, ok := val.(CircuitBreaker) - if ok { - cb = cbv - } else { - fmt.Printf("\nerr: %v; cannot load circuit_breaker, using noop", err.Error()) - cb = noopCircuitBreaker{} - } - } else { - fmt.Printf("\nerr: %v; cannot load circuit_breaker, using noop", err.Error()) - cb = noopCircuitBreaker{} - } - } else { - fmt.Println("circuit_breaker module not loaded, using noop") - cb = noopCircuitBreaker{} - } - } else { - cb = noopCircuitBreaker{} - } - nu.CB = cb - - // start a healthcheck worker which will periodically check to see if an upstream is healthy - // to proxy requests to. - nu.healthChecker = NewHealthCheckWorker(nu, defaultHealthCheckDur, defaultHTTPClient) - - // TODO :- if path is empty why does this empty the entire Target? - // nu.Target.Path = uc.HealthCheckPath - - nu.healthChecker.ScheduleChecks(nu.Target.String()) - lb.HealthCheckers = append(lb.HealthCheckers, nu.healthChecker) - - us = append(us, nu) - } - - lb.upstreams = us - - return nil -} - -// LoadBalanced represents a collection of upstream hosts that are loadbalanced. It -// contains multiple features like health checking and circuit breaking functionality -// for upstreams. -type LoadBalanced struct { - mu sync.Mutex - numUnhealthy int32 - selectedServer int // used during round robin load balancing - loadBalanceType int - tryInterval time.Duration - upstreams []*upstream - - // The following struct fields are set by caddy configuration. - // TryInterval is the max duration for which request retrys will be performed for a request. - TryInterval string `json:"try_interval,omitempty"` - - // Upstreams are the configs for upstream hosts - Upstreams []*UpstreamConfig `json:"upstreams,omitempty"` - - // LoadBalanceType is the string representation of what loadbalancing algorithm to use. i.e. "random" or "round_robin". - LoadBalanceType string `json:"load_balance_type,omitempty"` - - // NoHealthyUpstreamsMessage is returned as a response when there are no healthy upstreams to loadbalance to. - NoHealthyUpstreamsMessage string `json:"no_healthy_upstreams_message,omitempty"` - - // TODO :- store healthcheckers as package level state where each upstream gets a single healthchecker - // currently a healthchecker is created for each upstream defined, even if a healthchecker was previously created - // for that upstream - HealthCheckers []*HealthChecker `json:"health_checkers,omitempty"` -} - -// Cleanup stops all health checkers on a loadbalanced reverse proxy. -func (lb *LoadBalanced) Cleanup() error { - for _, hc := range lb.HealthCheckers { - hc.Stop() - } - - return nil -} - -// Provision sets up a new loadbalanced reverse proxy. -func (lb *LoadBalanced) Provision(ctx caddy.Context) error { - return NewLoadBalancedReverseProxy(lb, ctx) -} - -// ServeHTTP implements the caddyhttp.MiddlewareHandler interface to -// dispatch an HTTP request to the proper server. -func (lb *LoadBalanced) ServeHTTP(w http.ResponseWriter, r *http.Request, _ caddyhttp.Handler) error { - // ensure requests don't hang if an upstream does not respond or is not eventually healthy - var u *upstream - var done bool - - retryTimer := time.NewTicker(lb.tryInterval) - defer retryTimer.Stop() - - go func() { - select { - case <-retryTimer.C: - done = true - } - }() - - // keep trying to get an available upstream to process the request - for { - switch lb.loadBalanceType { - case TypeBalanceRandom: - u = lb.random() - case TypeBalanceRoundRobin: - u = lb.roundRobin() - } - - // if we can't get an upstream and our retry interval has ended return an error response - if u == nil && done { - w.WriteHeader(http.StatusBadGateway) - fmt.Fprint(w, lb.NoHealthyUpstreamsMessage) - - return fmt.Errorf(msgNoHealthyUpstreams) - } - - // attempt to get an available upstream - if u == nil { - continue - } - - start := time.Now() - - // if we get an error retry until we get a healthy upstream - res, err := u.ReverseProxy.ServeHTTP(w, r) - if err != nil { - if err == context.Canceled { - return nil - } - - continue - } - - // record circuit breaker metrics - go u.CB.RecordMetric(res.StatusCode, time.Now().Sub(start)) - - return nil - } -} - -// incrUnhealthy increments the amount of unhealthy nodes in a loadbalancer. -func (lb *LoadBalanced) incrUnhealthy() { - atomic.AddInt32(&lb.numUnhealthy, 1) -} - -// decrUnhealthy decrements the amount of unhealthy nodes in a loadbalancer. -func (lb *LoadBalanced) decrUnhealthy() { - atomic.AddInt32(&lb.numUnhealthy, -1) -} - -// roundRobin implements a round robin load balancing algorithm to select -// which server to forward requests to. -func (lb *LoadBalanced) roundRobin() *upstream { - if atomic.LoadInt32(&lb.numUnhealthy) == int32(len(lb.upstreams)) { - return nil - } - - selected := lb.upstreams[lb.selectedServer] - - lb.mu.Lock() - lb.selectedServer++ - if lb.selectedServer >= len(lb.upstreams) { - lb.selectedServer = 0 - } - lb.mu.Unlock() - - if selected.IsHealthy() && selected.CB.Ok() { - return selected - } - - return nil -} - -// random implements a random server selector for load balancing. -func (lb *LoadBalanced) random() *upstream { - if atomic.LoadInt32(&lb.numUnhealthy) == int32(len(lb.upstreams)) { - return nil - } - - n := rand.Int() % len(lb.upstreams) - selected := lb.upstreams[n] - - if selected.IsHealthy() && selected.CB.Ok() { - return selected - } - - return nil -} - -// UpstreamConfig represents the config of an upstream. -type UpstreamConfig struct { - // Host is the host name of the upstream server. - Host string `json:"host,omitempty"` - - // FastHealthCheckDuration is the duration for which a health check is performed when a node is considered unhealthy. - FastHealthCheckDuration string `json:"fast_health_check_duration,omitempty"` - - CircuitBreaker json.RawMessage `json:"circuit_breaker,omitempty"` - - // // CircuitBreakerConfig is the config passed to setup a circuit breaker. - // CircuitBreakerConfig *circuitbreaker.Config `json:"circuit_breaker,omitempty"` - circuitbreaker CircuitBreaker - - // HealthCheckDuration is the default duration for which a health check is performed. - HealthCheckDuration string `json:"health_check_duration,omitempty"` - - // HealthCheckPath is the path at the upstream host to use for healthchecks. - HealthCheckPath string `json:"health_check_path,omitempty"` -} - -// upstream represents an upstream host. -type upstream struct { - Healthy int32 // 0 = false, 1 = true - Target *url.URL - ReverseProxy *ReverseProxy - Incr func() - Decr func() - CB CircuitBreaker - healthChecker *HealthChecker - healthCheckDur time.Duration - fastHealthCheckDur time.Duration -} - -// newUpstream returns a new upstream. -func newUpstream(uc *UpstreamConfig, d func(), i func()) (*upstream, error) { - host := strings.TrimSpace(uc.Host) - protoIdx := strings.Index(host, "://") - if protoIdx == -1 || len(host[:protoIdx]) == 0 { - return nil, fmt.Errorf("protocol is required for host") - } - - hostURL, err := url.Parse(host) - if err != nil { - return nil, err - } - - // parse healthcheck durations - hcd, err := time.ParseDuration(uc.HealthCheckDuration) - if err != nil { - hcd = defaultHealthCheckDur - } - - fhcd, err := time.ParseDuration(uc.FastHealthCheckDuration) - if err != nil { - fhcd = defaultFastHealthCheckDur - } - - u := upstream{ - healthCheckDur: hcd, - fastHealthCheckDur: fhcd, - Target: hostURL, - Decr: d, - Incr: i, - Healthy: int32(0), // assume is unhealthy on start - } - - u.ReverseProxy = newReverseProxy(hostURL, u.SetHealthiness) - return &u, nil -} - -// SetHealthiness sets whether an upstream is healthy or not. The health check worker is updated to -// perform checks faster if a node is unhealthy. -func (u *upstream) SetHealthiness(ok bool) { - h := atomic.LoadInt32(&u.Healthy) - var wasHealthy bool - if h == 1 { - wasHealthy = true - } else { - wasHealthy = false - } - - if ok { - u.healthChecker.Ticker = time.NewTicker(u.healthCheckDur) - - if !wasHealthy { - atomic.AddInt32(&u.Healthy, 1) - u.Decr() - } - } else { - u.healthChecker.Ticker = time.NewTicker(u.fastHealthCheckDur) - - if wasHealthy { - atomic.AddInt32(&u.Healthy, -1) - u.Incr() - } - } -} - -// IsHealthy returns whether an Upstream is healthy or not. -func (u *upstream) IsHealthy() bool { - i := atomic.LoadInt32(&u.Healthy) - if i == 1 { - return true - } - - return false -} - -// newReverseProxy returns a new reverse proxy handler. -func newReverseProxy(target *url.URL, setHealthiness func(bool)) *ReverseProxy { - errorHandler := func(w http.ResponseWriter, r *http.Request, err error) { - // we don't need to worry about cancelled contexts since this doesn't necessarilly mean that - // the upstream is unhealthy. - if err != context.Canceled { - setHealthiness(false) - } - } - - rp := NewSingleHostReverseProxy(target) - rp.ErrorHandler = errorHandler - rp.Transport = defaultTransport // use default transport that times out in 5 seconds - return rp -} - -// Interface guards -var ( - _ caddyhttp.MiddlewareHandler = (*LoadBalanced)(nil) - _ caddy.Provisioner = (*LoadBalanced)(nil) - _ caddy.CleanerUpper = (*LoadBalanced)(nil) -) diff --git a/usagepool.go b/usagepool.go new file mode 100644 index 00000000..3b8e975f --- /dev/null +++ b/usagepool.go @@ -0,0 +1,86 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package caddy + +import ( + "fmt" + "sync" + "sync/atomic" +) + +// UsagePool is a thread-safe map that pools values +// based on usage; a LoadOrStore operation increments +// the usage, and a Delete decrements from the usage. +// If the usage count reaches 0, the value will be +// removed from the map. There is no way to overwrite +// existing keys in the pool without first deleting +// it as many times as it was stored. Deleting too +// many times will panic. +// +// An empty UsagePool is NOT safe to use; always call +// NewUsagePool() to make a new value. +type UsagePool struct { + pool *sync.Map +} + +// NewUsagePool returns a new usage pool. +func NewUsagePool() *UsagePool { + return &UsagePool{pool: new(sync.Map)} +} + +// Delete decrements the usage count for key and removes the +// value from the underlying map if the usage is 0. It returns +// true if the usage count reached 0 and the value was deleted. +// It panics if the usage count drops below 0; always call +// Delete precisely as many times as LoadOrStore. +func (up *UsagePool) Delete(key interface{}) (deleted bool) { + usageVal, ok := up.pool.Load(key) + if !ok { + return false + } + upv := usageVal.(*usagePoolVal) + newUsage := atomic.AddInt32(&upv.usage, -1) + if newUsage == 0 { + up.pool.Delete(key) + return true + } else if newUsage < 0 { + panic(fmt.Sprintf("deleted more than stored: %#v (usage: %d)", + upv.value, upv.usage)) + } + return false +} + +// LoadOrStore puts val in the pool and returns false if key does +// not already exist; otherwise if the key exists, it loads the +// existing value, increments the usage for that value, and returns +// the value along with true. +func (up *UsagePool) LoadOrStore(key, val interface{}) (actual interface{}, loaded bool) { + usageVal := &usagePoolVal{ + usage: 1, + value: val, + } + actual, loaded = up.pool.LoadOrStore(key, usageVal) + if loaded { + upv := actual.(*usagePoolVal) + actual = upv.value + atomic.AddInt32(&upv.usage, 1) + } + return +} + +type usagePoolVal struct { + usage int32 // accessed atomically; must be 64-bit aligned for 32-bit systems + value interface{} +}