0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2025-01-13 22:51:08 -05:00
caddy/caddyhttp/proxy/upstream.go
Leonard Hecker 8048e9c3bc proxy: Added unbuffered request optimization
If only one upstream is defined we don't need to buffer the body.
Instead we directly stream the body to the upstream host,
which reduces memory usage as well as latency.
Furthermore this enables different kinds of HTTP streaming
applications like gRPC for instance.
2017-01-11 19:38:14 +01:00

433 lines
9.5 KiB
Go

package proxy
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/mholt/caddy/caddyfile"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// supportedPolicies maps a policy name to a constructor that returns a new
// Policy. RegisterPolicy adds entries; the "policy" directive handled in
// parseBlock looks them up.
var supportedPolicies = map[string]func() Policy{}
// staticUpstream is the Upstream built by NewStaticUpstreams from one proxy
// directive: a fixed pool of hosts plus the settings shared by all of them.
type staticUpstream struct {
// from is the first argument of the directive; it is returned by From.
from string
// upstreamHeaders is filled by header_upstream, transparent and websocket
// and handed to every host as its UpstreamHeaders.
upstreamHeaders http.Header
// downstreamHeaders is filled by header_downstream and handed to every host
// as its DownstreamHeaders.
downstreamHeaders http.Header
// Hosts holds one UpstreamHost per configured destination address.
Hosts HostPool
// Policy picks a host in Select when the pool has more than one host.
// NewStaticUpstreams defaults it to &Random{}.
Policy Policy
// KeepAlive is passed to NewSingleHostReverseProxy for every host. It
// defaults to http.DefaultMaxIdleConnsPerHost and is set by "keepalive".
KeepAlive int
// FailTimeout is copied into every host's FailTimeout; set by "fail_timeout".
FailTimeout time.Duration
// TryDuration is set by "try_duration" and returned by GetTryDuration.
TryDuration time.Duration
// TryInterval is set by "try_interval" (default 250ms) and returned by
// GetTryInterval.
TryInterval time.Duration
// MaxConns is copied into every host's MaxConns; set by "max_conns".
MaxConns int64
// HealthCheck configures the periodic check run by HealthCheckWorker. The
// worker is only started when Path is non-empty.
HealthCheck struct {
// Client issues the health-check GET requests; its Timeout is set from
// the Timeout field below.
Client http.Client
// Path is appended to each host's Name to form the health-check URL.
Path string
// Interval is the period between checks (30s unless configured).
Interval time.Duration
// Timeout is the per-request timeout (60s unless configured).
Timeout time.Duration
}
// WithoutPathPrefix is set by "without" and passed on to every host and
// its reverse proxy.
WithoutPathPrefix string
// IgnoredSubPaths is set by "except"; AllowedPath rejects request paths
// matching any of these joined onto from.
IgnoredSubPaths []string
// insecureSkipVerify is set by "insecure_skip_verify"; when true each
// host's reverse proxy is switched to its insecure transport.
insecureSkipVerify bool
// MaxFails is the Fails count at which a host's CheckDown reports it as
// down. Defaults to 1; "max_fails" rejects values below 1.
MaxFails int32
}
// NewStaticUpstreams parses the configuration input and sets up static
// upstreams for the proxy middleware, one per directive in the dispenser.
//
// For each directive the first argument becomes the upstream's from value,
// the remaining arguments and any "upstream" sub-directives are expanded by
// parseUpstream into destination addresses, and every other sub-directive is
// delegated to parseBlock. A directive without any destination is an argument
// error. When a health-check path is configured, a HealthCheckWorker goroutine
// is started for that upstream.
//
// On error the upstreams completed so far are returned alongside the error.
func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) {
	var upstreams []Upstream

	for c.Next() {
		// Fields not listed here start at their zero value.
		upstream := &staticUpstream{
			upstreamHeaders:   make(http.Header),
			downstreamHeaders: make(http.Header),
			Policy:            &Random{},
			MaxFails:          1,
			TryInterval:       250 * time.Millisecond,
			KeepAlive:         http.DefaultMaxIdleConnsPerHost,
		}

		if !c.Args(&upstream.from) {
			return upstreams, c.ArgErr()
		}

		// Destinations listed on the directive line itself.
		var targets []string
		for _, arg := range c.RemainingArgs() {
			expanded, err := parseUpstream(arg)
			if err != nil {
				return upstreams, err
			}
			targets = append(targets, expanded...)
		}

		// Sub-directives inside the block.
		for c.NextBlock() {
			if c.Val() != "upstream" {
				if err := parseBlock(&c, upstream); err != nil {
					return upstreams, err
				}
				continue
			}

			// "upstream" adds one more destination (or port range).
			if !c.NextArg() {
				return upstreams, c.ArgErr()
			}
			expanded, err := parseUpstream(c.Val())
			if err != nil {
				return upstreams, err
			}
			targets = append(targets, expanded...)
		}

		if len(targets) == 0 {
			return upstreams, c.ArgErr()
		}

		upstream.Hosts = make([]*UpstreamHost, len(targets))
		for idx := range targets {
			newHost, err := upstream.NewHost(targets[idx])
			if err != nil {
				return upstreams, err
			}
			upstream.Hosts[idx] = newHost
		}

		if upstream.HealthCheck.Path != "" {
			upstream.HealthCheck.Client = http.Client{
				Timeout: upstream.HealthCheck.Timeout,
			}
			// No stop channel is available here, so the worker is started
			// with nil.
			go upstream.HealthCheckWorker(nil)
		}

		upstreams = append(upstreams, upstream)
	}

	return upstreams, nil
}
// From returns the upstream's from value, i.e. the first argument of the
// directive it was built from. AllowedPath joins it with each ignored
// sub-path when matching request paths.
func (u *staticUpstream) From() string {
return u.from
}
// NewHost builds an UpstreamHost for the given address, inheriting this
// upstream's headers, fail timeout, connection limit and path-prefix setting,
// and attaches a single-host reverse proxy to it.
//
// Addresses that do not start with "http://", "https://" or "unix:" are
// treated as plain HTTP and get "http://" prepended. The full scheme prefix is
// matched (not just "http") so that bare host names which merely begin with
// those letters, such as "httpbin.org:80", are still given a scheme.
//
// It returns an error if the resulting address cannot be parsed as a URL.
func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
	hasScheme := strings.HasPrefix(host, "http://") ||
		strings.HasPrefix(host, "https://") ||
		strings.HasPrefix(host, "unix:")
	if !hasScheme {
		host = "http://" + host
	}

	uh := &UpstreamHost{
		Name:              host,
		Conns:             0,
		Fails:             0,
		FailTimeout:       u.FailTimeout,
		Unhealthy:         false,
		UpstreamHeaders:   u.upstreamHeaders,
		DownstreamHeaders: u.downstreamHeaders,
		// A host is down when a health check flagged it or when its failure
		// count has reached the upstream's MaxFails. MaxFails is read on each
		// call, so the closure captures u rather than a copy of the value.
		CheckDown: func(h *UpstreamHost) bool {
			return h.Unhealthy || h.Fails >= u.MaxFails
		},
		WithoutPathPrefix: u.WithoutPathPrefix,
		MaxConns:          u.MaxConns,
	}

	baseURL, err := url.Parse(uh.Name)
	if err != nil {
		return nil, err
	}

	uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive)
	if u.insecureSkipVerify {
		uh.ReverseProxy.UseInsecureTransport()
	}

	return uh, nil
}
// parseUpstream expands an upstream address into the list of addresses it
// stands for.
//
// An address whose port part (the text after the last ':' up to the next '/')
// contains exactly one '-' is read as an inclusive port range "first-last" and
// yields one address per port, keeping whatever precedes the ':' and whatever
// follows the range. Every other address, including anything starting with
// "unix:", is returned unchanged as a single-element slice.
//
// An error is returned when either end of the range is not an integer or when
// the last port is not strictly greater than the first.
func parseUpstream(u string) ([]string, error) {
	if strings.HasPrefix(u, "unix:") {
		return []string{u}, nil
	}

	// No colon at all, or the only colon belongs to the "://" of the scheme:
	// there is no port part to inspect.
	lastColon := strings.LastIndex(u, ":")
	if lastColon == -1 || lastColon == strings.Index(u, "://") {
		return []string{u}, nil
	}

	// Split "prefix:ports[/suffix]".
	prefix := u[:lastColon]
	ports := u[lastColon+1:]
	suffix := ""
	if slash := strings.Index(ports, "/"); slash != -1 {
		ports, suffix = ports[:slash], ports[slash:]
	}

	if strings.Count(ports, "-") != 1 {
		return []string{u}, nil
	}

	dash := strings.Index(ports, "-")
	first, err := strconv.Atoi(ports[:dash])
	if err != nil {
		return nil, err
	}
	last, err := strconv.Atoi(ports[dash+1:])
	if err != nil {
		return nil, err
	}
	if last <= first {
		return nil, fmt.Errorf("port range [%s] is invalid", ports)
	}

	var expanded []string
	for port := first; port <= last; port++ {
		expanded = append(expanded, prefix+":"+strconv.Itoa(port)+suffix)
	}
	return expanded, nil
}
// parseBlock applies the proxy sub-directive the dispenser is currently
// positioned on to u, consuming that directive's arguments from c.
//
// Recognised directives: policy, fail_timeout, max_fails, try_duration,
// try_interval, max_conns, health_check, health_check_interval,
// health_check_timeout, header_upstream, header_downstream, transparent,
// websocket, without, except, insecure_skip_verify and keepalive.
//
// It returns an argument error when a required argument is missing or
// rejected, the underlying parse error for malformed numbers and durations,
// and an "unknown property" error for any other directive.
func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
	// nextDuration advances to the next argument and parses it as a duration.
	nextDuration := func() (time.Duration, error) {
		if !c.NextArg() {
			return 0, c.ArgErr()
		}
		return time.ParseDuration(c.Val())
	}

	// loadDuration loads one argument through c.Args and parses it as a
	// duration.
	loadDuration := func() (time.Duration, error) {
		var raw string
		if !c.Args(&raw) {
			return 0, c.ArgErr()
		}
		return time.ParseDuration(raw)
	}

	// headerArgs loads a header name and value. The value may be omitted only
	// when the name starts with "-" (header removal).
	headerArgs := func() (name, value string, err error) {
		if !c.Args(&name, &value) && !strings.HasPrefix(name, "-") {
			return "", "", c.ArgErr()
		}
		return name, value, nil
	}

	switch c.Val() {
	case "policy":
		if !c.NextArg() {
			return c.ArgErr()
		}
		newPolicy, known := supportedPolicies[c.Val()]
		if !known {
			return c.ArgErr()
		}
		u.Policy = newPolicy()

	case "fail_timeout":
		timeout, err := nextDuration()
		if err != nil {
			return err
		}
		u.FailTimeout = timeout

	case "max_fails":
		if !c.NextArg() {
			return c.ArgErr()
		}
		count, err := strconv.Atoi(c.Val())
		switch {
		case err != nil:
			return err
		case count < 1:
			return c.Err("max_fails must be at least 1")
		}
		u.MaxFails = int32(count)

	case "try_duration":
		duration, err := nextDuration()
		if err != nil {
			return err
		}
		u.TryDuration = duration

	case "try_interval":
		every, err := nextDuration()
		if err != nil {
			return err
		}
		u.TryInterval = every

	case "max_conns":
		if !c.NextArg() {
			return c.ArgErr()
		}
		limit, err := strconv.ParseInt(c.Val(), 10, 64)
		if err != nil {
			return err
		}
		u.MaxConns = limit

	case "health_check":
		if !c.NextArg() {
			return c.ArgErr()
		}
		u.HealthCheck.Path = c.Val()
		// Fill in defaults only where nothing has been configured yet, so
		// that an interval or timeout given earlier in the block survives.
		if u.HealthCheck.Interval == 0 {
			u.HealthCheck.Interval = 30 * time.Second
		}
		if u.HealthCheck.Timeout == 0 {
			u.HealthCheck.Timeout = 60 * time.Second
		}

	case "health_check_interval":
		every, err := loadDuration()
		if err != nil {
			return err
		}
		u.HealthCheck.Interval = every

	case "health_check_timeout":
		timeout, err := loadDuration()
		if err != nil {
			return err
		}
		u.HealthCheck.Timeout = timeout

	case "header_upstream":
		name, value, err := headerArgs()
		if err != nil {
			return err
		}
		u.upstreamHeaders.Add(name, value)

	case "header_downstream":
		name, value, err := headerArgs()
		if err != nil {
			return err
		}
		u.downstreamHeaders.Add(name, value)

	case "transparent":
		// Preset that forwards the original host and client information.
		for _, kv := range [][2]string{
			{"Host", "{host}"},
			{"X-Real-IP", "{remote}"},
			{"X-Forwarded-For", "{remote}"},
			{"X-Forwarded-Proto", "{scheme}"},
		} {
			u.upstreamHeaders.Add(kv[0], kv[1])
		}

	case "websocket":
		// Preset that passes the client's Connection and Upgrade headers on.
		for _, kv := range [][2]string{
			{"Connection", "{>Connection}"},
			{"Upgrade", "{>Upgrade}"},
		} {
			u.upstreamHeaders.Add(kv[0], kv[1])
		}

	case "without":
		if !c.NextArg() {
			return c.ArgErr()
		}
		u.WithoutPathPrefix = c.Val()

	case "except":
		subPaths := c.RemainingArgs()
		if len(subPaths) == 0 {
			return c.ArgErr()
		}
		u.IgnoredSubPaths = subPaths

	case "insecure_skip_verify":
		u.insecureSkipVerify = true

	case "keepalive":
		if !c.NextArg() {
			return c.ArgErr()
		}
		idle, err := strconv.Atoi(c.Val())
		switch {
		case err != nil:
			return err
		case idle < 0:
			return c.ArgErr()
		}
		u.KeepAlive = idle

	default:
		return c.Errf("unknown property '%s'", c.Val())
	}

	return nil
}
// healthCheck probes every host in the pool once by issuing a GET to the
// host's Name followed by the configured health-check path, and records the
// outcome in host.Unhealthy. A host is unhealthy when the request fails or the
// response status is outside the range [200, 400).
func (u *staticUpstream) healthCheck() {
	for _, host := range u.Hosts {
		resp, err := u.HealthCheck.Client.Get(host.Name + u.HealthCheck.Path)
		if err != nil {
			host.Unhealthy = true
			continue
		}

		// Drain the body before closing it; the content itself is not needed.
		io.Copy(ioutil.Discard, resp.Body)
		resp.Body.Close()

		healthy := resp.StatusCode >= 200 && resp.StatusCode < 400
		host.Unhealthy = !healthy
	}
}
// HealthCheckWorker runs one health check immediately and then one on every
// tick of u.HealthCheck.Interval. It blocks until stop delivers a value or is
// closed, at which point it releases the ticker and returns.
//
// A nil stop channel never becomes ready, so in that case the worker runs for
// the lifetime of the process. time.NewTicker panics when the interval is not
// positive.
func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
	ticker := time.NewTicker(u.HealthCheck.Interval)
	defer ticker.Stop()

	u.healthCheck()
	for {
		select {
		case <-ticker.C:
			u.healthCheck()
		case <-stop:
			// Returning here is essential: a closed channel is always ready
			// to receive, so staying in the loop would spin without ever
			// blocking.
			//
			// TODO: the library should provide a stop channel and global
			// waitgroup to allow goroutines started by plugins a chance
			// to clean themselves up.
			return
		}
	}
}
// Select picks the host that should serve r, or returns nil when no host in
// the pool is available. A pool with exactly one host bypasses the policy and
// returns that host if it is available. Otherwise the choice is delegated to
// u.Policy, falling back to Random when no policy is set.
func (u *staticUpstream) Select(r *http.Request) *UpstreamHost {
	hosts := u.Hosts

	if len(hosts) == 1 {
		if only := hosts[0]; only.Available() {
			return only
		}
		return nil
	}

	// Only consult the policy when at least one host is available.
	anyAvailable := false
	for i := range hosts {
		if hosts[i].Available() {
			anyAvailable = true
			break
		}
	}
	if !anyAvailable {
		return nil
	}

	if u.Policy != nil {
		return u.Policy.Select(hosts, r)
	}
	return (&Random{}).Select(hosts, r)
}
// AllowedPath reports whether requestPath is permitted for this upstream. It
// returns false when the cleaned request path matches u.From() joined with any
// of the ignored sub-paths, and true otherwise.
func (u *staticUpstream) AllowedPath(requestPath string) bool {
	cleaned := httpserver.Path(path.Clean(requestPath))
	base := u.From()

	for i := range u.IgnoredSubPaths {
		excluded := path.Join(base, u.IgnoredSubPaths[i])
		if cleaned.Matches(excluded) {
			return false
		}
	}
	return true
}
// GetTryDuration returns u.TryDuration, the value configured with the
// "try_duration" directive (zero when the directive is absent).
func (u *staticUpstream) GetTryDuration() time.Duration {
return u.TryDuration
}
// GetTryInterval returns u.TryInterval, the value configured with the
// "try_interval" directive. NewStaticUpstreams defaults it to 250ms.
func (u *staticUpstream) GetTryInterval() time.Duration {
return u.TryInterval
}
// GetHostCount returns the number of hosts in this upstream's pool,
// regardless of whether they are currently available.
func (u *staticUpstream) GetHostCount() int {
return len(u.Hosts)
}
// RegisterPolicy adds a custom policy to the proxy under the given name,
// replacing any constructor already registered with that name. The "policy"
// directive looks the name up and calls the constructor to obtain the
// upstream's Policy.
//
// The registry is a plain map and this function takes no lock.
func RegisterPolicy(name string, policy func() Policy) {
supportedPolicies[name] = policy
}