0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2024-12-30 22:34:15 -05:00

caddyhttp: record num. bytes read when response writer is hijacked (#6173)

* record the number of bytes read when response writer is hijacked

* record body size when not nil
This commit is contained in:
WeidiDeng 2024-04-17 23:00:37 +08:00 committed by GitHub
parent 70953e873a
commit e0daa39cd3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 42 additions and 0 deletions

View file

@ -66,6 +66,8 @@ type responseRecorder struct {
size int
wroteHeader bool
stream bool
readSize *int
}
// NewResponseRecorder returns a new ResponseRecorder that can be
@ -240,6 +242,12 @@ func (rr *responseRecorder) FlushError() error {
return nil
}
// Private interface so it can only be used in this package
// #TODO: maybe export it later
func (rr *responseRecorder) setReadSize(size *int) {
rr.readSize = size
}
func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
//nolint:bodyclose
conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
@ -249,6 +257,15 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not
conn = &hijackedConn{conn, rr}
brw.Writer.Reset(conn)
buffered := brw.Reader.Buffered()
if buffered != 0 {
conn.(*hijackedConn).updateReadSize(buffered)
data, _ := brw.Peek(buffered)
brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn))
} else {
brw.Reader.Reset(conn)
}
return conn, brw, nil
}
@ -258,6 +275,24 @@ type hijackedConn struct {
rr *responseRecorder
}
func (hc *hijackedConn) updateReadSize(n int) {
if hc.rr.readSize != nil {
*hc.rr.readSize += n
}
}
func (hc *hijackedConn) Read(p []byte) (int, error) {
n, err := hc.Conn.Read(p)
hc.updateReadSize(n)
return n, err
}
func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) {
n, err := io.Copy(w, hc.Conn)
hc.updateReadSize(int(n))
return n, err
}
func (hc *hijackedConn) Write(p []byte) (int, error) {
n, err := hc.Conn.Write(p)
hc.rr.size += n
@ -298,4 +333,6 @@ var (
_ io.ReaderFrom = (*ResponseWriterWrapper)(nil)
_ io.ReaderFrom = (*responseRecorder)(nil)
_ io.ReaderFrom = (*hijackedConn)(nil)
_ io.WriterTo = (*hijackedConn)(nil)
)

View file

@ -326,6 +326,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Body != nil {
bodyReader = &lengthReader{Source: r.Body}
r.Body = bodyReader
// should always be true, private interface can only be referenced in the same package
if setReadSizer, ok := wrec.(interface{ setReadSize(*int) }); ok {
setReadSizer.setReadSize(&bodyReader.Length)
}
}
// capture the original version of the request