diff --git a/middleware/extensionless/extensionless.go b/middleware/extensionless/extensionless.go index ef629d9d..0be4c1c2 100644 --- a/middleware/extensionless/extensionless.go +++ b/middleware/extensionless/extensionless.go @@ -7,11 +7,19 @@ package extensionless import ( "net/http" "os" - "strings" + "path" "github.com/mholt/caddy/middleware" ) +// Extensionless is an http.Handler that can assume an extension from clean URLs. +// It tries extensions in the order listed in Extensions. +type Extensionless struct { + Next http.HandlerFunc + Root string + Extensions []string +} + // New creates a new instance of middleware that assumes extensions // so the site can use cleaner, extensionless URLs func New(c middleware.Controller) (middleware.Middleware, error) { @@ -31,17 +39,9 @@ func New(c middleware.Controller) (middleware.Middleware, error) { }, nil } -// Extensionless is an http.Handler that can assume an extension from clean URLs. -// It tries extensions in the order listed in Extensions. -type Extensionless struct { - Next http.HandlerFunc - Extensions []string - Root string -} - // ServeHTTP implements the http.Handler interface. func (e Extensionless) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if !hasExt(r) { + if path.Ext(r.URL.Path) == "" { for _, ext := range e.Extensions { if resourceExists(e.Root, r.URL.Path+ext) { r.URL.Path = r.URL.Path + ext @@ -65,9 +65,7 @@ func parse(c middleware.Controller) ([]string, error) { extensions = append(extensions, c.Val()) // Tack on any other extensions that may have been listed - for c.NextArg() { - extensions = append(extensions, c.Val()) - } + extensions = append(extensions, c.RemainingArgs()...) } return extensions, nil @@ -81,15 +79,3 @@ func resourceExists(root, path string) bool { // but we don't handle any other kinds of errors anyway return err == nil } - -// hasExt returns true if the HTTP request r has an extension, -// false otherwise. -func hasExt(r *http.Request) bool { - if r.URL.Path[len(r.URL.Path)-1] == '/' { - // directory - return true - } - lastSep := strings.LastIndex(r.URL.Path, "/") - lastDot := strings.LastIndex(r.URL.Path, ".") - return lastDot > lastSep -}