mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-13 22:51:08 -05:00
update logic and refactoring
This commit is contained in:
parent
b7af081949
commit
aa37c1e05c
1 changed files with 76 additions and 39 deletions
|
@ -2,8 +2,10 @@ package sed
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
@ -12,7 +14,7 @@ import (
|
|||
"github.com/mholt/caddy/middleware"
|
||||
)
|
||||
|
||||
// bufferedWriter buffers responce for replacing content
|
||||
// bufferedWriter buffers responce for replacing content.
|
||||
type bufferedWriter struct {
|
||||
http.ResponseWriter
|
||||
Body *bytes.Buffer
|
||||
|
@ -26,6 +28,9 @@ type bufferedWriter struct {
|
|||
firstChunk bool
|
||||
}
|
||||
|
||||
// NewBufferedWriter returns bufferedWriter.
|
||||
// ct is a list of Content-Type's for buffering.
|
||||
// size is a size of a buffer.
|
||||
func NewBufferedWriter(w http.ResponseWriter, ct []string, size int) *bufferedWriter {
|
||||
cts := make(map[string]bool)
|
||||
for _, c := range ct {
|
||||
|
@ -42,28 +47,41 @@ func NewBufferedWriter(w http.ResponseWriter, ct []string, size int) *bufferedWr
|
|||
}
|
||||
}
|
||||
|
||||
// checkContentType checks Content-Type for buffering.
|
||||
func (bw *bufferedWriter) checkContentType() bool {
|
||||
// buffering only ct's content types
|
||||
ct := bw.ResponseWriter.Header().Get("Content-Type")
|
||||
if ct != "" {
|
||||
ct = strings.ToLower(ct)
|
||||
ct = strings.SplitN(ct, ";", 2)[0]
|
||||
if _, ok := bw.cts[ct]; ok {
|
||||
bw.Buffered = true
|
||||
bw.ContentType = ct
|
||||
bw.ContentEncoding = bw.ResponseWriter.Header().Get("Content-Encoding")
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WriteHeader implements the WriteHeader method of http.ResponseWriter.
|
||||
func (bw *bufferedWriter) WriteHeader(code int) {
|
||||
bw.Code = code
|
||||
bw.Buffered = true
|
||||
if bw.firstChunk {
|
||||
if !bw.checkContentType() {
|
||||
bw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
bw.firstChunk = false
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements the write method of http.ResponseWriter.
|
||||
func (bw *bufferedWriter) Write(buf []byte) (int, error) {
|
||||
// do once
|
||||
if bw.firstChunk {
|
||||
// buffering only cts content types
|
||||
ct := bw.ResponseWriter.Header().Get("Content-Type")
|
||||
if ct != "" {
|
||||
ct = strings.ToLower(ct)
|
||||
ct = strings.SplitN(ct, ";", 2)[0]
|
||||
if _, ok := bw.cts[ct]; ok {
|
||||
bw.Buffered = true
|
||||
bw.ContentType = ct
|
||||
bw.ContentEncoding = bw.ResponseWriter.Header().Get("Content-Encoding")
|
||||
}
|
||||
}
|
||||
bw.checkContentType()
|
||||
bw.firstChunk = false
|
||||
}
|
||||
bw.firstChunk = false
|
||||
|
||||
// unbuffered write
|
||||
if !bw.Buffered {
|
||||
|
@ -77,11 +95,7 @@ func (bw *bufferedWriter) Write(buf []byte) (int, error) {
|
|||
return n, err
|
||||
}
|
||||
// drop state
|
||||
bw.Buffered = false
|
||||
bw.n = 0
|
||||
bw.ContentType = ""
|
||||
bw.ContentEncoding = ""
|
||||
bw.Body.Reset()
|
||||
bw.reset()
|
||||
// write new data
|
||||
return bw.ResponseWriter.Write(buf)
|
||||
}
|
||||
|
@ -89,10 +103,21 @@ func (bw *bufferedWriter) Write(buf []byte) (int, error) {
|
|||
n, err := bw.Body.Write(buf)
|
||||
bw.n += n
|
||||
return n, err
|
||||
|
||||
}
|
||||
|
||||
var ErrNotBuffered = errors.New("not buffered write")
|
||||
var (
|
||||
ErrNotBuffered = errors.New("not buffered write")
|
||||
ErrUnknownEncoding = errors.New("unknown encoding")
|
||||
)
|
||||
|
||||
func (bw *bufferedWriter) reset() {
|
||||
// drop state
|
||||
bw.Buffered = false
|
||||
bw.n = 0
|
||||
bw.ContentType = ""
|
||||
bw.ContentEncoding = ""
|
||||
bw.Body.Reset()
|
||||
}
|
||||
|
||||
// Apply returns buffered data (might be nil: e.g. 304 responce) and error.
|
||||
func (bw *bufferedWriter) Apply(r *http.Request) ([]byte, error) {
|
||||
|
@ -103,31 +128,44 @@ func (bw *bufferedWriter) Apply(r *http.Request) ([]byte, error) {
|
|||
|
||||
// buffered
|
||||
var body []byte
|
||||
switch bw.ContentEncoding {
|
||||
case "gzip":
|
||||
if bw.ContentEncoding != "" {
|
||||
if r.Header.Get("Accept-Encoding") == "" {
|
||||
// gzip middleware has been already ungziped data
|
||||
body = bw.Body.Bytes()
|
||||
break
|
||||
return body, nil
|
||||
}
|
||||
|
||||
bw.Header().Del("Content-Encoding")
|
||||
gzr, err := gzip.NewReader(bw.Body)
|
||||
defer gzr.Close()
|
||||
var ce io.ReadCloser
|
||||
var err error
|
||||
switch bw.ContentEncoding {
|
||||
case "gzip":
|
||||
ce, err = gzip.NewReader(bw.Body)
|
||||
case "deflate":
|
||||
ce = flate.NewReader(bw.Body)
|
||||
default:
|
||||
// Unknown Content-Encoding so write data and return.
|
||||
bw.ResponseWriter.WriteHeader(bw.Code)
|
||||
bw.ResponseWriter.Write(body)
|
||||
bw.reset()
|
||||
return nil, ErrUnknownEncoding
|
||||
}
|
||||
|
||||
defer ce.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body, err = ioutil.ReadAll(gzr)
|
||||
body, err = ioutil.ReadAll(ce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "":
|
||||
} else {
|
||||
// no Content-Encoding
|
||||
body = bw.Body.Bytes()
|
||||
return body, nil
|
||||
}
|
||||
|
||||
return body, nil
|
||||
|
||||
}
|
||||
|
||||
type Sed struct {
|
||||
|
@ -139,8 +177,8 @@ func (s Sed) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|||
for _, rule := range s.Rules {
|
||||
if middleware.Path(r.URL.Path).Matches(rule.Url) {
|
||||
|
||||
// buffering write
|
||||
cts := []string{"text/html"} // TODO (brk0v): right now only html pages
|
||||
// Buffering write.
|
||||
cts := []string{"text/html"} // TODO (brk0v): only html pages
|
||||
size := 1 << 18 // default 256 KB
|
||||
if rule.Size != 0 {
|
||||
size = rule.Size
|
||||
|
@ -150,23 +188,23 @@ func (s Sed) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|||
st, err := s.Next.ServeHTTP(bw, r)
|
||||
body, bufErr := bw.Apply(r)
|
||||
|
||||
// not buffered
|
||||
if bufErr == ErrNotBuffered {
|
||||
// Not buffered or unknow encoding.
|
||||
if bufErr == ErrNotBuffered || bufErr == ErrUnknownEncoding {
|
||||
return st, err
|
||||
}
|
||||
|
||||
// rest are errors
|
||||
// Rest are errors.
|
||||
if bufErr != nil {
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
// send headers and return (304 and others)
|
||||
// Send headers and return immediately (304 and others).
|
||||
if body == nil {
|
||||
w.WriteHeader(bw.Code)
|
||||
return st, err
|
||||
}
|
||||
|
||||
// replace data
|
||||
// Replace data.
|
||||
var oldnew []string
|
||||
for _, pattern := range rule.Patterns {
|
||||
oldnew = append(oldnew, pattern.Find)
|
||||
|
@ -175,16 +213,15 @@ func (s Sed) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|||
replacer := strings.NewReplacer(oldnew...)
|
||||
data := replacer.Replace(string(body))
|
||||
|
||||
// update Content-Length if we have Content-Length
|
||||
// Update Content-Length if we have Content-Length.
|
||||
if bw.Header().Get("Content-Length") != "" {
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
||||
}
|
||||
|
||||
// send data
|
||||
// Send data.
|
||||
w.WriteHeader(bw.Code)
|
||||
w.Write([]byte(data))
|
||||
return st, err
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue