mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-27 23:03:37 -05:00
update logic and refactoring
This commit is contained in:
parent
b7af081949
commit
aa37c1e05c
1 changed files with 76 additions and 39 deletions
|
@ -2,8 +2,10 @@ package sed
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"compress/flate"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -12,7 +14,7 @@ import (
|
||||||
"github.com/mholt/caddy/middleware"
|
"github.com/mholt/caddy/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
// bufferedWriter buffers responce for replacing content
|
// bufferedWriter buffers responce for replacing content.
|
||||||
type bufferedWriter struct {
|
type bufferedWriter struct {
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
Body *bytes.Buffer
|
Body *bytes.Buffer
|
||||||
|
@ -26,6 +28,9 @@ type bufferedWriter struct {
|
||||||
firstChunk bool
|
firstChunk bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewBufferedWriter returns bufferedWriter.
|
||||||
|
// ct is a list of Content-Type's for buffering.
|
||||||
|
// size is a size of a buffer.
|
||||||
func NewBufferedWriter(w http.ResponseWriter, ct []string, size int) *bufferedWriter {
|
func NewBufferedWriter(w http.ResponseWriter, ct []string, size int) *bufferedWriter {
|
||||||
cts := make(map[string]bool)
|
cts := make(map[string]bool)
|
||||||
for _, c := range ct {
|
for _, c := range ct {
|
||||||
|
@ -42,28 +47,41 @@ func NewBufferedWriter(w http.ResponseWriter, ct []string, size int) *bufferedWr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkContentType checks Content-Type for buffering.
|
||||||
|
func (bw *bufferedWriter) checkContentType() bool {
|
||||||
|
// buffering only ct's content types
|
||||||
|
ct := bw.ResponseWriter.Header().Get("Content-Type")
|
||||||
|
if ct != "" {
|
||||||
|
ct = strings.ToLower(ct)
|
||||||
|
ct = strings.SplitN(ct, ";", 2)[0]
|
||||||
|
if _, ok := bw.cts[ct]; ok {
|
||||||
|
bw.Buffered = true
|
||||||
|
bw.ContentType = ct
|
||||||
|
bw.ContentEncoding = bw.ResponseWriter.Header().Get("Content-Encoding")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// WriteHeader implements the WriteHeader method of http.ResponseWriter.
|
// WriteHeader implements the WriteHeader method of http.ResponseWriter.
|
||||||
func (bw *bufferedWriter) WriteHeader(code int) {
|
func (bw *bufferedWriter) WriteHeader(code int) {
|
||||||
bw.Code = code
|
bw.Code = code
|
||||||
bw.Buffered = true
|
if bw.firstChunk {
|
||||||
|
if !bw.checkContentType() {
|
||||||
|
bw.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
bw.firstChunk = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write implements the write method of http.ResponseWriter.
|
// Write implements the write method of http.ResponseWriter.
|
||||||
func (bw *bufferedWriter) Write(buf []byte) (int, error) {
|
func (bw *bufferedWriter) Write(buf []byte) (int, error) {
|
||||||
|
// do once
|
||||||
if bw.firstChunk {
|
if bw.firstChunk {
|
||||||
// buffering only cts content types
|
bw.checkContentType()
|
||||||
ct := bw.ResponseWriter.Header().Get("Content-Type")
|
bw.firstChunk = false
|
||||||
if ct != "" {
|
|
||||||
ct = strings.ToLower(ct)
|
|
||||||
ct = strings.SplitN(ct, ";", 2)[0]
|
|
||||||
if _, ok := bw.cts[ct]; ok {
|
|
||||||
bw.Buffered = true
|
|
||||||
bw.ContentType = ct
|
|
||||||
bw.ContentEncoding = bw.ResponseWriter.Header().Get("Content-Encoding")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
bw.firstChunk = false
|
|
||||||
|
|
||||||
// unbuffered write
|
// unbuffered write
|
||||||
if !bw.Buffered {
|
if !bw.Buffered {
|
||||||
|
@ -77,11 +95,7 @@ func (bw *bufferedWriter) Write(buf []byte) (int, error) {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
// drop state
|
// drop state
|
||||||
bw.Buffered = false
|
bw.reset()
|
||||||
bw.n = 0
|
|
||||||
bw.ContentType = ""
|
|
||||||
bw.ContentEncoding = ""
|
|
||||||
bw.Body.Reset()
|
|
||||||
// write new data
|
// write new data
|
||||||
return bw.ResponseWriter.Write(buf)
|
return bw.ResponseWriter.Write(buf)
|
||||||
}
|
}
|
||||||
|
@ -89,10 +103,21 @@ func (bw *bufferedWriter) Write(buf []byte) (int, error) {
|
||||||
n, err := bw.Body.Write(buf)
|
n, err := bw.Body.Write(buf)
|
||||||
bw.n += n
|
bw.n += n
|
||||||
return n, err
|
return n, err
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrNotBuffered = errors.New("not buffered write")
|
var (
|
||||||
|
ErrNotBuffered = errors.New("not buffered write")
|
||||||
|
ErrUnknownEncoding = errors.New("unknown encoding")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (bw *bufferedWriter) reset() {
|
||||||
|
// drop state
|
||||||
|
bw.Buffered = false
|
||||||
|
bw.n = 0
|
||||||
|
bw.ContentType = ""
|
||||||
|
bw.ContentEncoding = ""
|
||||||
|
bw.Body.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
// Apply returns buffered data (might be nil: e.g. 304 responce) and error.
|
// Apply returns buffered data (might be nil: e.g. 304 responce) and error.
|
||||||
func (bw *bufferedWriter) Apply(r *http.Request) ([]byte, error) {
|
func (bw *bufferedWriter) Apply(r *http.Request) ([]byte, error) {
|
||||||
|
@ -103,31 +128,44 @@ func (bw *bufferedWriter) Apply(r *http.Request) ([]byte, error) {
|
||||||
|
|
||||||
// buffered
|
// buffered
|
||||||
var body []byte
|
var body []byte
|
||||||
switch bw.ContentEncoding {
|
if bw.ContentEncoding != "" {
|
||||||
case "gzip":
|
|
||||||
if r.Header.Get("Accept-Encoding") == "" {
|
if r.Header.Get("Accept-Encoding") == "" {
|
||||||
// gzip middleware has been already ungziped data
|
// gzip middleware has been already ungziped data
|
||||||
body = bw.Body.Bytes()
|
body = bw.Body.Bytes()
|
||||||
break
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
bw.Header().Del("Content-Encoding")
|
bw.Header().Del("Content-Encoding")
|
||||||
gzr, err := gzip.NewReader(bw.Body)
|
var ce io.ReadCloser
|
||||||
defer gzr.Close()
|
var err error
|
||||||
|
switch bw.ContentEncoding {
|
||||||
|
case "gzip":
|
||||||
|
ce, err = gzip.NewReader(bw.Body)
|
||||||
|
case "deflate":
|
||||||
|
ce = flate.NewReader(bw.Body)
|
||||||
|
default:
|
||||||
|
// Unknown Content-Encoding so write data and return.
|
||||||
|
bw.ResponseWriter.WriteHeader(bw.Code)
|
||||||
|
bw.ResponseWriter.Write(body)
|
||||||
|
bw.reset()
|
||||||
|
return nil, ErrUnknownEncoding
|
||||||
|
}
|
||||||
|
|
||||||
|
defer ce.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
body, err = ioutil.ReadAll(gzr)
|
body, err = ioutil.ReadAll(ce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case "":
|
} else {
|
||||||
// no Content-Encoding
|
// no Content-Encoding
|
||||||
body = bw.Body.Bytes()
|
body = bw.Body.Bytes()
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return body, nil
|
return body, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Sed struct {
|
type Sed struct {
|
||||||
|
@ -139,8 +177,8 @@ func (s Sed) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
for _, rule := range s.Rules {
|
for _, rule := range s.Rules {
|
||||||
if middleware.Path(r.URL.Path).Matches(rule.Url) {
|
if middleware.Path(r.URL.Path).Matches(rule.Url) {
|
||||||
|
|
||||||
// buffering write
|
// Buffering write.
|
||||||
cts := []string{"text/html"} // TODO (brk0v): right now only html pages
|
cts := []string{"text/html"} // TODO (brk0v): only html pages
|
||||||
size := 1 << 18 // default 256 KB
|
size := 1 << 18 // default 256 KB
|
||||||
if rule.Size != 0 {
|
if rule.Size != 0 {
|
||||||
size = rule.Size
|
size = rule.Size
|
||||||
|
@ -150,23 +188,23 @@ func (s Sed) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
st, err := s.Next.ServeHTTP(bw, r)
|
st, err := s.Next.ServeHTTP(bw, r)
|
||||||
body, bufErr := bw.Apply(r)
|
body, bufErr := bw.Apply(r)
|
||||||
|
|
||||||
// not buffered
|
// Not buffered or unknow encoding.
|
||||||
if bufErr == ErrNotBuffered {
|
if bufErr == ErrNotBuffered || bufErr == ErrUnknownEncoding {
|
||||||
return st, err
|
return st, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// rest are errors
|
// Rest are errors.
|
||||||
if bufErr != nil {
|
if bufErr != nil {
|
||||||
return http.StatusInternalServerError, err
|
return http.StatusInternalServerError, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// send headers and return (304 and others)
|
// Send headers and return immediately (304 and others).
|
||||||
if body == nil {
|
if body == nil {
|
||||||
w.WriteHeader(bw.Code)
|
w.WriteHeader(bw.Code)
|
||||||
return st, err
|
return st, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// replace data
|
// Replace data.
|
||||||
var oldnew []string
|
var oldnew []string
|
||||||
for _, pattern := range rule.Patterns {
|
for _, pattern := range rule.Patterns {
|
||||||
oldnew = append(oldnew, pattern.Find)
|
oldnew = append(oldnew, pattern.Find)
|
||||||
|
@ -175,16 +213,15 @@ func (s Sed) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
replacer := strings.NewReplacer(oldnew...)
|
replacer := strings.NewReplacer(oldnew...)
|
||||||
data := replacer.Replace(string(body))
|
data := replacer.Replace(string(body))
|
||||||
|
|
||||||
// update Content-Length if we have Content-Length
|
// Update Content-Length if we have Content-Length.
|
||||||
if bw.Header().Get("Content-Length") != "" {
|
if bw.Header().Get("Content-Length") != "" {
|
||||||
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// send data
|
// Send data.
|
||||||
w.WriteHeader(bw.Code)
|
w.WriteHeader(bw.Code)
|
||||||
w.Write([]byte(data))
|
w.Write([]byte(data))
|
||||||
return st, err
|
return st, err
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue