1
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2024-12-16 21:56:40 -05:00

caddyhttp: Make use of http.ResponseController (#5654)

* caddyhttp: Make use of http.ResponseController

Also syncs the reverseproxy implementation with stdlib's which now uses ResponseController as well 2449bbb5e6

* Enable full-duplex for HTTP/1.1

* Appease linter

* Add warning for builds with Go 1.20, so it's less surprising to users

* Improved godoc for EnableFullDuplex, copied text from stdlib

* Only wrap in encode if not already wrapped
This commit is contained in:
Francis Lavoie 2023-08-02 16:03:26 -04:00 committed by GitHub
parent e198c605bd
commit cd486c25d1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 167 additions and 95 deletions

View file

@ -41,6 +41,7 @@ type serverOptions struct {
IdleTimeout caddy.Duration
KeepAliveInterval caddy.Duration
MaxHeaderBytes int
EnableFullDuplex bool
Protocols []string
StrictSNIHost *bool
TrustedProxiesRaw json.RawMessage
@ -157,6 +158,12 @@ func unmarshalCaddyfileServerOptions(d *caddyfile.Dispenser) (any, error) {
}
serverOpts.MaxHeaderBytes = int(size)
case "enable_full_duplex":
if d.NextArg() {
return nil, d.ArgErr()
}
serverOpts.EnableFullDuplex = true
case "log_credentials":
if d.NextArg() {
return nil, d.ArgErr()
@ -327,6 +334,7 @@ func applyServerOptions(
server.IdleTimeout = opts.IdleTimeout
server.KeepAliveInterval = opts.KeepAliveInterval
server.MaxHeaderBytes = opts.MaxHeaderBytes
server.EnableFullDuplex = opts.EnableFullDuplex
server.Protocols = opts.Protocols
server.StrictSNIHost = opts.StrictSNIHost
server.TrustedProxiesRaw = opts.TrustedProxiesRaw

View file

@ -11,6 +11,7 @@
idle 30s
}
max_header_size 100MB
enable_full_duplex
log_credentials
protocols h1 h2 h2c h3
strict_sni_host
@ -45,6 +46,7 @@ foo.com {
"write_timeout": 30000000000,
"idle_timeout": 30000000000,
"max_header_bytes": 100000000,
"enable_full_duplex": true,
"routes": [
{
"match": [

View file

@ -176,9 +176,7 @@ func testH2ToH2CStreamServeH2C(t *testing.T) *http.Server {
w.Header().Set("Cache-Control", "no-store")
w.WriteHeader(200)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
http.NewResponseController(w).Flush()
buf := make([]byte, 4*1024)

View file

@ -20,7 +20,9 @@ import (
"fmt"
"net"
"net/http"
"runtime"
"strconv"
"strings"
"sync"
"time"
@ -325,9 +327,15 @@ func (app *App) Provision(ctx caddy.Context) error {
// Validate ensures the app's configuration is valid.
func (app *App) Validate() error {
isGo120 := strings.Contains(runtime.Version(), "go1.20")
// each server must use distinct listener addresses
lnAddrs := make(map[string]string)
for srvName, srv := range app.Servers {
if isGo120 && srv.EnableFullDuplex {
app.logger.Warn("enable_full_duplex is not supported in Go 1.20, use a build made with Go 1.21 or later", zap.String("server", srvName))
}
for _, addr := range srv.Listen {
listenAddr, err := caddy.ParseNetworkAddress(addr)
if err != nil {

View file

@ -0,0 +1,25 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !go1.21
package caddyhttp
import (
"net/http"
)
func enableFullDuplex(w http.ResponseWriter) {
// Do nothing, Go 1.20 and earlier do not support full duplex
}

View file

@ -0,0 +1,25 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build go1.21
package caddyhttp
import (
"net/http"
)
func enableFullDuplex(w http.ResponseWriter) {
http.NewResponseController(w).EnableFullDuplex()
}

View file

@ -167,10 +167,10 @@ func (enc *Encode) openResponseWriter(encodingName string, w http.ResponseWriter
// initResponseWriter initializes the responseWriter instance
// allocated in openResponseWriter, enabling mid-stack inlining.
func (enc *Encode) initResponseWriter(rw *responseWriter, encodingName string, wrappedRW http.ResponseWriter) *responseWriter {
if httpInterfaces, ok := wrappedRW.(caddyhttp.HTTPInterfaces); ok {
rw.HTTPInterfaces = httpInterfaces
if rww, ok := wrappedRW.(*caddyhttp.ResponseWriterWrapper); ok {
rw.ResponseWriter = rww
} else {
rw.HTTPInterfaces = &caddyhttp.ResponseWriterWrapper{ResponseWriter: wrappedRW}
rw.ResponseWriter = &caddyhttp.ResponseWriterWrapper{ResponseWriter: wrappedRW}
}
rw.encodingName = encodingName
rw.config = enc
@ -182,7 +182,7 @@ func (enc *Encode) initResponseWriter(rw *responseWriter, encodingName string, w
// using the encoding represented by encodingName and
// configured by config.
type responseWriter struct {
caddyhttp.HTTPInterfaces
http.ResponseWriter
encodingName string
w Encoder
config *Encode
@ -211,7 +211,8 @@ func (rw *responseWriter) Flush() {
// to rw.Write (see bug in #4314)
return
}
rw.HTTPInterfaces.Flush()
//nolint:bodyclose
http.NewResponseController(rw).Flush()
}
// Hijack implements http.Hijacker. It will flush status code if set. We don't track actual hijacked
@ -219,11 +220,12 @@ func (rw *responseWriter) Flush() {
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if !rw.wroteHeader {
if rw.statusCode != 0 {
rw.HTTPInterfaces.WriteHeader(rw.statusCode)
rw.ResponseWriter.WriteHeader(rw.statusCode)
}
rw.wroteHeader = true
}
return rw.HTTPInterfaces.Hijack()
//nolint:bodyclose
return http.NewResponseController(rw).Hijack()
}
// Write writes to the response. If the response qualifies,
@ -260,7 +262,7 @@ func (rw *responseWriter) Write(p []byte) (int, error) {
// by the standard library
if !rw.wroteHeader {
if rw.statusCode != 0 {
rw.HTTPInterfaces.WriteHeader(rw.statusCode)
rw.ResponseWriter.WriteHeader(rw.statusCode)
}
rw.wroteHeader = true
}
@ -268,7 +270,7 @@ func (rw *responseWriter) Write(p []byte) (int, error) {
if rw.w != nil {
return rw.w.Write(p)
} else {
return rw.HTTPInterfaces.Write(p)
return rw.ResponseWriter.Write(p)
}
}
@ -284,7 +286,7 @@ func (rw *responseWriter) Close() error {
// issue #5059, don't write status code if not set explicitly.
if rw.statusCode != 0 {
rw.HTTPInterfaces.WriteHeader(rw.statusCode)
rw.ResponseWriter.WriteHeader(rw.statusCode)
}
rw.wroteHeader = true
}
@ -301,7 +303,7 @@ func (rw *responseWriter) Close() error {
// Unwrap returns the underlying ResponseWriter.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
return rw.HTTPInterfaces
return rw.ResponseWriter
}
// init should be called before we write a response, if rw.buf has contents.
@ -310,7 +312,7 @@ func (rw *responseWriter) init() {
rw.config.Match(rw) {
rw.w = rw.config.writerPools[rw.encodingName].Get().(Encoder)
rw.w.Reset(rw.HTTPInterfaces)
rw.w.Reset(rw.ResponseWriter)
rw.Header().Del("Content-Length") // https://github.com/golang/go/issues/14975
rw.Header().Set("Content-Encoding", rw.encodingName)
rw.Header().Add("Vary", "Accept-Encoding")
@ -429,5 +431,4 @@ var (
_ caddy.Provisioner = (*Encode)(nil)
_ caddy.Validator = (*Encode)(nil)
_ caddyhttp.MiddlewareHandler = (*Encode)(nil)
_ caddyhttp.HTTPInterfaces = (*responseWriter)(nil)
)

View file

@ -371,5 +371,5 @@ func (rww *responseWriterWrapper) Write(d []byte) (int, error) {
var (
_ caddy.Provisioner = (*Handler)(nil)
_ caddyhttp.MiddlewareHandler = (*Handler)(nil)
_ caddyhttp.HTTPInterfaces = (*responseWriterWrapper)(nil)
_ http.ResponseWriter = (*responseWriterWrapper)(nil)
)

View file

@ -251,5 +251,6 @@ const pushedLink = "http.handlers.push.pushed_link"
var (
_ caddy.Provisioner = (*Handler)(nil)
_ caddyhttp.MiddlewareHandler = (*Handler)(nil)
_ caddyhttp.HTTPInterfaces = (*linkPusher)(nil)
_ http.ResponseWriter = (*linkPusher)(nil)
_ http.Pusher = (*linkPusher)(nil)
)

View file

@ -24,34 +24,14 @@ import (
)
// ResponseWriterWrapper wraps an underlying ResponseWriter and
// promotes its Pusher/Flusher/Hijacker methods as well. To use
// this type, embed a pointer to it within your own struct type
// that implements the http.ResponseWriter interface, then call
// methods on the embedded value. You can make sure your type
// wraps correctly by asserting that it implements the
// HTTPInterfaces interface.
// promotes its Pusher method as well. To use this type, embed
// a pointer to it within your own struct type that implements
// the http.ResponseWriter interface, then call methods on the
// embedded value.
type ResponseWriterWrapper struct {
http.ResponseWriter
}
// Hijack implements http.Hijacker. It simply calls the underlying
// ResponseWriter's Hijack method if there is one, or returns
// ErrNotImplemented otherwise.
func (rww *ResponseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := rww.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, ErrNotImplemented
}
// Flush implements http.Flusher. It simply calls the underlying
// ResponseWriter's Flush method if there is one.
func (rww *ResponseWriterWrapper) Flush() {
if f, ok := rww.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// Push implements http.Pusher. It simply calls the underlying
// ResponseWriter's Push method if there is one, or returns
// ErrNotImplemented otherwise.
@ -62,29 +42,18 @@ func (rww *ResponseWriterWrapper) Push(target string, opts *http.PushOptions) er
return ErrNotImplemented
}
// ReadFrom implements io.ReaderFrom. It simply calls the underlying
// ResponseWriter's ReadFrom method if there is one, otherwise it defaults
// to io.Copy.
// ReadFrom implements io.ReaderFrom. It simply calls io.Copy,
// which uses io.ReaderFrom if available.
func (rww *ResponseWriterWrapper) ReadFrom(r io.Reader) (n int64, err error) {
if rf, ok := rww.ResponseWriter.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
return io.Copy(rww.ResponseWriter, r)
}
// Unwrap returns the underlying ResponseWriter.
// Unwrap returns the underlying ResponseWriter, necessary for
// http.ResponseController to work correctly.
func (rww *ResponseWriterWrapper) Unwrap() http.ResponseWriter {
return rww.ResponseWriter
}
// HTTPInterfaces mix all the interfaces that middleware ResponseWriters need to support.
type HTTPInterfaces interface {
http.ResponseWriter
http.Pusher
http.Flusher
http.Hijacker
}
// ErrNotImplemented is returned when an underlying
// ResponseWriter does not implement the required method.
var ErrNotImplemented = fmt.Errorf("method not implemented")
@ -262,7 +231,8 @@ func (rr *responseRecorder) WriteResponse() error {
}
func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
conn, brw, err := rr.ResponseWriterWrapper.Hijack()
//nolint:bodyclose
conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
if err != nil {
return nil, nil, err
}
@ -294,7 +264,7 @@ func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) {
// responses instead of writing them to the client. See
// docs for NewResponseRecorder for proper usage.
type ResponseRecorder interface {
HTTPInterfaces
http.ResponseWriter
Status() int
Buffer() *bytes.Buffer
Buffered() bool
@ -309,12 +279,13 @@ type ShouldBufferFunc func(status int, header http.Header) bool
// Interface guards
var (
_ HTTPInterfaces = (*ResponseWriterWrapper)(nil)
_ ResponseRecorder = (*responseRecorder)(nil)
_ http.ResponseWriter = (*ResponseWriterWrapper)(nil)
_ ResponseRecorder = (*responseRecorder)(nil)
// Implementing ReaderFrom can be such a significant
// optimization that it should probably be required!
// see PR #5022 (25%-50% speedup)
_ io.ReaderFrom = (*ResponseWriterWrapper)(nil)
_ io.ReaderFrom = (*responseRecorder)(nil)
_ io.ReaderFrom = (*hijackedConn)(nil)
)

View file

@ -962,9 +962,8 @@ func (h *Handler) finalizeResponse(
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
if fl, ok := rw.(http.Flusher); ok {
fl.Flush()
}
//nolint:bodyclose
http.NewResponseController(rw).Flush()
}
// total duration spent proxying, including writing response body

View file

@ -20,6 +20,7 @@ package reverseproxy
import (
"context"
"errors"
"fmt"
"io"
weakrand "math/rand"
@ -51,17 +52,19 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
return
}
hj, ok := rw.(http.Hijacker)
if !ok {
logger.Error("can't switch protocols using non-Hijacker ResponseWriter", zap.String("type", fmt.Sprintf("%T", rw)))
return
}
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
logger.Error("internal error: 101 switching protocols response with non-writable body")
return
}
//nolint:bodyclose
conn, brw, hijackErr := http.NewResponseController(rw).Hijack()
if errors.Is(hijackErr, http.ErrNotSupported) {
h.logger.Sugar().Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)
return
}
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
backConnCloseCh := make(chan struct{})
go func() {
@ -81,9 +84,8 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
rw.WriteHeader(res.StatusCode)
logger.Debug("upgrading connection")
conn, brw, err := hj.Hijack()
if err != nil {
logger.Error("hijack failed on protocol switch", zap.Error(err))
if hijackErr != nil {
h.logger.Error("hijack failed on protocol switch", zap.Error(hijackErr))
return
}
@ -181,26 +183,28 @@ func (h Handler) isBidirectionalStream(req *http.Request, res *http.Response) bo
(ae == "identity" || ae == "")
}
func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
func (h Handler) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
var w io.Writer = dst
if flushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
latency: flushInterval,
}
defer mlw.stop()
// set up initial timer so headers get flushed even if body writes are delayed
mlw.flushPending = true
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
dst = mlw
mlw := &maxLatencyWriter{
dst: dst,
//nolint:bodyclose
flush: http.NewResponseController(dst).Flush,
latency: flushInterval,
}
defer mlw.stop()
// set up initial timer so headers get flushed even if body writes are delayed
mlw.flushPending = true
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
w = mlw
}
buf := streamingBufPool.Get().(*[]byte)
defer streamingBufPool.Put(buf)
_, err := h.copyBuffer(dst, src, *buf)
_, err := h.copyBuffer(w, src, *buf)
return err
}
@ -439,13 +443,9 @@ type openConnection struct {
gracefulClose func() error
}
type writeFlusher interface {
io.Writer
http.Flusher
}
type maxLatencyWriter struct {
dst writeFlusher
dst io.Writer
flush func() error
latency time.Duration // non-zero; negative means to flush immediately
mu sync.Mutex // protects t, flushPending, and dst.Flush
@ -458,7 +458,8 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
defer m.mu.Unlock()
n, err = m.dst.Write(p)
if m.latency < 0 {
m.dst.Flush()
//nolint:errcheck
m.flush()
return
}
if m.flushPending {
@ -479,7 +480,8 @@ func (m *maxLatencyWriter) delayedFlush() {
if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
return
}
m.dst.Flush()
//nolint:errcheck
m.flush()
m.flushPending = false
}

View file

@ -2,6 +2,7 @@ package reverseproxy
import (
"bytes"
"net/http/httptest"
"strings"
"testing"
)
@ -13,12 +14,15 @@ func TestHandlerCopyResponse(t *testing.T) {
strings.Repeat("a", defaultBufferSize),
strings.Repeat("123456789 123456789 123456789 12", 3000),
}
dst := bytes.NewBuffer(nil)
recorder := httptest.NewRecorder()
recorder.Body = dst
for _, d := range testdata {
src := bytes.NewBuffer([]byte(d))
dst.Reset()
err := h.copyResponse(dst, src, 0)
err := h.copyResponse(recorder, src, 0)
if err != nil {
t.Errorf("failed with error: %v", err)
}

View file

@ -82,6 +82,26 @@ type Server struct {
// HTTP request headers.
MaxHeaderBytes int `json:"max_header_bytes,omitempty"`
// Enable full-duplex communication for HTTP/1 requests.
// Only has an effect if Caddy was built with Go 1.21 or later.
//
// For HTTP/1 requests, the Go HTTP server by default consumes any
// unread portion of the request body before beginning to write the
// response, preventing handlers from concurrently reading from the
// request and writing the response. Enabling this option disables
// this behavior and permits handlers to continue to read from the
// request while concurrently writing the response.
//
// For HTTP/2 requests, the Go HTTP server always permits concurrent
// reads and responses, so this option has no effect.
//
// Test thoroughly with your HTTP clients, as some older clients may
// not support full-duplex HTTP/1 which can cause them to deadlock.
// See https://github.com/golang/go/issues/57786 for more info.
//
// TODO: This is an EXPERIMENTAL feature. Subject to change or removal.
EnableFullDuplex bool `json:"enable_full_duplex,omitempty"`
// Routes describes how this server will handle requests.
// Routes are executed sequentially. First a route's matchers
// are evaluated, then its grouping. If it matches and has
@ -264,6 +284,14 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
repl := caddy.NewReplacer()
r = PrepareRequest(r, repl, w, s)
// enable full-duplex for HTTP/1, ensuring the entire
// request body gets consumed before writing the response
if s.EnableFullDuplex {
// TODO: Remove duplex_go12*.go abstraction once our
// minimum Go version is 1.21 or later
enableFullDuplex(w)
}
// encode the request for logging purposes before
// it enters any handler chain; this is necessary
// to capture the original request in case it gets