diff --git a/caddy/caddy.go b/caddy/caddy.go index 4c6f2237b..734e984d1 100644 --- a/caddy/caddy.go +++ b/caddy/caddy.go @@ -1,4 +1,5 @@ -// Package caddy implements the Caddy web server as a service. +// Package caddy implements the Caddy web server as a service +// in your own Go programs. // // To use this package, follow a few simple steps: // @@ -190,7 +191,8 @@ func startServers(groupings bindingGroup) error { if err != nil { return err } - s.HTTP2 = HTTP2 // TODO: This setting is temporary + s.HTTP2 = HTTP2 // TODO: This setting is temporary + s.ReqCallback = letsencrypt.RequestCallback // ensures we can solve ACME challenges while running var ln server.ListenerFile if IsRestart() { diff --git a/caddy/caddyfile/json.go b/caddy/caddyfile/json.go index fb04b5565..e1213c27d 100644 --- a/caddy/caddyfile/json.go +++ b/caddy/caddyfile/json.go @@ -28,7 +28,7 @@ func ToJSON(caddyfile []byte) ([]byte, error) { // Fill up host list for _, host := range sb.HostList() { - block.Hosts = append(block.Hosts, strings.TrimSuffix(host, ":")) + block.Hosts = append(block.Hosts, standardizeScheme(host)) } // Extract directives deterministically by sorting them @@ -62,7 +62,6 @@ func ToJSON(caddyfile []byte) ([]byte, error) { // but only one line at a time, to be used at the top-level of // a server block only (where the first token on each line is a // directive) - not to be used at any other nesting level. -// goes to end of line func constructLine(d *parse.Dispenser) []interface{} { var args []interface{} @@ -80,8 +79,8 @@ func constructLine(d *parse.Dispenser) []interface{} { } // constructBlock recursively processes tokens into a -// JSON-encodable structure. -// goes to end of block +// JSON-encodable structure. To be used in a directive's +// block. Goes to end of block. func constructBlock(d *parse.Dispenser) [][]interface{} { block := [][]interface{}{} @@ -110,15 +109,10 @@ func FromJSON(jsonBytes []byte) ([]byte, error) { result += "\n\n" } for i, host := range sb.Hosts { - if hostname, port, err := net.SplitHostPort(host); err == nil { - if port == "http" || port == "https" { - host = port + "://" + hostname - } - } if i > 0 { result += ", " } - result += strings.TrimSuffix(host, ":") + result += standardizeScheme(host) } result += jsonToText(sb.Body, 1) } @@ -170,6 +164,17 @@ func jsonToText(scope interface{}, depth int) string { return result } +// standardizeScheme turns an address like host:https into https://host, +// or "host:" into "host". +func standardizeScheme(addr string) string { + if hostname, port, err := net.SplitHostPort(addr); err == nil { + if port == "http" || port == "https" { + addr = port + "://" + hostname + } + } + return strings.TrimSuffix(addr, ":") +} + // Caddyfile encapsulates a slice of ServerBlocks. type Caddyfile []ServerBlock diff --git a/caddy/caddyfile/json_test.go b/caddy/caddyfile/json_test.go index 024792638..2e44ae2a2 100644 --- a/caddy/caddyfile/json_test.go +++ b/caddy/caddyfile/json_test.go @@ -63,7 +63,7 @@ baz" { // 8 caddyfile: `http://host, https://host { }`, - json: `[{"hosts":["host:http","host:https"],"body":[]}]`, // hosts in JSON are always host:port format (if port is specified), for consistency + json: `[{"hosts":["http://host","https://host"],"body":[]}]`, // hosts in JSON are always host:port format (if port is specified), for consistency }, { // 9 caddyfile: `host { @@ -124,3 +124,38 @@ func TestFromJSON(t *testing.T) { } } } + +func TestStandardizeAddress(t *testing.T) { + // host:https should be converted to https://host + output, err := ToJSON([]byte(`host:https`)) + if err != nil { + t.Fatal(err) + } + if expected, actual := `[{"hosts":["https://host"],"body":[]}]`, string(output); expected != actual { + t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual) + } + + output, err = FromJSON([]byte(`[{"hosts":["https://host"],"body":[]}]`)) + if err != nil { + t.Fatal(err) + } + if expected, actual := "https://host {\n}", string(output); expected != actual { + t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual) + } + + // host: should be converted to just host + output, err = ToJSON([]byte(`host:`)) + if err != nil { + t.Fatal(err) + } + if expected, actual := `[{"hosts":["host"],"body":[]}]`, string(output); expected != actual { + t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual) + } + output, err = FromJSON([]byte(`[{"hosts":["host:"],"body":[]}]`)) + if err != nil { + t.Fatal(err) + } + if expected, actual := "host {\n}", string(output); expected != actual { + t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual) + } +} diff --git a/caddy/config.go b/caddy/config.go index e29826751..3ff63b481 100644 --- a/caddy/config.go +++ b/caddy/config.go @@ -21,25 +21,22 @@ const ( DefaultConfigFile = "Caddyfile" ) -// loadConfigs reads input (named filename) and parses it, returning the -// server configurations in the order they appeared in the input. As part -// of this, it activates Let's Encrypt for the configs that are produced. -// Thus, the returned configs are already optimally configured optimally -// for HTTPS. -func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { +// loadConfigsUpToIncludingTLS loads the configs from input with name filename and returns them, +// the parsed server blocks, the index of the last directive it processed, and an error (if any). +func loadConfigsUpToIncludingTLS(filename string, input io.Reader) ([]server.Config, []parse.ServerBlock, int, error) { var configs []server.Config // Each server block represents similar hosts/addresses, since they // were grouped together in the Caddyfile. serverBlocks, err := parse.ServerBlocks(filename, input, true) if err != nil { - return nil, err + return nil, nil, 0, err } if len(serverBlocks) == 0 { newInput := DefaultInput() serverBlocks, err = parse.ServerBlocks(newInput.Path(), bytes.NewReader(newInput.Body()), true) if err != nil { - return nil, err + return nil, nil, 0, err } } @@ -56,6 +53,7 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { config := server.Config{ Host: addr.Host, Port: addr.Port, + Scheme: addr.Scheme, Root: Root, Middleware: make(map[string][]middleware.Middleware), ConfigFile: filename, @@ -88,7 +86,7 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { // execute setup function and append middleware handler, if any midware, err := dir.setup(controller) if err != nil { - return nil, err + return nil, nil, lastDirectiveIndex, err } if midware != nil { // TODO: For now, we only support the default path scope / @@ -109,22 +107,31 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { } } + return configs, serverBlocks, lastDirectiveIndex, nil +} + +// loadConfigs reads input (named filename) and parses it, returning the +// server configurations in the order they appeared in the input. As part +// of this, it activates Let's Encrypt for the configs that are produced. +// Thus, the returned configs are already optimally configured for HTTPS. +func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { + configs, serverBlocks, lastDirectiveIndex, err := loadConfigsUpToIncludingTLS(filename, input) + if err != nil { + return nil, err + } + // Now we have all the configs, but they have only been set up to the // point of tls. We need to activate Let's Encrypt before setting up // the rest of the middlewares so they have correct information regarding - // TLS configuration, if necessary. (this call is append-only, so our - // iterations below shouldn't be affected) + // TLS configuration, if necessary. (this only appends, so our iterations + // over server blocks below shouldn't be affected) if !IsRestart() && !Quiet { fmt.Print("Activating privacy features...") } configs, err = letsencrypt.Activate(configs) if err != nil { - if !Quiet { - fmt.Println() - } return nil, err - } - if !IsRestart() && !Quiet { + } else if !IsRestart() && !Quiet { fmt.Println(" done.") } @@ -277,44 +284,17 @@ func arrangeBindings(allConfigs []server.Config) (bindingGroup, error) { // but execution may continue. The second error, if not nil, is a real // problem and the server should not be started. // -// This function handles edge cases gracefully. If a port name like -// "http" or "https" is unknown to the system, this function will -// change them to 80 or 443 respectively. If a hostname fails to -// resolve, that host can still be served but will be listening on -// the wildcard host instead. This function takes care of this for you. +// This function does not handle edge cases like port "http" or "https" if +// they are not known to the system. It does, however, serve on the wildcard +// host if resolving the address of the specific hostname fails. func resolveAddr(conf server.Config) (resolvAddr *net.TCPAddr, warnErr, fatalErr error) { - bindHost := conf.BindHost - - // TODO: Do we even need the port? Maybe we just need to look up the host. - resolvAddr, warnErr = net.ResolveTCPAddr("tcp", net.JoinHostPort(bindHost, conf.Port)) + resolvAddr, warnErr = net.ResolveTCPAddr("tcp", net.JoinHostPort(conf.BindHost, conf.Port)) if warnErr != nil { - // Most likely the host lookup failed or the port is unknown - tryPort := conf.Port - - switch errVal := warnErr.(type) { - case *net.AddrError: - if errVal.Err == "unknown port" { - // some odd Linux machines don't support these port names; see issue #136 - switch conf.Port { - case "http": - tryPort = "80" - case "https": - tryPort = "443" - } - } - resolvAddr, fatalErr = net.ResolveTCPAddr("tcp", net.JoinHostPort(bindHost, tryPort)) - if fatalErr != nil { - return - } - default: - // the hostname probably couldn't be resolved, just bind to wildcard then - resolvAddr, fatalErr = net.ResolveTCPAddr("tcp", net.JoinHostPort("0.0.0.0", tryPort)) - if fatalErr != nil { - return - } + // the hostname probably couldn't be resolved, just bind to wildcard then + resolvAddr, fatalErr = net.ResolveTCPAddr("tcp", net.JoinHostPort("", conf.Port)) + if fatalErr != nil { + return } - - return } return @@ -334,12 +314,12 @@ func validDirective(d string) bool { // DefaultInput returns the default Caddyfile input // to use when it is otherwise empty or missing. // It uses the default host and port (depends on -// host, e.g. localhost is 2015, otherwise https) and +// host, e.g. localhost is 2015, otherwise 443) and // root. func DefaultInput() CaddyfileInput { port := Port - if letsencrypt.HostQualifies(Host) { - port = "https" + if letsencrypt.HostQualifies(Host) && port == DefaultPort { + port = "443" } return CaddyfileInput{ Contents: []byte(fmt.Sprintf("%s:%s\nroot %s", Host, port, Root)), diff --git a/caddy/config_test.go b/caddy/config_test.go index 3e70a9311..f5f0db6c2 100644 --- a/caddy/config_test.go +++ b/caddy/config_test.go @@ -13,10 +13,10 @@ func TestDefaultInput(t *testing.T) { t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) } - // next few tests simulate user providing -host flag + // next few tests simulate user providing -host and/or -port flags Host = "not-localhost.com" - if actual, expected := string(DefaultInput().Body()), "not-localhost.com:https\nroot ."; actual != expected { + if actual, expected := string(DefaultInput().Body()), "not-localhost.com:443\nroot ."; actual != expected { t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) } @@ -29,6 +29,18 @@ func TestDefaultInput(t *testing.T) { if actual, expected := string(DefaultInput().Body()), "127.0.1.1:2015\nroot ."; actual != expected { t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) } + + Host = "not-localhost.com" + Port = "1234" + if actual, expected := string(DefaultInput().Body()), "not-localhost.com:1234\nroot ."; actual != expected { + t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) + } + + Host = DefaultHost + Port = "1234" + if actual, expected := string(DefaultInput().Body()), ":1234\nroot ."; actual != expected { + t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) + } } func TestResolveAddr(t *testing.T) { @@ -51,14 +63,14 @@ func TestResolveAddr(t *testing.T) { {server.Config{Host: "localhost", Port: "80"}, false, false, "", 80}, {server.Config{BindHost: "localhost", Port: "1234"}, false, false, "127.0.0.1", 1234}, {server.Config{BindHost: "127.0.0.1", Port: "1234"}, false, false, "127.0.0.1", 1234}, - {server.Config{BindHost: "should-not-resolve", Port: "1234"}, true, false, "0.0.0.0", 1234}, + {server.Config{BindHost: "should-not-resolve", Port: "1234"}, true, false, "", 1234}, {server.Config{BindHost: "localhost", Port: "http"}, false, false, "127.0.0.1", 80}, {server.Config{BindHost: "localhost", Port: "https"}, false, false, "127.0.0.1", 443}, {server.Config{BindHost: "", Port: "1234"}, false, false, "", 1234}, {server.Config{BindHost: "localhost", Port: "abcd"}, false, true, "", 0}, {server.Config{BindHost: "127.0.0.1", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234}, {server.Config{BindHost: "localhost", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234}, - {server.Config{BindHost: "should-not-resolve", Host: "localhost", Port: "1234"}, true, false, "0.0.0.0", 1234}, + {server.Config{BindHost: "should-not-resolve", Host: "localhost", Port: "1234"}, true, false, "", 1234}, } { actualAddr, warnErr, fatalErr := resolveAddr(test.config) diff --git a/caddy/letsencrypt/crypto_test.go b/caddy/letsencrypt/crypto_test.go index 7f791a6c3..ca81efd68 100644 --- a/caddy/letsencrypt/crypto_test.go +++ b/caddy/letsencrypt/crypto_test.go @@ -40,12 +40,12 @@ func TestSaveAndLoadRSAPrivateKey(t *testing.T) { } } -// rsaPrivateKeyBytes returns the bytes of DER-encoded key. -func rsaPrivateKeyBytes(key *rsa.PrivateKey) []byte { - return x509.MarshalPKCS1PrivateKey(key) -} - // rsaPrivateKeysSame compares the bytes of a and b and returns true if they are the same. func rsaPrivateKeysSame(a, b *rsa.PrivateKey) bool { return bytes.Equal(rsaPrivateKeyBytes(a), rsaPrivateKeyBytes(b)) } + +// rsaPrivateKeyBytes returns the bytes of DER-encoded key. +func rsaPrivateKeyBytes(key *rsa.PrivateKey) []byte { + return x509.MarshalPKCS1PrivateKey(key) +} diff --git a/caddy/letsencrypt/handler.go b/caddy/letsencrypt/handler.go index 6c9f962dd..e147e00c8 100644 --- a/caddy/letsencrypt/handler.go +++ b/caddy/letsencrypt/handler.go @@ -2,30 +2,21 @@ package letsencrypt import ( "crypto/tls" + "log" "net" "net/http" "net/http/httputil" "net/url" "strings" - - "github.com/mholt/caddy/middleware" ) const challengeBasePath = "/.well-known/acme-challenge" -// Handler is a Caddy middleware that can proxy ACME challenge -// requests to the real ACME client endpoint. This is necessary -// to renew certificates while the server is running. -type Handler struct { - Next middleware.Handler - //ChallengeActive int32 // (TODO) use sync/atomic to set/get this flag safely and efficiently -} - -// ServeHTTP is basically a no-op unless an ACME challenge is active on this host -// and the request path matches the expected path exactly. -func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - // Proxy challenge requests to ACME client - // TODO: Only do this if a challenge is active? +// RequestCallback proxies challenge requests to ACME client if the +// request path starts with challengeBasePath. It returns true if it +// handled the request and no more needs to be done; it returns false +// if this call was a no-op and the request still needs handling. +func RequestCallback(w http.ResponseWriter, r *http.Request) bool { if strings.HasPrefix(r.URL.Path, challengeBasePath) { scheme := "http" if r.TLS != nil { @@ -37,9 +28,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) hostname = r.URL.Host } - upstream, err := url.Parse(scheme + "://" + hostname + ":" + alternatePort) + upstream, err := url.Parse(scheme + "://" + hostname + ":" + AlternatePort) if err != nil { - return http.StatusInternalServerError, err + w.WriteHeader(http.StatusInternalServerError) + log.Printf("[ERROR] letsencrypt handler: %v", err) + return true } proxy := httputil.NewSingleHostReverseProxy(upstream) @@ -48,8 +41,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) } proxy.ServeHTTP(w, r) - return 0, nil + return true } - return h.Next.ServeHTTP(w, r) + return false } diff --git a/caddy/letsencrypt/handler_test.go b/caddy/letsencrypt/handler_test.go new file mode 100644 index 000000000..ac6f48001 --- /dev/null +++ b/caddy/letsencrypt/handler_test.go @@ -0,0 +1,63 @@ +package letsencrypt + +import ( + "net" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRequestCallbackNoOp(t *testing.T) { + // try base paths that aren't handled by this handler + for _, url := range []string{ + "http://localhost/", + "http://localhost/foo.html", + "http://localhost/.git", + "http://localhost/.well-known/", + "http://localhost/.well-known/acme-challenging", + } { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("Could not craft request, got error: %v", err) + } + rw := httptest.NewRecorder() + if RequestCallback(rw, req) { + t.Errorf("Got true with this URL, but shouldn't have: %s", url) + } + } +} + +func TestRequestCallbackSuccess(t *testing.T) { + expectedPath := challengeBasePath + "/asdf" + + // Set up fake acme handler backend to make sure proxying succeeds + var proxySuccess bool + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxySuccess = true + if r.URL.Path != expectedPath { + t.Errorf("Expected path '%s' but got '%s' instead", expectedPath, r.URL.Path) + } + })) + + // Custom listener that uses the port we expect + ln, err := net.Listen("tcp", "127.0.0.1:"+AlternatePort) + if err != nil { + t.Fatalf("Unable to start test server listener: %v", err) + } + ts.Listener = ln + + // Start our engines and run the test + ts.Start() + defer ts.Close() + req, err := http.NewRequest("GET", "http://127.0.0.1:"+AlternatePort+expectedPath, nil) + if err != nil { + t.Fatalf("Could not craft request, got error: %v", err) + } + rw := httptest.NewRecorder() + + RequestCallback(rw, req) + + if !proxySuccess { + t.Fatal("Expected request to be proxied, but it wasn't") + } +} diff --git a/caddy/letsencrypt/letsencrypt.go b/caddy/letsencrypt/letsencrypt.go index f6b55a5ff..cc7aa9d8c 100644 --- a/caddy/letsencrypt/letsencrypt.go +++ b/caddy/letsencrypt/letsencrypt.go @@ -7,11 +7,14 @@ import ( "encoding/json" "errors" "io/ioutil" + "net" "net/http" "os" "strings" "time" + "golang.org/x/crypto/ocsp" + "github.com/mholt/caddy/caddy/setup" "github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware/redirect" @@ -20,10 +23,9 @@ import ( ) // Activate sets up TLS for each server config in configs -// as needed. It only skips the config if the cert and key -// are already provided, if plaintext http is explicitly -// specified as the port, TLS is explicitly disabled, or -// the host looks like a loopback or wildcard address. +// as needed; this consists of acquiring and maintaining +// certificates and keys for qualifying configs and enabling +// OCSP stapling for all TLS-enabled configs. // // This function may prompt the user to provide an email // address if none is available through other means. It @@ -46,113 +48,35 @@ func Activate(configs []server.Config) ([]server.Config, error) { // just in case previous caller forgot... Deactivate() - // reset cached ocsp statuses from any previous activations - ocspStatus = make(map[*[]byte]int) + // reset cached ocsp from any previous activations + ocspCache = make(map[*[]byte]*ocsp.Response) - // Identify and configure any eligible hosts for which - // we already have certs and keys in storage from last time. - configLen := len(configs) // avoid infinite loop since this loop appends plaintext to the slice - for i := 0; i < configLen; i++ { - if existingCertAndKey(configs[i].Host) && configQualifies(configs, i) { - configs = autoConfigure(configs, i) - } - } + // pre-screen each config and earmark the ones that qualify for managed TLS + MarkQualified(configs) - // Group configs by email address; only configs that are eligible - // for TLS management are included. We group by email so that we - // can request certificates in batches with the same client. - // Note: The return value is a map, and iteration over a map is - // not ordered. I don't think it will be a problem, but if an - // ordering problem arises, look at this carefully. - groupedConfigs, err := groupConfigsByEmail(configs) + // place certificates and keys on disk + err := ObtainCerts(configs, "") if err != nil { return configs, err } - // obtain certificates for configs that need one, and reconfigure each - // config to use the certificates - for leEmail, cfgIndexes := range groupedConfigs { - // make client to service this email address with CA server - client, err := newClient(leEmail) - if err != nil { - return configs, errors.New("error creating client: " + err.Error()) - } + // update TLS configurations + EnableTLS(configs) - // little bit of housekeeping; gather the hostnames into a slice - var hosts []string - for _, idx := range cfgIndexes { - // don't allow duplicates (happens when serving same host on multiple ports!) - var duplicate bool - for _, otherHost := range hosts { - if configs[idx].Host == otherHost { - duplicate = true - break - } - } - if !duplicate { - hosts = append(hosts, configs[idx].Host) - } - } - - // client is ready, so let's get free, trusted SSL certificates! - Obtain: - certificates, failures := client.ObtainCertificates(hosts, true) - if len(failures) > 0 { - // Build an error string to return, using all the failures in the list. - var errMsg string - - // If an error is because of updated SA, only prompt user for agreement once - var promptedForAgreement bool - - for domain, obtainErr := range failures { - // If the failure was simply because the terms have changed, re-prompt and re-try - if tosErr, ok := obtainErr.(acme.TOSError); ok { - if !Agreed && !promptedForAgreement { - Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL - promptedForAgreement = true - } - if Agreed { - err := client.AgreeToTOS() - if err != nil { - return configs, errors.New("error agreeing to updated terms: " + err.Error()) - } - goto Obtain - } - } - - // If user did not agree or it was any other kind of error, just append to the list of errors - errMsg += "[" + domain + "] failed to get certificate: " + obtainErr.Error() + "\n" - } - - // Save the certs we did obtain, though, before leaving - if err := saveCertsAndKeys(certificates); err == nil { - if len(certificates) > 0 { - var certList []string - for _, cert := range certificates { - certList = append(certList, cert.Domain) - } - errMsg += "Saved certificates for: " + strings.Join(certList, ", ") + "\n" - } - } else { - errMsg += "Unable to save obtained certificates: " + err.Error() + "\n" - } - - return configs, errors.New(errMsg) - } - - // ... that's it. save the certs, keys, and metadata files to disk - err = saveCertsAndKeys(certificates) - if err != nil { - return configs, errors.New("error saving assets: " + err.Error()) - } - - // it all comes down to this: turning on TLS with all the new certs - for _, idx := range cfgIndexes { - configs = autoConfigure(configs, idx) - } + // enable OCSP stapling (this affects all TLS-enabled configs) + err = StapleOCSP(configs) + if err != nil { + return configs, err } - // renew all certificates that need renewal + // set up redirects + configs = MakePlaintextRedirects(configs) + + // renew all relevant certificates that need renewal. this is important + // to do right away for a couple reasons, mainly because each restart, + // the renewal ticker is reset, so if restarts happen more often than + // the ticker interval, renewals would never happen. but doing + // it right away at start guarantees that renewals aren't missed. renewCertificates(configs, false) // keep certificates renewed and OCSP stapling updated @@ -176,23 +100,191 @@ func Deactivate() (err error) { return } -// configQualifies returns true if the config at cfgIndex (within allConfigs) -// qualifes for automatic LE activation. It does NOT check to see if a cert -// and key already exist for the config. -func configQualifies(allConfigs []server.Config, cfgIndex int) bool { - cfg := allConfigs[cfgIndex] +// MarkQualified scans each config and, if it qualifies for managed +// TLS, it sets the Marked field of the TLSConfig to true. +func MarkQualified(configs []server.Config) { + for i := 0; i < len(configs); i++ { + if ConfigQualifies(configs[i]) { + configs[i].TLS.Managed = true + } + } +} + +// ObtainCerts obtains certificates for all these configs as long as a certificate does not +// already exist on disk. It does not modify the configs at all; it only obtains and stores +// certificates and keys to the disk. +// +// TODO: Right now by potentially prompting about ToS error, we assume this function is only +// called at startup, but that is not always the case because it could be during a restart. +func ObtainCerts(configs []server.Config, optPort string) error { + groupedConfigs := groupConfigsByEmail(configs) + + for email, group := range groupedConfigs { + client, err := newClientPort(email, optPort) + if err != nil { + return errors.New("error creating client: " + err.Error()) + } + + for _, cfg := range group { + if existingCertAndKey(cfg.Host) { + continue + } + + Obtain: + certificate, failures := client.ObtainCertificate([]string{cfg.Host}, true, nil) + if len(failures) == 0 { + // Success - immediately save the certificate resource + err := saveCertResource(certificate) + if err != nil { + return errors.New("error saving assets for " + cfg.Host + ": " + err.Error()) + } + } else { + // Error - either try to fix it or report them it to the user and abort + var errMsg string // we'll combine all the failures into a single error message + var promptedForAgreement bool // only prompt user for agreement at most once + + for errDomain, obtainErr := range failures { + // TODO: Double-check, will obtainErr ever be nil? + if tosErr, ok := obtainErr.(acme.TOSError); ok { + // Terms of Service agreement error; we can probably deal with this + if !Agreed && !promptedForAgreement { + Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL + promptedForAgreement = true + } + if Agreed { + err := client.AgreeToTOS() + if err != nil { + return errors.New("error agreeing to updated terms: " + err.Error()) + } + goto Obtain + } + } + + // If user did not agree or it was any other kind of error, just append to the list of errors + errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n" + } + + return errors.New(errMsg) + } + } + } + + return nil +} + +// groupConfigsByEmail groups configs by the email address to be used by its +// ACME client. It only includes configs that are marked as fully managed. +// This is the function that may prompt for an email address. +func groupConfigsByEmail(configs []server.Config) map[string][]server.Config { + initMap := make(map[string][]server.Config) + for _, cfg := range configs { + if !cfg.TLS.Managed { + continue + } + leEmail := getEmail(cfg) + initMap[leEmail] = append(initMap[leEmail], cfg) + } + return initMap +} + +// EnableTLS configures each config to use TLS according to default settings. +// It will only change configs that are marked as managed, and assumes that +// certificates and keys are already on disk. +func EnableTLS(configs []server.Config) { + for i := 0; i < len(configs); i++ { + if !configs[i].TLS.Managed { + continue + } + configs[i].TLS.Enabled = true + configs[i].TLS.Certificate = storage.SiteCertFile(configs[i].Host) + configs[i].TLS.Key = storage.SiteKeyFile(configs[i].Host) + setup.SetDefaultTLSParams(&configs[i]) + } +} + +// StapleOCSP staples OCSP responses to each config according to their certificate. +// This should work for any TLS-enabled config, not just Let's Encrypt ones. +func StapleOCSP(configs []server.Config) error { + for i := 0; i < len(configs); i++ { + if configs[i].TLS.Certificate == "" { + continue + } + + bundleBytes, err := ioutil.ReadFile(configs[i].TLS.Certificate) + if err != nil { + return errors.New("load certificate to staple ocsp: " + err.Error()) + } + + ocspBytes, ocspResp, err := acme.GetOCSPForCert(bundleBytes) + if err == nil { + // TODO: We ignore the error if it exists because some certificates + // may not have an issuer URL which we should ignore anyway, and + // sometimes we get syntax errors in the responses. To reproduce this + // behavior, start Caddy with an empty Caddyfile and -log stderr. Then + // add a host to the Caddyfile which requires a new LE certificate. + // Reload Caddy's config with SIGUSR1, and see the log report that it + // obtains the certificate, but then an error: + // getting ocsp: asn1: syntax error: sequence truncated + // But retrying the reload again sometimes solves the problem. It's flaky... + ocspCache[&bundleBytes] = ocspResp + if ocspResp.Status == ocsp.Good { + configs[i].TLS.OCSPStaple = ocspBytes + } + } + } + return nil +} + +// hostHasOtherPort returns true if there is another config in the list with the same +// hostname that has port otherPort, or false otherwise. All the configs are checked +// against the hostname of allConfigs[thisConfigIdx]. +func hostHasOtherPort(allConfigs []server.Config, thisConfigIdx int, otherPort string) bool { + for i, otherCfg := range allConfigs { + if i == thisConfigIdx { + continue // has to be a config OTHER than the one we're comparing against + } + if otherCfg.Host == allConfigs[thisConfigIdx].Host && otherCfg.Port == otherPort { + return true + } + } + return false +} + +// MakePlaintextRedirects sets up redirects from port 80 to the relevant HTTPS +// hosts. You must pass in all configs, not just configs that qualify, since +// we must know whether the same host already exists on port 80, and those would +// not be in a list of configs that qualify for automatic HTTPS. This function will +// only set up redirects for configs that qualify. It returns the updated list of +// all configs. +func MakePlaintextRedirects(allConfigs []server.Config) []server.Config { + for i, cfg := range allConfigs { + if cfg.TLS.Managed && + !hostHasOtherPort(allConfigs, i, "80") && + (cfg.Port == "443" || !hostHasOtherPort(allConfigs, i, "443")) { + allConfigs = append(allConfigs, redirPlaintextHost(cfg)) + } + } + return allConfigs +} + +// ConfigQualifies returns true if cfg qualifies for +// fully managed TLS. It does NOT check to see if a +// cert and key already exist for the config. If the +// config does qualify, you should set cfg.TLS.Managed +// to true and use that instead, because the process of +// setting up the config may make it look like it +// doesn't qualify even though it originally did. +func ConfigQualifies(cfg server.Config) bool { return cfg.TLS.Certificate == "" && // user could provide their own cert and key cfg.TLS.Key == "" && // user can force-disable automatic HTTPS for this host - cfg.Port != "http" && + cfg.Scheme != "http" && + cfg.Port != "80" && cfg.TLS.LetsEncryptEmail != "off" && - // obviously we get can't certs for loopback or internal hosts - HostQualifies(cfg.Host) && - - // make sure another HTTPS version of this config doesn't exist in the list already - !otherHostHasScheme(allConfigs, cfgIndex, "https") + // we get can't certs for some kinds of hostnames + HostQualifies(cfg.Host) } // HostQualifies returns true if the hostname alone @@ -201,36 +293,16 @@ func configQualifies(allConfigs []server.Config, cfgIndex int) bool { // not eligible because we cannot obtain certificates // for those names. func HostQualifies(hostname string) bool { - return hostname != "localhost" && - strings.TrimSpace(hostname) != "" && - hostname != "0.0.0.0" && - hostname != "[::]" && // before parsing - hostname != "::" && // after parsing - hostname != "[::1]" && // before parsing - hostname != "::1" && // after parsing - !strings.HasPrefix(hostname, "127.") // to use boulder on your own machine, add fake domain to hosts file - // not excluding 10.* and 192.168.* hosts for possibility of running internal Boulder instance -} + return hostname != "localhost" && // localhost is ineligible -// groupConfigsByEmail groups configs by user email address. The returned map is -// a map of email address to the configs that are serviced under that account. -// If an email address is not available for an eligible config, the user will be -// prompted to provide one. The returned map contains pointers to the original -// server config values. -func groupConfigsByEmail(configs []server.Config) (map[string][]int, error) { - initMap := make(map[string][]int) - for i := 0; i < len(configs); i++ { - // filter out configs that we already have certs for and - // that we won't be obtaining certs for - this way we won't - // bother the user for an email address unnecessarily and - // we don't obtain new certs for a host we already have certs for. - if existingCertAndKey(configs[i].Host) || !configQualifies(configs, i) { - continue - } - leEmail := getEmail(configs[i]) - initMap[leEmail] = append(initMap[leEmail], i) - } - return initMap, nil + // hostname must not be empty + strings.TrimSpace(hostname) != "" && + + // cannot be an IP address, see + // https://community.letsencrypt.org/t/certificate-for-static-ip/84/2?u=mholt + // (also trim [] from either end, since that special case can sneak through + // for IPv6 addresses using the -host flag and with empty/no Caddyfile) + net.ParseIP(strings.Trim(hostname, "[]")) == nil } // existingCertAndKey returns true if the host has a certificate @@ -268,10 +340,13 @@ func newClientPort(leEmail, port string) (*acme.Client, error) { } // The client facilitates our communication with the CA server. - client, err := acme.NewClient(CAUrl, &leUser, rsaKeySizeToUse, port) + client, err := acme.NewClient(CAUrl, &leUser, rsaKeySizeToUse) if err != nil { return nil, err } + client.SetHTTPAddress(":" + port) + client.SetTLSAddress(":" + port) + client.ExcludeChallenges([]string{"tls-sni-01", "dns-01"}) // We can only guarantee http-01 at this time // If not registered, the user must register an account with the CA // and agree to terms @@ -305,144 +380,47 @@ func newClientPort(leEmail, port string) (*acme.Client, error) { return client, nil } -// obtainCertificates obtains certificates from the CA server for -// the configurations in serverConfigs using client. -func obtainCertificates(client *acme.Client, serverConfigs []server.Config) ([]acme.CertificateResource, map[string]error) { - var hosts []string - for _, cfg := range serverConfigs { - hosts = append(hosts, cfg.Host) - } - return client.ObtainCertificates(hosts, true) -} - -// saveCertificates saves each certificate resource to disk. This +// saveCertResource saves the certificate resource to disk. This // includes the certificate file itself, the private key, and the // metadata file. -func saveCertsAndKeys(certificates []acme.CertificateResource) error { - for _, cert := range certificates { - err := os.MkdirAll(storage.Site(cert.Domain), 0700) - if err != nil { - return err - } - - // Save cert - err = ioutil.WriteFile(storage.SiteCertFile(cert.Domain), cert.Certificate, 0600) - if err != nil { - return err - } - - // Save private key - err = ioutil.WriteFile(storage.SiteKeyFile(cert.Domain), cert.PrivateKey, 0600) - if err != nil { - return err - } - - // Save cert metadata - jsonBytes, err := json.MarshalIndent(&cert, "", "\t") - if err != nil { - return err - } - err = ioutil.WriteFile(storage.SiteMetaFile(cert.Domain), jsonBytes, 0600) - if err != nil { - return err - } +func saveCertResource(cert acme.CertificateResource) error { + err := os.MkdirAll(storage.Site(cert.Domain), 0700) + if err != nil { + return err } + + // Save cert + err = ioutil.WriteFile(storage.SiteCertFile(cert.Domain), cert.Certificate, 0600) + if err != nil { + return err + } + + // Save private key + err = ioutil.WriteFile(storage.SiteKeyFile(cert.Domain), cert.PrivateKey, 0600) + if err != nil { + return err + } + + // Save cert metadata + jsonBytes, err := json.MarshalIndent(&cert, "", "\t") + if err != nil { + return err + } + err = ioutil.WriteFile(storage.SiteMetaFile(cert.Domain), jsonBytes, 0600) + if err != nil { + return err + } + return nil } -// autoConfigure enables TLS on allConfigs[cfgIndex] and appends, if necessary, -// a new config to allConfigs that redirects plaintext HTTP to its new HTTPS -// counterpart. It expects the certificate and key to already be in storage. It -// returns the new list of allConfigs, since it may append a new config. This -// function assumes that allConfigs[cfgIndex] is already set up for HTTPS. -func autoConfigure(allConfigs []server.Config, cfgIndex int) []server.Config { - cfg := &allConfigs[cfgIndex] - - bundleBytes, err := ioutil.ReadFile(storage.SiteCertFile(cfg.Host)) - // TODO: Handle these errors better - if err == nil { - ocsp, status, err := acme.GetOCSPForCert(bundleBytes) - ocspStatus[&bundleBytes] = status - if err == nil && status == acme.OCSPGood { - cfg.TLS.OCSPStaple = ocsp - } - } - cfg.TLS.Certificate = storage.SiteCertFile(cfg.Host) - cfg.TLS.Key = storage.SiteKeyFile(cfg.Host) - cfg.TLS.Enabled = true - // Ensure all defaults are set for the TLS config - setup.SetDefaultTLSParams(cfg) - - if cfg.Port == "" { - cfg.Port = "https" - } - - // Set up http->https redirect as long as there isn't already a http counterpart - // in the configs and this isn't, for some reason, already on port 80. - // Also, the port 80 variant of this config is necessary for proxying challenge requests. - if !otherHostHasScheme(allConfigs, cfgIndex, "http") && - cfg.Port != "80" && cfg.Port != "http" { // (would not be http port with current program flow, but just in case) - allConfigs = append(allConfigs, redirPlaintextHost(*cfg)) - } - - // To support renewals, we need handlers at ports 80 and 443, - // depending on the challenge type that is used to complete renewal. - for i, c := range allConfigs { - if c.Address() == cfg.Host+":80" || - c.Address() == cfg.Host+":443" || - c.Address() == cfg.Host+":http" || - c.Address() == cfg.Host+":https" { - - // Each virtualhost must have their own handlers, or the chaining gets messed up when middlewares are compiled! - handler := new(Handler) - mid := func(next middleware.Handler) middleware.Handler { - handler.Next = next - return handler - } - // TODO: Currently, acmeHandlers are not referenced, but we need to add a way to toggle - // their proxy functionality -- or maybe not. Gotta figure this out for sure. - acmeHandlers[c.Address()] = handler - - allConfigs[i].Middleware["/"] = append(allConfigs[i].Middleware["/"], mid) - } - } - - return allConfigs -} - -// otherHostHasScheme tells you whether there is ANOTHER config in allConfigs -// for the same host but with the port equal to scheme as allConfigs[cfgIndex]. -// This function considers "443" and "https" to be the same scheme, as well as -// "http" and "80". It does not tell you whether there is ANY config with scheme, -// only if there's a different one with it. -func otherHostHasScheme(allConfigs []server.Config, cfgIndex int, scheme string) bool { - if scheme == "80" { - scheme = "http" - } else if scheme == "443" { - scheme = "https" - } - for i, otherCfg := range allConfigs { - if i == cfgIndex { - continue // has to be a config OTHER than the one we're comparing against - } - if otherCfg.Host == allConfigs[cfgIndex].Host { - if (otherCfg.Port == scheme) || - (scheme == "https" && otherCfg.Port == "443") || - (scheme == "http" && otherCfg.Port == "80") { - return true - } - } - } - return false -} - // redirPlaintextHost returns a new plaintext HTTP configuration for // a virtualHost that simply redirects to cfg, which is assumed to // be the HTTPS configuration. The returned configuration is set -// to listen on the "http" port (port 80). +// to listen on port 80. func redirPlaintextHost(cfg server.Config) server.Config { toURL := "https://" + cfg.Host - if cfg.Port != "https" && cfg.Port != "http" { + if cfg.Port != "443" && cfg.Port != "80" { toURL += ":" + cfg.Port } @@ -460,7 +438,7 @@ func redirPlaintextHost(cfg server.Config) server.Config { return server.Config{ Host: cfg.Host, BindHost: cfg.BindHost, - Port: "http", + Port: "80", Middleware: map[string][]middleware.Middleware{ "/": []middleware.Middleware{redirMidware}, }, @@ -515,17 +493,17 @@ var ( // Some essential values related to the Let's Encrypt process const ( - // alternatePort is the port on which the acme client will open a + // AlternatePort is the port on which the acme client will open a // listener and solve the CA's challenges. If this alternate port // is used instead of the default port (80 or 443), then the // default port for the challenge must be forwarded to this one. - alternatePort = "5033" + AlternatePort = "5033" - // How often to check certificates for renewal. - renewInterval = 24 * time.Hour + // RenewInterval is how often to check certificates for renewal. + RenewInterval = 24 * time.Hour - // How often to update OCSP stapling. - ocspInterval = 1 * time.Hour + // OCSPInterval is how often to check if OCSP stapling needs updating. + OCSPInterval = 1 * time.Hour ) // KeySize represents the length of a key in bits. @@ -533,22 +511,22 @@ type KeySize int // Key sizes are used to determine the strength of a key. const ( - ECC_224 KeySize = 224 - ECC_256 = 256 - RSA_2048 = 2048 - RSA_4096 = 4096 + Ecc224 KeySize = 224 + Ecc256 = 256 + Rsa2048 = 2048 + Rsa4096 = 4096 ) // rsaKeySizeToUse is the size to use for new RSA keys. // This shouldn't need to change except for in tests; // the size can be drastically reduced for speed. -var rsaKeySizeToUse = RSA_2048 +var rsaKeySizeToUse = Rsa2048 // stopChan is used to signal the maintenance goroutine // to terminate. var stopChan chan struct{} -// ocspStatus maps certificate bundle to OCSP status at start. +// ocspCache maps certificate bundle to OCSP response. // It is used during regular OCSP checks to see if the OCSP -// status has changed. -var ocspStatus = make(map[*[]byte]int) +// response needs to be updated. +var ocspCache = make(map[*[]byte]*ocsp.Response) diff --git a/caddy/letsencrypt/letsencrypt_test.go b/caddy/letsencrypt/letsencrypt_test.go index 0547b9291..606e08a9f 100644 --- a/caddy/letsencrypt/letsencrypt_test.go +++ b/caddy/letsencrypt/letsencrypt_test.go @@ -1,11 +1,14 @@ package letsencrypt import ( + "io/ioutil" "net/http" + "os" "testing" "github.com/mholt/caddy/middleware/redirect" "github.com/mholt/caddy/server" + "github.com/xenolf/lego/acme" ) func TestHostQualifies(t *testing.T) { @@ -23,9 +26,11 @@ func TestHostQualifies(t *testing.T) { {"", false}, {" ", false}, {"0.0.0.0", false}, - {"192.168.1.3", true}, - {"10.0.2.1", true}, + {"192.168.1.3", false}, + {"10.0.2.1", false}, + {"169.112.53.4", false}, {"foobar.com", true}, + {"sub.foobar.com", true}, } { if HostQualifies(test.host) && !test.expect { t.Errorf("Test %d: Expected '%s' to NOT qualify, but it did", i, test.host) @@ -36,11 +41,37 @@ func TestHostQualifies(t *testing.T) { } } +func TestConfigQualifies(t *testing.T) { + for i, test := range []struct { + cfg server.Config + expect bool + }{ + {server.Config{Host: "localhost"}, false}, + {server.Config{Host: "example.com"}, true}, + {server.Config{Host: "example.com", TLS: server.TLSConfig{Certificate: "cert.pem"}}, false}, + {server.Config{Host: "example.com", TLS: server.TLSConfig{Key: "key.pem"}}, false}, + {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, false}, + {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, true}, + {server.Config{Host: "example.com", Scheme: "http"}, false}, + {server.Config{Host: "example.com", Port: "80"}, false}, + {server.Config{Host: "example.com", Port: "1234"}, true}, + {server.Config{Host: "example.com", Scheme: "https"}, true}, + {server.Config{Host: "example.com", Port: "80", Scheme: "https"}, false}, + } { + if test.expect && !ConfigQualifies(test.cfg) { + t.Errorf("Test %d: Expected config to qualify, but it did NOT: %#v", i, test.cfg) + } + if !test.expect && ConfigQualifies(test.cfg) { + t.Errorf("Test %d: Expected config to NOT qualify, but it did: %#v", i, test.cfg) + } + } +} + func TestRedirPlaintextHost(t *testing.T) { cfg := redirPlaintextHost(server.Config{ Host: "example.com", BindHost: "93.184.216.34", - Port: "http", + Port: "1234", }) // Check host and port @@ -50,7 +81,7 @@ func TestRedirPlaintextHost(t *testing.T) { if actual, expected := cfg.BindHost, "93.184.216.34"; actual != expected { t.Errorf("Expected redir config to have bindhost %s but got %s", expected, actual) } - if actual, expected := cfg.Port, "http"; actual != expected { + if actual, expected := cfg.Port, "80"; actual != expected { t.Errorf("Expected redir config to have port '%s' but got '%s'", expected, actual) } @@ -74,10 +105,239 @@ func TestRedirPlaintextHost(t *testing.T) { if actual, expected := handler.Rules[0].FromPath, "/"; actual != expected { t.Errorf("Expected redirect rule to be for path '%s' but is actually for '%s'", expected, actual) } - if actual, expected := handler.Rules[0].To, "https://example.com{uri}"; actual != expected { + if actual, expected := handler.Rules[0].To, "https://example.com:1234{uri}"; actual != expected { t.Errorf("Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual) } if actual, expected := handler.Rules[0].Code, http.StatusMovedPermanently; actual != expected { t.Errorf("Expected redirect rule to have code %d but was %d", expected, actual) } + + // browsers can interpret default ports with scheme, so make sure the port + // doesn't get added in explicitly for default ports. + cfg = redirPlaintextHost(server.Config{Host: "example.com", Port: "443"}) + handler, ok = cfg.Middleware["/"][0](nil).(redirect.Redirect) + if actual, expected := handler.Rules[0].To, "https://example.com{uri}"; actual != expected { + t.Errorf("(Default Port) Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual) + } +} + +func TestSaveCertResource(t *testing.T) { + storage = Storage("./le_test_save") + defer func() { + err := os.RemoveAll(string(storage)) + if err != nil { + t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err) + } + }() + + domain := "example.com" + certContents := "certificate" + keyContents := "private key" + metaContents := `{ + "domain": "example.com", + "certUrl": "https://example.com/cert", + "certStableUrl": "https://example.com/cert/stable" +}` + + cert := acme.CertificateResource{ + Domain: domain, + CertURL: "https://example.com/cert", + CertStableURL: "https://example.com/cert/stable", + PrivateKey: []byte(keyContents), + Certificate: []byte(certContents), + } + + err := saveCertResource(cert) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + certFile, err := ioutil.ReadFile(storage.SiteCertFile(domain)) + if err != nil { + t.Errorf("Expected no error reading certificate file, got: %v", err) + } + if string(certFile) != certContents { + t.Errorf("Expected certificate file to contain '%s', got '%s'", certContents, string(certFile)) + } + + keyFile, err := ioutil.ReadFile(storage.SiteKeyFile(domain)) + if err != nil { + t.Errorf("Expected no error reading private key file, got: %v", err) + } + if string(keyFile) != keyContents { + t.Errorf("Expected private key file to contain '%s', got '%s'", keyContents, string(keyFile)) + } + + metaFile, err := ioutil.ReadFile(storage.SiteMetaFile(domain)) + if err != nil { + t.Errorf("Expected no error reading meta file, got: %v", err) + } + if string(metaFile) != metaContents { + t.Errorf("Expected meta file to contain '%s', got '%s'", metaContents, string(metaFile)) + } +} + +func TestExistingCertAndKey(t *testing.T) { + storage = Storage("./le_test_existing") + defer func() { + err := os.RemoveAll(string(storage)) + if err != nil { + t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err) + } + }() + + domain := "example.com" + + if existingCertAndKey(domain) { + t.Errorf("Did NOT expect %v to have existing cert or key, but it did", domain) + } + + err := saveCertResource(acme.CertificateResource{ + Domain: domain, + PrivateKey: []byte("key"), + Certificate: []byte("cert"), + }) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if !existingCertAndKey(domain) { + t.Errorf("Expected %v to have existing cert and key, but it did NOT", domain) + } +} + +func TestHostHasOtherPort(t *testing.T) { + configs := []server.Config{ + server.Config{Host: "example.com", Port: "80"}, + server.Config{Host: "sub1.example.com", Port: "80"}, + server.Config{Host: "sub1.example.com", Port: "443"}, + } + + if hostHasOtherPort(configs, 0, "80") { + t.Errorf(`Expected hostHasOtherPort(configs, 0, "80") to be false, but got true`) + } + if hostHasOtherPort(configs, 0, "443") { + t.Errorf(`Expected hostHasOtherPort(configs, 0, "443") to be false, but got true`) + } + if !hostHasOtherPort(configs, 1, "443") { + t.Errorf(`Expected hostHasOtherPort(configs, 1, "443") to be true, but got false`) + } +} + +func TestMakePlaintextRedirects(t *testing.T) { + configs := []server.Config{ + // Happy path = standard redirect from 80 to 443 + server.Config{Host: "example.com", TLS: server.TLSConfig{Managed: true}}, + + // Host on port 80 already defined; don't change it (no redirect) + server.Config{Host: "sub1.example.com", Port: "80", Scheme: "http"}, + server.Config{Host: "sub1.example.com", TLS: server.TLSConfig{Managed: true}}, + + // Redirect from port 80 to port 5000 in this case + server.Config{Host: "sub2.example.com", Port: "5000", TLS: server.TLSConfig{Managed: true}}, + + // Can redirect from 80 to either 443 or 5001, but choose 443 + server.Config{Host: "sub3.example.com", Port: "443", TLS: server.TLSConfig{Managed: true}}, + server.Config{Host: "sub3.example.com", Port: "5001", Scheme: "https", TLS: server.TLSConfig{Managed: true}}, + } + + result := MakePlaintextRedirects(configs) + expectedRedirCount := 3 + + if len(result) != len(configs)+expectedRedirCount { + t.Errorf("Expected %d redirect(s) to be added, but got %d", + expectedRedirCount, len(result)-len(configs)) + } +} + +func TestEnableTLS(t *testing.T) { + configs := []server.Config{ + server.Config{TLS: server.TLSConfig{Managed: true}}, + server.Config{}, // not managed - no changes! + } + + EnableTLS(configs) + + if !configs[0].TLS.Enabled { + t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false") + } + if configs[0].TLS.Certificate == "" { + t.Errorf("Expected config 0 to have TLS.Certificate set, but it was empty") + } + if configs[0].TLS.Key == "" { + t.Errorf("Expected config 0 to have TLS.Key set, but it was empty") + } + + if configs[1].TLS.Enabled { + t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true") + } + if configs[1].TLS.Certificate != "" { + t.Errorf("Expected config 1 to have TLS.Certificate empty, but it was: %s", configs[1].TLS.Certificate) + } + if configs[1].TLS.Key != "" { + t.Errorf("Expected config 1 to have TLS.Key empty, but it was: %s", configs[1].TLS.Key) + } +} + +func TestGroupConfigsByEmail(t *testing.T) { + if groupConfigsByEmail([]server.Config{}) == nil { + t.Errorf("With empty input, returned map was nil, but expected non-nil map") + } + + configs := []server.Config{ + server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, + server.Config{Host: "sub1.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}}, + server.Config{Host: "sub2.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, + server.Config{Host: "sub3.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}}, + server.Config{Host: "sub4.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, + server.Config{Host: "sub5.example.com", TLS: server.TLSConfig{LetsEncryptEmail: ""}}, // not managed + } + DefaultEmail = "test@example.com" // bypass prompt during tests... + + groups := groupConfigsByEmail(configs) + + if groups == nil { + t.Fatalf("Returned map was nil, but expected values") + } + + if len(groups) != 2 { + t.Errorf("Expected 2 groups, got %d: %#v", len(groups), groups) + } + if len(groups["foo@bar"]) != 2 { + t.Errorf("Expected 2 configs for foo@bar, got %d: %#v", len(groups["foobar"]), groups["foobar"]) + } + if len(groups[DefaultEmail]) != 3 { + t.Errorf("Expected 3 configs for %s, got %d: %#v", DefaultEmail, len(groups["foobar"]), groups["foobar"]) + } +} + +func TestMarkQualified(t *testing.T) { + // TODO: TestConfigQualifies and this test share the same config list... + configs := []server.Config{ + {Host: "localhost"}, + {Host: "example.com"}, + {Host: "example.com", TLS: server.TLSConfig{Certificate: "cert.pem"}}, + {Host: "example.com", TLS: server.TLSConfig{Key: "key.pem"}}, + {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, + {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, + {Host: "example.com", Scheme: "http"}, + {Host: "example.com", Port: "80"}, + {Host: "example.com", Port: "1234"}, + {Host: "example.com", Scheme: "https"}, + {Host: "example.com", Port: "80", Scheme: "https"}, + } + expectedManagedCount := 4 + + MarkQualified(configs) + + count := 0 + for _, cfg := range configs { + if cfg.TLS.Managed { + count++ + } + } + + if count != expectedManagedCount { + t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count) + } } diff --git a/caddy/letsencrypt/maintain.go b/caddy/letsencrypt/maintain.go index c6701b40a..5a59dc23a 100644 --- a/caddy/letsencrypt/maintain.go +++ b/caddy/letsencrypt/maintain.go @@ -27,8 +27,8 @@ var OnChange func() error // which you'll close when maintenance should stop, to allow this // goroutine to clean up after itself and unblock. func maintainAssets(configs []server.Config, stopChan chan struct{}) { - renewalTicker := time.NewTicker(renewInterval) - ocspTicker := time.NewTicker(ocspInterval) + renewalTicker := time.NewTicker(RenewInterval) + ocspTicker := time.NewTicker(OCSPInterval) for { select { @@ -47,15 +47,29 @@ func maintainAssets(configs []server.Config, stopChan chan struct{}) { } } case <-ocspTicker.C: - for bundle, oldStatus := range ocspStatus { - _, newStatus, err := acme.GetOCSPForCert(*bundle) - if err == nil && newStatus != oldStatus && OnChange != nil { - log.Printf("[INFO] OCSP status changed from %v to %v", oldStatus, newStatus) - err := OnChange() + for bundle, oldResp := range ocspCache { + // start checking OCSP staple about halfway through validity period for good measure + refreshTime := oldResp.ThisUpdate.Add(oldResp.NextUpdate.Sub(oldResp.ThisUpdate) / 2) + + // only check for updated OCSP validity window if refreshTime is in the past + if time.Now().After(refreshTime) { + _, newResp, err := acme.GetOCSPForCert(*bundle) if err != nil { - log.Printf("[ERROR] OnChange after OCSP update: %v", err) + log.Printf("[ERROR] Checking OCSP for bundle: %v", err) + continue + } + + // we're not looking for different status, just a more future expiration + if newResp.NextUpdate != oldResp.NextUpdate { + if OnChange != nil { + log.Printf("[INFO] Updating OCSP stapling to extend validity period to %v", newResp.NextUpdate) + err := OnChange() + if err != nil { + log.Printf("[ERROR] OnChange after OCSP trigger: %v", err) + } + break + } } - break } } case <-stopChan: @@ -102,12 +116,12 @@ func renewCertificates(configs []server.Config, useCustomPort bool) (int, []erro // Directly convert it to days for the following checks. daysLeft := int(expTime.Sub(time.Now().UTC()).Hours() / 24) - // Renew with two weeks or less remaining. - if daysLeft <= 14 { + // Renew if getting close to expiration. + if daysLeft <= renewDaysBefore { log.Printf("[INFO] Certificate for %s has %d days remaining; attempting renewal", cfg.Host, daysLeft) var client *acme.Client if useCustomPort { - client, err = newClientPort("", alternatePort) // email not used for renewal + client, err = newClientPort("", AlternatePort) // email not used for renewal } else { client, err = newClient("") } @@ -134,7 +148,7 @@ func renewCertificates(configs []server.Config, useCustomPort bool) (int, []erro // Renew certificate Renew: - newCertMeta, err := client.RenewCertificate(certMeta, true, true) + newCertMeta, err := client.RenewCertificate(certMeta, true) if err != nil { if _, ok := err.(acme.TOSError); ok { err := client.AgreeToTOS() @@ -145,24 +159,22 @@ func renewCertificates(configs []server.Config, useCustomPort bool) (int, []erro } time.Sleep(10 * time.Second) - newCertMeta, err = client.RenewCertificate(certMeta, true, true) + newCertMeta, err = client.RenewCertificate(certMeta, true) if err != nil { errs = append(errs, err) continue } } - saveCertsAndKeys([]acme.CertificateResource{newCertMeta}) + saveCertResource(newCertMeta) n++ - } else if daysLeft <= 30 { - // Warn on 30 days remaining. TODO: Just do this once... - log.Printf("[WARNING] Certificate for %s has %d days remaining; will automatically renew when 14 days remain\n", cfg.Host, daysLeft) + } else if daysLeft <= renewDaysBefore+7 && daysLeft >= renewDaysBefore+6 { + log.Printf("[WARNING] Certificate for %s has %d days remaining; will automatically renew when %d days remain\n", cfg.Host, daysLeft, renewDaysBefore) } } return n, errs } -// acmeHandlers is a map of host to ACME handler. These -// are used to proxy ACME requests to the ACME client. -var acmeHandlers = make(map[string]*Handler) +// renewDaysBefore is how many days before expiration to renew certificates. +const renewDaysBefore = 14 diff --git a/caddy/letsencrypt/storage_test.go b/caddy/letsencrypt/storage_test.go index 95fb71833..545c46b64 100644 --- a/caddy/letsencrypt/storage_test.go +++ b/caddy/letsencrypt/storage_test.go @@ -6,44 +6,44 @@ import ( ) func TestStorage(t *testing.T) { - storage = Storage("./letsencrypt") + storage = Storage("./le_test") - if expected, actual := filepath.Join("letsencrypt", "sites"), storage.Sites(); actual != expected { + if expected, actual := filepath.Join("le_test", "sites"), storage.Sites(); actual != expected { t.Errorf("Expected Sites() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("letsencrypt", "sites", "test.com"), storage.Site("test.com"); actual != expected { + if expected, actual := filepath.Join("le_test", "sites", "test.com"), storage.Site("test.com"); actual != expected { t.Errorf("Expected Site() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("letsencrypt", "sites", "test.com", "test.com.crt"), storage.SiteCertFile("test.com"); actual != expected { + if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.crt"), storage.SiteCertFile("test.com"); actual != expected { t.Errorf("Expected SiteCertFile() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("letsencrypt", "sites", "test.com", "test.com.key"), storage.SiteKeyFile("test.com"); actual != expected { + if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.key"), storage.SiteKeyFile("test.com"); actual != expected { t.Errorf("Expected SiteKeyFile() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("letsencrypt", "sites", "test.com", "test.com.json"), storage.SiteMetaFile("test.com"); actual != expected { + if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.json"), storage.SiteMetaFile("test.com"); actual != expected { t.Errorf("Expected SiteMetaFile() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("letsencrypt", "users"), storage.Users(); actual != expected { + if expected, actual := filepath.Join("le_test", "users"), storage.Users(); actual != expected { t.Errorf("Expected Users() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("letsencrypt", "users", "me@example.com"), storage.User("me@example.com"); actual != expected { + if expected, actual := filepath.Join("le_test", "users", "me@example.com"), storage.User("me@example.com"); actual != expected { t.Errorf("Expected User() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("letsencrypt", "users", "me@example.com", "me.json"), storage.UserRegFile("me@example.com"); actual != expected { + if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.json"), storage.UserRegFile("me@example.com"); actual != expected { t.Errorf("Expected UserRegFile() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("letsencrypt", "users", "me@example.com", "me.key"), storage.UserKeyFile("me@example.com"); actual != expected { + if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.key"), storage.UserKeyFile("me@example.com"); actual != expected { t.Errorf("Expected UserKeyFile() to return '%s' but got '%s'", expected, actual) } // Test with empty emails - if expected, actual := filepath.Join("letsencrypt", "users", emptyEmail), storage.User(emptyEmail); actual != expected { + if expected, actual := filepath.Join("le_test", "users", emptyEmail), storage.User(emptyEmail); actual != expected { t.Errorf("Expected User(\"\") to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("letsencrypt", "users", emptyEmail, emptyEmail+".json"), storage.UserRegFile(""); actual != expected { + if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".json"), storage.UserRegFile(""); actual != expected { t.Errorf("Expected UserRegFile(\"\") to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("letsencrypt", "users", emptyEmail, emptyEmail+".key"), storage.UserKeyFile(""); actual != expected { + if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".key"), storage.UserKeyFile(""); actual != expected { t.Errorf("Expected UserKeyFile(\"\") to return '%s' but got '%s'", expected, actual) } } diff --git a/caddy/letsencrypt/user.go b/caddy/letsencrypt/user.go index 7fae3bb41..dbcc0f493 100644 --- a/caddy/letsencrypt/user.go +++ b/caddy/letsencrypt/user.go @@ -144,7 +144,7 @@ func getEmail(cfg server.Config) string { // Alas, we must bother the user and ask for an email address; // if they proceed they also agree to the SA. reader := bufio.NewReader(stdin) - fmt.Println("Your sites will be served over HTTPS automatically using Let's Encrypt.") + fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.") fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:") fmt.Println(" " + saURL) // TODO: Show current SA link fmt.Println("Please enter your email address so you can recover your account if needed.") diff --git a/caddy/letsencrypt/user_test.go b/caddy/letsencrypt/user_test.go index d074856af..1f9c9e4f2 100644 --- a/caddy/letsencrypt/user_test.go +++ b/caddy/letsencrypt/user_test.go @@ -125,6 +125,11 @@ func TestGetUserAlreadyExists(t *testing.T) { } func TestGetEmail(t *testing.T) { + // let's not clutter up the output + origStdout := os.Stdout + os.Stdout = nil + defer func() { os.Stdout = origStdout }() + storage = Storage("./testdata") defer os.RemoveAll(string(storage)) DefaultEmail = "test2@foo.com" diff --git a/caddy/parse/parse.go b/caddy/parse/parse.go index 1f9137d9b..faef36c28 100644 --- a/caddy/parse/parse.go +++ b/caddy/parse/parse.go @@ -8,7 +8,7 @@ import "io" // If checkDirectives is true, only valid directives will be allowed // otherwise we consider it a parse error. Server blocks are returned // in the order in which they appear. -func ServerBlocks(filename string, input io.Reader, checkDirectives bool) ([]serverBlock, error) { +func ServerBlocks(filename string, input io.Reader, checkDirectives bool) ([]ServerBlock, error) { p := parser{Dispenser: NewDispenser(filename, input)} p.checkDirectives = checkDirectives blocks, err := p.parseAll() diff --git a/caddy/parse/parsing.go b/caddy/parse/parsing.go index 6ef908b0b..1b7587b32 100644 --- a/caddy/parse/parsing.go +++ b/caddy/parse/parsing.go @@ -1,6 +1,7 @@ package parse import ( + "fmt" "net" "os" "path/filepath" @@ -9,13 +10,13 @@ import ( type parser struct { Dispenser - block serverBlock // current server block being parsed + block ServerBlock // current server block being parsed eof bool // if we encounter a valid EOF in a hard place checkDirectives bool // if true, directives must be known } -func (p *parser) parseAll() ([]serverBlock, error) { - var blocks []serverBlock +func (p *parser) parseAll() ([]ServerBlock, error) { + var blocks []ServerBlock for p.Next() { err := p.parseOne() @@ -31,7 +32,7 @@ func (p *parser) parseAll() ([]serverBlock, error) { } func (p *parser) parseOne() error { - p.block = serverBlock{Tokens: make(map[string][]token)} + p.block = ServerBlock{Tokens: make(map[string][]token)} err := p.begin() if err != nil { @@ -99,11 +100,11 @@ func (p *parser) addresses() error { } // Parse and save this address - host, port, err := standardAddress(tkn) + addr, err := standardAddress(tkn) if err != nil { return err } - p.block.Addresses = append(p.block.Addresses, address{host, port}) + p.block.Addresses = append(p.block.Addresses, addr) } // Advance token and possibly break out of loop or return error @@ -303,39 +304,57 @@ func (p *parser) closeCurlyBrace() error { return nil } -// standardAddress turns the accepted host and port patterns -// into a format accepted by net.Dial. -func standardAddress(str string) (host, port string, err error) { - var schemePort, splitPort string +// standardAddress parses an address string into a structured format with separate +// scheme, host, and port portions, as well as the original input string. +func standardAddress(str string) (address, error) { + var scheme string + var err error + // first check for scheme and strip it off + input := str if strings.HasPrefix(str, "https://") { - schemePort = "https" + scheme = "https" str = str[8:] } else if strings.HasPrefix(str, "http://") { - schemePort = "http" + scheme = "http" str = str[7:] } - host, splitPort, err = net.SplitHostPort(str) + // separate host and port + host, port, err := net.SplitHostPort(str) if err != nil { - host, splitPort, err = net.SplitHostPort(str + ":") // tack on empty port - } - if err != nil { - // ¯\_(ツ)_/¯ - host = str + host, port, err = net.SplitHostPort(str + ":") + // no error check here; return err at end of function } - if splitPort != "" { - port = splitPort - } else { - port = schemePort + // see if we can set port based off scheme + if port == "" { + if scheme == "http" { + port = "80" + } else if scheme == "https" { + port = "443" + } } - return + // repeated or conflicting scheme is confusing, so error + if scheme != "" && (port == "http" || port == "https") { + return address{}, fmt.Errorf("[%s] scheme specified twice in address", str) + } + + // standardize http and https ports to their respective port numbers + if port == "http" { + scheme = "http" + port = "80" + } else if port == "https" { + scheme = "https" + port = "443" + } + + return address{Original: input, Scheme: scheme, Host: host, Port: port}, err } // replaceEnvVars replaces environment variables that appear in the token -// and understands both the Unix $SYNTAX and Windows %SYNTAX%. +// and understands both the $UNIX and %WINDOWS% syntaxes. func replaceEnvVars(s string) string { s = replaceEnvReferences(s, "{%", "%}") s = replaceEnvReferences(s, "{$", "}") @@ -360,26 +379,26 @@ func replaceEnvReferences(s, refStart, refEnd string) string { } type ( - // serverBlock associates tokens with a list of addresses + // ServerBlock associates tokens with a list of addresses // and groups tokens by directive name. - serverBlock struct { + ServerBlock struct { Addresses []address Tokens map[string][]token } address struct { - Host, Port string + Original, Scheme, Host, Port string } ) -// HostList converts the list of addresses (hosts) -// that are associated with this server block into -// a slice of strings. Each string is a host:port -// combination. -func (sb serverBlock) HostList() []string { +// HostList converts the list of addresses that are +// associated with this server block into a slice of +// strings, where each address is as it was originally +// read from the input. +func (sb ServerBlock) HostList() []string { sbHosts := make([]string, len(sb.Addresses)) for j, addr := range sb.Addresses { - sbHosts[j] = net.JoinHostPort(addr.Host, addr.Port) + sbHosts[j] = addr.Original } return sbHosts } diff --git a/caddy/parse/parsing_test.go b/caddy/parse/parsing_test.go index bda6b29bc..8533b2a38 100644 --- a/caddy/parse/parsing_test.go +++ b/caddy/parse/parsing_test.go @@ -8,51 +8,55 @@ import ( func TestStandardAddress(t *testing.T) { for i, test := range []struct { - input string - host, port string - shouldErr bool + input string + scheme, host, port string + shouldErr bool }{ - {`localhost`, "localhost", "", false}, - {`localhost:1234`, "localhost", "1234", false}, - {`localhost:`, "localhost", "", false}, - {`0.0.0.0`, "0.0.0.0", "", false}, - {`127.0.0.1:1234`, "127.0.0.1", "1234", false}, - {`:1234`, "", "1234", false}, - {`[::1]`, "::1", "", false}, - {`[::1]:1234`, "::1", "1234", false}, - {`:`, "", "", false}, - {`localhost:http`, "localhost", "http", false}, - {`localhost:https`, "localhost", "https", false}, - {`:http`, "", "http", false}, - {`:https`, "", "https", false}, - {`http://localhost`, "localhost", "http", false}, - {`https://localhost`, "localhost", "https", false}, - {`http://127.0.0.1`, "127.0.0.1", "http", false}, - {`https://127.0.0.1`, "127.0.0.1", "https", false}, - {`http://[::1]`, "::1", "http", false}, - {`http://localhost:1234`, "localhost", "1234", false}, - {`https://127.0.0.1:1234`, "127.0.0.1", "1234", false}, - {`http://[::1]:1234`, "::1", "1234", false}, - {``, "", "", false}, - {`::1`, "::1", "", true}, - {`localhost::`, "localhost::", "", true}, - {`#$%@`, "#$%@", "", true}, + {`localhost`, "", "localhost", "", false}, + {`localhost:1234`, "", "localhost", "1234", false}, + {`localhost:`, "", "localhost", "", false}, + {`0.0.0.0`, "", "0.0.0.0", "", false}, + {`127.0.0.1:1234`, "", "127.0.0.1", "1234", false}, + {`:1234`, "", "", "1234", false}, + {`[::1]`, "", "::1", "", false}, + {`[::1]:1234`, "", "::1", "1234", false}, + {`:`, "", "", "", false}, + {`localhost:http`, "http", "localhost", "80", false}, + {`localhost:https`, "https", "localhost", "443", false}, + {`:http`, "http", "", "80", false}, + {`:https`, "https", "", "443", false}, + {`http://localhost:https`, "", "", "", true}, // conflict + {`http://localhost:http`, "", "", "", true}, // repeated scheme + {`http://localhost`, "http", "localhost", "80", false}, + {`https://localhost`, "https", "localhost", "443", false}, + {`http://127.0.0.1`, "http", "127.0.0.1", "80", false}, + {`https://127.0.0.1`, "https", "127.0.0.1", "443", false}, + {`http://[::1]`, "http", "::1", "80", false}, + {`http://localhost:1234`, "http", "localhost", "1234", false}, + {`https://127.0.0.1:1234`, "https", "127.0.0.1", "1234", false}, + {`http://[::1]:1234`, "http", "::1", "1234", false}, + {``, "", "", "", false}, + {`::1`, "", "::1", "", true}, + {`localhost::`, "", "localhost::", "", true}, + {`#$%@`, "", "#$%@", "", true}, } { - host, port, err := standardAddress(test.input) + actual, err := standardAddress(test.input) if err != nil && !test.shouldErr { - t.Errorf("Test %d: Expected no error, but had error: %v", i, err) + t.Errorf("Test %d (%s): Expected no error, but had error: %v", i, test.input, err) } if err == nil && test.shouldErr { - t.Errorf("Test %d: Expected error, but had none", i) + t.Errorf("Test %d (%s): Expected error, but had none", i, test.input) } - if host != test.host { - t.Errorf("Test %d: Expected host '%s', got '%s'", i, test.host, host) + if actual.Scheme != test.scheme { + t.Errorf("Test %d (%s): Expected scheme '%s', got '%s'", i, test.input, test.scheme, actual.Scheme) } - - if port != test.port { - t.Errorf("Test %d: Expected port '%s', got '%s'", i, test.port, port) + if actual.Host != test.host { + t.Errorf("Test %d (%s): Expected host '%s', got '%s'", i, test.input, test.host, actual.Host) + } + if actual.Port != test.port { + t.Errorf("Test %d (%s): Expected port '%s', got '%s'", i, test.input, test.port, actual.Port) } } } @@ -60,7 +64,7 @@ func TestStandardAddress(t *testing.T) { func TestParseOneAndImport(t *testing.T) { setupParseTests() - testParseOne := func(input string) (serverBlock, error) { + testParseOne := func(input string) (ServerBlock, error) { p := testParser(input) p.Next() // parseOne doesn't call Next() to start, so we must err := p.parseOne() @@ -74,19 +78,19 @@ func TestParseOneAndImport(t *testing.T) { tokens map[string]int // map of directive name to number of tokens expected }{ {`localhost`, false, []address{ - {"localhost", ""}, + {"localhost", "", "localhost", ""}, }, map[string]int{}}, {`localhost dir1`, false, []address{ - {"localhost", ""}, + {"localhost", "", "localhost", ""}, }, map[string]int{ "dir1": 1, }}, {`localhost:1234 dir1 foo bar`, false, []address{ - {"localhost", "1234"}, + {"localhost:1234", "", "localhost", "1234"}, }, map[string]int{ "dir1": 3, }}, @@ -94,7 +98,7 @@ func TestParseOneAndImport(t *testing.T) { {`localhost { dir1 }`, false, []address{ - {"localhost", ""}, + {"localhost", "", "localhost", ""}, }, map[string]int{ "dir1": 1, }}, @@ -103,7 +107,7 @@ func TestParseOneAndImport(t *testing.T) { dir1 foo bar dir2 }`, false, []address{ - {"localhost", "1234"}, + {"localhost:1234", "", "localhost", "1234"}, }, map[string]int{ "dir1": 3, "dir2": 1, @@ -111,8 +115,8 @@ func TestParseOneAndImport(t *testing.T) { {`http://localhost https://localhost dir1 foo bar`, false, []address{ - {"localhost", "http"}, - {"localhost", "https"}, + {"http://localhost", "http", "localhost", "80"}, + {"https://localhost", "https", "localhost", "443"}, }, map[string]int{ "dir1": 3, }}, @@ -120,8 +124,8 @@ func TestParseOneAndImport(t *testing.T) { {`http://localhost https://localhost { dir1 foo bar }`, false, []address{ - {"localhost", "http"}, - {"localhost", "https"}, + {"http://localhost", "http", "localhost", "80"}, + {"https://localhost", "https", "localhost", "443"}, }, map[string]int{ "dir1": 3, }}, @@ -129,22 +133,22 @@ func TestParseOneAndImport(t *testing.T) { {`http://localhost, https://localhost { dir1 foo bar }`, false, []address{ - {"localhost", "http"}, - {"localhost", "https"}, + {"http://localhost", "http", "localhost", "80"}, + {"https://localhost", "https", "localhost", "443"}, }, map[string]int{ "dir1": 3, }}, {`http://localhost, { }`, true, []address{ - {"localhost", "http"}, + {"http://localhost", "http", "localhost", "80"}, }, map[string]int{}}, {`host1:80, http://host2.com dir1 foo bar dir2 baz`, false, []address{ - {"host1", "80"}, - {"host2.com", "http"}, + {"host1:80", "", "host1", "80"}, + {"http://host2.com", "http", "host2.com", "80"}, }, map[string]int{ "dir1": 3, "dir2": 2, @@ -153,9 +157,9 @@ func TestParseOneAndImport(t *testing.T) { {`http://host1.com, http://host2.com, https://host3.com`, false, []address{ - {"host1.com", "http"}, - {"host2.com", "http"}, - {"host3.com", "https"}, + {"http://host1.com", "http", "host1.com", "80"}, + {"http://host2.com", "http", "host2.com", "80"}, + {"https://host3.com", "https", "host3.com", "443"}, }, map[string]int{}}, {`http://host1.com:1234, https://host2.com @@ -163,8 +167,8 @@ func TestParseOneAndImport(t *testing.T) { bar baz } dir2`, false, []address{ - {"host1.com", "1234"}, - {"host2.com", "https"}, + {"http://host1.com:1234", "http", "host1.com", "1234"}, + {"https://host2.com", "https", "host2.com", "443"}, }, map[string]int{ "dir1": 6, "dir2": 1, @@ -177,7 +181,7 @@ func TestParseOneAndImport(t *testing.T) { dir2 { foo bar }`, false, []address{ - {"127.0.0.1", ""}, + {"127.0.0.1", "", "127.0.0.1", ""}, }, map[string]int{ "dir1": 5, "dir2": 5, @@ -185,13 +189,13 @@ func TestParseOneAndImport(t *testing.T) { {`127.0.0.1 unknown_directive`, true, []address{ - {"127.0.0.1", ""}, + {"127.0.0.1", "", "127.0.0.1", ""}, }, map[string]int{}}, {`localhost dir1 { foo`, true, []address{ - {"localhost", ""}, + {"localhost", "", "localhost", ""}, }, map[string]int{ "dir1": 3, }}, @@ -199,7 +203,7 @@ func TestParseOneAndImport(t *testing.T) { {`localhost dir1 { }`, false, []address{ - {"localhost", ""}, + {"localhost", "", "localhost", ""}, }, map[string]int{ "dir1": 3, }}, @@ -207,7 +211,7 @@ func TestParseOneAndImport(t *testing.T) { {`localhost dir1 { } }`, true, []address{ - {"localhost", ""}, + {"localhost", "", "localhost", ""}, }, map[string]int{ "dir1": 3, }}, @@ -219,7 +223,7 @@ func TestParseOneAndImport(t *testing.T) { } } dir2 foo bar`, false, []address{ - {"localhost", ""}, + {"localhost", "", "localhost", ""}, }, map[string]int{ "dir1": 7, "dir2": 3, @@ -230,7 +234,7 @@ func TestParseOneAndImport(t *testing.T) { {`localhost dir1 arg1 import import_test1.txt`, false, []address{ - {"localhost", ""}, + {"localhost", "", "localhost", ""}, }, map[string]int{ "dir1": 2, "dir2": 3, @@ -238,7 +242,7 @@ func TestParseOneAndImport(t *testing.T) { }}, {`import import_test2.txt`, false, []address{ - {"host1", ""}, + {"host1", "", "host1", ""}, }, map[string]int{ "dir1": 1, "dir2": 2, @@ -301,23 +305,23 @@ func TestParseAll(t *testing.T) { addresses [][]address // addresses per server block, in order }{ {`localhost`, false, [][]address{ - {{"localhost", ""}}, + {{"localhost", "", "localhost", ""}}, }}, {`localhost:1234`, false, [][]address{ - []address{{"localhost", "1234"}}, + []address{{"localhost:1234", "", "localhost", "1234"}}, }}, {`localhost:1234 { } localhost:2015 { }`, false, [][]address{ - []address{{"localhost", "1234"}}, - []address{{"localhost", "2015"}}, + []address{{"localhost:1234", "", "localhost", "1234"}}, + []address{{"localhost:2015", "", "localhost", "2015"}}, }}, {`localhost:1234, http://host2`, false, [][]address{ - []address{{"localhost", "1234"}, {"host2", "http"}}, + []address{{"localhost:1234", "", "localhost", "1234"}, {"http://host2", "http", "host2", "80"}}, }}, {`localhost:1234, http://host2,`, true, [][]address{}}, @@ -326,15 +330,15 @@ func TestParseAll(t *testing.T) { } https://host3.com, https://host4.com { }`, false, [][]address{ - []address{{"host1.com", "http"}, {"host2.com", "http"}}, - []address{{"host3.com", "https"}, {"host4.com", "https"}}, + []address{{"http://host1.com", "http", "host1.com", "80"}, {"http://host2.com", "http", "host2.com", "80"}}, + []address{{"https://host3.com", "https", "host3.com", "443"}, {"https://host4.com", "https", "host4.com", "443"}}, }}, {`import import_glob*.txt`, false, [][]address{ - []address{{"glob0.host0", ""}}, - []address{{"glob0.host1", ""}}, - []address{{"glob1.host0", ""}}, - []address{{"glob2.host0", ""}}, + []address{{"glob0.host0", "", "glob0.host0", ""}}, + []address{{"glob0.host1", "", "glob0.host1", ""}}, + []address{{"glob1.host0", "", "glob1.host0", ""}}, + []address{{"glob2.host0", "", "glob2.host0", ""}}, }}, } { p := testParser(test.input) @@ -446,6 +450,13 @@ func TestEnvironmentReplacement(t *testing.T) { if actual, expected := blocks[0].Addresses[0].Port, ""; expected != actual { t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) } + + // in quoted field + p = testParser(":1234\ndir1 \"Test {$FOOBAR} test\"") + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Tokens["dir1"][1].text, "Test foobar test"; expected != actual { + t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) + } } func setupParseTests() { diff --git a/caddy/restart.go b/caddy/restart.go index c921e44cd..0c3cc218b 100644 --- a/caddy/restart.go +++ b/caddy/restart.go @@ -3,11 +3,17 @@ package caddy import ( + "bytes" "encoding/gob" + "errors" "io/ioutil" "log" "os" "os/exec" + "path" + + "github.com/mholt/caddy/caddy/letsencrypt" + "github.com/mholt/caddy/server" ) func init() { @@ -33,6 +39,12 @@ func Restart(newCaddyfile Input) error { caddyfileMu.Unlock() } + // Get certificates for any new hosts in the new Caddyfile without causing downtime + err := getCertsForNewCaddyfile(newCaddyfile) + if err != nil { + return errors.New("TLS preload: " + err.Error()) + } + if len(os.Args) == 0 { // this should never happen, but... os.Args = []string{""} } @@ -61,7 +73,7 @@ func Restart(newCaddyfile Input) error { // Pass along relevant file descriptors to child process; ordering // is very important since we rely on these being in certain positions. - extraFiles := []*os.File{sigwpipe} + extraFiles := []*os.File{sigwpipe} // fd 3 // Add file descriptors of all the sockets serversMu.Lock() @@ -110,3 +122,44 @@ func Restart(newCaddyfile Input) error { // Looks like child is successful; we can exit gracefully. return Stop() } + +func getCertsForNewCaddyfile(newCaddyfile Input) error { + // parse the new caddyfile only up to (and including) TLS + // so we can know what we need to get certs for. + configs, _, _, err := loadConfigsUpToIncludingTLS(path.Base(newCaddyfile.Path()), bytes.NewReader(newCaddyfile.Body())) + if err != nil { + return errors.New("loading Caddyfile: " + err.Error()) + } + + // first mark the configs that are qualified for managed TLS + letsencrypt.MarkQualified(configs) + + // we must make sure port is set before we group by bind address + letsencrypt.EnableTLS(configs) + + // we only need to issue certs for hosts where we already have an active listener + groupings, err := arrangeBindings(configs) + if err != nil { + return errors.New("arranging bindings: " + err.Error()) + } + var configsToSetup []server.Config + serversMu.Lock() +GroupLoop: + for _, group := range groupings { + for _, server := range servers { + if server.Addr == group.BindAddr.String() { + configsToSetup = append(configsToSetup, group.Configs...) + continue GroupLoop + } + } + } + serversMu.Unlock() + + // place certs on the disk + err = letsencrypt.ObtainCerts(configsToSetup, letsencrypt.AlternatePort) + if err != nil { + return errors.New("obtaining certs: " + err.Error()) + } + + return nil +} diff --git a/caddy/setup/startupshutdown_test.go b/caddy/setup/startupshutdown_test.go index a6bdf1b78..16fa973c3 100644 --- a/caddy/setup/startupshutdown_test.go +++ b/caddy/setup/startupshutdown_test.go @@ -2,7 +2,6 @@ package setup import ( "os" - "os/exec" "path/filepath" "strconv" "testing" @@ -13,16 +12,19 @@ import ( // because the Startup and Shutdown functions share virtually the // same functionality func TestStartup(t *testing.T) { - tempDirPath, err := getTempDirPath() if err != nil { t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err) } - testDir := filepath.Join(tempDirPath, "temp_dir_for_testing_startupshutdown.go") + testDir := filepath.Join(tempDirPath, "temp_dir_for_testing_startupshutdown") + defer func() { + // clean up after non-blocking startup function quits + time.Sleep(500 * time.Millisecond) + os.RemoveAll(testDir) + }() osSenitiveTestDir := filepath.FromSlash(testDir) - - exec.Command("rm", "-r", osSenitiveTestDir).Run() // removes osSenitiveTestDir from the OS's temp directory, if the osSenitiveTestDir already exists + os.RemoveAll(osSenitiveTestDir) // start with a clean slate tests := []struct { input string @@ -53,6 +55,5 @@ func TestStartup(t *testing.T) { if err != nil && !test.shouldRemoveErr { t.Errorf("Test %d recieved an error of:\n%v", i, err) } - } } diff --git a/caddy/setup/tls.go b/caddy/setup/tls.go index 79954de48..0ca5f521c 100644 --- a/caddy/setup/tls.go +++ b/caddy/setup/tls.go @@ -11,12 +11,12 @@ import ( // TLS sets up the TLS configuration (but does not activate Let's Encrypt; that is handled elsewhere). func TLS(c *Controller) (middleware.Middleware, error) { - if c.Port == "http" { + if c.Scheme == "http" && c.Port != "80" { c.TLS.Enabled = false log.Printf("[WARNING] TLS disabled for %s://%s. To force TLS over the plaintext HTTP port, "+ - "specify port 80 explicitly (https://%s:80).", c.Port, c.Host, c.Host) + "specify port 80 explicitly (https://%s:80).", c.Scheme, c.Address(), c.Host) } else { - c.TLS.Enabled = true // they had a tls directive, so assume it's on unless we confirm otherwise later + c.TLS.Enabled = true } for c.Next() { @@ -32,18 +32,9 @@ func TLS(c *Controller) (middleware.Middleware, error) { case 2: c.TLS.Certificate = args[0] c.TLS.Key = args[1] - - // manual HTTPS configuration without port specified should be - // served on the HTTPS port; that is what user would expect, and - // makes it consistent with how the letsencrypt package works. - if c.Port == "" { - c.Port = "https" - } - default: - return nil, c.ArgErr() } - // Optional block + // Optional block with extra parameters for c.NextBlock() { switch c.Val() { case "protocols": @@ -74,6 +65,9 @@ func TLS(c *Controller) (middleware.Middleware, error) { if len(c.TLS.ClientCerts) == 0 { return nil, c.ArgErr() } + // TODO: Allow this? It's a bad idea to allow HTTP. If we do this, make sure invoking tls at all (even manually) also sets up a redirect if possible? + // case "allow_http": + // c.TLS.DisableHTTPRedir = true default: return nil, c.Errf("Unknown keyword '%s'", c.Val()) } @@ -85,8 +79,9 @@ func TLS(c *Controller) (middleware.Middleware, error) { return nil, nil } -// SetDefaultTLSParams sets the default TLS cipher suites, protocol versions and server preferences -// of a server.Config if they were not previously set. +// SetDefaultTLSParams sets the default TLS cipher suites, protocol versions, +// and server preferences of a server.Config if they were not previously set +// (it does not overwrite; only fills in missing values). func SetDefaultTLSParams(c *server.Config) { // If no ciphers provided, use all that Caddy supports for the protocol if len(c.TLS.Ciphers) == 0 { @@ -106,6 +101,11 @@ func SetDefaultTLSParams(c *server.Config) { // Prefer server cipher suites c.TLS.PreferServerCipherSuites = true + + // Default TLS port is 443; only use if port is not manually specified + if c.Port == "" { + c.Port = "443" + } } // Map of supported protocols diff --git a/caddy/setup/tls_test.go b/caddy/setup/tls_test.go index e2d2e0155..3077da15f 100644 --- a/caddy/setup/tls_test.go +++ b/caddy/setup/tls_test.go @@ -64,11 +64,12 @@ func TestTLSParseBasic(t *testing.T) { } func TestTLSParseIncompleteParams(t *testing.T) { + // This doesn't do anything useful but is allowed in case the user wants to be explicit + // about TLS being enabled... c := NewTestController(`tls`) - _, err := TLS(c) - if err == nil { - t.Errorf("Expected errors (first check), but no error returned") + if err != nil { + t.Errorf("Expected no error, but got %v", err) } } @@ -93,10 +94,39 @@ func TestTLSParseWithOptionalParams(t *testing.T) { } if len(c.TLS.Ciphers)-1 != 3 { - t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)) + t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1) } } +func TestTLSDefaultWithOptionalParams(t *testing.T) { + params := `tls { + ciphers RSA-3DES-EDE-CBC-SHA + }` + c := NewTestController(params) + + _, err := TLS(c) + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } + if len(c.TLS.Ciphers)-1 != 1 { + t.Errorf("Expected 1 ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1) + } +} + +// TODO: If we allow this... but probably not a good idea. +// func TestTLSDisableHTTPRedirect(t *testing.T) { +// c := NewTestController(`tls { +// allow_http +// }`) +// _, err := TLS(c) +// if err != nil { +// t.Errorf("Expected no error, but got %v", err) +// } +// if !c.TLS.DisableHTTPRedir { +// t.Error("Expected HTTP redirect to be disabled, but it wasn't") +// } +// } + func TestTLSParseWithWrongOptionalParams(t *testing.T) { // Test protocols wrong params params := `tls cert.crt cert.key { diff --git a/main.go b/main.go index f0055d422..acff133a3 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "github.com/mholt/caddy/caddy" "github.com/mholt/caddy/caddy/letsencrypt" + "github.com/xenolf/lego/acme" ) var ( @@ -53,6 +54,7 @@ func main() { caddy.AppName = appName caddy.AppVersion = appVersion + acme.UserAgent = appName + "/" + appVersion // set up process log before anything bad happens switch logfile { diff --git a/middleware/gzip/response_filter.go b/middleware/gzip/response_filter.go index b561649e5..3039eb9e6 100644 --- a/middleware/gzip/response_filter.go +++ b/middleware/gzip/response_filter.go @@ -40,8 +40,8 @@ func NewResponseFilterWriter(filters []ResponseFilter, gz *gzipResponseWriter) * return &ResponseFilterWriter{filters: filters, gzipResponseWriter: gz} } -// Write wraps underlying WriteHeader method and compresses if filters -// are satisfied. +// WriteHeader wraps underlying WriteHeader method and +// compresses if filters are satisfied. func (r *ResponseFilterWriter) WriteHeader(code int) { // Determine if compression should be used or not. r.shouldCompress = true diff --git a/middleware/gzip/response_filter_test.go b/middleware/gzip/response_filter_test.go index 75f726922..2878336c3 100644 --- a/middleware/gzip/response_filter_test.go +++ b/middleware/gzip/response_filter_test.go @@ -11,7 +11,7 @@ import ( ) func TestLengthFilter(t *testing.T) { - var filters []ResponseFilter = []ResponseFilter{ + var filters = []ResponseFilter{ LengthFilter(100), LengthFilter(1000), LengthFilter(0), diff --git a/middleware/rewrite/condition.go b/middleware/rewrite/condition.go index 220b489fc..1431afc9c 100644 --- a/middleware/rewrite/condition.go +++ b/middleware/rewrite/condition.go @@ -9,8 +9,8 @@ import ( "github.com/mholt/caddy/middleware" ) +// Operators const ( - // Operators Is = "is" Not = "not" Has = "has" diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go index ad3b4adfb..7d9793b2b 100644 --- a/middleware/rewrite/rewrite.go +++ b/middleware/rewrite/rewrite.go @@ -13,12 +13,12 @@ import ( "github.com/mholt/caddy/middleware" ) -// RewriteResult is the result of a rewrite -type RewriteResult int +// Result is the result of a rewrite +type Result int const ( // RewriteIgnored is returned when rewrite is not done on request. - RewriteIgnored RewriteResult = iota + RewriteIgnored Result = iota // RewriteDone is returned when rewrite is done on request. RewriteDone // RewriteStatus is returned when rewrite is not needed and status code should be set @@ -55,7 +55,7 @@ outer: // Rule describes an internal location rewrite rule. type Rule interface { // Rewrite rewrites the internal location of the current request. - Rewrite(http.FileSystem, *http.Request) RewriteResult + Rewrite(http.FileSystem, *http.Request) Result } // SimpleRule is a simple rewrite rule. @@ -69,7 +69,7 @@ func NewSimpleRule(from, to string) SimpleRule { } // Rewrite rewrites the internal location of the current request. -func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) RewriteResult { +func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result { if s.From == r.URL.Path { // take note of this rewrite for internal use by fastcgi // all we need is the URI, not full URL @@ -102,7 +102,7 @@ type ComplexRule struct { *regexp.Regexp } -// NewRegexpRule creates a new RegexpRule. It returns an error if regexp +// NewComplexRule creates a new RegexpRule. It returns an error if regexp // pattern (pattern) or extensions (ext) are invalid. func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) { // validate regexp if present @@ -136,7 +136,7 @@ func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If } // Rewrite rewrites the internal location of the current request. -func (r *ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re RewriteResult) { +func (r *ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result) { rPath := req.URL.Path replacer := newReplacer(req) diff --git a/middleware/rewrite/to.go b/middleware/rewrite/to.go index de07b7fbe..7a38349ff 100644 --- a/middleware/rewrite/to.go +++ b/middleware/rewrite/to.go @@ -13,7 +13,7 @@ import ( // To attempts rewrite. It attempts to rewrite to first valid path // or the last path if none of the paths are valid. // Returns true if rewrite is successful and false otherwise. -func To(fs http.FileSystem, r *http.Request, to string, replacer middleware.Replacer) RewriteResult { +func To(fs http.FileSystem, r *http.Request, to string, replacer middleware.Replacer) Result { tos := strings.Fields(to) // try each rewrite paths diff --git a/server/config.go b/server/config.go index f9ec05fc3..11d69e142 100644 --- a/server/config.go +++ b/server/config.go @@ -17,6 +17,9 @@ type Config struct { // The port to listen on Port string + // The protocol (http/https) to serve with this config; only set if user explicitly specifies it + Scheme string + // The directory from which to serve files Root string @@ -62,10 +65,12 @@ func (c Config) Address() string { // TLSConfig describes how TLS should be configured and used. type TLSConfig struct { - Enabled bool - Certificate string - Key string - LetsEncryptEmail string + Enabled bool + Certificate string + Key string + LetsEncryptEmail string + Managed bool // will be set to true if config qualifies for automatic, managed TLS + //DisableHTTPRedir bool // TODO: not a good idea - should we really allow it? OCSPStaple []byte Ciphers []uint16 ProtocolMinVersion uint16 diff --git a/server/config_test.go b/server/config_test.go index d94f3581e..8787e467b 100644 --- a/server/config_test.go +++ b/server/config_test.go @@ -18,8 +18,8 @@ func TestConfigAddress(t *testing.T) { t.Errorf("Expected '%s' but got '%s'", expected, actual) } - cfg = Config{Host: "::1", Port: "https"} - if actual, expected := cfg.Address(), "[::1]:https"; expected != actual { + cfg = Config{Host: "::1", Port: "443"} + if actual, expected := cfg.Address(), "[::1]:443"; expected != actual { t.Errorf("Expected '%s' but got '%s'", expected, actual) } } diff --git a/server/server.go b/server/server.go index 96979aad2..7e828957b 100644 --- a/server/server.go +++ b/server/server.go @@ -33,6 +33,7 @@ type Server struct { httpWg sync.WaitGroup // used to wait on outstanding connections startChan chan struct{} // used to block until server is finished starting connTimeout time.Duration // the maximum duration of a graceful shutdown + ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request } // ListenerFile represents a listener. @@ -41,6 +42,11 @@ type ListenerFile interface { File() (*os.File, error) } +// OptionalCallback is a function that may or may not handle a request. +// It returns whether or not it handled the request. If it handled the +// request, it is presumed that no further request handling should occur. +type OptionalCallback func(http.ResponseWriter, *http.Request) bool + // New creates a new Server which will bind to addr and serve // the sites/hosts configured in configs. Its listener will // gracefully close when the server is stopped which will take @@ -309,6 +315,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } }() + w.Header().Set("Server", "Caddy") + + // Execute the optional request callback if it exists + if s.ReqCallback != nil && s.ReqCallback(w, r) { + return + } + host, _, err := net.SplitHostPort(r.Host) if err != nil { host = r.Host // oh well @@ -324,8 +337,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if vh, ok := s.vhosts[host]; ok { - w.Header().Set("Server", "Caddy") - status, _ := vh.stack.ServeHTTP(w, r) // Fallback error response in case error handling wasn't chained in