diff --git a/appveyor.yml b/appveyor.yml index a436177c9..eddfcaa7f 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -9,7 +9,6 @@ environment: install: - go get golang.org/x/tools/cmd/vet - - echo %PATH% - echo %GOPATH% - go version - go env diff --git a/caddy/config_test.go b/caddy/config_test.go index 477da2071..3df0b5cc9 100644 --- a/caddy/config_test.go +++ b/caddy/config_test.go @@ -1,11 +1,27 @@ package caddy import ( + "reflect" + "sync" "testing" "github.com/mholt/caddy/server" ) +func TestNewDefault(t *testing.T) { + config := NewDefault() + + if actual, expected := config.Root, DefaultRoot; actual != expected { + t.Errorf("Root was %s but expected %s", actual, expected) + } + if actual, expected := config.Host, DefaultHost; actual != expected { + t.Errorf("Host was %s but expected %s", actual, expected) + } + if actual, expected := config.Port, DefaultPort; actual != expected { + t.Errorf("Port was %s but expected %s", actual, expected) + } +} + func TestResolveAddr(t *testing.T) { // NOTE: If tests fail due to comparing to string "127.0.0.1", // it's possible that system env resolves with IPv6, or ::1. @@ -62,3 +78,61 @@ func TestResolveAddr(t *testing.T) { } } } + +func TestMakeOnces(t *testing.T) { + directives := []directive{ + {"dummy", nil}, + {"dummy2", nil}, + } + directiveOrder = directives + onces := makeOnces() + if len(onces) != len(directives) { + t.Errorf("onces had len %d , expected %d", len(onces), len(directives)) + } + expected := map[string]*sync.Once{ + "dummy": new(sync.Once), + "dummy2": new(sync.Once), + } + if !reflect.DeepEqual(onces, expected) { + t.Errorf("onces was %v, expected %v", onces, expected) + } +} + +func TestMakeStorages(t *testing.T) { + directives := []directive{ + {"dummy", nil}, + {"dummy2", nil}, + } + directiveOrder = directives + storages := makeStorages() + if len(storages) != len(directives) { + t.Errorf("storages had len %d , expected %d", len(storages), len(directives)) + } + expected := map[string]interface{}{ + "dummy": nil, + "dummy2": nil, + } + if !reflect.DeepEqual(storages, expected) { + t.Errorf("storages was %v, expected %v", storages, expected) + } +} + +func TestValidDirective(t *testing.T) { + directives := []directive{ + {"dummy", nil}, + {"dummy2", nil}, + } + directiveOrder = directives + for i, test := range []struct { + directive string + valid bool + }{ + {"dummy", true}, + {"dummy2", true}, + {"dummy3", false}, + } { + if actual, expected := validDirective(test.directive), test.valid; actual != expected { + t.Errorf("Test %d: valid was %t, expected %t", i, actual, expected) + } + } +} diff --git a/caddy/setup/controller.go b/caddy/setup/controller.go index 02b366cd8..e31207263 100644 --- a/caddy/setup/controller.go +++ b/caddy/setup/controller.go @@ -58,6 +58,9 @@ func NewTestController(input string) *Controller { Root: ".", }, Dispenser: parse.NewDispenser("Testfile", strings.NewReader(input)), + OncePerServerBlock: func(f func() error) error { + return f() + }, } } diff --git a/caddy/setup/startupshutdown_test.go b/caddy/setup/startupshutdown_test.go new file mode 100644 index 000000000..cf07a7e8c --- /dev/null +++ b/caddy/setup/startupshutdown_test.go @@ -0,0 +1,58 @@ +package setup + +import ( + "os" + "os/exec" + "path/filepath" + "strconv" + "testing" + "time" +) + +// The Startup function's tests are symmetrical to Shutdown tests, +// because the Startup and Shutdown functions share virtually the +// same functionality +func TestStartup(t *testing.T) { + + tempDirPath, err := getTempDirPath() + if err != nil { + t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err) + } + + testDir := filepath.Join(tempDirPath, "temp_dir_for_testing_startupshutdown.go") + osSenitiveTestDir := filepath.FromSlash(testDir) + + exec.Command("rm", "-r", osSenitiveTestDir).Run() // removes osSenitiveTestDir from the OS's temp directory, if the osSenitiveTestDir already exists + + tests := []struct { + input string + shouldExecutionErr bool + shouldRemoveErr bool + }{ + // test case #0 tests proper functionality blocking commands + {"startup mkdir " + osSenitiveTestDir, false, false}, + + // test case #1 tests proper functionality of non-blocking commands + {"startup mkdir " + osSenitiveTestDir + " &", false, true}, + + // test case #2 tests handling of non-existant commands + {"startup " + strconv.Itoa(int(time.Now().UnixNano())), true, true}, + } + + for i, test := range tests { + c := NewTestController(test.input) + _, err = Startup(c) + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } + err = c.Startup[0]() + if err != nil && !test.shouldExecutionErr { + t.Errorf("Test %d recieved an error of:\n%v", i, err) + } + err = os.Remove(osSenitiveTestDir) + if err != nil && !test.shouldRemoveErr { + t.Errorf("Test %d recieved an error of:\n%v", i, err) + } + + } +} diff --git a/caddy/setup/templates.go b/caddy/setup/templates.go index 51d78d5ce..f8d7e98bd 100644 --- a/caddy/setup/templates.go +++ b/caddy/setup/templates.go @@ -32,18 +32,48 @@ func templatesParse(c *Controller) ([]templates.Rule, error) { for c.Next() { var rule templates.Rule - if c.NextArg() { + rule.Path = defaultTemplatePath + rule.Extensions = defaultTemplateExtensions + + args := c.RemainingArgs() + + switch len(args) { + case 0: + // Optional block + for c.NextBlock() { + switch c.Val() { + case "path": + args := c.RemainingArgs() + if len(args) != 1 { + return nil, c.ArgErr() + } + rule.Path = args[0] + + case "ext": + args := c.RemainingArgs() + if len(args) == 0 { + return nil, c.ArgErr() + } + rule.Extensions = args + + case "between": + args := c.RemainingArgs() + if len(args) != 2 { + return nil, c.ArgErr() + } + rule.Delims[0] = args[0] + rule.Delims[1] = args[1] + } + } + default: // First argument would be the path - rule.Path = c.Val() + rule.Path = args[0] // Any remaining arguments are extensions - rule.Extensions = c.RemainingArgs() + rule.Extensions = args[1:] if len(rule.Extensions) == 0 { rule.Extensions = defaultTemplateExtensions } - } else { - rule.Path = defaultTemplatePath - rule.Extensions = defaultTemplateExtensions } for _, ext := range rule.Extensions { @@ -52,7 +82,6 @@ func templatesParse(c *Controller) ([]templates.Rule, error) { rules = append(rules, rule) } - return rules, nil } diff --git a/caddy/setup/templates_test.go b/caddy/setup/templates_test.go index e97ab3799..b1cfb29ce 100644 --- a/caddy/setup/templates_test.go +++ b/caddy/setup/templates_test.go @@ -2,8 +2,9 @@ package setup import ( "fmt" - "github.com/mholt/caddy/middleware/templates" "testing" + + "github.com/mholt/caddy/middleware/templates" ) func TestTemplates(t *testing.T) { @@ -40,7 +41,11 @@ func TestTemplates(t *testing.T) { if fmt.Sprint(myHandler.Rules[0].IndexFiles) != fmt.Sprint(indexFiles) { t.Errorf("Expected %v to be the Default Index files", indexFiles) } + if myHandler.Rules[0].Delims != [2]string{} { + t.Errorf("Expected %v to be the Default Delims", [2]string{}) + } } + func TestTemplatesParse(t *testing.T) { tests := []struct { inputTemplateConfig string @@ -50,19 +55,32 @@ func TestTemplatesParse(t *testing.T) { {`templates /api1`, false, []templates.Rule{{ Path: "/api1", Extensions: defaultTemplateExtensions, + Delims: [2]string{}, }}}, {`templates /api2 .txt .htm`, false, []templates.Rule{{ Path: "/api2", Extensions: []string{".txt", ".htm"}, + Delims: [2]string{}, }}}, - {`templates /api3 .htm .html + {`templates /api3 .htm .html templates /api4 .txt .tpl `, false, []templates.Rule{{ Path: "/api3", Extensions: []string{".htm", ".html"}, + Delims: [2]string{}, }, { Path: "/api4", Extensions: []string{".txt", ".tpl"}, + Delims: [2]string{}, + }}}, + {`templates { + path /api5 + ext .html + between {% %} + }`, false, []templates.Rule{{ + Path: "/api5", + Extensions: []string{".html"}, + Delims: [2]string{"{%", "%}"}, }}}, } for i, test := range tests { diff --git a/caddy/setup/tls.go b/caddy/setup/tls.go index 8a7269506..1345c11c2 100644 --- a/caddy/setup/tls.go +++ b/caddy/setup/tls.go @@ -16,6 +16,12 @@ func TLS(c *Controller) (middleware.Middleware, error) { "specify port 80 explicitly (https://%s:80).", c.Port, c.Host, c.Host) } + if c.Port == "http" { + c.TLS.Enabled = false + log.Printf("Warning: TLS disabled for %s://%s. To force TLS over the plaintext HTTP port, "+ + "specify port 80 explicitly (https://%s:80).", c.Port, c.Host, c.Host) + } + for c.Next() { args := c.RemainingArgs() switch len(args) { diff --git a/caddy/setup/websocket_test.go b/caddy/setup/websocket_test.go index 750f2a1d8..ae3513602 100644 --- a/caddy/setup/websocket_test.go +++ b/caddy/setup/websocket_test.go @@ -54,6 +54,25 @@ func TestWebSocketParse(t *testing.T) { Path: "/api4", Command: "cat", }}}, + + {`websocket /api5 "cmd arg1 arg2 arg3"`, false, []websocket.Config{{ + Path: "/api5", + Command: "cmd", + Arguments: []string{"arg1", "arg2", "arg3"}, + }}}, + + // accept respawn + {`websocket /api6 cat { + respawn + }`, false, []websocket.Config{{ + Path: "/api6", + Command: "cat", + }}}, + + // invalid configuration + {`websocket /api7 cat { + invalid + }`, true, []websocket.Config{}}, } for i, test := range tests { c := NewTestController(test.inputWebSocketConfig) diff --git a/main_test.go b/main_test.go new file mode 100644 index 000000000..4e61afa81 --- /dev/null +++ b/main_test.go @@ -0,0 +1,40 @@ +package main + +import ( + "runtime" + "testing" +) + +func TestSetCPU(t *testing.T) { + currentCPU := runtime.GOMAXPROCS(-1) + maxCPU := runtime.NumCPU() + for i, test := range []struct { + input string + output int + shouldErr bool + }{ + {"1", 1, false}, + {"-1", currentCPU, true}, + {"0", currentCPU, true}, + {"100%", maxCPU, false}, + {"50%", int(0.5 * float32(maxCPU)), false}, + {"110%", currentCPU, true}, + {"-10%", currentCPU, true}, + {"invalid input", currentCPU, true}, + {"invalid input%", currentCPU, true}, + {"9999", maxCPU, false}, // over available CPU + } { + err := setCPU(test.input) + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error, but there wasn't any", i) + } + if !test.shouldErr && err != nil { + t.Errorf("Test %d: Expected no error, but there was one: %v", i, err) + } + if actual, expected := runtime.GOMAXPROCS(-1), test.output; actual != expected { + t.Errorf("Test %d: GOMAXPROCS was %d but expected %d", i, actual, expected) + } + // teardown + runtime.GOMAXPROCS(currentCPU) + } +} diff --git a/middleware/commands.go b/middleware/commands.go index 6fb4a72e4..5c241161e 100644 --- a/middleware/commands.go +++ b/middleware/commands.go @@ -2,18 +2,30 @@ package middleware import ( "errors" + "runtime" + "unicode" "github.com/flynn/go-shlex" ) +var runtimeGoos = runtime.GOOS + // SplitCommandAndArgs takes a command string and parses it // shell-style into the command and its separate arguments. func SplitCommandAndArgs(command string) (cmd string, args []string, err error) { - parts, err := shlex.Split(command) - if err != nil { - err = errors.New("error parsing command: " + err.Error()) - return - } else if len(parts) == 0 { + var parts []string + + if runtimeGoos == "windows" { + parts = parseWindowsCommand(command) // parse it Windows-style + } else { + parts, err = parseUnixCommand(command) // parse it Unix-style + if err != nil { + err = errors.New("error parsing command: " + err.Error()) + return + } + } + + if len(parts) == 0 { err = errors.New("no command contained in '" + command + "'") return } @@ -25,3 +37,84 @@ func SplitCommandAndArgs(command string) (cmd string, args []string, err error) return } + +// parseUnixCommand parses a unix style command line and returns the +// command and its arguments or an error +func parseUnixCommand(cmd string) ([]string, error) { + return shlex.Split(cmd) +} + +// parseWindowsCommand parses windows command lines and +// returns the command and the arguments as an array. It +// should be able to parse commonly used command lines. +// Only basic syntax is supported: +// - spaces in double quotes are not token delimiters +// - double quotes are escaped by either backspace or another double quote +// - except for the above case backspaces are path separators (not special) +// +// Many sources point out that escaping quotes using backslash can be unsafe. +// Use two double quotes when possible. (Source: http://stackoverflow.com/a/31413730/2616179 ) +// +// This function has to be used on Windows instead +// of the shlex package because this function treats backslash +// characters properly. +func parseWindowsCommand(cmd string) []string { + const backslash = '\\' + const quote = '"' + + var parts []string + var part string + var inQuotes bool + var lastRune rune + + for i, ch := range cmd { + + if i != 0 { + lastRune = rune(cmd[i-1]) + } + + if ch == backslash { + // put it in the part - for now we don't know if it's an + // escaping char or path separator + part += string(ch) + continue + } + + if ch == quote { + if lastRune == backslash { + // remove the backslash from the part and add the escaped quote instead + part = part[:len(part)-1] + part += string(ch) + continue + } + + if lastRune == quote { + // revert the last change of the inQuotes state + // it was an escaping quote + inQuotes = !inQuotes + part += string(ch) + continue + } + + // normal escaping quotes + inQuotes = !inQuotes + continue + + } + + if unicode.IsSpace(ch) && !inQuotes && len(part) > 0 { + parts = append(parts, part) + part = "" + continue + } + + part += string(ch) + } + + if len(part) > 0 { + parts = append(parts, part) + part = "" + } + + return parts +} diff --git a/middleware/commands_test.go b/middleware/commands_test.go index 3a5b33342..3001e65a5 100644 --- a/middleware/commands_test.go +++ b/middleware/commands_test.go @@ -2,11 +2,176 @@ package middleware import ( "fmt" + "runtime" "strings" "testing" ) +func TestParseUnixCommand(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + // 0 - emtpy command + { + input: ``, + expected: []string{}, + }, + // 1 - command without arguments + { + input: `command`, + expected: []string{`command`}, + }, + // 2 - command with single argument + { + input: `command arg1`, + expected: []string{`command`, `arg1`}, + }, + // 3 - command with multiple arguments + { + input: `command arg1 arg2`, + expected: []string{`command`, `arg1`, `arg2`}, + }, + // 4 - command with single argument with space character - in quotes + { + input: `command "arg1 arg1"`, + expected: []string{`command`, `arg1 arg1`}, + }, + // 5 - command with multiple spaces and tab character + { + input: "command arg1 arg2\targ3", + expected: []string{`command`, `arg1`, `arg2`, `arg3`}, + }, + // 6 - command with single argument with space character - escaped with backspace + { + input: `command arg1\ arg2`, + expected: []string{`command`, `arg1 arg2`}, + }, + // 7 - single quotes should escape special chars + { + input: `command 'arg1\ arg2'`, + expected: []string{`command`, `arg1\ arg2`}, + }, + } + + for i, test := range tests { + errorPrefix := fmt.Sprintf("Test [%d]: ", i) + errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input) + actual, _ := parseUnixCommand(test.input) + if len(actual) != len(test.expected) { + t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual) + continue + } + for j := 0; j < len(actual); j++ { + if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart { + t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j) + } + } + } +} + +func TestParseWindowsCommand(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + { // 0 - empty command - do not fail + input: ``, + expected: []string{}, + }, + { // 1 - cmd without args + input: `cmd`, + expected: []string{`cmd`}, + }, + { // 2 - multiple args + input: `cmd arg1 arg2`, + expected: []string{`cmd`, `arg1`, `arg2`}, + }, + { // 3 - multiple args with space + input: `cmd "combined arg" arg2`, + expected: []string{`cmd`, `combined arg`, `arg2`}, + }, + { // 4 - path without spaces + input: `mkdir C:\Windows\foo\bar`, + expected: []string{`mkdir`, `C:\Windows\foo\bar`}, + }, + { // 5 - command with space in quotes + input: `"command here"`, + expected: []string{`command here`}, + }, + { // 6 - argument with escaped quotes (two quotes) + input: `cmd ""arg""`, + expected: []string{`cmd`, `"arg"`}, + }, + { // 7 - argument with escaped quotes (backslash) + input: `cmd \"arg\"`, + expected: []string{`cmd`, `"arg"`}, + }, + { // 8 - two quotes (escaped) inside an inQuote element + input: `cmd "a ""quoted value"`, + expected: []string{`cmd`, `a "quoted value`}, + }, + // TODO - see how many quotes are dislayed if we use "", """, """"""" + { // 9 - two quotes outside an inQuote element + input: `cmd a ""quoted value`, + expected: []string{`cmd`, `a`, `"quoted`, `value`}, + }, + { // 10 - path with space in quotes + input: `mkdir "C:\directory name\foobar"`, + expected: []string{`mkdir`, `C:\directory name\foobar`}, + }, + { // 11 - space without quotes + input: `mkdir C:\ space`, + expected: []string{`mkdir`, `C:\`, `space`}, + }, + { // 12 - space in quotes + input: `mkdir "C:\ space"`, + expected: []string{`mkdir`, `C:\ space`}, + }, + { // 13 - UNC + input: `mkdir \\?\C:\Users`, + expected: []string{`mkdir`, `\\?\C:\Users`}, + }, + { // 14 - UNC with space + input: `mkdir "\\?\C:\Program Files"`, + expected: []string{`mkdir`, `\\?\C:\Program Files`}, + }, + + { // 15 - unclosed quotes - treat as if the path ends with quote + input: `mkdir "c:\Program files`, + expected: []string{`mkdir`, `c:\Program files`}, + }, + { // 16 - quotes used inside the argument + input: `mkdir "c:\P"rogra"m f"iles`, + expected: []string{`mkdir`, `c:\Program files`}, + }, + } + + for i, test := range tests { + errorPrefix := fmt.Sprintf("Test [%d]: ", i) + errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input) + + actual := parseWindowsCommand(test.input) + if len(actual) != len(test.expected) { + t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual) + continue + } + for j := 0; j < len(actual); j++ { + if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart { + t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j) + } + } + } +} + func TestSplitCommandAndArgs(t *testing.T) { + + // force linux parsing. It's more robust and covers error cases + runtimeGoos = "linux" + defer func() { + runtimeGoos = runtime.GOOS + }() + var parseErrorContent = "error parsing command:" var noCommandErrContent = "no command contained in" @@ -16,84 +181,42 @@ func TestSplitCommandAndArgs(t *testing.T) { expectedArgs []string expectedErrContent string }{ - // Test case 0 - emtpy command + // 0 - emtpy command { input: ``, expectedCommand: ``, expectedArgs: nil, expectedErrContent: noCommandErrContent, }, - // Test case 1 - command without arguments + // 1 - command without arguments { input: `command`, expectedCommand: `command`, expectedArgs: nil, expectedErrContent: ``, }, - // Test case 2 - command with single argument + // 2 - command with single argument { input: `command arg1`, expectedCommand: `command`, expectedArgs: []string{`arg1`}, expectedErrContent: ``, }, - // Test case 3 - command with multiple arguments + // 3 - command with multiple arguments { input: `command arg1 arg2`, expectedCommand: `command`, expectedArgs: []string{`arg1`, `arg2`}, expectedErrContent: ``, }, - // Test case 4 - command with single argument with space character - in quotes - { - input: `command "arg1 arg1"`, - expectedCommand: `command`, - expectedArgs: []string{`arg1 arg1`}, - expectedErrContent: ``, - }, - // Test case 4 - command with single argument with space character - escaped - { - input: `command arg1\ arg1`, - expectedCommand: `command`, - expectedArgs: []string{`arg1 arg1`}, - expectedErrContent: ``, - }, - // Test case 6 - command with escaped quote character - { - input: `command "arg1 \" arg1"`, - expectedCommand: `command`, - expectedArgs: []string{`arg1 " arg1`}, - expectedErrContent: ``, - }, - // Test case 7 - command with escaped backslash - { - input: `command '\arg1'`, - expectedCommand: `command`, - expectedArgs: []string{`\arg1`}, - expectedErrContent: ``, - }, - // Test case 8 - command with comments - { - input: `command arg1 #comment1 comment2`, - expectedCommand: `command`, - expectedArgs: []string{`arg1`}, - expectedErrContent: "", - }, - // Test case 9 - command with multiple spaces and tab character - { - input: "command arg1 arg2\targ3", - expectedCommand: `command`, - expectedArgs: []string{`arg1`, `arg2`, "arg3"}, - expectedErrContent: "", - }, - // Test case 10 - command with unclosed quotes + // 4 - command with unclosed quotes { input: `command "arg1 arg2`, expectedCommand: "", expectedArgs: nil, expectedErrContent: parseErrorContent, }, - // Test case 11 - command with unclosed quotes + // 5 - command with unclosed quotes { input: `command 'arg1 arg2"`, expectedCommand: "", @@ -120,19 +243,49 @@ func TestSplitCommandAndArgs(t *testing.T) { // test if command matches if test.expectedCommand != actualCommand { - t.Errorf("Expected command: [%s], actual: [%s]."+errorSuffix, test.expectedCommand, actualCommand) + t.Errorf(errorPrefix+"Expected command: [%s], actual: [%s]."+errorSuffix, test.expectedCommand, actualCommand) } // test if arguments match if len(test.expectedArgs) != len(actualArgs) { - t.Errorf("Wrong number of arguments! Expected [%v], actual [%v]."+errorSuffix, test.expectedArgs, actualArgs) - } - - for j, actualArg := range actualArgs { - expectedArg := test.expectedArgs[j] - if actualArg != expectedArg { - t.Errorf(errorPrefix+"Argument at position [%d] differ! Expected [%s], actual [%s]"+errorSuffix, j, expectedArg, actualArg) + t.Errorf(errorPrefix+"Wrong number of arguments! Expected [%v], actual [%v]."+errorSuffix, test.expectedArgs, actualArgs) + } else { + // test args only if the count matches. + for j, actualArg := range actualArgs { + expectedArg := test.expectedArgs[j] + if actualArg != expectedArg { + t.Errorf(errorPrefix+"Argument at position [%d] differ! Expected [%s], actual [%s]"+errorSuffix, j, expectedArg, actualArg) + } } } } } + +func ExampleSplitCommandAndArgs() { + var commandLine string + var command string + var args []string + + // just for the test - change GOOS and reset it at the end of the test + runtimeGoos = "windows" + defer func() { + runtimeGoos = runtime.GOOS + }() + + commandLine = `mkdir /P "C:\Program Files"` + command, args, _ = SplitCommandAndArgs(commandLine) + + fmt.Printf("Windows: %s: %s [%s]\n", commandLine, command, strings.Join(args, ",")) + + // set GOOS to linux + runtimeGoos = "linux" + + commandLine = `mkdir -p /path/with\ space` + command, args, _ = SplitCommandAndArgs(commandLine) + + fmt.Printf("Linux: %s: %s [%s]\n", commandLine, command, strings.Join(args, ",")) + + // Output: + // Windows: mkdir /P "C:\Program Files": mkdir [/P,C:\Program Files] + // Linux: mkdir -p /path/with\ space: mkdir [-p,/path/with space] +} diff --git a/middleware/context.go b/middleware/context.go index 6c45d0337..b00d163eb 100644 --- a/middleware/context.go +++ b/middleware/context.go @@ -97,6 +97,10 @@ func (c Context) URI() string { func (c Context) Host() (string, error) { host, _, err := net.SplitHostPort(c.Req.Host) if err != nil { + if !strings.Contains(c.Req.Host, ":") { + // common with sites served on the default port 80 + return c.Req.Host, nil + } return "", err } return host, nil diff --git a/middleware/context_test.go b/middleware/context_test.go new file mode 100644 index 000000000..11d2a4390 --- /dev/null +++ b/middleware/context_test.go @@ -0,0 +1,545 @@ +package middleware + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestInclude(t *testing.T) { + context := getContextOrFail(t) + + inputFilename := "test_file" + absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename) + defer func() { + err := os.Remove(absInFilePath) + if err != nil && !os.IsNotExist(err) { + t.Fatalf("Failed to clean test file!") + } + }() + + tests := []struct { + fileContent string + expectedContent string + shouldErr bool + expectedErrorContent string + }{ + // Test 0 - all good + { + fileContent: `str1 {{ .Root }} str2`, + expectedContent: fmt.Sprintf("str1 %s str2", context.Root), + shouldErr: false, + expectedErrorContent: "", + }, + // Test 1 - failure on template.Parse + { + fileContent: `str1 {{ .Root } str2`, + expectedContent: "", + shouldErr: true, + expectedErrorContent: `unexpected "}" in operand`, + }, + // Test 3 - failure on template.Execute + { + fileContent: `str1 {{ .InvalidField }} str2`, + expectedContent: "", + shouldErr: true, + expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`, + }, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + + // WriteFile truncates the contentt + err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm) + if err != nil { + t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err) + } + + content, err := context.Include(inputFilename) + if err != nil { + if !test.shouldErr { + t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error()) + } + if !strings.Contains(err.Error(), test.expectedErrorContent) { + t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error()) + } + } + + if err == nil && test.shouldErr { + t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename) + } + + if content != test.expectedContent { + t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename) + } + } +} + +func TestIncludeNotExisting(t *testing.T) { + context := getContextOrFail(t) + + _, err := context.Include("not_existing") + if err == nil { + t.Errorf("Expected error but found nil!") + } +} + +func TestCookie(t *testing.T) { + + tests := []struct { + cookie *http.Cookie + cookieName string + expectedValue string + }{ + // Test 0 - happy path + { + cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"}, + cookieName: "cookieName", + expectedValue: "cookieValue", + }, + // Test 1 - try to get a non-existing cookie + { + cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"}, + cookieName: "notExisting", + expectedValue: "", + }, + // Test 2 - partial name match + { + cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"}, + cookieName: "cook", + expectedValue: "", + }, + // Test 3 - cookie with optional fields + { + cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120}, + cookieName: "cookie", + expectedValue: "cookieValue", + }, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + + // reinitialize the context for each test + context := getContextOrFail(t) + + context.Req.AddCookie(test.cookie) + + actualCookieVal := context.Cookie(test.cookieName) + + if actualCookieVal != test.expectedValue { + t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName) + } + } +} + +func TestCookieMultipleCookies(t *testing.T) { + context := getContextOrFail(t) + + cookieNameBase, cookieValueBase := "cookieName", "cookieValue" + + // make sure that there's no state and multiple requests for different cookies return the correct result + for i := 0; i < 10; i++ { + context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)}) + } + + for i := 0; i < 10; i++ { + expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i) + actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i)) + if actualCookieVal != expectedCookieVal { + t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal) + } + } +} + +func TestHeader(t *testing.T) { + context := getContextOrFail(t) + + headerKey, headerVal := "Header1", "HeaderVal1" + context.Req.Header.Add(headerKey, headerVal) + + actualHeaderVal := context.Header(headerKey) + if actualHeaderVal != headerVal { + t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal) + } + + missingHeaderVal := context.Header("not-existing") + if missingHeaderVal != "" { + t.Errorf("Expected empty header value, found %s", missingHeaderVal) + } +} + +func TestIP(t *testing.T) { + context := getContextOrFail(t) + + tests := []struct { + inputRemoteAddr string + expectedIP string + }{ + // Test 0 - ipv4 with port + {"1.1.1.1:1111", "1.1.1.1"}, + // Test 1 - ipv4 without port + {"1.1.1.1", "1.1.1.1"}, + // Test 2 - ipv6 with port + {"[::1]:11", "::1"}, + // Test 3 - ipv6 without port and brackets + {"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"}, + // Test 4 - ipv6 with zone and port + {`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`}, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + + context.Req.RemoteAddr = test.inputRemoteAddr + actualIP := context.IP() + + if actualIP != test.expectedIP { + t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP) + } + } +} + +func TestURL(t *testing.T) { + context := getContextOrFail(t) + + inputURL := "http://localhost" + context.Req.RequestURI = inputURL + + if inputURL != context.URI() { + t.Errorf("Expected url %s, found %s", inputURL, context.URI()) + } +} + +func TestHost(t *testing.T) { + tests := []struct { + input string + expectedHost string + shouldErr bool + }{ + { + input: "localhost:123", + expectedHost: "localhost", + shouldErr: false, + }, + { + input: "localhost", + expectedHost: "localhost", + shouldErr: false, + }, + { + input: "[::]", + expectedHost: "", + shouldErr: true, + }, + } + + for _, test := range tests { + testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr) + } +} + +func TestPort(t *testing.T) { + tests := []struct { + input string + expectedPort string + shouldErr bool + }{ + { + input: "localhost:123", + expectedPort: "123", + shouldErr: false, + }, + { + input: "localhost", + expectedPort: "", + shouldErr: true, // missing port in address + }, + { + input: ":8080", + expectedPort: "8080", + shouldErr: false, + }, + } + + for _, test := range tests { + testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr) + } +} + +func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) { + context := getContextOrFail(t) + + context.Req.Host = input + var actualResult, testedObject string + var err error + + if isTestingHost { + actualResult, err = context.Host() + testedObject = "host" + } else { + actualResult, err = context.Port() + testedObject = "port" + } + + if shouldErr && err == nil { + t.Errorf("Expected error, found nil!") + return + } + + if !shouldErr && err != nil { + t.Errorf("Expected no error, found %s", err) + return + } + + if actualResult != expectedResult { + t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult) + } +} + +func TestMethod(t *testing.T) { + context := getContextOrFail(t) + + method := "POST" + context.Req.Method = method + + if method != context.Method() { + t.Errorf("Expected method %s, found %s", method, context.Method()) + } + +} + +func TestPathMatches(t *testing.T) { + context := getContextOrFail(t) + + tests := []struct { + urlStr string + pattern string + shouldMatch bool + }{ + // Test 0 + { + urlStr: "http://localhost/", + pattern: "", + shouldMatch: true, + }, + // Test 1 + { + urlStr: "http://localhost", + pattern: "", + shouldMatch: true, + }, + // Test 1 + { + urlStr: "http://localhost/", + pattern: "/", + shouldMatch: true, + }, + // Test 3 + { + urlStr: "http://localhost/?param=val", + pattern: "/", + shouldMatch: true, + }, + // Test 4 + { + urlStr: "http://localhost/dir1/dir2", + pattern: "/dir2", + shouldMatch: false, + }, + // Test 5 + { + urlStr: "http://localhost/dir1/dir2", + pattern: "/dir1", + shouldMatch: true, + }, + // Test 6 + { + urlStr: "http://localhost:444/dir1/dir2", + pattern: "/dir1", + shouldMatch: true, + }, + // Test 7 + { + urlStr: "http://localhost/dir1/dir2", + pattern: "*/dir2", + shouldMatch: false, + }, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + var err error + context.Req.URL, err = url.Parse(test.urlStr) + if err != nil { + t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err) + } + + matches := context.PathMatches(test.pattern) + if matches != test.shouldMatch { + t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches) + } + } +} + +func TestTruncate(t *testing.T) { + context := getContextOrFail(t) + tests := []struct { + inputString string + inputLength int + expected string + }{ + // Test 0 - small length + { + inputString: "string", + inputLength: 1, + expected: "s", + }, + // Test 1 - exact length + { + inputString: "string", + inputLength: 6, + expected: "string", + }, + // Test 2 - bigger length + { + inputString: "string", + inputLength: 10, + expected: "string", + }, + } + + for i, test := range tests { + actual := context.Truncate(test.inputString, test.inputLength) + if actual != test.expected { + t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength) + } + } +} + +func TestStripHTML(t *testing.T) { + context := getContextOrFail(t) + tests := []struct { + input string + expected string + }{ + // Test 0 - no tags + { + input: `h1`, + expected: `h1`, + }, + // Test 1 - happy path + { + input: `

h1

`, + expected: `h1`, + }, + // Test 2 - tag in quotes + { + input: `">h1`, + expected: `h1`, + }, + // Test 3 - multiple tags + { + input: `

h1

`, + expected: `h1`, + }, + // Test 4 - tags not closed + { + input: `hi`, + expected: `file1.html", + filepath.Join("dirwithindex", "index.html"): "

dirwithindex/index.html

", + filepath.Join("dir", "file2.html"): "

dir/file2.html

", + filepath.Join("dir", "hidden.html"): "

dir/hidden.html

", +} + +// TestServeHTTP covers positive scenarios when serving files. +func TestServeHTTP(t *testing.T) { + + beforeServeHttpTest(t) + defer afterServeHttpTest(t) + + fileserver := FileServer(http.Dir(testDir), []string{"hidden.html"}) + + movedPermanently := "Moved Permanently" + + tests := []struct { + url string + + expectedStatus int + expectedBodyContent string + }{ + // Test 0 - access withoutt any path + { + url: "https://foo", + expectedStatus: http.StatusNotFound, + }, + // Test 1 - access root (without index.html) + { + url: "https://foo/", + expectedStatus: http.StatusNotFound, + }, + // Test 2 - access existing file + { + url: "https://foo/file1.html", + expectedStatus: http.StatusOK, + expectedBodyContent: testFiles["file1.html"], + }, + // Test 3 - access folder with index file with trailing slash + { + url: "https://foo/dirwithindex/", + expectedStatus: http.StatusOK, + expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")], + }, + // Test 4 - access folder with index file without trailing slash + { + url: "https://foo/dirwithindex", + expectedStatus: http.StatusMovedPermanently, + expectedBodyContent: movedPermanently, + }, + // Test 5 - access folder without index file + { + url: "https://foo/dir/", + expectedStatus: http.StatusNotFound, + }, + // Test 6 - access folder withtout trailing slash + { + url: "https://foo/dir", + expectedStatus: http.StatusMovedPermanently, + expectedBodyContent: movedPermanently, + }, + // Test 6 - access file with trailing slash + { + url: "https://foo/file1.html/", + expectedStatus: http.StatusMovedPermanently, + expectedBodyContent: movedPermanently, + }, + // Test 7 - access not existing path + { + url: "https://foo/not_existing", + expectedStatus: http.StatusNotFound, + }, + // Test 8 - access a file, marked as hidden + { + url: "https://foo/dir/hidden.html", + expectedStatus: http.StatusNotFound, + }, + // Test 9 - access a index file directly + { + url: "https://foo/dirwithindex/index.html", + expectedStatus: http.StatusOK, + expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")], + }, + // Test 10 - send a request with query params + { + url: "https://foo/dir?param1=val", + expectedStatus: http.StatusMovedPermanently, + expectedBodyContent: movedPermanently, + }, + } + + for i, test := range tests { + responseRecorder := httptest.NewRecorder() + request, err := http.NewRequest("GET", test.url, strings.NewReader("")) + status, err := fileserver.ServeHTTP(responseRecorder, request) + + // check if error matches expectations + if err != nil { + t.Errorf(getTestPrefix(i)+"Serving file at %s failed. Error was: %v", test.url, err) + } + + // check status code + if test.expectedStatus != status { + t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status) + } + + // check body content + if !strings.Contains(responseRecorder.Body.String(), test.expectedBodyContent) { + t.Errorf(getTestPrefix(i)+"Expected body to contain %q, found %q", test.expectedBodyContent, responseRecorder.Body.String()) + } + } + +} + +// beforeServeHttpTest creates a test directory with the structure, defined in the variable testFiles +func beforeServeHttpTest(t *testing.T) { + // make the root test dir + err := os.Mkdir(testDir, os.ModePerm) + if err != nil { + if !os.IsExist(err) { + t.Fatalf("Failed to create test dir. Error was: %v", err) + return + } + } + + for relFile, fileContent := range testFiles { + absFile := filepath.Join(testDir, relFile) + + // make sure the parent directories exist + parentDir := filepath.Dir(absFile) + _, err = os.Stat(parentDir) + if err != nil { + os.MkdirAll(parentDir, os.ModePerm) + } + + // now create the test files + f, err := os.Create(absFile) + if err != nil { + t.Fatalf("Failed to create test file %s. Error was: %v", absFile, err) + return + } + + // and fill them with content + _, err = f.WriteString(fileContent) + if err != nil { + t.Fatalf("Failed to write to %s. Error was: %v", absFile, err) + return + } + f.Close() + } + +} + +// afterServeHttpTest removes the test dir and all its content +func afterServeHttpTest(t *testing.T) { + // cleans up everything under the test dir. No need to clean the individual files. + err := os.RemoveAll(testDir) + if err != nil { + t.Fatalf("Failed to clean up test dir %s. Error was: %v", testDir, err) + } +} + +// failingFS implements the http.FileSystem interface. The Open method always returns the error, assigned to err +type failingFS struct { + err error // the error to return when Open is called + fileImpl http.File // inject the file implementation +} + +// Open returns the assigned failingFile and error +func (f failingFS) Open(path string) (http.File, error) { + return f.fileImpl, f.err +} + +// failingFile implements http.File but returns a predefined error on every Stat() method call. +type failingFile struct { + http.File + err error +} + +// Stat returns nil FileInfo and the provided error on every call +func (ff failingFile) Stat() (os.FileInfo, error) { + return nil, ff.err +} + +// Close is noop and returns no error +func (ff failingFile) Close() error { + return nil +} + +// TestServeHTTPFailingFS tests error cases where the Open function fails with various errors. +func TestServeHTTPFailingFS(t *testing.T) { + + tests := []struct { + fsErr error + expectedStatus int + expectedErr error + expectedHeaders map[string]string + }{ + { + fsErr: os.ErrNotExist, + expectedStatus: http.StatusNotFound, + expectedErr: nil, + }, + { + fsErr: os.ErrPermission, + expectedStatus: http.StatusForbidden, + expectedErr: os.ErrPermission, + }, + { + fsErr: customErr, + expectedStatus: http.StatusServiceUnavailable, + expectedErr: customErr, + expectedHeaders: map[string]string{"Retry-After": "5"}, + }, + } + + for i, test := range tests { + // initialize a file server with the failing FileSystem + fileserver := FileServer(failingFS{err: test.fsErr}, nil) + + // prepare the request and response + request, err := http.NewRequest("GET", "https://foo/", nil) + if err != nil { + t.Fatalf("Failed to build request. Error was: %v", err) + } + responseRecorder := httptest.NewRecorder() + + status, actualErr := fileserver.ServeHTTP(responseRecorder, request) + + // check the status + if status != test.expectedStatus { + t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status) + } + + // check the error + if actualErr != test.expectedErr { + t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr) + } + + // check the headers - a special case for server under load + if test.expectedHeaders != nil && len(test.expectedHeaders) > 0 { + for expectedKey, expectedVal := range test.expectedHeaders { + actualVal := responseRecorder.Header().Get(expectedKey) + if expectedVal != actualVal { + t.Errorf(getTestPrefix(i)+"Expected header %s: %s, found %s", expectedKey, expectedVal, actualVal) + } + } + } + } +} + +// TestServeHTTPFailingStat tests error cases where the initial Open function succeeds, but the Stat method on the opened file fails. +func TestServeHTTPFailingStat(t *testing.T) { + + tests := []struct { + statErr error + expectedStatus int + expectedErr error + }{ + { + statErr: os.ErrNotExist, + expectedStatus: http.StatusNotFound, + expectedErr: nil, + }, + { + statErr: os.ErrPermission, + expectedStatus: http.StatusForbidden, + expectedErr: os.ErrPermission, + }, + { + statErr: customErr, + expectedStatus: http.StatusInternalServerError, + expectedErr: customErr, + }, + } + + for i, test := range tests { + // initialize a file server. The FileSystem will not fail, but calls to the Stat method of the returned File object will + fileserver := FileServer(failingFS{err: nil, fileImpl: failingFile{err: test.statErr}}, nil) + + // prepare the request and response + request, err := http.NewRequest("GET", "https://foo/", nil) + if err != nil { + t.Fatalf("Failed to build request. Error was: %v", err) + } + responseRecorder := httptest.NewRecorder() + + status, actualErr := fileserver.ServeHTTP(responseRecorder, request) + + // check the status + if status != test.expectedStatus { + t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status) + } + + // check the error + if actualErr != test.expectedErr { + t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr) + } + } +} diff --git a/middleware/markdown/generator.go b/middleware/markdown/generator.go index a02cb3b05..9db82bdcf 100644 --- a/middleware/markdown/generator.go +++ b/middleware/markdown/generator.go @@ -17,18 +17,18 @@ import ( // It only generates static files if it is enabled (cfg.StaticDir // must be set). func GenerateStatic(md Markdown, cfg *Config) error { - generated, err := generateLinks(md, cfg) - if err != nil { - return err - } - - // No new file changes, return. - if !generated { - return nil - } - // If static site generation is enabled. if cfg.StaticDir != "" { + generated, err := generateLinks(md, cfg) + if err != nil { + return err + } + + // No new file changes, return. + if !generated { + return nil + } + if err := generateStaticHTML(md, cfg); err != nil { return err } diff --git a/middleware/markdown/markdown.go b/middleware/markdown/markdown.go index bdf142cf2..3b3bc96e0 100644 --- a/middleware/markdown/markdown.go +++ b/middleware/markdown/markdown.go @@ -136,6 +136,7 @@ func (md Markdown) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error // generation, serve the static page if fs.ModTime().Before(fs1.ModTime()) { if html, err := ioutil.ReadFile(filepath); err == nil { + middleware.SetLastModifiedHeader(w, fs1.ModTime()) w.Write(html) return http.StatusOK, nil } @@ -162,6 +163,7 @@ func (md Markdown) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error return http.StatusInternalServerError, err } + middleware.SetLastModifiedHeader(w, fs.ModTime()) w.Write(html) return http.StatusOK, nil } diff --git a/middleware/markdown/markdown_test.go b/middleware/markdown/markdown_test.go index 8cb89fae5..fbbe845d7 100644 --- a/middleware/markdown/markdown_test.go +++ b/middleware/markdown/markdown_test.go @@ -92,7 +92,7 @@ func TestMarkdown(t *testing.T) { expectedBody := ` -Markdown test +Markdown test 1

Header

@@ -102,11 +102,10 @@ Welcome to A Caddy website!

Body

-

go -func getTrue() bool { +

func getTrue() bool {
     return true
 }
-

+
@@ -129,7 +128,7 @@ func getTrue() bool { expectedBody = ` - Markdown test + Markdown test 2 @@ -143,11 +142,10 @@ func getTrue() bool {

Body

-

go -func getTrue() bool { +

func getTrue() bool {
     return true
 }
-

+
` diff --git a/middleware/markdown/process.go b/middleware/markdown/process.go index 0fb48dba1..65f22d66d 100644 --- a/middleware/markdown/process.go +++ b/middleware/markdown/process.go @@ -65,7 +65,8 @@ func (md Markdown) Process(c *Config, requestPath string, b []byte, ctx middlewa } // process markdown - markdown = blackfriday.Markdown(markdown, c.Renderer, 0) + extns := blackfriday.EXTENSION_TABLES | blackfriday.EXTENSION_FENCED_CODE | blackfriday.EXTENSION_STRIKETHROUGH + markdown = blackfriday.Markdown(markdown, c.Renderer, extns) // set it as body for template metadata.Variables["body"] = string(markdown) diff --git a/middleware/markdown/testdata/blog/test.md b/middleware/markdown/testdata/blog/test.md index 3d33ad918..93f07a493 100644 --- a/middleware/markdown/testdata/blog/test.md +++ b/middleware/markdown/testdata/blog/test.md @@ -1,5 +1,5 @@ --- -title: Markdown test +title: Markdown test 1 sitename: A Caddy website --- diff --git a/middleware/markdown/testdata/log/test.md b/middleware/markdown/testdata/log/test.md index 3d33ad918..476ab3015 100644 --- a/middleware/markdown/testdata/log/test.md +++ b/middleware/markdown/testdata/log/test.md @@ -1,5 +1,5 @@ --- -title: Markdown test +title: Markdown test 2 sitename: A Caddy website --- diff --git a/middleware/middleware.go b/middleware/middleware.go index ba7699ce4..b88b24474 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -4,6 +4,7 @@ package middleware import ( "net/http" "path" + "time" ) type ( @@ -78,3 +79,30 @@ func IndexFile(root http.FileSystem, fpath string, indexFiles []string) (string, } return "", false } + +// SetLastModifiedHeader checks if the provided modTime is valid and if it is sets it +// as a Last-Modified header to the ResponseWriter. If the modTime is in the future +// the current time is used instead. +func SetLastModifiedHeader(w http.ResponseWriter, modTime time.Time) { + if modTime.IsZero() || modTime.Equal(time.Unix(0, 0)) { + // the time does not appear to be valid. Don't put it in the response + return + } + + // RFC 2616 - Section 14.29 - Last-Modified: + // An origin server MUST NOT send a Last-Modified date which is later than the + // server's time of message origination. In such cases, where the resource's last + // modification would indicate some time in the future, the server MUST replace + // that date with the message origination date. + now := currentTime() + if modTime.After(now) { + modTime = now + } + + w.Header().Set("Last-Modified", modTime.UTC().Format(http.TimeFormat)) +} + +// currentTime returns time.Now() everytime it's called. It's used for mocking in tests. +var currentTime = func() time.Time { + return time.Now() +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 700beed84..62fa4e250 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -1,8 +1,11 @@ package middleware import ( + "fmt" "net/http" + "net/http/httptest" "testing" + "time" ) func TestIndexfile(t *testing.T) { @@ -42,3 +45,64 @@ func TestIndexfile(t *testing.T) { } } } + +func TestSetLastModified(t *testing.T) { + nowTime := time.Now() + + // ovewrite the function to return reliable time + originalGetCurrentTimeFunc := currentTime + currentTime = func() time.Time { + return nowTime + } + defer func() { + currentTime = originalGetCurrentTimeFunc + }() + + pastTime := nowTime.Truncate(1 * time.Hour) + futureTime := nowTime.Add(1 * time.Hour) + + tests := []struct { + inputModTime time.Time + expectedIsHeaderSet bool + expectedLastModified string + }{ + { + inputModTime: pastTime, + expectedIsHeaderSet: true, + expectedLastModified: pastTime.UTC().Format(http.TimeFormat), + }, + { + inputModTime: nowTime, + expectedIsHeaderSet: true, + expectedLastModified: nowTime.UTC().Format(http.TimeFormat), + }, + { + inputModTime: futureTime, + expectedIsHeaderSet: true, + expectedLastModified: nowTime.UTC().Format(http.TimeFormat), + }, + { + inputModTime: time.Time{}, + expectedIsHeaderSet: false, + }, + } + + for i, test := range tests { + responseRecorder := httptest.NewRecorder() + errorPrefix := fmt.Sprintf("Test [%d]: ", i) + SetLastModifiedHeader(responseRecorder, test.inputModTime) + actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified") + + if test.expectedIsHeaderSet && actualLastModifiedHeader == "" { + t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing") + } + + if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" { + t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader) + } + + if test.expectedLastModified != actualLastModifiedHeader { + t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader) + } + } +} diff --git a/middleware/templates/templates.go b/middleware/templates/templates.go index a699d0026..bc48ac45d 100644 --- a/middleware/templates/templates.go +++ b/middleware/templates/templates.go @@ -33,8 +33,18 @@ func (t Templates) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error // Create execution context ctx := middleware.Context{Root: t.FileSys, Req: r, URL: r.URL} + // New template + templateName := filepath.Base(fpath) + tpl := template.New(templateName) + + // Set delims + if rule.Delims != [2]string{} { + tpl.Delims(rule.Delims[0], rule.Delims[1]) + } + // Build the template - tpl, err := template.ParseFiles(filepath.Join(t.Root, fpath)) + templatePath := filepath.Join(t.Root, fpath) + tpl, err := tpl.ParseFiles(templatePath) if err != nil { if os.IsNotExist(err) { return http.StatusNotFound, nil @@ -50,6 +60,12 @@ func (t Templates) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error if err != nil { return http.StatusInternalServerError, err } + + templateInfo, err := os.Stat(templatePath) + if err == nil { + // add the Last-Modified header if we were able to optain the information + middleware.SetLastModifiedHeader(w, templateInfo.ModTime()) + } buf.WriteTo(w) return http.StatusOK, nil @@ -75,4 +91,5 @@ type Rule struct { Path string Extensions []string IndexFiles []string + Delims [2]string } diff --git a/middleware/templates/templates_test.go b/middleware/templates/templates_test.go index 3ee6072ce..c5a5d24a8 100644 --- a/middleware/templates/templates_test.go +++ b/middleware/templates/templates_test.go @@ -23,6 +23,7 @@ func Test(t *testing.T) { Extensions: []string{".html", ".htm"}, IndexFiles: []string{"index.html", "index.htm"}, Path: "/images", + Delims: [2]string{"{%", "%}"}, }, }, Root: "./testdata", @@ -94,6 +95,30 @@ func Test(t *testing.T) { t.Fatalf("Test: the expected body %v is different from the response one: %v", expectedBody, respBody) } + /* + * Test tmpl on /images/img2.htm + */ + req, err = http.NewRequest("GET", "/images/img2.htm", nil) + if err != nil { + t.Fatalf("Could not create HTTP request: %v", err) + } + + rec = httptest.NewRecorder() + + tmpl.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("Test: Wrong response code: %d, should be %d", rec.Code, http.StatusOK) + } + + respBody = rec.Body.String() + expectedBody = `img{{.Include "header.html"}} +` + + if respBody != expectedBody { + t.Fatalf("Test: the expected body %v is different from the response one: %v", expectedBody, respBody) + } + /* * Test tmplroot on /root.html */ diff --git a/middleware/templates/testdata/images/img.htm b/middleware/templates/testdata/images/img.htm index 865a73809..c90602044 100644 --- a/middleware/templates/testdata/images/img.htm +++ b/middleware/templates/testdata/images/img.htm @@ -1 +1 @@ -img{{.Include "header.html"}} +img{%.Include "header.html"%} diff --git a/middleware/templates/testdata/images/img2.htm b/middleware/templates/testdata/images/img2.htm new file mode 100644 index 000000000..865a73809 --- /dev/null +++ b/middleware/templates/testdata/images/img2.htm @@ -0,0 +1 @@ +img{{.Include "header.html"}} diff --git a/middleware/websocket/websocket.go b/middleware/websocket/websocket.go index f344fe511..76b2bfed8 100644 --- a/middleware/websocket/websocket.go +++ b/middleware/websocket/websocket.go @@ -172,7 +172,7 @@ func reader(conn *websocket.Conn, stdout io.ReadCloser, stdin io.WriteCloser) { conn.SetReadDeadline(time.Now().Add(pongWait)) conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) tickerChan := make(chan bool) - defer func() { tickerChan <- true }() // make sure to close the ticker when we are done. + defer close(tickerChan) // make sure to close the ticker when we are done. go ticker(conn, tickerChan) for { @@ -213,10 +213,7 @@ func reader(conn *websocket.Conn, stdout io.ReadCloser, stdin io.WriteCloser) { // between the server and client to keep it alive with ping messages. func ticker(conn *websocket.Conn, c chan bool) { ticker := time.NewTicker(pingPeriod) - defer func() { - ticker.Stop() - close(c) - }() + defer ticker.Stop() for { // blocking loop with select to wait for stimulation. select {