0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2025-01-27 23:03:37 -05:00

update logic and refactoring

This commit is contained in:
Viacheslav Biriukov 2015-06-05 22:29:02 +03:00
parent b7af081949
commit aa37c1e05c

View file

@ -2,8 +2,10 @@ package sed
import ( import (
"bytes" "bytes"
"compress/flate"
"compress/gzip" "compress/gzip"
"errors" "errors"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strconv" "strconv"
@ -12,7 +14,7 @@ import (
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
) )
// bufferedWriter buffers responce for replacing content // bufferedWriter buffers responce for replacing content.
type bufferedWriter struct { type bufferedWriter struct {
http.ResponseWriter http.ResponseWriter
Body *bytes.Buffer Body *bytes.Buffer
@ -26,6 +28,9 @@ type bufferedWriter struct {
firstChunk bool firstChunk bool
} }
// NewBufferedWriter returns bufferedWriter.
// ct is a list of Content-Type's for buffering.
// size is a size of a buffer.
func NewBufferedWriter(w http.ResponseWriter, ct []string, size int) *bufferedWriter { func NewBufferedWriter(w http.ResponseWriter, ct []string, size int) *bufferedWriter {
cts := make(map[string]bool) cts := make(map[string]bool)
for _, c := range ct { for _, c := range ct {
@ -42,28 +47,41 @@ func NewBufferedWriter(w http.ResponseWriter, ct []string, size int) *bufferedWr
} }
} }
// checkContentType checks Content-Type for buffering.
func (bw *bufferedWriter) checkContentType() bool {
// buffering only ct's content types
ct := bw.ResponseWriter.Header().Get("Content-Type")
if ct != "" {
ct = strings.ToLower(ct)
ct = strings.SplitN(ct, ";", 2)[0]
if _, ok := bw.cts[ct]; ok {
bw.Buffered = true
bw.ContentType = ct
bw.ContentEncoding = bw.ResponseWriter.Header().Get("Content-Encoding")
return true
}
}
return false
}
// WriteHeader implements the WriteHeader method of http.ResponseWriter. // WriteHeader implements the WriteHeader method of http.ResponseWriter.
func (bw *bufferedWriter) WriteHeader(code int) { func (bw *bufferedWriter) WriteHeader(code int) {
bw.Code = code bw.Code = code
bw.Buffered = true if bw.firstChunk {
if !bw.checkContentType() {
bw.ResponseWriter.WriteHeader(code)
}
bw.firstChunk = false
}
} }
// Write implements the write method of http.ResponseWriter. // Write implements the write method of http.ResponseWriter.
func (bw *bufferedWriter) Write(buf []byte) (int, error) { func (bw *bufferedWriter) Write(buf []byte) (int, error) {
// do once
if bw.firstChunk { if bw.firstChunk {
// buffering only cts content types bw.checkContentType()
ct := bw.ResponseWriter.Header().Get("Content-Type") bw.firstChunk = false
if ct != "" {
ct = strings.ToLower(ct)
ct = strings.SplitN(ct, ";", 2)[0]
if _, ok := bw.cts[ct]; ok {
bw.Buffered = true
bw.ContentType = ct
bw.ContentEncoding = bw.ResponseWriter.Header().Get("Content-Encoding")
}
}
} }
bw.firstChunk = false
// unbuffered write // unbuffered write
if !bw.Buffered { if !bw.Buffered {
@ -77,11 +95,7 @@ func (bw *bufferedWriter) Write(buf []byte) (int, error) {
return n, err return n, err
} }
// drop state // drop state
bw.Buffered = false bw.reset()
bw.n = 0
bw.ContentType = ""
bw.ContentEncoding = ""
bw.Body.Reset()
// write new data // write new data
return bw.ResponseWriter.Write(buf) return bw.ResponseWriter.Write(buf)
} }
@ -89,10 +103,21 @@ func (bw *bufferedWriter) Write(buf []byte) (int, error) {
n, err := bw.Body.Write(buf) n, err := bw.Body.Write(buf)
bw.n += n bw.n += n
return n, err return n, err
} }
var ErrNotBuffered = errors.New("not buffered write") var (
ErrNotBuffered = errors.New("not buffered write")
ErrUnknownEncoding = errors.New("unknown encoding")
)
func (bw *bufferedWriter) reset() {
// drop state
bw.Buffered = false
bw.n = 0
bw.ContentType = ""
bw.ContentEncoding = ""
bw.Body.Reset()
}
// Apply returns buffered data (might be nil: e.g. 304 responce) and error. // Apply returns buffered data (might be nil: e.g. 304 responce) and error.
func (bw *bufferedWriter) Apply(r *http.Request) ([]byte, error) { func (bw *bufferedWriter) Apply(r *http.Request) ([]byte, error) {
@ -103,31 +128,44 @@ func (bw *bufferedWriter) Apply(r *http.Request) ([]byte, error) {
// buffered // buffered
var body []byte var body []byte
switch bw.ContentEncoding { if bw.ContentEncoding != "" {
case "gzip":
if r.Header.Get("Accept-Encoding") == "" { if r.Header.Get("Accept-Encoding") == "" {
// gzip middleware has been already ungziped data // gzip middleware has been already ungziped data
body = bw.Body.Bytes() body = bw.Body.Bytes()
break return body, nil
} }
bw.Header().Del("Content-Encoding") bw.Header().Del("Content-Encoding")
gzr, err := gzip.NewReader(bw.Body) var ce io.ReadCloser
defer gzr.Close() var err error
switch bw.ContentEncoding {
case "gzip":
ce, err = gzip.NewReader(bw.Body)
case "deflate":
ce = flate.NewReader(bw.Body)
default:
// Unknown Content-Encoding so write data and return.
bw.ResponseWriter.WriteHeader(bw.Code)
bw.ResponseWriter.Write(body)
bw.reset()
return nil, ErrUnknownEncoding
}
defer ce.Close()
if err != nil { if err != nil {
return nil, err return nil, err
} }
body, err = ioutil.ReadAll(gzr) body, err = ioutil.ReadAll(ce)
if err != nil { if err != nil {
return nil, err return nil, err
} }
case "": } else {
// no Content-Encoding // no Content-Encoding
body = bw.Body.Bytes() body = bw.Body.Bytes()
return body, nil return body, nil
} }
return body, nil return body, nil
} }
type Sed struct { type Sed struct {
@ -139,8 +177,8 @@ func (s Sed) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, rule := range s.Rules { for _, rule := range s.Rules {
if middleware.Path(r.URL.Path).Matches(rule.Url) { if middleware.Path(r.URL.Path).Matches(rule.Url) {
// buffering write // Buffering write.
cts := []string{"text/html"} // TODO (brk0v): right now only html pages cts := []string{"text/html"} // TODO (brk0v): only html pages
size := 1 << 18 // default 256 KB size := 1 << 18 // default 256 KB
if rule.Size != 0 { if rule.Size != 0 {
size = rule.Size size = rule.Size
@ -150,23 +188,23 @@ func (s Sed) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
st, err := s.Next.ServeHTTP(bw, r) st, err := s.Next.ServeHTTP(bw, r)
body, bufErr := bw.Apply(r) body, bufErr := bw.Apply(r)
// not buffered // Not buffered or unknow encoding.
if bufErr == ErrNotBuffered { if bufErr == ErrNotBuffered || bufErr == ErrUnknownEncoding {
return st, err return st, err
} }
// rest are errors // Rest are errors.
if bufErr != nil { if bufErr != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }
// send headers and return (304 and others) // Send headers and return immediately (304 and others).
if body == nil { if body == nil {
w.WriteHeader(bw.Code) w.WriteHeader(bw.Code)
return st, err return st, err
} }
// replace data // Replace data.
var oldnew []string var oldnew []string
for _, pattern := range rule.Patterns { for _, pattern := range rule.Patterns {
oldnew = append(oldnew, pattern.Find) oldnew = append(oldnew, pattern.Find)
@ -175,16 +213,15 @@ func (s Sed) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
replacer := strings.NewReplacer(oldnew...) replacer := strings.NewReplacer(oldnew...)
data := replacer.Replace(string(body)) data := replacer.Replace(string(body))
// update Content-Length if we have Content-Length // Update Content-Length if we have Content-Length.
if bw.Header().Get("Content-Length") != "" { if bw.Header().Get("Content-Length") != "" {
w.Header().Set("Content-Length", strconv.Itoa(len(data))) w.Header().Set("Content-Length", strconv.Itoa(len(data)))
} }
// send data // Send data.
w.WriteHeader(bw.Code) w.WriteHeader(bw.Code)
w.Write([]byte(data)) w.Write([]byte(data))
return st, err return st, err
} }
} }