0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2025-04-01 02:42:35 -05:00

Got startupshutdown test working on *nix

This commit is contained in:
Dipen Patel 2015-10-28 19:48:59 -04:00
commit 027bf49945
22 changed files with 1579 additions and 237 deletions

42
app/app_test.go Normal file
View file

@ -0,0 +1,42 @@
package app_test
import (
"runtime"
"testing"
"github.com/mholt/caddy/app"
)
func TestSetCPU(t *testing.T) {
currentCPU := runtime.GOMAXPROCS(-1)
maxCPU := runtime.NumCPU()
for i, test := range []struct {
input string
output int
shouldErr bool
}{
{"1", 1, false},
{"-1", currentCPU, true},
{"0", currentCPU, true},
{"100%", maxCPU, false},
{"50%", int(0.5 * float32(maxCPU)), false},
{"110%", currentCPU, true},
{"-10%", currentCPU, true},
{"invalid input", currentCPU, true},
{"invalid input%", currentCPU, true},
{"9999", maxCPU, false}, // over available CPU
} {
err := app.SetCPU(test.input)
if test.shouldErr && err == nil {
t.Errorf("Test %d: Expected error, but there wasn't any", i)
}
if !test.shouldErr && err != nil {
t.Errorf("Test %d: Expected no error, but there was one: %v", i, err)
}
if actual, expected := runtime.GOMAXPROCS(-1), test.output; actual != expected {
t.Errorf("Test %d: GOMAXPROCS was %d but expected %d", i, actual, expected)
}
// teardown
runtime.GOMAXPROCS(currentCPU)
}
}

View file

@ -2,10 +2,10 @@ version: "{build}"
os: Windows Server 2012 R2
clone_folder: c:\go\src\github.com\mholt\caddy
clone_folder: c:\gopath\src\github.com\mholt\caddy
environment:
GOPATH: c:\go
GOPATH: c:\gopath
install:
- go get golang.org/x/tools/cmd/vet

View file

@ -5,6 +5,7 @@ import (
"io"
"log"
"net"
"sync"
"github.com/mholt/caddy/app"
"github.com/mholt/caddy/config/parse"
@ -40,32 +41,63 @@ func Load(filename string, input io.Reader) (Group, error) {
return Default()
}
// Each server block represents one or more servers/addresses.
// Each server block represents similar hosts/addresses.
// Iterate each server block and make a config for each one,
// executing the directives that were parsed.
for _, sb := range serverBlocks {
sharedConfig, err := serverBlockToConfig(filename, sb)
if err != nil {
return nil, err
}
for i, sb := range serverBlocks {
onces := makeOnces()
storages := makeStorages()
for j, addr := range sb.Addresses {
config := server.Config{
Host: addr.Host,
Port: addr.Port,
Root: Root,
Middleware: make(map[string][]middleware.Middleware),
ConfigFile: filename,
AppName: app.Name,
AppVersion: app.Version,
}
// It is crucial that directives are executed in the proper order.
for _, dir := range directiveOrder {
// Execute directive if it is in the server block
if tokens, ok := sb.Tokens[dir.name]; ok {
// Each setup function gets a controller, which is the
// server config and the dispenser containing only
// this directive's tokens.
controller := &setup.Controller{
Config: &config,
Dispenser: parse.NewDispenserTokens(filename, tokens),
OncePerServerBlock: func(f func() error) error {
var err error
onces[dir.name].Do(func() {
err = f()
})
return err
},
ServerBlockIndex: i,
ServerBlockHostIndex: j,
ServerBlockHosts: sb.HostList(),
ServerBlockStorage: storages[dir.name],
}
midware, err := dir.setup(controller)
if err != nil {
return nil, err
}
if midware != nil {
// TODO: For now, we only support the default path scope /
config.Middleware["/"] = append(config.Middleware["/"], midware)
}
storages[dir.name] = controller.ServerBlockStorage // persist for this server block
}
}
// Now share the config with as many hosts as share the server block
for i, addr := range sb.Addresses {
config := sharedConfig
config.Host = addr.Host
config.Port = addr.Port
if config.Port == "" {
config.Port = Port
}
if config.Port == "http" {
config.TLS.Enabled = false
log.Printf("Warning: TLS disabled for %s://%s. To force TLS over the plaintext HTTP port, "+
"specify port 80 explicitly (https://%s:80).", config.Port, config.Host, config.Host)
}
if i == 0 {
sharedConfig.Startup = []func() error{}
sharedConfig.Shutdown = []func() error{}
}
configs = append(configs, config)
}
}
@ -73,48 +105,36 @@ func Load(filename string, input io.Reader) (Group, error) {
// restore logging settings
log.SetFlags(flags)
// Group by address/virtualhosts
return arrangeBindings(configs)
}
// serverBlockToConfig makes a config for the server block
// by executing the tokens that were parsed. The returned
// config is shared among all hosts/addresses for the server
// block, so Host and Port information is not filled out
// here.
func serverBlockToConfig(filename string, sb parse.ServerBlock) (server.Config, error) {
sharedConfig := server.Config{
Root: Root,
Middleware: make(map[string][]middleware.Middleware),
ConfigFile: filename,
AppName: app.Name,
AppVersion: app.Version,
}
// It is crucial that directives are executed in the proper order.
// makeOnces makes a map of directive name to sync.Once
// instance. This is intended to be called once per server
// block when setting up configs so that Setup functions
// for each directive can perform a task just once per
// server block, even if there are multiple hosts on the block.
//
// We need one Once per directive, otherwise the first
// directive to use it would exclude other directives from
// using it at all, which would be a bug.
func makeOnces() map[string]*sync.Once {
onces := make(map[string]*sync.Once)
for _, dir := range directiveOrder {
// Execute directive if it is in the server block
if tokens, ok := sb.Tokens[dir.name]; ok {
// Each setup function gets a controller, which is the
// server config and the dispenser containing only
// this directive's tokens.
controller := &setup.Controller{
Config: &sharedConfig,
Dispenser: parse.NewDispenserTokens(filename, tokens),
}
midware, err := dir.setup(controller)
if err != nil {
return sharedConfig, err
}
if midware != nil {
// TODO: For now, we only support the default path scope /
sharedConfig.Middleware["/"] = append(sharedConfig.Middleware["/"], midware)
}
}
onces[dir.name] = new(sync.Once)
}
return onces
}
return sharedConfig, nil
// makeStorages makes a map of directive name to interface{}
// so that directives' setup functions can persist state
// between different hosts on the same server block during the
// setup phase.
func makeStorages() map[string]interface{} {
storages := make(map[string]interface{})
for _, dir := range directiveOrder {
storages[dir.name] = nil
}
return storages
}
// arrangeBindings groups configurations by their bind address. For example,
@ -125,8 +145,8 @@ func serverBlockToConfig(filename string, sb parse.ServerBlock) (server.Config,
// bind address to list of configs that would become VirtualHosts on that
// server. Use the keys of the returned map to create listeners, and use
// the associated values to set up the virtualhosts.
func arrangeBindings(allConfigs []server.Config) (Group, error) {
addresses := make(Group)
func arrangeBindings(allConfigs []server.Config) (map[*net.TCPAddr][]server.Config, error) {
addresses := make(map[*net.TCPAddr][]server.Config)
// Group configs by bind address
for _, conf := range allConfigs {
@ -234,8 +254,9 @@ func validDirective(d string) bool {
return false
}
// NewDefault creates a default configuration using the default
// root, host, and port.
// NewDefault makes a default configuration, which
// is empty except for root, host, and port,
// which are essentials for serving the cwd.
func NewDefault() server.Config {
return server.Config{
Root: Root,
@ -244,9 +265,8 @@ func NewDefault() server.Config {
}
}
// Default makes a default configuration which
// is empty except for root, host, and port,
// which are essentials for serving the cwd.
// Default obtains a default config and arranges
// bindings so it's ready to use.
func Default() (Group, error) {
return arrangeBindings([]server.Config{NewDefault()})
}

View file

@ -1,11 +1,27 @@
package config
import (
"reflect"
"sync"
"testing"
"github.com/mholt/caddy/server"
)
func TestNewDefault(t *testing.T) {
config := NewDefault()
if actual, expected := config.Root, DefaultRoot; actual != expected {
t.Errorf("Root was %s but expected %s", actual, expected)
}
if actual, expected := config.Host, DefaultHost; actual != expected {
t.Errorf("Host was %s but expected %s", actual, expected)
}
if actual, expected := config.Port, DefaultPort; actual != expected {
t.Errorf("Port was %s but expected %s", actual, expected)
}
}
func TestResolveAddr(t *testing.T) {
// NOTE: If tests fail due to comparing to string "127.0.0.1",
// it's possible that system env resolves with IPv6, or ::1.
@ -62,3 +78,61 @@ func TestResolveAddr(t *testing.T) {
}
}
}
func TestMakeOnces(t *testing.T) {
directives := []directive{
{"dummy", nil},
{"dummy2", nil},
}
directiveOrder = directives
onces := makeOnces()
if len(onces) != len(directives) {
t.Errorf("onces had len %d , expected %d", len(onces), len(directives))
}
expected := map[string]*sync.Once{
"dummy": new(sync.Once),
"dummy2": new(sync.Once),
}
if !reflect.DeepEqual(onces, expected) {
t.Errorf("onces was %v, expected %v", onces, expected)
}
}
func TestMakeStorages(t *testing.T) {
directives := []directive{
{"dummy", nil},
{"dummy2", nil},
}
directiveOrder = directives
storages := makeStorages()
if len(storages) != len(directives) {
t.Errorf("storages had len %d , expected %d", len(storages), len(directives))
}
expected := map[string]interface{}{
"dummy": nil,
"dummy2": nil,
}
if !reflect.DeepEqual(storages, expected) {
t.Errorf("storages was %v, expected %v", storages, expected)
}
}
func TestValidDirective(t *testing.T) {
directives := []directive{
{"dummy", nil},
{"dummy2", nil},
}
directiveOrder = directives
for i, test := range []struct {
directive string
valid bool
}{
{"dummy", true},
{"dummy2", true},
{"dummy3", false},
} {
if actual, expected := validDirective(test.directive), test.valid; actual != expected {
t.Errorf("Test %d: valid was %t, expected %t", i, actual, expected)
}
}
}

View file

@ -6,7 +6,7 @@ import "io"
// ServerBlocks parses the input just enough to organize tokens,
// in order, by server block. No further parsing is performed.
// Server blocks are returned in the order in which they appear.
func ServerBlocks(filename string, input io.Reader) ([]ServerBlock, error) {
func ServerBlocks(filename string, input io.Reader) ([]serverBlock, error) {
p := parser{Dispenser: NewDispenser(filename, input)}
blocks, err := p.parseAll()
return blocks, err

View file

@ -9,12 +9,12 @@ import (
type parser struct {
Dispenser
block ServerBlock // current server block being parsed
block serverBlock // current server block being parsed
eof bool // if we encounter a valid EOF in a hard place
}
func (p *parser) parseAll() ([]ServerBlock, error) {
var blocks []ServerBlock
func (p *parser) parseAll() ([]serverBlock, error) {
var blocks []serverBlock
for p.Next() {
err := p.parseOne()
@ -30,7 +30,7 @@ func (p *parser) parseAll() ([]ServerBlock, error) {
}
func (p *parser) parseOne() error {
p.block = ServerBlock{Tokens: make(map[string][]token)}
p.block = serverBlock{Tokens: make(map[string][]token)}
err := p.begin()
if err != nil {
@ -87,7 +87,7 @@ func (p *parser) addresses() error {
break
}
if tkn != "" {
if tkn != "" { // empty token possible if user typed "" in Caddyfile
// Trailing comma indicates another address will follow, which
// may possibly be on the next line
if tkn[len(tkn)-1] == ',' {
@ -102,7 +102,7 @@ func (p *parser) addresses() error {
if err != nil {
return err
}
p.block.Addresses = append(p.block.Addresses, Address{host, port})
p.block.Addresses = append(p.block.Addresses, address{host, port})
}
// Advance token and possibly break out of loop or return error
@ -301,15 +301,26 @@ func standardAddress(str string) (host, port string, err error) {
}
type (
// ServerBlock associates tokens with a list of addresses
// serverBlock associates tokens with a list of addresses
// and groups tokens by directive name.
ServerBlock struct {
Addresses []Address
serverBlock struct {
Addresses []address
Tokens map[string][]token
}
// Address represents a host and port.
Address struct {
address struct {
Host, Port string
}
)
// HostList converts the list of addresses (hosts)
// that are associated with this server block into
// a slice of strings. Each string is a host:port
// combination.
func (sb serverBlock) HostList() []string {
sbHosts := make([]string, len(sb.Addresses))
for j, addr := range sb.Addresses {
sbHosts[j] = net.JoinHostPort(addr.Host, addr.Port)
}
return sbHosts
}

View file

@ -59,7 +59,7 @@ func TestStandardAddress(t *testing.T) {
func TestParseOneAndImport(t *testing.T) {
setupParseTests()
testParseOne := func(input string) (ServerBlock, error) {
testParseOne := func(input string) (serverBlock, error) {
p := testParser(input)
p.Next() // parseOne doesn't call Next() to start, so we must
err := p.parseOne()
@ -69,22 +69,22 @@ func TestParseOneAndImport(t *testing.T) {
for i, test := range []struct {
input string
shouldErr bool
addresses []Address
addresses []address
tokens map[string]int // map of directive name to number of tokens expected
}{
{`localhost`, false, []Address{
{`localhost`, false, []address{
{"localhost", ""},
}, map[string]int{}},
{`localhost
dir1`, false, []Address{
dir1`, false, []address{
{"localhost", ""},
}, map[string]int{
"dir1": 1,
}},
{`localhost:1234
dir1 foo bar`, false, []Address{
dir1 foo bar`, false, []address{
{"localhost", "1234"},
}, map[string]int{
"dir1": 3,
@ -92,7 +92,7 @@ func TestParseOneAndImport(t *testing.T) {
{`localhost {
dir1
}`, false, []Address{
}`, false, []address{
{"localhost", ""},
}, map[string]int{
"dir1": 1,
@ -101,7 +101,7 @@ func TestParseOneAndImport(t *testing.T) {
{`localhost:1234 {
dir1 foo bar
dir2
}`, false, []Address{
}`, false, []address{
{"localhost", "1234"},
}, map[string]int{
"dir1": 3,
@ -109,7 +109,7 @@ func TestParseOneAndImport(t *testing.T) {
}},
{`http://localhost https://localhost
dir1 foo bar`, false, []Address{
dir1 foo bar`, false, []address{
{"localhost", "http"},
{"localhost", "https"},
}, map[string]int{
@ -118,7 +118,7 @@ func TestParseOneAndImport(t *testing.T) {
{`http://localhost https://localhost {
dir1 foo bar
}`, false, []Address{
}`, false, []address{
{"localhost", "http"},
{"localhost", "https"},
}, map[string]int{
@ -127,7 +127,7 @@ func TestParseOneAndImport(t *testing.T) {
{`http://localhost, https://localhost {
dir1 foo bar
}`, false, []Address{
}`, false, []address{
{"localhost", "http"},
{"localhost", "https"},
}, map[string]int{
@ -135,13 +135,13 @@ func TestParseOneAndImport(t *testing.T) {
}},
{`http://localhost, {
}`, true, []Address{
}`, true, []address{
{"localhost", "http"},
}, map[string]int{}},
{`host1:80, http://host2.com
dir1 foo bar
dir2 baz`, false, []Address{
dir2 baz`, false, []address{
{"host1", "80"},
{"host2.com", "http"},
}, map[string]int{
@ -151,7 +151,7 @@ func TestParseOneAndImport(t *testing.T) {
{`http://host1.com,
http://host2.com,
https://host3.com`, false, []Address{
https://host3.com`, false, []address{
{"host1.com", "http"},
{"host2.com", "http"},
{"host3.com", "https"},
@ -161,7 +161,7 @@ func TestParseOneAndImport(t *testing.T) {
dir1 foo {
bar baz
}
dir2`, false, []Address{
dir2`, false, []address{
{"host1.com", "1234"},
{"host2.com", "https"},
}, map[string]int{
@ -175,7 +175,7 @@ func TestParseOneAndImport(t *testing.T) {
}
dir2 {
foo bar
}`, false, []Address{
}`, false, []address{
{"127.0.0.1", ""},
}, map[string]int{
"dir1": 5,
@ -183,13 +183,13 @@ func TestParseOneAndImport(t *testing.T) {
}},
{`127.0.0.1
unknown_directive`, true, []Address{
unknown_directive`, true, []address{
{"127.0.0.1", ""},
}, map[string]int{}},
{`localhost
dir1 {
foo`, true, []Address{
foo`, true, []address{
{"localhost", ""},
}, map[string]int{
"dir1": 3,
@ -197,7 +197,15 @@ func TestParseOneAndImport(t *testing.T) {
{`localhost
dir1 {
}`, false, []Address{
}`, false, []address{
{"localhost", ""},
}, map[string]int{
"dir1": 3,
}},
{`localhost
dir1 {
} }`, true, []address{
{"localhost", ""},
}, map[string]int{
"dir1": 3,
@ -209,18 +217,18 @@ func TestParseOneAndImport(t *testing.T) {
foo
}
}
dir2 foo bar`, false, []Address{
dir2 foo bar`, false, []address{
{"localhost", ""},
}, map[string]int{
"dir1": 7,
"dir2": 3,
}},
{``, false, []Address{}, map[string]int{}},
{``, false, []address{}, map[string]int{}},
{`localhost
dir1 arg1
import import_test1.txt`, false, []Address{
import import_test1.txt`, false, []address{
{"localhost", ""},
}, map[string]int{
"dir1": 2,
@ -228,16 +236,20 @@ func TestParseOneAndImport(t *testing.T) {
"dir3": 1,
}},
{`import import_test2.txt`, false, []Address{
{`import import_test2.txt`, false, []address{
{"host1", ""},
}, map[string]int{
"dir1": 1,
"dir2": 2,
}},
{``, false, []Address{}, map[string]int{}},
{`import import_test1.txt import_test2.txt`, true, []address{}, map[string]int{}},
{`""`, false, []Address{}, map[string]int{}},
{`import not_found.txt`, true, []address{}, map[string]int{}},
{`""`, false, []address{}, map[string]int{}},
{``, false, []address{}, map[string]int{}},
} {
result, err := testParseOne(test.input)
@ -282,43 +294,43 @@ func TestParseOneAndImport(t *testing.T) {
func TestParseAll(t *testing.T) {
setupParseTests()
testParseAll := func(input string) ([]ServerBlock, error) {
p := testParser(input)
return p.parseAll()
}
for i, test := range []struct {
input string
shouldErr bool
numBlocks int
addresses [][]address // addresses per server block, in order
}{
{`localhost`, false, 1},
{`localhost`, false, [][]address{
{{"localhost", ""}},
}},
{`localhost {
dir1
}`, false, 1},
{`localhost:1234`, false, [][]address{
[]address{{"localhost", "1234"}},
}},
{`http://localhost https://localhost
dir1 foo bar`, false, 1},
{`localhost:1234 {
}
localhost:2015 {
}`, false, [][]address{
[]address{{"localhost", "1234"}},
[]address{{"localhost", "2015"}},
}},
{`http://localhost, https://localhost {
dir1 foo bar
}`, false, 1},
{`localhost:1234, http://host2`, false, [][]address{
[]address{{"localhost", "1234"}, {"host2", "http"}},
}},
{`http://host1.com,
http://host2.com,
https://host3.com`, false, 1},
{`localhost:1234, http://host2,`, true, [][]address{}},
{`host1 {
}
host2 {
}`, false, 2},
{`""`, false, 0},
{``, false, 0},
{`http://host1.com, http://host2.com {
}
https://host3.com, https://host4.com {
}`, false, [][]address{
[]address{{"host1.com", "http"}, {"host2.com", "http"}},
[]address{{"host3.com", "https"}, {"host4.com", "https"}},
}},
} {
results, err := testParseAll(test.input)
p := testParser(test.input)
blocks, err := p.parseAll()
if test.shouldErr && err == nil {
t.Errorf("Test %d: Expected an error, but didn't get one", i)
@ -327,11 +339,28 @@ func TestParseAll(t *testing.T) {
t.Errorf("Test %d: Expected no error, but got: %v", i, err)
}
if len(results) != test.numBlocks {
if len(blocks) != len(test.addresses) {
t.Errorf("Test %d: Expected %d server blocks, got %d",
i, test.numBlocks, len(results))
i, len(test.addresses), len(blocks))
continue
}
for j, block := range blocks {
if len(block.Addresses) != len(test.addresses[j]) {
t.Errorf("Test %d: Expected %d addresses in block %d, got %d",
i, len(test.addresses[j]), j, len(block.Addresses))
continue
}
for k, addr := range block.Addresses {
if addr.Host != test.addresses[j][k].Host {
t.Errorf("Test %d, block %d, address %d: Expected host to be '%s', but was '%s'",
i, j, k, test.addresses[j][k].Host, addr.Host)
}
if addr.Port != test.addresses[j][k].Port {
t.Errorf("Test %d, block %d, address %d: Expected port to be '%s', but was '%s'",
i, j, k, test.addresses[j][k].Port, addr.Port)
}
}
}
}
}

View file

@ -11,23 +11,56 @@ import (
)
// Controller is given to the setup function of middlewares which
// gives them access to be able to read tokens and set config.
// gives them access to be able to read tokens and set config. Each
// virtualhost gets their own server config and dispenser.
type Controller struct {
*server.Config
parse.Dispenser
// OncePerServerBlock is a function that executes f
// exactly once per server block, no matter how many
// hosts are associated with it. If it is the first
// time, the function f is executed immediately
// (not deferred) and may return an error which is
// returned by OncePerServerBlock.
OncePerServerBlock func(f func() error) error
// ServerBlockIndex is the 0-based index of the
// server block as it appeared in the input.
ServerBlockIndex int
// ServerBlockHostIndex is the 0-based index of this
// host as it appeared in the input at the head of the
// server block.
ServerBlockHostIndex int
// ServerBlockHosts is a list of hosts that are
// associated with this server block. All these
// hosts, consequently, share the same tokens.
ServerBlockHosts []string
// ServerBlockStorage is used by a directive's
// setup function to persist state between all
// the hosts on a server block.
ServerBlockStorage interface{}
}
// NewTestController creates a new *Controller for
// the input specified, with a filename of "Testfile"
// the input specified, with a filename of "Testfile".
// The Config is bare, consisting only of a Root of cwd.
//
// Used primarily for testing but needs to be exported so
// add-ons can use this as a convenience.
// add-ons can use this as a convenience. Does not initialize
// the server-block-related fields.
func NewTestController(input string) *Controller {
return &Controller{
Config: &server.Config{
Root: ".",
},
Dispenser: parse.NewDispenser("Testfile", strings.NewReader(input)),
OncePerServerBlock: func(f func() error) error {
return f()
},
}
}

View file

@ -31,7 +31,7 @@ func FastCGI(c *Controller) (middleware.Middleware, error) {
SoftwareName: c.AppName,
SoftwareVersion: c.AppVersion,
ServerName: c.Host,
ServerPort: c.Port, // BUG: This is not known until the server blocks are split up...
ServerPort: c.Port,
}
}, nil
}

View file

@ -20,6 +20,8 @@ func Shutdown(c *Controller) (middleware.Middleware, error) {
// using c to parse the line. It appends the callback function
// to the list of callback functions passed in by reference.
func registerCallback(c *Controller, list *[]func() error) error {
var funcs []func() error
for c.Next() {
args := c.RemainingArgs()
if len(args) == 0 {
@ -50,8 +52,11 @@ func registerCallback(c *Controller, list *[]func() error) error {
return cmd.Run()
}
*list = append(*list, fn)
funcs = append(funcs, fn)
}
return nil
return c.OncePerServerBlock(func() error {
*list = append(*list, funcs...)
return nil
})
}

View file

@ -4,11 +4,9 @@ import (
"os"
"os/exec"
"path/filepath"
"strings"
"strconv"
"testing"
"github.com/mholt/caddy/config/parse"
"github.com/mholt/caddy/server"
"time"
)
// The Startup function's tests are symmetrical to Shutdown tests,
@ -16,9 +14,7 @@ import (
// same functionality
func TestStartup(t *testing.T) {
var startupFuncs []func() error
tempDirPath, err := getTempDirPath() // function defined in caddy/config/setup/root_test.go
tempDirPath, err := getTempDirPath()
if err != nil {
t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err)
}
@ -28,30 +24,28 @@ func TestStartup(t *testing.T) {
exec.Command("rm", "-r", osSenitiveTestDir).Start() // removes osSenitiveTestDir from the OS's temp directory, if the osSenitiveTestDir already exists
startupCommand := ` startup mkdir ` + osSenitiveTestDir + `
startup mkdir ` + osSenitiveTestDir + ` &
startup highly_unlikely_to_exist_command
`
c := &Controller{
Config: &server.Config{Startup: startupFuncs},
Dispenser: parse.NewDispenser("", strings.NewReader(startupCommand)),
}
_, err = Startup(c)
if err != nil {
t.Errorf("Expected no errors, got: %v", err)
}
tests := []struct {
input string
shouldExecutionErr bool
shouldRemoveErr bool
}{
{false, false}, // expected struct booleans for blocking commands
{false, true}, // expected struct booleans for non-blocking commands
{true, true}, // expected struct booleans for non-existant commands
// test case #0 tests proper functionality blocking commands
{"startup mkdir " + osSenitiveTestDir, false, false},
// test case #1 tests proper functionality of non-blocking commands
{"startup mkdir " + osSenitiveTestDir, false, true},
// test case #2 tests handling of non-existant commands
{"startup " + strconv.Itoa(int(time.Now().UnixNano())), true, true},
}
for i, test := range tests {
err = c.Startup[i]()
c := NewTestController(test.input)
_, err = Startup(c)
if err != nil {
t.Errorf("Expected no errors, got: %v", err)
}
err = c.Startup[0]()
if err != nil && !test.shouldExecutionErr {
t.Errorf("Test %d recieved an error of:\n%v", i, err)
}

View file

@ -2,6 +2,7 @@ package setup
import (
"crypto/tls"
"log"
"strings"
"github.com/mholt/caddy/middleware"
@ -10,6 +11,12 @@ import (
func TLS(c *Controller) (middleware.Middleware, error) {
c.TLS.Enabled = true
if c.Port == "http" {
c.TLS.Enabled = false
log.Printf("Warning: TLS disabled for %s://%s. To force TLS over the plaintext HTTP port, "+
"specify port 80 explicitly (https://%s:80).", c.Port, c.Host, c.Host)
}
for c.Next() {
if !c.NextArg() {
return nil, c.ArgErr()

View file

@ -54,6 +54,25 @@ func TestWebSocketParse(t *testing.T) {
Path: "/api4",
Command: "cat",
}}},
{`websocket /api5 "cmd arg1 arg2 arg3"`, false, []websocket.Config{{
Path: "/api5",
Command: "cmd",
Arguments: []string{"arg1", "arg2", "arg3"},
}}},
// accept respawn
{`websocket /api6 cat {
respawn
}`, false, []websocket.Config{{
Path: "/api6",
Command: "cat",
}}},
// invalid configuration
{`websocket /api7 cat {
invalid
}`, true, []websocket.Config{}},
}
for i, test := range tests {
c := NewTestController(test.inputWebSocketConfig)

View file

@ -49,7 +49,7 @@ func main() {
log.Fatal(err)
}
// Load address configurations from highest priority input
// Load config from file
addresses, err := loadConfigs()
if err != nil {
log.Fatal(err)
@ -123,10 +123,9 @@ func isLocalhost(s string) bool {
// loadConfigs loads configuration from a file or stdin (piped).
// The configurations are grouped by bind address.
// Configuration is obtained from one of three sources, tried
// Configuration is obtained from one of four sources, tried
// in this order: 1. -conf flag, 2. stdin, 3. command line argument 4. Caddyfile.
// If none of those are available, a default configuration is
// loaded.
// If none of those are available, a default configuration is loaded.
func loadConfigs() (config.Group, error) {
// -conf flag
if conf != "" {

View file

@ -3,25 +3,29 @@ package middleware
import (
"errors"
"runtime"
"strings"
"unicode"
"github.com/flynn/go-shlex"
)
var runtimeGoos = runtime.GOOS
// SplitCommandAndArgs takes a command string and parses it
// shell-style into the command and its separate arguments.
func SplitCommandAndArgs(command string) (cmd string, args []string, err error) {
var parts []string
// ensures that shlex does not strip Windows paths of their path separator
if runtime.GOOS == "windows" {
command = strings.Replace(command, "\\", "\\\\", -1)
if runtimeGoos == "windows" {
parts = parseWindowsCommand(command) // parse it Windows-style
} else {
parts, err = parseUnixCommand(command) // parse it Unix-style
if err != nil {
err = errors.New("error parsing command: " + err.Error())
return
}
}
parts, err := shlex.Split(command)
if err != nil {
err = errors.New("error parsing command: " + err.Error())
return
} else if len(parts) == 0 {
if len(parts) == 0 {
err = errors.New("no command contained in '" + command + "'")
return
}
@ -33,3 +37,84 @@ func SplitCommandAndArgs(command string) (cmd string, args []string, err error)
return
}
// parseUnixCommand parses a unix style command line and returns the
// command and its arguments or an error
func parseUnixCommand(cmd string) ([]string, error) {
return shlex.Split(cmd)
}
// parseWindowsCommand parses windows command lines and
// returns the command and the arguments as an array. It
// should be able to parse commonly used command lines.
// Only basic syntax is supported:
// - spaces in double quotes are not token delimiters
// - double quotes are escaped by either backspace or another double quote
// - except for the above case backspaces are path separators (not special)
//
// Many sources point out that escaping quotes using backslash can be unsafe.
// Use two double quotes when possible. (Source: http://stackoverflow.com/a/31413730/2616179 )
//
// This function has to be used on Windows instead
// of the shlex package because this function treats backslash
// characters properly.
func parseWindowsCommand(cmd string) []string {
const backslash = '\\'
const quote = '"'
var parts []string
var part string
var inQuotes bool
var lastRune rune
for i, ch := range cmd {
if i != 0 {
lastRune = rune(cmd[i-1])
}
if ch == backslash {
// put it in the part - for now we don't know if it's an
// escaping char or path separator
part += string(ch)
continue
}
if ch == quote {
if lastRune == backslash {
// remove the backslash from the part and add the escaped quote instead
part = part[:len(part)-1]
part += string(ch)
continue
}
if lastRune == quote {
// revert the last change of the inQuotes state
// it was an escaping quote
inQuotes = !inQuotes
part += string(ch)
continue
}
// normal escaping quotes
inQuotes = !inQuotes
continue
}
if unicode.IsSpace(ch) && !inQuotes && len(part) > 0 {
parts = append(parts, part)
part = ""
continue
}
part += string(ch)
}
if len(part) > 0 {
parts = append(parts, part)
part = ""
}
return parts
}

View file

@ -2,11 +2,176 @@ package middleware
import (
"fmt"
"runtime"
"strings"
"testing"
)
func TestParseUnixCommand(t *testing.T) {
tests := []struct {
input string
expected []string
}{
// 0 - emtpy command
{
input: ``,
expected: []string{},
},
// 1 - command without arguments
{
input: `command`,
expected: []string{`command`},
},
// 2 - command with single argument
{
input: `command arg1`,
expected: []string{`command`, `arg1`},
},
// 3 - command with multiple arguments
{
input: `command arg1 arg2`,
expected: []string{`command`, `arg1`, `arg2`},
},
// 4 - command with single argument with space character - in quotes
{
input: `command "arg1 arg1"`,
expected: []string{`command`, `arg1 arg1`},
},
// 5 - command with multiple spaces and tab character
{
input: "command arg1 arg2\targ3",
expected: []string{`command`, `arg1`, `arg2`, `arg3`},
},
// 6 - command with single argument with space character - escaped with backspace
{
input: `command arg1\ arg2`,
expected: []string{`command`, `arg1 arg2`},
},
// 7 - single quotes should escape special chars
{
input: `command 'arg1\ arg2'`,
expected: []string{`command`, `arg1\ arg2`},
},
}
for i, test := range tests {
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
actual, _ := parseUnixCommand(test.input)
if len(actual) != len(test.expected) {
t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual)
continue
}
for j := 0; j < len(actual); j++ {
if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart {
t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j)
}
}
}
}
func TestParseWindowsCommand(t *testing.T) {
tests := []struct {
input string
expected []string
}{
{ // 0 - empty command - do not fail
input: ``,
expected: []string{},
},
{ // 1 - cmd without args
input: `cmd`,
expected: []string{`cmd`},
},
{ // 2 - multiple args
input: `cmd arg1 arg2`,
expected: []string{`cmd`, `arg1`, `arg2`},
},
{ // 3 - multiple args with space
input: `cmd "combined arg" arg2`,
expected: []string{`cmd`, `combined arg`, `arg2`},
},
{ // 4 - path without spaces
input: `mkdir C:\Windows\foo\bar`,
expected: []string{`mkdir`, `C:\Windows\foo\bar`},
},
{ // 5 - command with space in quotes
input: `"command here"`,
expected: []string{`command here`},
},
{ // 6 - argument with escaped quotes (two quotes)
input: `cmd ""arg""`,
expected: []string{`cmd`, `"arg"`},
},
{ // 7 - argument with escaped quotes (backslash)
input: `cmd \"arg\"`,
expected: []string{`cmd`, `"arg"`},
},
{ // 8 - two quotes (escaped) inside an inQuote element
input: `cmd "a ""quoted value"`,
expected: []string{`cmd`, `a "quoted value`},
},
// TODO - see how many quotes are dislayed if we use "", """, """""""
{ // 9 - two quotes outside an inQuote element
input: `cmd a ""quoted value`,
expected: []string{`cmd`, `a`, `"quoted`, `value`},
},
{ // 10 - path with space in quotes
input: `mkdir "C:\directory name\foobar"`,
expected: []string{`mkdir`, `C:\directory name\foobar`},
},
{ // 11 - space without quotes
input: `mkdir C:\ space`,
expected: []string{`mkdir`, `C:\`, `space`},
},
{ // 12 - space in quotes
input: `mkdir "C:\ space"`,
expected: []string{`mkdir`, `C:\ space`},
},
{ // 13 - UNC
input: `mkdir \\?\C:\Users`,
expected: []string{`mkdir`, `\\?\C:\Users`},
},
{ // 14 - UNC with space
input: `mkdir "\\?\C:\Program Files"`,
expected: []string{`mkdir`, `\\?\C:\Program Files`},
},
{ // 15 - unclosed quotes - treat as if the path ends with quote
input: `mkdir "c:\Program files`,
expected: []string{`mkdir`, `c:\Program files`},
},
{ // 16 - quotes used inside the argument
input: `mkdir "c:\P"rogra"m f"iles`,
expected: []string{`mkdir`, `c:\Program files`},
},
}
for i, test := range tests {
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
actual := parseWindowsCommand(test.input)
if len(actual) != len(test.expected) {
t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual)
continue
}
for j := 0; j < len(actual); j++ {
if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart {
t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j)
}
}
}
}
func TestSplitCommandAndArgs(t *testing.T) {
// force linux parsing. It's more robust and covers error cases
runtimeGoos = "linux"
defer func() {
runtimeGoos = runtime.GOOS
}()
var parseErrorContent = "error parsing command:"
var noCommandErrContent = "no command contained in"
@ -16,84 +181,42 @@ func TestSplitCommandAndArgs(t *testing.T) {
expectedArgs []string
expectedErrContent string
}{
// Test case 0 - emtpy command
// 0 - emtpy command
{
input: ``,
expectedCommand: ``,
expectedArgs: nil,
expectedErrContent: noCommandErrContent,
},
// Test case 1 - command without arguments
// 1 - command without arguments
{
input: `command`,
expectedCommand: `command`,
expectedArgs: nil,
expectedErrContent: ``,
},
// Test case 2 - command with single argument
// 2 - command with single argument
{
input: `command arg1`,
expectedCommand: `command`,
expectedArgs: []string{`arg1`},
expectedErrContent: ``,
},
// Test case 3 - command with multiple arguments
// 3 - command with multiple arguments
{
input: `command arg1 arg2`,
expectedCommand: `command`,
expectedArgs: []string{`arg1`, `arg2`},
expectedErrContent: ``,
},
// Test case 4 - command with single argument with space character - in quotes
{
input: `command "arg1 arg1"`,
expectedCommand: `command`,
expectedArgs: []string{`arg1 arg1`},
expectedErrContent: ``,
},
// Test case 4 - command with single argument with space character - escaped
{
input: `command arg1\ arg1`,
expectedCommand: `command`,
expectedArgs: []string{`arg1 arg1`},
expectedErrContent: ``,
},
// Test case 6 - command with escaped quote character
{
input: `command "arg1 \" arg1"`,
expectedCommand: `command`,
expectedArgs: []string{`arg1 " arg1`},
expectedErrContent: ``,
},
// Test case 7 - command with escaped backslash
{
input: `command '\arg1'`,
expectedCommand: `command`,
expectedArgs: []string{`\arg1`},
expectedErrContent: ``,
},
// Test case 8 - command with comments
{
input: `command arg1 #comment1 comment2`,
expectedCommand: `command`,
expectedArgs: []string{`arg1`},
expectedErrContent: "",
},
// Test case 9 - command with multiple spaces and tab character
{
input: "command arg1 arg2\targ3",
expectedCommand: `command`,
expectedArgs: []string{`arg1`, `arg2`, "arg3"},
expectedErrContent: "",
},
// Test case 10 - command with unclosed quotes
// 4 - command with unclosed quotes
{
input: `command "arg1 arg2`,
expectedCommand: "",
expectedArgs: nil,
expectedErrContent: parseErrorContent,
},
// Test case 11 - command with unclosed quotes
// 5 - command with unclosed quotes
{
input: `command 'arg1 arg2"`,
expectedCommand: "",
@ -120,19 +243,49 @@ func TestSplitCommandAndArgs(t *testing.T) {
// test if command matches
if test.expectedCommand != actualCommand {
t.Errorf("Expected command: [%s], actual: [%s]."+errorSuffix, test.expectedCommand, actualCommand)
t.Errorf(errorPrefix+"Expected command: [%s], actual: [%s]."+errorSuffix, test.expectedCommand, actualCommand)
}
// test if arguments match
if len(test.expectedArgs) != len(actualArgs) {
t.Errorf("Wrong number of arguments! Expected [%v], actual [%v]."+errorSuffix, test.expectedArgs, actualArgs)
}
for j, actualArg := range actualArgs {
expectedArg := test.expectedArgs[j]
if actualArg != expectedArg {
t.Errorf(errorPrefix+"Argument at position [%d] differ! Expected [%s], actual [%s]"+errorSuffix, j, expectedArg, actualArg)
t.Errorf(errorPrefix+"Wrong number of arguments! Expected [%v], actual [%v]."+errorSuffix, test.expectedArgs, actualArgs)
} else {
// test args only if the count matches.
for j, actualArg := range actualArgs {
expectedArg := test.expectedArgs[j]
if actualArg != expectedArg {
t.Errorf(errorPrefix+"Argument at position [%d] differ! Expected [%s], actual [%s]"+errorSuffix, j, expectedArg, actualArg)
}
}
}
}
}
func ExampleSplitCommandAndArgs() {
var commandLine string
var command string
var args []string
// just for the test - change GOOS and reset it at the end of the test
runtimeGoos = "windows"
defer func() {
runtimeGoos = runtime.GOOS
}()
commandLine = `mkdir /P "C:\Program Files"`
command, args, _ = SplitCommandAndArgs(commandLine)
fmt.Printf("Windows: %s: %s [%s]\n", commandLine, command, strings.Join(args, ","))
// set GOOS to linux
runtimeGoos = "linux"
commandLine = `mkdir -p /path/with\ space`
command, args, _ = SplitCommandAndArgs(commandLine)
fmt.Printf("Linux: %s: %s [%s]\n", commandLine, command, strings.Join(args, ","))
// Output:
// Windows: mkdir /P "C:\Program Files": mkdir [/P,C:\Program Files]
// Linux: mkdir -p /path/with\ space: mkdir [-p,/path/with space]
}

View file

@ -97,6 +97,10 @@ func (c Context) URI() string {
func (c Context) Host() (string, error) {
host, _, err := net.SplitHostPort(c.Req.Host)
if err != nil {
if !strings.Contains(c.Req.Host, ":") {
// common with sites served on the default port 80
return c.Req.Host, nil
}
return "", err
}
return host, nil

545
middleware/context_test.go Normal file
View file

@ -0,0 +1,545 @@
package middleware
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
func TestInclude(t *testing.T) {
context := getContextOrFail(t)
inputFilename := "test_file"
absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
defer func() {
err := os.Remove(absInFilePath)
if err != nil && !os.IsNotExist(err) {
t.Fatalf("Failed to clean test file!")
}
}()
tests := []struct {
fileContent string
expectedContent string
shouldErr bool
expectedErrorContent string
}{
// Test 0 - all good
{
fileContent: `str1 {{ .Root }} str2`,
expectedContent: fmt.Sprintf("str1 %s str2", context.Root),
shouldErr: false,
expectedErrorContent: "",
},
// Test 1 - failure on template.Parse
{
fileContent: `str1 {{ .Root } str2`,
expectedContent: "",
shouldErr: true,
expectedErrorContent: `unexpected "}" in operand`,
},
// Test 3 - failure on template.Execute
{
fileContent: `str1 {{ .InvalidField }} str2`,
expectedContent: "",
shouldErr: true,
expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`,
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
// WriteFile truncates the contentt
err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
if err != nil {
t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
}
content, err := context.Include(inputFilename)
if err != nil {
if !test.shouldErr {
t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error())
}
if !strings.Contains(err.Error(), test.expectedErrorContent) {
t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error())
}
}
if err == nil && test.shouldErr {
t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename)
}
if content != test.expectedContent {
t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
}
}
}
func TestIncludeNotExisting(t *testing.T) {
context := getContextOrFail(t)
_, err := context.Include("not_existing")
if err == nil {
t.Errorf("Expected error but found nil!")
}
}
func TestCookie(t *testing.T) {
tests := []struct {
cookie *http.Cookie
cookieName string
expectedValue string
}{
// Test 0 - happy path
{
cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
cookieName: "cookieName",
expectedValue: "cookieValue",
},
// Test 1 - try to get a non-existing cookie
{
cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
cookieName: "notExisting",
expectedValue: "",
},
// Test 2 - partial name match
{
cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"},
cookieName: "cook",
expectedValue: "",
},
// Test 3 - cookie with optional fields
{
cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120},
cookieName: "cookie",
expectedValue: "cookieValue",
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
// reinitialize the context for each test
context := getContextOrFail(t)
context.Req.AddCookie(test.cookie)
actualCookieVal := context.Cookie(test.cookieName)
if actualCookieVal != test.expectedValue {
t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName)
}
}
}
func TestCookieMultipleCookies(t *testing.T) {
context := getContextOrFail(t)
cookieNameBase, cookieValueBase := "cookieName", "cookieValue"
// make sure that there's no state and multiple requests for different cookies return the correct result
for i := 0; i < 10; i++ {
context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)})
}
for i := 0; i < 10; i++ {
expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i)
actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i))
if actualCookieVal != expectedCookieVal {
t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal)
}
}
}
func TestHeader(t *testing.T) {
context := getContextOrFail(t)
headerKey, headerVal := "Header1", "HeaderVal1"
context.Req.Header.Add(headerKey, headerVal)
actualHeaderVal := context.Header(headerKey)
if actualHeaderVal != headerVal {
t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
}
missingHeaderVal := context.Header("not-existing")
if missingHeaderVal != "" {
t.Errorf("Expected empty header value, found %s", missingHeaderVal)
}
}
func TestIP(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
inputRemoteAddr string
expectedIP string
}{
// Test 0 - ipv4 with port
{"1.1.1.1:1111", "1.1.1.1"},
// Test 1 - ipv4 without port
{"1.1.1.1", "1.1.1.1"},
// Test 2 - ipv6 with port
{"[::1]:11", "::1"},
// Test 3 - ipv6 without port and brackets
{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
// Test 4 - ipv6 with zone and port
{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
context.Req.RemoteAddr = test.inputRemoteAddr
actualIP := context.IP()
if actualIP != test.expectedIP {
t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
}
}
}
func TestURL(t *testing.T) {
context := getContextOrFail(t)
inputURL := "http://localhost"
context.Req.RequestURI = inputURL
if inputURL != context.URI() {
t.Errorf("Expected url %s, found %s", inputURL, context.URI())
}
}
func TestHost(t *testing.T) {
tests := []struct {
input string
expectedHost string
shouldErr bool
}{
{
input: "localhost:123",
expectedHost: "localhost",
shouldErr: false,
},
{
input: "localhost",
expectedHost: "localhost",
shouldErr: false,
},
{
input: "[::]",
expectedHost: "",
shouldErr: true,
},
}
for _, test := range tests {
testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
}
}
func TestPort(t *testing.T) {
tests := []struct {
input string
expectedPort string
shouldErr bool
}{
{
input: "localhost:123",
expectedPort: "123",
shouldErr: false,
},
{
input: "localhost",
expectedPort: "",
shouldErr: true, // missing port in address
},
{
input: ":8080",
expectedPort: "8080",
shouldErr: false,
},
}
for _, test := range tests {
testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
}
}
func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
context := getContextOrFail(t)
context.Req.Host = input
var actualResult, testedObject string
var err error
if isTestingHost {
actualResult, err = context.Host()
testedObject = "host"
} else {
actualResult, err = context.Port()
testedObject = "port"
}
if shouldErr && err == nil {
t.Errorf("Expected error, found nil!")
return
}
if !shouldErr && err != nil {
t.Errorf("Expected no error, found %s", err)
return
}
if actualResult != expectedResult {
t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
}
}
func TestMethod(t *testing.T) {
context := getContextOrFail(t)
method := "POST"
context.Req.Method = method
if method != context.Method() {
t.Errorf("Expected method %s, found %s", method, context.Method())
}
}
func TestPathMatches(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
urlStr string
pattern string
shouldMatch bool
}{
// Test 0
{
urlStr: "http://localhost/",
pattern: "",
shouldMatch: true,
},
// Test 1
{
urlStr: "http://localhost",
pattern: "",
shouldMatch: true,
},
// Test 1
{
urlStr: "http://localhost/",
pattern: "/",
shouldMatch: true,
},
// Test 3
{
urlStr: "http://localhost/?param=val",
pattern: "/",
shouldMatch: true,
},
// Test 4
{
urlStr: "http://localhost/dir1/dir2",
pattern: "/dir2",
shouldMatch: false,
},
// Test 5
{
urlStr: "http://localhost/dir1/dir2",
pattern: "/dir1",
shouldMatch: true,
},
// Test 6
{
urlStr: "http://localhost:444/dir1/dir2",
pattern: "/dir1",
shouldMatch: true,
},
// Test 7
{
urlStr: "http://localhost/dir1/dir2",
pattern: "*/dir2",
shouldMatch: false,
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
var err error
context.Req.URL, err = url.Parse(test.urlStr)
if err != nil {
t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
}
matches := context.PathMatches(test.pattern)
if matches != test.shouldMatch {
t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
}
}
}
func TestTruncate(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
inputString string
inputLength int
expected string
}{
// Test 0 - small length
{
inputString: "string",
inputLength: 1,
expected: "s",
},
// Test 1 - exact length
{
inputString: "string",
inputLength: 6,
expected: "string",
},
// Test 2 - bigger length
{
inputString: "string",
inputLength: 10,
expected: "string",
},
}
for i, test := range tests {
actual := context.Truncate(test.inputString, test.inputLength)
if actual != test.expected {
t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength)
}
}
}
func TestStripHTML(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
input string
expected string
}{
// Test 0 - no tags
{
input: `h1`,
expected: `h1`,
},
// Test 1 - happy path
{
input: `<h1>h1</h1>`,
expected: `h1`,
},
// Test 2 - tag in quotes
{
input: `<h1">">h1</h1>`,
expected: `h1`,
},
// Test 3 - multiple tags
{
input: `<h1><b>h1</b></h1>`,
expected: `h1`,
},
// Test 4 - tags not closed
{
input: `<h1`,
expected: `<h1`,
},
// Test 5 - false start
{
input: `<h1<b>hi`,
expected: `<h1hi`,
},
}
for i, test := range tests {
actual := context.StripHTML(test.input)
if actual != test.expected {
t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripHTML(%s)", test.expected, actual, test.input)
}
}
}
func TestStripExt(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
input string
expected string
}{
// Test 0 - empty input
{
input: "",
expected: "",
},
// Test 1 - relative file with ext
{
input: "file.ext",
expected: "file",
},
// Test 2 - relative file without ext
{
input: "file",
expected: "file",
},
// Test 3 - absolute file without ext
{
input: "/file",
expected: "/file",
},
// Test 4 - absolute file with ext
{
input: "/file.ext",
expected: "/file",
},
// Test 5 - with ext but ends with /
{
input: "/dir.ext/",
expected: "/dir.ext/",
},
// Test 6 - file with ext under dir with ext
{
input: "/dir.ext/file.ext",
expected: "/dir.ext/file",
},
}
for i, test := range tests {
actual := context.StripExt(test.input)
if actual != test.expected {
t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripExt(%q)", test.expected, actual, test.input)
}
}
}
func initTestContext() (Context, error) {
body := bytes.NewBufferString("request body")
request, err := http.NewRequest("GET", "https://localhost", body)
if err != nil {
return Context{}, err
}
return Context{Root: http.Dir(os.TempDir()), Req: request}, nil
}
func getContextOrFail(t *testing.T) Context {
context, err := initTestContext()
if err != nil {
t.Fatalf("Failed to prepare test context")
}
return context
}
func getTestPrefix(testN int) string {
return fmt.Sprintf("Test [%d]: ", testN)
}

View file

@ -62,8 +62,8 @@ func (fh *fileHandler) serveFile(w http.ResponseWriter, r *http.Request, name st
}
defer f.Close()
d, err1 := f.Stat()
if err1 != nil {
d, err := f.Stat()
if err != nil {
if os.IsNotExist(err) {
return http.StatusNotFound, nil
} else if os.IsPermission(err) {

View file

@ -0,0 +1,325 @@
package middleware
import (
"errors"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
var testDir = filepath.Join(os.TempDir(), "caddy_testdir")
var customErr = errors.New("Custom Error")
// testFiles is a map with relative paths to test files as keys and file content as values.
// The map represents the following structure:
// - $TEMP/caddy_testdir/
// '-- file1.html
// '-- dirwithindex/
// '---- index.html
// '-- dir/
// '---- file2.html
// '---- hidden.html
var testFiles = map[string]string{
"file1.html": "<h1>file1.html</h1>",
filepath.Join("dirwithindex", "index.html"): "<h1>dirwithindex/index.html</h1>",
filepath.Join("dir", "file2.html"): "<h1>dir/file2.html</h1>",
filepath.Join("dir", "hidden.html"): "<h1>dir/hidden.html</h1>",
}
// TestServeHTTP covers positive scenarios when serving files.
func TestServeHTTP(t *testing.T) {
beforeServeHttpTest(t)
defer afterServeHttpTest(t)
fileserver := FileServer(http.Dir(testDir), []string{"hidden.html"})
movedPermanently := "Moved Permanently"
tests := []struct {
url string
expectedStatus int
expectedBodyContent string
}{
// Test 0 - access withoutt any path
{
url: "https://foo",
expectedStatus: http.StatusNotFound,
},
// Test 1 - access root (without index.html)
{
url: "https://foo/",
expectedStatus: http.StatusNotFound,
},
// Test 2 - access existing file
{
url: "https://foo/file1.html",
expectedStatus: http.StatusOK,
expectedBodyContent: testFiles["file1.html"],
},
// Test 3 - access folder with index file with trailing slash
{
url: "https://foo/dirwithindex/",
expectedStatus: http.StatusOK,
expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")],
},
// Test 4 - access folder with index file without trailing slash
{
url: "https://foo/dirwithindex",
expectedStatus: http.StatusMovedPermanently,
expectedBodyContent: movedPermanently,
},
// Test 5 - access folder without index file
{
url: "https://foo/dir/",
expectedStatus: http.StatusNotFound,
},
// Test 6 - access folder withtout trailing slash
{
url: "https://foo/dir",
expectedStatus: http.StatusMovedPermanently,
expectedBodyContent: movedPermanently,
},
// Test 6 - access file with trailing slash
{
url: "https://foo/file1.html/",
expectedStatus: http.StatusMovedPermanently,
expectedBodyContent: movedPermanently,
},
// Test 7 - access not existing path
{
url: "https://foo/not_existing",
expectedStatus: http.StatusNotFound,
},
// Test 8 - access a file, marked as hidden
{
url: "https://foo/dir/hidden.html",
expectedStatus: http.StatusNotFound,
},
// Test 9 - access a index file directly
{
url: "https://foo/dirwithindex/index.html",
expectedStatus: http.StatusOK,
expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")],
},
// Test 10 - send a request with query params
{
url: "https://foo/dir?param1=val",
expectedStatus: http.StatusMovedPermanently,
expectedBodyContent: movedPermanently,
},
}
for i, test := range tests {
responseRecorder := httptest.NewRecorder()
request, err := http.NewRequest("GET", test.url, strings.NewReader(""))
status, err := fileserver.ServeHTTP(responseRecorder, request)
// check if error matches expectations
if err != nil {
t.Errorf(getTestPrefix(i)+"Serving file at %s failed. Error was: %v", test.url, err)
}
// check status code
if test.expectedStatus != status {
t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
}
// check body content
if !strings.Contains(responseRecorder.Body.String(), test.expectedBodyContent) {
t.Errorf(getTestPrefix(i)+"Expected body to contain %q, found %q", test.expectedBodyContent, responseRecorder.Body.String())
}
}
}
// beforeServeHttpTest creates a test directory with the structure, defined in the variable testFiles
func beforeServeHttpTest(t *testing.T) {
// make the root test dir
err := os.Mkdir(testDir, os.ModePerm)
if err != nil {
if !os.IsExist(err) {
t.Fatalf("Failed to create test dir. Error was: %v", err)
return
}
}
for relFile, fileContent := range testFiles {
absFile := filepath.Join(testDir, relFile)
// make sure the parent directories exist
parentDir := filepath.Dir(absFile)
_, err = os.Stat(parentDir)
if err != nil {
os.MkdirAll(parentDir, os.ModePerm)
}
// now create the test files
f, err := os.Create(absFile)
if err != nil {
t.Fatalf("Failed to create test file %s. Error was: %v", absFile, err)
return
}
// and fill them with content
_, err = f.WriteString(fileContent)
if err != nil {
t.Fatalf("Failed to write to %s. Error was: %v", absFile, err)
return
}
f.Close()
}
}
// afterServeHttpTest removes the test dir and all its content
func afterServeHttpTest(t *testing.T) {
// cleans up everything under the test dir. No need to clean the individual files.
err := os.RemoveAll(testDir)
if err != nil {
t.Fatalf("Failed to clean up test dir %s. Error was: %v", testDir, err)
}
}
// failingFS implements the http.FileSystem interface. The Open method always returns the error, assigned to err
type failingFS struct {
err error // the error to return when Open is called
fileImpl http.File // inject the file implementation
}
// Open returns the assigned failingFile and error
func (f failingFS) Open(path string) (http.File, error) {
return f.fileImpl, f.err
}
// failingFile implements http.File but returns a predefined error on every Stat() method call.
type failingFile struct {
http.File
err error
}
// Stat returns nil FileInfo and the provided error on every call
func (ff failingFile) Stat() (os.FileInfo, error) {
return nil, ff.err
}
// Close is noop and returns no error
func (ff failingFile) Close() error {
return nil
}
// TestServeHTTPFailingFS tests error cases where the Open function fails with various errors.
func TestServeHTTPFailingFS(t *testing.T) {
tests := []struct {
fsErr error
expectedStatus int
expectedErr error
expectedHeaders map[string]string
}{
{
fsErr: os.ErrNotExist,
expectedStatus: http.StatusNotFound,
expectedErr: nil,
},
{
fsErr: os.ErrPermission,
expectedStatus: http.StatusForbidden,
expectedErr: os.ErrPermission,
},
{
fsErr: customErr,
expectedStatus: http.StatusServiceUnavailable,
expectedErr: customErr,
expectedHeaders: map[string]string{"Retry-After": "5"},
},
}
for i, test := range tests {
// initialize a file server with the failing FileSystem
fileserver := FileServer(failingFS{err: test.fsErr}, nil)
// prepare the request and response
request, err := http.NewRequest("GET", "https://foo/", nil)
if err != nil {
t.Fatalf("Failed to build request. Error was: %v", err)
}
responseRecorder := httptest.NewRecorder()
status, actualErr := fileserver.ServeHTTP(responseRecorder, request)
// check the status
if status != test.expectedStatus {
t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
}
// check the error
if actualErr != test.expectedErr {
t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr)
}
// check the headers - a special case for server under load
if test.expectedHeaders != nil && len(test.expectedHeaders) > 0 {
for expectedKey, expectedVal := range test.expectedHeaders {
actualVal := responseRecorder.Header().Get(expectedKey)
if expectedVal != actualVal {
t.Errorf(getTestPrefix(i)+"Expected header %s: %s, found %s", expectedKey, expectedVal, actualVal)
}
}
}
}
}
// TestServeHTTPFailingStat tests error cases where the initial Open function succeeds, but the Stat method on the opened file fails.
func TestServeHTTPFailingStat(t *testing.T) {
tests := []struct {
statErr error
expectedStatus int
expectedErr error
}{
{
statErr: os.ErrNotExist,
expectedStatus: http.StatusNotFound,
expectedErr: nil,
},
{
statErr: os.ErrPermission,
expectedStatus: http.StatusForbidden,
expectedErr: os.ErrPermission,
},
{
statErr: customErr,
expectedStatus: http.StatusInternalServerError,
expectedErr: customErr,
},
}
for i, test := range tests {
// initialize a file server. The FileSystem will not fail, but calls to the Stat method of the returned File object will
fileserver := FileServer(failingFS{err: nil, fileImpl: failingFile{err: test.statErr}}, nil)
// prepare the request and response
request, err := http.NewRequest("GET", "https://foo/", nil)
if err != nil {
t.Fatalf("Failed to build request. Error was: %v", err)
}
responseRecorder := httptest.NewRecorder()
status, actualErr := fileserver.ServeHTTP(responseRecorder, request)
// check the status
if status != test.expectedStatus {
t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
}
// check the error
if actualErr != test.expectedErr {
t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr)
}
}
}

View file

@ -172,7 +172,7 @@ func reader(conn *websocket.Conn, stdout io.ReadCloser, stdin io.WriteCloser) {
conn.SetReadDeadline(time.Now().Add(pongWait))
conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
tickerChan := make(chan bool)
defer func() { tickerChan <- true }() // make sure to close the ticker when we are done.
defer close(tickerChan) // make sure to close the ticker when we are done.
go ticker(conn, tickerChan)
for {
@ -213,10 +213,7 @@ func reader(conn *websocket.Conn, stdout io.ReadCloser, stdin io.WriteCloser) {
// between the server and client to keep it alive with ping messages.
func ticker(conn *websocket.Conn, c chan bool) {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
close(c)
}()
defer ticker.Stop()
for { // blocking loop with select to wait for stimulation.
select {

View file

@ -86,7 +86,7 @@ func (s *Server) Serve() error {
go func(vh virtualHost) {
// Wait for signal
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt, os.Kill)
signal.Notify(interrupt, os.Interrupt, os.Kill) // TODO: syscall.SIGQUIT? (Ctrl+\, Unix-only)
<-interrupt
// Run callbacks