// Copyright 2015 Matthew Holt and The Caddy Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package reverseproxy import ( "net/http" "strconv" "strings" "time" "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/caddyserver/caddy/v2/modules/caddyhttp/headers" "github.com/dustin/go-humanize" ) func init() { httpcaddyfile.RegisterHandlerDirective("reverse_proxy", parseCaddyfile) } func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) { rp := new(Handler) err := rp.UnmarshalCaddyfile(h.Dispenser) if err != nil { return nil, err } return rp, nil } // UnmarshalCaddyfile sets up the handler from Caddyfile tokens. Syntax: // // reverse_proxy [<matcher>] [<upstreams...>] { // # upstreams // to <upstreams...> // // # load balancing // lb_policy <name> [<options...>] // lb_try_duration <duration> // lb_try_interval <interval> // // # active health checking // health_path <path> // health_port <port> // health_interval <interval> // health_timeout <duration> // health_status <status> // health_body <regexp> // // # passive health checking // max_fails <num> // fail_duration <duration> // max_conns <num> // unhealthy_status <status> // unhealthy_latency <duration> // // # streaming // flush_interval <duration> // // # header manipulation // header_up [+|-]<field> [<value|regexp> [<replacement>]] // header_down [+|-]<field> [<value|regexp> [<replacement>]] // // # round trip // transport <name> { // ... // } // } // func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { for d.Next() { for _, up := range d.RemainingArgs() { h.Upstreams = append(h.Upstreams, &Upstream{Dial: up}) } for d.NextBlock(0) { switch d.Val() { case "to": args := d.RemainingArgs() if len(args) == 0 { return d.ArgErr() } for _, up := range args { h.Upstreams = append(h.Upstreams, &Upstream{Dial: up}) } case "lb_policy": if !d.NextArg() { return d.ArgErr() } if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil { return d.Err("load balancing selection policy already specified") } name := d.Val() mod, err := caddy.GetModule("http.reverse_proxy.selection_policies." + name) if err != nil { return d.Errf("getting load balancing policy module '%s': %v", mod, err) } unm, ok := mod.New().(caddyfile.Unmarshaler) if !ok { return d.Errf("load balancing policy module '%s' is not a Caddyfile unmarshaler", mod) } err = unm.UnmarshalCaddyfile(d.NewFromNextTokens()) if err != nil { return err } sel, ok := unm.(Selector) if !ok { return d.Errf("module %s is not a Selector", mod) } if h.LoadBalancing == nil { h.LoadBalancing = new(LoadBalancing) } h.LoadBalancing.SelectionPolicyRaw = caddyconfig.JSONModuleObject(sel, "policy", name, nil) case "lb_try_duration": if !d.NextArg() { return d.ArgErr() } if h.LoadBalancing == nil { h.LoadBalancing = new(LoadBalancing) } dur, err := time.ParseDuration(d.Val()) if err != nil { return d.Errf("bad duration value %s: %v", d.Val(), err) } h.LoadBalancing.TryDuration = caddy.Duration(dur) case "lb_try_interval": if !d.NextArg() { return d.ArgErr() } if h.LoadBalancing == nil { h.LoadBalancing = new(LoadBalancing) } dur, err := time.ParseDuration(d.Val()) if err != nil { return d.Errf("bad interval value '%s': %v", d.Val(), err) } h.LoadBalancing.TryInterval = caddy.Duration(dur) case "health_path": if !d.NextArg() { return d.ArgErr() } if h.HealthChecks == nil { h.HealthChecks = new(HealthChecks) } if h.HealthChecks.Active == nil { h.HealthChecks.Active = new(ActiveHealthChecks) } h.HealthChecks.Active.Path = d.Val() case "health_port": if !d.NextArg() { return d.ArgErr() } if h.HealthChecks == nil { h.HealthChecks = new(HealthChecks) } if h.HealthChecks.Active == nil { h.HealthChecks.Active = new(ActiveHealthChecks) } portNum, err := strconv.Atoi(d.Val()) if err != nil { return d.Errf("bad port number '%s': %v", d.Val(), err) } h.HealthChecks.Active.Port = portNum case "health_interval": if !d.NextArg() { return d.ArgErr() } if h.HealthChecks == nil { h.HealthChecks = new(HealthChecks) } if h.HealthChecks.Active == nil { h.HealthChecks.Active = new(ActiveHealthChecks) } dur, err := time.ParseDuration(d.Val()) if err != nil { return d.Errf("bad interval value %s: %v", d.Val(), err) } h.HealthChecks.Active.Interval = caddy.Duration(dur) case "health_timeout": if !d.NextArg() { return d.ArgErr() } if h.HealthChecks == nil { h.HealthChecks = new(HealthChecks) } if h.HealthChecks.Active == nil { h.HealthChecks.Active = new(ActiveHealthChecks) } dur, err := time.ParseDuration(d.Val()) if err != nil { return d.Errf("bad timeout value %s: %v", d.Val(), err) } h.HealthChecks.Active.Timeout = caddy.Duration(dur) case "health_status": if !d.NextArg() { return d.ArgErr() } if h.HealthChecks == nil { h.HealthChecks = new(HealthChecks) } if h.HealthChecks.Active == nil { h.HealthChecks.Active = new(ActiveHealthChecks) } val := d.Val() if len(val) == 3 && strings.HasSuffix(val, "xx") { val = val[:1] } statusNum, err := strconv.Atoi(val[:1]) if err != nil { return d.Errf("bad status value '%s': %v", d.Val(), err) } h.HealthChecks.Active.ExpectStatus = statusNum case "health_body": if !d.NextArg() { return d.ArgErr() } if h.HealthChecks == nil { h.HealthChecks = new(HealthChecks) } if h.HealthChecks.Active == nil { h.HealthChecks.Active = new(ActiveHealthChecks) } h.HealthChecks.Active.ExpectBody = d.Val() case "max_fails": if !d.NextArg() { return d.ArgErr() } if h.HealthChecks == nil { h.HealthChecks = new(HealthChecks) } if h.HealthChecks.Passive == nil { h.HealthChecks.Passive = new(PassiveHealthChecks) } maxFails, err := strconv.Atoi(d.Val()) if err != nil { return d.Errf("invalid maximum fail count '%s': %v", d.Val(), err) } h.HealthChecks.Passive.MaxFails = maxFails case "fail_duration": if !d.NextArg() { return d.ArgErr() } if h.HealthChecks == nil { h.HealthChecks = new(HealthChecks) } if h.HealthChecks.Passive == nil { h.HealthChecks.Passive = new(PassiveHealthChecks) } dur, err := time.ParseDuration(d.Val()) if err != nil { return d.Errf("bad duration value '%s': %v", d.Val(), err) } h.HealthChecks.Passive.FailDuration = caddy.Duration(dur) case "unhealthy_request_count": if !d.NextArg() { return d.ArgErr() } if h.HealthChecks == nil { h.HealthChecks = new(HealthChecks) } if h.HealthChecks.Passive == nil { h.HealthChecks.Passive = new(PassiveHealthChecks) } maxConns, err := strconv.Atoi(d.Val()) if err != nil { return d.Errf("invalid maximum connection count '%s': %v", d.Val(), err) } h.HealthChecks.Passive.UnhealthyRequestCount = maxConns case "unhealthy_status": args := d.RemainingArgs() if len(args) == 0 { return d.ArgErr() } if h.HealthChecks == nil { h.HealthChecks = new(HealthChecks) } if h.HealthChecks.Passive == nil { h.HealthChecks.Passive = new(PassiveHealthChecks) } for _, arg := range args { if len(arg) == 3 && strings.HasSuffix(arg, "xx") { arg = arg[:1] } statusNum, err := strconv.Atoi(arg[:1]) if err != nil { return d.Errf("bad status value '%s': %v", d.Val(), err) } h.HealthChecks.Passive.UnhealthyStatus = append(h.HealthChecks.Passive.UnhealthyStatus, statusNum) } case "unhealthy_latency": if !d.NextArg() { return d.ArgErr() } if h.HealthChecks == nil { h.HealthChecks = new(HealthChecks) } if h.HealthChecks.Passive == nil { h.HealthChecks.Passive = new(PassiveHealthChecks) } dur, err := time.ParseDuration(d.Val()) if err != nil { return d.Errf("bad duration value '%s': %v", d.Val(), err) } h.HealthChecks.Passive.UnhealthyLatency = caddy.Duration(dur) case "flush_interval": if !d.NextArg() { return d.ArgErr() } dur, err := time.ParseDuration(d.Val()) if err != nil { return d.Errf("bad duration value '%s': %v", d.Val(), err) } h.FlushInterval = caddy.Duration(dur) case "header_up": if h.Headers == nil { h.Headers = new(headers.Handler) } if h.Headers.Request == nil { h.Headers.Request = new(headers.HeaderOps) } args := d.RemainingArgs() switch len(args) { case 1: headers.CaddyfileHeaderOp(h.Headers.Request, args[0], "", "") case 2: headers.CaddyfileHeaderOp(h.Headers.Request, args[0], args[1], "") case 3: headers.CaddyfileHeaderOp(h.Headers.Request, args[0], args[1], args[2]) default: return d.ArgErr() } case "header_down": if h.Headers == nil { h.Headers = new(headers.Handler) } if h.Headers.Response == nil { h.Headers.Response = &headers.RespHeaderOps{ HeaderOps: new(headers.HeaderOps), } } args := d.RemainingArgs() switch len(args) { case 1: headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], "", "") case 2: headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], args[1], "") case 3: headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], args[1], args[2]) default: return d.ArgErr() } case "transport": if !d.NextArg() { return d.ArgErr() } if h.TransportRaw != nil { return d.Err("transport already specified") } name := d.Val() mod, err := caddy.GetModule("http.reverse_proxy.transport." + name) if err != nil { return d.Errf("getting transport module '%s': %v", mod, err) } unm, ok := mod.New().(caddyfile.Unmarshaler) if !ok { return d.Errf("transport module '%s' is not a Caddyfile unmarshaler", mod) } err = unm.UnmarshalCaddyfile(d.NewFromNextTokens()) if err != nil { return err } rt, ok := unm.(http.RoundTripper) if !ok { return d.Errf("module %s is not a RoundTripper", mod) } h.TransportRaw = caddyconfig.JSONModuleObject(rt, "protocol", name, nil) default: return d.Errf("unrecognized subdirective %s", d.Val()) } } } return nil } // UnmarshalCaddyfile deserializes Caddyfile tokens into h. // // transport http { // read_buffer <size> // write_buffer <size> // dial_timeout <duration> // tls_client_auth <cert_file> <key_file> // tls_insecure_skip_verify // tls_timeout <duration> // tls_trusted_ca_certs <cert_files...> // keepalive [off|<duration>] // keepalive_idle_conns <max_count> // } // func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { for d.Next() { for d.NextBlock(0) { switch d.Val() { case "read_buffer": if !d.NextArg() { return d.ArgErr() } size, err := humanize.ParseBytes(d.Val()) if err != nil { return d.Errf("invalid read buffer size '%s': %v", d.Val(), err) } h.ReadBufferSize = int(size) case "write_buffer": if !d.NextArg() { return d.ArgErr() } size, err := humanize.ParseBytes(d.Val()) if err != nil { return d.Errf("invalid write buffer size '%s': %v", d.Val(), err) } h.WriteBufferSize = int(size) case "dial_timeout": if !d.NextArg() { return d.ArgErr() } dur, err := time.ParseDuration(d.Val()) if err != nil { return d.Errf("bad timeout value '%s': %v", d.Val(), err) } h.DialTimeout = caddy.Duration(dur) case "tls_client_auth": args := d.RemainingArgs() if len(args) != 2 { return d.ArgErr() } if h.TLS == nil { h.TLS = new(TLSConfig) } h.TLS.ClientCertificateFile = args[0] h.TLS.ClientCertificateKeyFile = args[1] case "tls": if h.TLS == nil { h.TLS = new(TLSConfig) } case "tls_insecure_skip_verify": if d.NextArg() { return d.ArgErr() } if h.TLS == nil { h.TLS = new(TLSConfig) } h.TLS.InsecureSkipVerify = true case "tls_timeout": if !d.NextArg() { return d.ArgErr() } dur, err := time.ParseDuration(d.Val()) if err != nil { return d.Errf("bad timeout value '%s': %v", d.Val(), err) } if h.TLS == nil { h.TLS = new(TLSConfig) } h.TLS.HandshakeTimeout = caddy.Duration(dur) case "tls_trusted_ca_certs": args := d.RemainingArgs() if len(args) == 0 { return d.ArgErr() } if h.TLS == nil { h.TLS = new(TLSConfig) } h.TLS.RootCAPemFiles = args case "keepalive": if !d.NextArg() { return d.ArgErr() } if h.KeepAlive == nil { h.KeepAlive = new(KeepAlive) } if d.Val() == "off" { var disable bool h.KeepAlive.Enabled = &disable break } dur, err := time.ParseDuration(d.Val()) if err != nil { return d.Errf("bad duration value '%s': %v", d.Val(), err) } h.KeepAlive.IdleConnTimeout = caddy.Duration(dur) case "keepalive_idle_conns": if !d.NextArg() { return d.ArgErr() } num, err := strconv.Atoi(d.Val()) if err != nil { return d.Errf("bad integer value '%s': %v", d.Val(), err) } if h.KeepAlive == nil { h.KeepAlive = new(KeepAlive) } h.KeepAlive.MaxIdleConns = num h.KeepAlive.MaxIdleConnsPerHost = num default: return d.Errf("unrecognized subdirective %s", d.Val()) } } } return nil } // Interface guards var ( _ caddyfile.Unmarshaler = (*Handler)(nil) _ caddyfile.Unmarshaler = (*HTTPTransport)(nil) )