diff --git a/caddyhttp/fastcgi/dialer.go b/caddyhttp/fastcgi/dialer.go index 58a8f156..0afd8c0a 100644 --- a/caddyhttp/fastcgi/dialer.go +++ b/caddyhttp/fastcgi/dialer.go @@ -1,10 +1,14 @@ package fastcgi -import "sync" +import ( + "errors" + "sync" + "sync/atomic" +) type dialer interface { - Dial() (*FCGIClient, error) - Close(*FCGIClient) error + Dial() (Client, error) + Close(Client) error } // basicDialer is a basic dialer that wraps default fcgi functions. @@ -12,8 +16,8 @@ type basicDialer struct { network, address string } -func (b basicDialer) Dial() (*FCGIClient, error) { return Dial(b.network, b.address) } -func (b basicDialer) Close(c *FCGIClient) error { return c.Close() } +func (b basicDialer) Dial() (Client, error) { return Dial(b.network, b.address) } +func (b basicDialer) Close(c Client) error { return c.Close() } // persistentDialer keeps a pool of fcgi connections. // connections are not closed after use, rather added back to the pool for reuse. @@ -21,11 +25,11 @@ type persistentDialer struct { size int network string address string - pool []*FCGIClient + pool []Client sync.Mutex } -func (p *persistentDialer) Dial() (*FCGIClient, error) { +func (p *persistentDialer) Dial() (Client, error) { p.Lock() // connection is available, return first one. if len(p.pool) > 0 { @@ -42,7 +46,7 @@ func (p *persistentDialer) Dial() (*FCGIClient, error) { return Dial(p.network, p.address) } -func (p *persistentDialer) Close(client *FCGIClient) error { +func (p *persistentDialer) Close(client Client) error { p.Lock() if len(p.pool) < p.size { // pool is not full yet, add connection for reuse @@ -57,3 +61,35 @@ func (p *persistentDialer) Close(client *FCGIClient) error { // otherwise, close the connection. return client.Close() } + +type loadBalancingDialer struct { + dialers []dialer + current int64 +} + +func (m *loadBalancingDialer) Dial() (Client, error) { + nextDialerIndex := atomic.AddInt64(&m.current, 1) % int64(len(m.dialers)) + currentDialer := m.dialers[nextDialerIndex] + + client, err := currentDialer.Dial() + + if err != nil { + return nil, err + } + + return &dialerAwareClient{Client: client, dialer: currentDialer}, nil +} + +func (m *loadBalancingDialer) Close(c Client) error { + // Close the client according to dialer behaviour + if da, ok := c.(*dialerAwareClient); ok { + return da.dialer.Close(c) + } + + return errors.New("Cannot close client") +} + +type dialerAwareClient struct { + Client + dialer dialer +} diff --git a/caddyhttp/fastcgi/dialer_test.go b/caddyhttp/fastcgi/dialer_test.go new file mode 100644 index 00000000..231d97aa --- /dev/null +++ b/caddyhttp/fastcgi/dialer_test.go @@ -0,0 +1,126 @@ +package fastcgi + +import ( + "errors" + "testing" +) + +func TestLoadbalancingDialer(t *testing.T) { + // given + runs := 100 + mockDialer1 := new(mockDialer) + mockDialer2 := new(mockDialer) + + dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1, mockDialer2}} + + // when + for i := 0; i < runs; i++ { + client, err := dialer.Dial() + dialer.Close(client) + + if err != nil { + t.Errorf("Expected error to be nil") + } + } + + // then + if mockDialer1.dialCalled != mockDialer2.dialCalled && mockDialer1.dialCalled != 50 { + t.Errorf("Expected dialer to call Dial() on multiple backend dialers %d times [actual: %d, %d]", 50, mockDialer1.dialCalled, mockDialer2.dialCalled) + } + + if mockDialer1.closeCalled != mockDialer2.closeCalled && mockDialer1.closeCalled != 50 { + t.Errorf("Expected dialer to call Close() on multiple backend dialers %d times [actual: %d, %d]", 50, mockDialer1.closeCalled, mockDialer2.closeCalled) + } +} + +func TestLoadBalancingDialerShouldReturnDialerAwareClient(t *testing.T) { + // given + mockDialer1 := new(mockDialer) + dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1}} + + // when + client, err := dialer.Dial() + + // then + if err != nil { + t.Errorf("Expected error to be nil") + } + + if awareClient, ok := client.(*dialerAwareClient); !ok { + t.Error("Expected dialer to wrap client") + } else { + if awareClient.dialer != mockDialer1 { + t.Error("Expected wrapped client to have reference to dialer") + } + } +} + +func TestLoadBalancingDialerShouldUnderlyingReturnDialerError(t *testing.T) { + // given + mockDialer1 := new(errorReturningDialer) + dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1}} + + // when + _, err := dialer.Dial() + + // then + if err.Error() != "Error during dial" { + t.Errorf("Expected 'Error during dial', got: '%s'", err.Error()) + } +} + +func TestLoadBalancingDialerShouldCloseClient(t *testing.T) { + // given + mockDialer1 := new(mockDialer) + mockDialer2 := new(mockDialer) + + dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1, mockDialer2}} + client, _ := dialer.Dial() + + // when + err := dialer.Close(client) + + // then + if err != nil { + t.Error("Expected error not to occur") + } + + // load balancing starts from index 1 + if mockDialer2.client != client { + t.Errorf("Expected Close() to be called on referenced dialer") + } +} + +type mockDialer struct { + dialCalled int + closeCalled int + client Client +} + +type mockClient struct { + Client +} + +func (m *mockDialer) Dial() (Client, error) { + m.dialCalled++ + return mockClient{Client: &FCGIClient{}}, nil +} + +func (m *mockDialer) Close(c Client) error { + m.client = c + m.closeCalled++ + return nil +} + +type errorReturningDialer struct { + client Client +} + +func (m *errorReturningDialer) Dial() (Client, error) { + return mockClient{Client: &FCGIClient{}}, errors.New("Error during dial") +} + +func (m *errorReturningDialer) Close(c Client) error { + m.client = c + return errors.New("Error during close") +} diff --git a/caddyhttp/fastcgi/fastcgi.go b/caddyhttp/fastcgi/fastcgi.go index 8d0f282a..90417943 100644 --- a/caddyhttp/fastcgi/fastcgi.go +++ b/caddyhttp/fastcgi/fastcgi.go @@ -111,9 +111,9 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) defer rule.dialer.Close(fcgiBackend) // Log any stderr output from upstream - if fcgiBackend.stderr.Len() != 0 { + if stderr := fcgiBackend.StdErr(); stderr.Len() != 0 { // Remove trailing newline, error logger already does this. - err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n")) + err = LogError(strings.TrimSuffix(stderr.String(), "\n")) } // Normally we would return the status code if it is an error status (>= 400), diff --git a/caddyhttp/fastcgi/fcgiclient.go b/caddyhttp/fastcgi/fcgiclient.go index 4f0d28d1..925a0689 100644 --- a/caddyhttp/fastcgi/fcgiclient.go +++ b/caddyhttp/fastcgi/fcgiclient.go @@ -106,6 +106,16 @@ const ( maxPad = 255 ) +// Client interface +type Client interface { + Get(pair map[string]string) (response *http.Response, err error) + Head(pair map[string]string) (response *http.Response, err error) + Options(pairs map[string]string) (response *http.Response, err error) + Post(pairs map[string]string, method string, bodyType string, body io.Reader, contentLength int) (response *http.Response, err error) + Close() error + StdErr() bytes.Buffer +} + type header struct { Version uint8 Type uint8 @@ -197,22 +207,29 @@ func (c *FCGIClient) Close() error { return c.rwc.Close() } -func (c *FCGIClient) writeRecord(recType uint8, content []byte) (err error) { +func (c *FCGIClient) writeRecord(recType uint8, content []byte) error { c.mutex.Lock() defer c.mutex.Unlock() c.buf.Reset() c.h.init(recType, c.reqID, len(content)) + if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { return err } + if _, err := c.buf.Write(content); err != nil { return err } + if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { return err } - _, err = c.rwc.Write(c.buf.Bytes()) - return err + + if _, err := c.rwc.Write(c.buf.Bytes()); err != nil { + return err + } + + return nil } func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error { @@ -360,6 +377,11 @@ func (w *streamReader) Read(p []byte) (n int, err error) { return } +// StdErr returns stderr stream +func (c *FCGIClient) StdErr() bytes.Buffer { + return c.stderr +} + // Do made the request and returns a io.Reader that translates the data read // from fcgi responder out of fcgi packet before returning it. func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) { diff --git a/caddyhttp/fastcgi/setup.go b/caddyhttp/fastcgi/setup.go index 9b8bad6c..b4444f2b 100644 --- a/caddyhttp/fastcgi/setup.go +++ b/caddyhttp/fastcgi/setup.go @@ -5,6 +5,7 @@ import ( "net/http" "path/filepath" "strconv" + "strings" "github.com/mholt/caddy" "github.com/mholt/caddy/caddyhttp/httpserver" @@ -55,26 +56,21 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { args := c.RemainingArgs() - switch len(args) { - case 0: + if len(args) < 2 || len(args) > 3 { return rules, c.ArgErr() - case 1: - rule.Path = "/" - rule.Address = args[0] - case 2: - rule.Path = args[0] - rule.Address = args[1] - case 3: - rule.Path = args[0] - rule.Address = args[1] - err := fastcgiPreset(args[2], &rule) - if err != nil { - return rules, c.Err("Invalid fastcgi rule preset '" + args[2] + "'") + } + + rule.Path = args[0] + upstreams := []string{args[1]} + + if len(args) == 3 { + if err := fastcgiPreset(args[2], &rule); err != nil { + return rules, err } } - network, address := parseAddress(rule.Address) - rule.dialer = basicDialer{network: network, address: address} + var dialers []dialer + var poolSize = -1 for c.NextBlock() { switch c.Val() { @@ -94,6 +90,15 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { return rules, c.ArgErr() } rule.IndexFiles = args + + case "upstream": + args := c.RemainingArgs() + + if len(args) != 1 { + return rules, c.ArgErr() + } + + upstreams = append(upstreams, args[0]) case "env": envArgs := c.RemainingArgs() if len(envArgs) < 2 { @@ -106,6 +111,7 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { return rules, c.ArgErr() } rule.IgnoredSubPaths = ignoredPaths + case "pool": if !c.NextArg() { return rules, c.ArgErr() @@ -115,13 +121,24 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { return rules, err } if pool >= 0 { - rule.dialer = &persistentDialer{size: pool, network: network, address: address} + poolSize = pool } else { return rules, c.Errf("positive integer expected, found %d", pool) } } } + for _, rawAddress := range upstreams { + network, address := parseAddress(rawAddress) + if poolSize >= 0 { + dialers = append(dialers, &persistentDialer{size: poolSize, network: network, address: address}) + } else { + dialers = append(dialers, basicDialer{network: network, address: address}) + } + } + + rule.dialer = &loadBalancingDialer{dialers: dialers} + rule.Address = strings.Join(upstreams, ",") rules = append(rules, rule) } diff --git a/caddyhttp/fastcgi/setup_test.go b/caddyhttp/fastcgi/setup_test.go index e53bc605..95d2d765 100644 --- a/caddyhttp/fastcgi/setup_test.go +++ b/caddyhttp/fastcgi/setup_test.go @@ -76,9 +76,31 @@ func TestFastcgiParse(t *testing.T) { Address: "127.0.0.1:9000", Ext: ".php", SplitPath: ".php", - dialer: basicDialer{network: "tcp", address: "127.0.0.1:9000"}, + dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000"}}}, IndexFiles: []string{"index.php"}, }}}, + {`fastcgi /blog 127.0.0.1:9000 php { + upstream 127.0.0.1:9001 + }`, + false, []Rule{{ + Path: "/blog", + Address: "127.0.0.1:9000,127.0.0.1:9001", + Ext: ".php", + SplitPath: ".php", + dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000"}, basicDialer{network: "tcp", address: "127.0.0.1:9001"}}}, + IndexFiles: []string{"index.php"}, + }}}, + {`fastcgi /blog 127.0.0.1:9000 { + upstream 127.0.0.1:9001 + }`, + false, []Rule{{ + Path: "/blog", + Address: "127.0.0.1:9000,127.0.0.1:9001", + Ext: "", + SplitPath: "", + dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000"}, basicDialer{network: "tcp", address: "127.0.0.1:9001"}}}, + IndexFiles: []string{}, + }}}, {`fastcgi / ` + defaultAddress + ` { split .html }`, @@ -87,7 +109,7 @@ func TestFastcgiParse(t *testing.T) { Address: defaultAddress, Ext: "", SplitPath: ".html", - dialer: basicDialer{network: network, address: address}, + dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}}, IndexFiles: []string{}, }}}, {`fastcgi / ` + defaultAddress + ` { @@ -99,7 +121,7 @@ func TestFastcgiParse(t *testing.T) { Address: "127.0.0.1:9001", Ext: "", SplitPath: ".html", - dialer: basicDialer{network: network, address: address}, + dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}}, IndexFiles: []string{}, IgnoredSubPaths: []string{"/admin", "/user"}, }}}, @@ -111,18 +133,19 @@ func TestFastcgiParse(t *testing.T) { Address: defaultAddress, Ext: "", SplitPath: "", - dialer: &persistentDialer{size: 0, network: network, address: address}, + dialer: &loadBalancingDialer{dialers: []dialer{&persistentDialer{size: 0, network: network, address: address}}}, IndexFiles: []string{}, }}}, - {`fastcgi / ` + defaultAddress + ` { + {`fastcgi / 127.0.0.1:8080 { + upstream 127.0.0.1:9000 pool 5 }`, false, []Rule{{ Path: "/", - Address: defaultAddress, + Address: "127.0.0.1:8080,127.0.0.1:9000", Ext: "", SplitPath: "", - dialer: &persistentDialer{size: 5, network: network, address: address}, + dialer: &loadBalancingDialer{dialers: []dialer{&persistentDialer{size: 5, network: "tcp", address: "127.0.0.1:8080"}, &persistentDialer{size: 5, network: "tcp", address: "127.0.0.1:9000"}}}, IndexFiles: []string{}, }}}, {`fastcgi / ` + defaultAddress + ` { @@ -133,9 +156,14 @@ func TestFastcgiParse(t *testing.T) { Address: defaultAddress, Ext: "", SplitPath: ".php", - dialer: basicDialer{network: network, address: address}, + dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}}, IndexFiles: []string{}, }}}, + {`fastcgi / { + + }`, + true, []Rule{}, + }, } for i, test := range tests { actualFastcgiConfigs, err := fastcgiParse(caddy.NewTestController("http", test.inputFastcgiConfig)) @@ -175,20 +203,7 @@ func TestFastcgiParse(t *testing.T) { t.Errorf("Test %d expected %dth FastCGI dialer to be of type %T, but got %T", i, j, test.expectedFastcgiConfig[j].dialer, actualFastcgiConfig.dialer) } else { - equal := true - switch actual := actualFastcgiConfig.dialer.(type) { - case basicDialer: - equal = actualFastcgiConfig.dialer == test.expectedFastcgiConfig[j].dialer - case *persistentDialer: - if expected, ok := test.expectedFastcgiConfig[j].dialer.(*persistentDialer); ok { - equal = actual.Equals(expected) - } else { - equal = false - } - default: - t.Errorf("Unkonw dialer type %T", actualFastcgiConfig.dialer) - } - if !equal { + if !areDialersEqual(actualFastcgiConfig.dialer, test.expectedFastcgiConfig[j].dialer, t) { t.Errorf("Test %d expected %dth FastCGI dialer to be %v, but got %v", i, j, test.expectedFastcgiConfig[j].dialer, actualFastcgiConfig.dialer) } @@ -205,5 +220,31 @@ func TestFastcgiParse(t *testing.T) { } } } - +} + +func areDialersEqual(current, expected dialer, t *testing.T) bool { + + switch actual := current.(type) { + case *loadBalancingDialer: + if expected, ok := expected.(*loadBalancingDialer); ok { + for i := 0; i < len(actual.dialers); i++ { + if !areDialersEqual(actual.dialers[i], expected.dialers[i], t) { + return false + } + } + + return true + } + case basicDialer: + return current == expected + case *persistentDialer: + if expected, ok := expected.(*persistentDialer); ok { + return actual.Equals(expected) + } + + default: + t.Errorf("Unknown dialer type %T", current) + } + + return false }