0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2025-01-13 22:51:08 -05:00

update logic and refactoring

This commit is contained in:
Viacheslav Biriukov 2015-06-05 22:29:02 +03:00
parent b7af081949
commit aa37c1e05c

View file

@ -2,8 +2,10 @@ package sed
import (
"bytes"
"compress/flate"
"compress/gzip"
"errors"
"io"
"io/ioutil"
"net/http"
"strconv"
@ -12,7 +14,7 @@ import (
"github.com/mholt/caddy/middleware"
)
// bufferedWriter buffers responce for replacing content
// bufferedWriter buffers responce for replacing content.
type bufferedWriter struct {
http.ResponseWriter
Body *bytes.Buffer
@ -26,6 +28,9 @@ type bufferedWriter struct {
firstChunk bool
}
// NewBufferedWriter returns bufferedWriter.
// ct is a list of Content-Type's for buffering.
// size is a size of a buffer.
func NewBufferedWriter(w http.ResponseWriter, ct []string, size int) *bufferedWriter {
cts := make(map[string]bool)
for _, c := range ct {
@ -42,28 +47,41 @@ func NewBufferedWriter(w http.ResponseWriter, ct []string, size int) *bufferedWr
}
}
// checkContentType checks Content-Type for buffering.
func (bw *bufferedWriter) checkContentType() bool {
// buffering only ct's content types
ct := bw.ResponseWriter.Header().Get("Content-Type")
if ct != "" {
ct = strings.ToLower(ct)
ct = strings.SplitN(ct, ";", 2)[0]
if _, ok := bw.cts[ct]; ok {
bw.Buffered = true
bw.ContentType = ct
bw.ContentEncoding = bw.ResponseWriter.Header().Get("Content-Encoding")
return true
}
}
return false
}
// WriteHeader implements the WriteHeader method of http.ResponseWriter.
func (bw *bufferedWriter) WriteHeader(code int) {
bw.Code = code
bw.Buffered = true
if bw.firstChunk {
if !bw.checkContentType() {
bw.ResponseWriter.WriteHeader(code)
}
bw.firstChunk = false
}
}
// Write implements the write method of http.ResponseWriter.
func (bw *bufferedWriter) Write(buf []byte) (int, error) {
// do once
if bw.firstChunk {
// buffering only cts content types
ct := bw.ResponseWriter.Header().Get("Content-Type")
if ct != "" {
ct = strings.ToLower(ct)
ct = strings.SplitN(ct, ";", 2)[0]
if _, ok := bw.cts[ct]; ok {
bw.Buffered = true
bw.ContentType = ct
bw.ContentEncoding = bw.ResponseWriter.Header().Get("Content-Encoding")
}
}
bw.checkContentType()
bw.firstChunk = false
}
bw.firstChunk = false
// unbuffered write
if !bw.Buffered {
@ -77,11 +95,7 @@ func (bw *bufferedWriter) Write(buf []byte) (int, error) {
return n, err
}
// drop state
bw.Buffered = false
bw.n = 0
bw.ContentType = ""
bw.ContentEncoding = ""
bw.Body.Reset()
bw.reset()
// write new data
return bw.ResponseWriter.Write(buf)
}
@ -89,10 +103,21 @@ func (bw *bufferedWriter) Write(buf []byte) (int, error) {
n, err := bw.Body.Write(buf)
bw.n += n
return n, err
}
var ErrNotBuffered = errors.New("not buffered write")
var (
ErrNotBuffered = errors.New("not buffered write")
ErrUnknownEncoding = errors.New("unknown encoding")
)
func (bw *bufferedWriter) reset() {
// drop state
bw.Buffered = false
bw.n = 0
bw.ContentType = ""
bw.ContentEncoding = ""
bw.Body.Reset()
}
// Apply returns buffered data (might be nil: e.g. 304 responce) and error.
func (bw *bufferedWriter) Apply(r *http.Request) ([]byte, error) {
@ -103,31 +128,44 @@ func (bw *bufferedWriter) Apply(r *http.Request) ([]byte, error) {
// buffered
var body []byte
switch bw.ContentEncoding {
case "gzip":
if bw.ContentEncoding != "" {
if r.Header.Get("Accept-Encoding") == "" {
// gzip middleware has been already ungziped data
body = bw.Body.Bytes()
break
return body, nil
}
bw.Header().Del("Content-Encoding")
gzr, err := gzip.NewReader(bw.Body)
defer gzr.Close()
var ce io.ReadCloser
var err error
switch bw.ContentEncoding {
case "gzip":
ce, err = gzip.NewReader(bw.Body)
case "deflate":
ce = flate.NewReader(bw.Body)
default:
// Unknown Content-Encoding so write data and return.
bw.ResponseWriter.WriteHeader(bw.Code)
bw.ResponseWriter.Write(body)
bw.reset()
return nil, ErrUnknownEncoding
}
defer ce.Close()
if err != nil {
return nil, err
}
body, err = ioutil.ReadAll(gzr)
body, err = ioutil.ReadAll(ce)
if err != nil {
return nil, err
}
case "":
} else {
// no Content-Encoding
body = bw.Body.Bytes()
return body, nil
}
return body, nil
}
type Sed struct {
@ -139,8 +177,8 @@ func (s Sed) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, rule := range s.Rules {
if middleware.Path(r.URL.Path).Matches(rule.Url) {
// buffering write
cts := []string{"text/html"} // TODO (brk0v): right now only html pages
// Buffering write.
cts := []string{"text/html"} // TODO (brk0v): only html pages
size := 1 << 18 // default 256 KB
if rule.Size != 0 {
size = rule.Size
@ -150,23 +188,23 @@ func (s Sed) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
st, err := s.Next.ServeHTTP(bw, r)
body, bufErr := bw.Apply(r)
// not buffered
if bufErr == ErrNotBuffered {
// Not buffered or unknow encoding.
if bufErr == ErrNotBuffered || bufErr == ErrUnknownEncoding {
return st, err
}
// rest are errors
// Rest are errors.
if bufErr != nil {
return http.StatusInternalServerError, err
}
// send headers and return (304 and others)
// Send headers and return immediately (304 and others).
if body == nil {
w.WriteHeader(bw.Code)
return st, err
}
// replace data
// Replace data.
var oldnew []string
for _, pattern := range rule.Patterns {
oldnew = append(oldnew, pattern.Find)
@ -175,16 +213,15 @@ func (s Sed) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
replacer := strings.NewReplacer(oldnew...)
data := replacer.Replace(string(body))
// update Content-Length if we have Content-Length
// Update Content-Length if we have Content-Length.
if bw.Header().Get("Content-Length") != "" {
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
}
// send data
// Send data.
w.WriteHeader(bw.Code)
w.Write([]byte(data))
return st, err
}
}