0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2025-01-06 22:40:31 -05:00

proxy: Add, remove, or replace upstream and downstream headers (closes #666) (PR #788)

* Overwrite proxy headers based on directive

Headers of the request sent by the proxy upstream can now be modified in
the following way:

Prefix header with `+`: Header will be added if it doesn't exist
otherwise, the values will be merge
Prefix header with `-': Header will be removed
No prefix: Header will be replaced with given value

* Add missing formating directive reported by go vet

* Overwrite up/down stream proxy headers

Add Up/DownStreamHeaders to UpstreamHost

Split `proxy_header` option in `proxy` directive into `header_upstream`
and `header_downstream`. By splitting into two, it makes it clear in
what direction the given headers must be applied.

`proxy_header` can still be used (to maintain backward compatability)
but its assumed to be `header_upstream`

Response headers received by the reverse proxy from the upstream host
are updated according the `header_downstream` rules.
The update occurs through a func given to the reverse proxy, which is
applied once a response is received.

Headers (for upstream and downstream) can now be modified in
the following way:

Prefix header with `+`: Header will be added if it doesn't exist
otherwise, the values will be merge
Prefix header with `-': Header will be removed
No prefix: Header will be replaced with given value

Updated branch with changes from master

* minor refactor to make intent clearer

* Make Up/Down stream headers naming consistent

* Fix error descriptions to be more clear

* Fix lint issue
This commit is contained in:
William Bezuidenhout 2016-04-30 21:41:30 +02:00 committed by Matt Holt
parent 96425f0f40
commit e2234497b7
4 changed files with 234 additions and 31 deletions

View file

@ -43,7 +43,8 @@ type UpstreamHost struct {
Fails int32
FailTimeout time.Duration
Unhealthy bool
ExtraHeaders http.Header
UpstreamHeaders http.Header
DownstreamHeaders http.Header
CheckDown UpstreamHostDownFunc
WithoutPathPrefix string
MaxConns int64
@ -99,26 +100,33 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
}
outreq.Host = host.Name
if host.ExtraHeaders != nil {
extraHeaders := make(http.Header)
if host.UpstreamHeaders != nil {
if replacer == nil {
rHost := r.Host
replacer = middleware.NewReplacer(r, nil, "")
outreq.Host = rHost
}
for header, values := range host.ExtraHeaders {
for _, value := range values {
extraHeaders.Add(header, replacer.Replace(value))
if header == "Host" {
outreq.Host = replacer.Replace(value)
}
}
if v, ok := host.UpstreamHeaders["Host"]; ok {
r.Host = replacer.Replace(v[len(v)-1])
}
for k, v := range extraHeaders {
// Modify headers for request that will be sent to the upstream host
upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer)
for k, v := range upHeaders {
outreq.Header[k] = v
}
}
var downHeaderUpdateFn respUpdateFn
if host.DownstreamHeaders != nil {
if replacer == nil {
rHost := r.Host
replacer = middleware.NewReplacer(r, nil, "")
outreq.Host = rHost
}
//Creates a function that is used to update headers the response received by the reverse proxy
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
}
proxy := host.ReverseProxy
if baseURL, err := url.Parse(host.Name); err == nil {
r.Host = baseURL.Host
@ -130,7 +138,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
}
atomic.AddInt64(&host.Conns, 1)
backendErr := proxy.ServeHTTP(w, outreq)
backendErr := proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
atomic.AddInt64(&host.Conns, -1)
if backendErr == nil {
return 0, nil
@ -182,3 +190,48 @@ func createUpstreamRequest(r *http.Request) *http.Request {
return outreq
}
func createRespHeaderUpdateFn(rules http.Header, replacer middleware.Replacer) respUpdateFn {
return func(resp *http.Response) {
newHeaders := createHeadersByRules(rules, resp.Header, replacer)
for h, v := range newHeaders {
resp.Header[h] = v
}
}
}
func createHeadersByRules(rules http.Header, base http.Header, repl middleware.Replacer) http.Header {
newHeaders := make(http.Header)
for header, values := range rules {
if strings.HasPrefix(header, "+") {
header = strings.TrimLeft(header, "+")
add(newHeaders, header, base[header])
applyEach(values, repl.Replace)
add(newHeaders, header, values)
} else if strings.HasPrefix(header, "-") {
base.Del(strings.TrimLeft(header, "-"))
} else if _, ok := base[header]; ok {
applyEach(values, repl.Replace)
for _, v := range values {
newHeaders.Set(header, v)
}
} else {
applyEach(values, repl.Replace)
add(newHeaders, header, values)
add(newHeaders, header, base[header])
}
}
return newHeaders
}
func applyEach(values []string, mapFn func(string) string) {
for i, v := range values {
values[i] = mapFn(v)
}
}
func add(base http.Header, header string, values []string) {
for _, v := range values {
base.Add(header, v)
}
}

View file

@ -348,6 +348,141 @@ func TestUnixSocketProxyPaths(t *testing.T) {
}
}
func TestUpstreamHeadersUpdate(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
var actualHeaders http.Header
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, client"))
actualHeaders = r.Header
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream.host.UpstreamHeaders = http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"},
"+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"},
"-Remove-Me": {""},
"Replace-Me": {"{hostname}"},
}
// set up proxy
p := &Proxy{
Upstreams: []Upstream{upstream},
}
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
w := httptest.NewRecorder()
//add initial headers
r.Header.Add("Merge-Me", "Initial")
r.Header.Add("Remove-Me", "Remove-Value")
r.Header.Add("Replace-Me", "Replace-Value")
p.ServeHTTP(w, r)
replacer := middleware.NewReplacer(r, nil, "")
headerKey := "Merge-Me"
values, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Request sent to upstream backend does not contain expected %v header. Expected header to be added", headerKey)
} else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) {
t.Errorf("Values for proxy header `+Merge-Me` should be merged. Got %v", values)
}
headerKey = "Add-Me"
if _, ok := actualHeaders[headerKey]; !ok {
t.Errorf("Request sent to upstream backend does not contain expected %v header", headerKey)
}
headerKey = "Remove-Me"
if _, ok := actualHeaders[headerKey]; ok {
t.Errorf("Request sent to upstream backend should not contain %v header", headerKey)
}
headerKey = "Replace-Me"
headerValue := replacer.Replace("{hostname}")
value, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Request sent to upstream backend should not remove %v header", headerKey)
} else if len(value) > 0 && headerValue != value[0] {
t.Errorf("Request sent to upstream backend should replace value of %v header with %v. Instead value was %v", headerKey, headerValue, value)
}
}
func TestDownstreamHeadersUpdate(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Merge-Me", "Initial")
w.Header().Add("Remove-Me", "Remove-Value")
w.Header().Add("Replace-Me", "Replace-Value")
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream.host.DownstreamHeaders = http.Header{
"+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"},
"-Remove-Me": {""},
"Replace-Me": {"{hostname}"},
}
// set up proxy
p := &Proxy{
Upstreams: []Upstream{upstream},
}
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
replacer := middleware.NewReplacer(r, nil, "")
actualHeaders := w.Header()
headerKey := "Merge-Me"
values, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Downstream response does not contain expected %v header. Expected header should be added", headerKey)
} else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) {
t.Errorf("Values for header `+Merge-Me` should be merged. Got %v", values)
}
headerKey = "Add-Me"
if _, ok := actualHeaders[headerKey]; !ok {
t.Errorf("Downstream response does not contain expected %v header", headerKey)
}
headerKey = "Remove-Me"
if _, ok := actualHeaders[headerKey]; ok {
t.Errorf("Downstream response should not contain %v header received from upstream", headerKey)
}
headerKey = "Replace-Me"
headerValue := replacer.Replace("{hostname}")
value, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Downstream response should contain %v header and not remove it", headerKey)
} else if len(value) > 0 && headerValue != value[0] {
t.Errorf("Downstream response should have header %v with value %v. Instead value was %v", headerKey, headerValue, value)
}
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
uri, _ := url.Parse(name)
u := &fakeUpstream{
@ -410,7 +545,7 @@ func (u *fakeWsUpstream) Select() *UpstreamHost {
return &UpstreamHost{
Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without),
ExtraHeaders: http.Header{
UpstreamHeaders: http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}},
}

View file

@ -154,7 +154,9 @@ var InsecureTransport http.RoundTripper = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request) error {
type respUpdateFn func(resp *http.Response)
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
transport := p.Transport
if transport == nil {
transport = http.DefaultTransport
@ -169,6 +171,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request) e
res, err := transport.RoundTrip(outreq)
if err != nil {
return err
} else if respUpdateFn != nil {
respUpdateFn(res)
}
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {

View file

@ -20,7 +20,8 @@ var (
type staticUpstream struct {
from string
proxyHeaders http.Header
upstreamHeaders http.Header
downstreamHeaders http.Header
Hosts HostPool
Policy Policy
insecureSkipVerify bool
@ -42,13 +43,14 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
var upstreams []Upstream
for c.Next() {
upstream := &staticUpstream{
from: "",
proxyHeaders: make(http.Header),
Hosts: nil,
Policy: &Random{},
FailTimeout: 10 * time.Second,
MaxFails: 1,
MaxConns: 0,
from: "",
upstreamHeaders: make(http.Header),
downstreamHeaders: make(http.Header),
Hosts: nil,
Policy: &Random{},
FailTimeout: 10 * time.Second,
MaxFails: 1,
MaxConns: 0,
}
if !c.Args(&upstream.from) {
@ -97,12 +99,13 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
host = "http://" + host
}
uh := &UpstreamHost{
Name: host,
Conns: 0,
Fails: 0,
FailTimeout: u.FailTimeout,
Unhealthy: false,
ExtraHeaders: u.proxyHeaders,
Name: host,
Conns: 0,
Fails: 0,
FailTimeout: u.FailTimeout,
Unhealthy: false,
UpstreamHeaders: u.upstreamHeaders,
DownstreamHeaders: u.downstreamHeaders,
CheckDown: func(u *staticUpstream) UpstreamHostDownFunc {
return func(uh *UpstreamHost) bool {
if uh.Unhealthy {
@ -182,15 +185,23 @@ func parseBlock(c *parse.Dispenser, u *staticUpstream) error {
}
u.HealthCheck.Interval = dur
}
case "header_upstream":
fallthrough
case "proxy_header":
var header, value string
if !c.Args(&header, &value) {
return c.ArgErr()
}
u.proxyHeaders.Add(header, value)
u.upstreamHeaders.Add(header, value)
case "header_downstream":
var header, value string
if !c.Args(&header, &value) {
return c.ArgErr()
}
u.downstreamHeaders.Add(header, value)
case "websocket":
u.proxyHeaders.Add("Connection", "{>Connection}")
u.proxyHeaders.Add("Upgrade", "{>Upgrade}")
u.upstreamHeaders.Add("Connection", "{>Connection}")
u.upstreamHeaders.Add("Upgrade", "{>Upgrade}")
case "without":
if !c.NextArg() {
return c.ArgErr()