diff --git a/data.go b/data.go index c8ed164..cfc3b99 100644 --- a/data.go +++ b/data.go @@ -36,15 +36,18 @@ const ( optScaleUp = "scaleUp" ) -// urlError reports a malformed URL error. -type urlError struct { +type requestError struct { message string - url *url.URL + status int } -func (e urlError) Error() string { - return fmt.Sprintf("malformed URL %q: %s", e.url, e.message) +func (e requestError) Error() string { return e.message } +func (e requestError) StatusCode() int { return e.status } + +func urlError(msg string, u *url.URL) error { + return requestError{fmt.Sprintf("malformed URL %q: %s", u, msg), http.StatusBadRequest} } +func permissionError(msg string) error { return requestError{msg, http.StatusForbidden} } // Options specifies transformations to be performed on the requested image. type Options struct { @@ -270,13 +273,13 @@ func NewRequest(r *http.Request, baseURL *url.URL) (*Request, error) { // first segment should be options parts := strings.SplitN(path, "/", 2) if len(parts) != 2 { - return nil, urlError{"too few path segments", r.URL} + return nil, urlError("too few path segments", r.URL) } var err error req.URL, err = parseURL(parts[1]) if err != nil { - return nil, urlError{fmt.Sprintf("unable to parse remote URL: %v", err), r.URL} + return nil, urlError(fmt.Sprintf("unable to parse remote URL: %v", err), r.URL) } req.Options = ParseOptions(parts[0]) @@ -287,11 +290,11 @@ func NewRequest(r *http.Request, baseURL *url.URL) (*Request, error) { } if !req.URL.IsAbs() { - return nil, urlError{"must provide absolute remote URL", r.URL} + return nil, urlError("must provide absolute remote URL", r.URL) } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { - return nil, urlError{"remote URL must have http or https scheme", r.URL} + return nil, urlError("remote URL must have http or https scheme", r.URL) } // query string is always part of the remote URL diff --git a/imageproxy.go b/imageproxy.go index 850fd32..c18840c 100644 --- a/imageproxy.go +++ b/imageproxy.go @@ -115,9 +115,7 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (p *Proxy) serveImage(w http.ResponseWriter, r *http.Request) { req, err := NewRequest(r, p.DefaultBaseURL) if err != nil { - msg := fmt.Sprintf("invalid request URL: %v", err) - glog.Error(msg) - http.Error(w, msg, http.StatusBadRequest) + p.writeError(w, err) return } @@ -125,16 +123,13 @@ func (p *Proxy) serveImage(w http.ResponseWriter, r *http.Request) { req.Options.ScaleUp = p.ScaleUp if err := p.allowed(req); err != nil { - glog.Error(err) - http.Error(w, err.Error(), http.StatusForbidden) + p.writeError(w, err) return } resp, err := p.Client.Get(req.String()) if err != nil { - msg := fmt.Sprintf("error fetching remote image: %v", err) - glog.Error(msg) - http.Error(w, msg, http.StatusInternalServerError) + p.writeError(w, err) return } defer resp.Body.Close() @@ -154,6 +149,20 @@ func (p *Proxy) serveImage(w http.ResponseWriter, r *http.Request) { io.Copy(w, resp.Body) } +// writerError writes err to the http response. +func (p *Proxy) writeError(w http.ResponseWriter, err error) { + type statusCoder interface { + StatusCode() int + } + + glog.Error(err) + code := http.StatusBadGateway + if err, ok := err.(statusCoder); ok { + code = err.StatusCode() + } + http.Error(w, err.Error(), code) +} + // copyHeader copies header values from src to dst, adding to any existing // values with the same header name. If keys is not empty, only those header // keys will be copied. @@ -176,7 +185,7 @@ func copyHeader(dst, src http.Header, keys ...string) { // allowed. func (p *Proxy) allowed(r *Request) error { if len(p.Referrers) > 0 && !validReferrer(p.Referrers, r.Original) { - return fmt.Errorf("request does not contain an allowed referrer: %v", r) + return permissionError(fmt.Sprintf("request does not contain an allowed referrer: %v", r)) } if len(p.Whitelist) == 0 && len(p.SignatureKey) == 0 { @@ -191,7 +200,7 @@ func (p *Proxy) allowed(r *Request) error { return nil } - return fmt.Errorf("request does not contain an allowed host or valid signature: %v", r) + return permissionError(fmt.Sprintf("request does not contain an allowed host or valid signature: %v", r)) } // validHost returns whether the host in u matches one of hosts.