mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-23 22:27:38 -05:00
PoC: on-demand TLS
Implements "on-demand TLS" as I call it, which means obtaining TLS certificates on-the-fly during TLS handshakes if a certificate for the requested hostname is not already available. Only the first request for a new hostname will experience higher latency; subsequent requests will get the new certificates right out of memory. Code still needs lots of cleanup but the feature is basically working.
This commit is contained in:
parent
b4cab78bec
commit
47079c3d24
4 changed files with 218 additions and 45 deletions
|
@ -191,8 +191,9 @@ func startServers(groupings bindingGroup) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.HTTP2 = HTTP2 // TODO: This setting is temporary
|
s.HTTP2 = HTTP2 // TODO: This setting is temporary
|
||||||
s.ReqCallback = letsencrypt.RequestCallback // ensures we can solve ACME challenges while running
|
s.ReqCallback = letsencrypt.RequestCallback // ensures we can solve ACME challenges while running
|
||||||
|
s.SNICallback = letsencrypt.GetCertificateDuringHandshake // TLS on demand -- awesome!
|
||||||
|
|
||||||
var ln server.ListenerFile
|
var ln server.ListenerFile
|
||||||
if IsRestart() {
|
if IsRestart() {
|
||||||
|
|
99
caddy/letsencrypt/handshake.go
Normal file
99
caddy/letsencrypt/handshake.go
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
package letsencrypt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/mholt/caddy/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetCertificateDuringHandshake is a function that gets a certificate during a TLS handshake.
|
||||||
|
// It first checks an in-memory cache in case the cert was requested before, then tries to load
|
||||||
|
// a certificate in the storage folder from disk. If it can't find an existing certificate, it
|
||||||
|
// will try to obtain one using ACME, which will then be stored on disk and cached in memory.
|
||||||
|
//
|
||||||
|
// This function is safe for use by multiple concurrent goroutines.
|
||||||
|
func GetCertificateDuringHandshake(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
// Utility function to help us load a cert from disk and put it in the cache if successful
|
||||||
|
loadCertFromDisk := func(domain string) *tls.Certificate {
|
||||||
|
cert, err := tls.LoadX509KeyPair(storage.SiteCertFile(domain), storage.SiteKeyFile(domain))
|
||||||
|
if err == nil {
|
||||||
|
certCacheMu.Lock()
|
||||||
|
if len(certCache) < 10000 { // limit size of cache to prevent a ridiculous, unusual kind of attack
|
||||||
|
certCache[domain] = &cert
|
||||||
|
}
|
||||||
|
certCacheMu.Unlock()
|
||||||
|
return &cert
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// First check our in-memory cache to see if we've already loaded it
|
||||||
|
certCacheMu.RLock()
|
||||||
|
cert := server.GetCertificateFromCache(clientHello, certCache)
|
||||||
|
certCacheMu.RUnlock()
|
||||||
|
if cert != nil {
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then check to see if we already have one on disk; if we do, add it to cache and use it
|
||||||
|
name := strings.ToLower(clientHello.ServerName)
|
||||||
|
cert = loadCertFromDisk(name)
|
||||||
|
if cert != nil {
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only option left is to get one from LE, but the name has to qualify first
|
||||||
|
if !HostQualifies(name) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// By this point, we need to obtain one from the CA. We must protect this process
|
||||||
|
// from happening concurrently, so synchronize.
|
||||||
|
obtainCertWaitGroupsMutex.Lock()
|
||||||
|
wg, ok := obtainCertWaitGroups[name]
|
||||||
|
if ok {
|
||||||
|
// lucky us -- another goroutine is already obtaining the certificate.
|
||||||
|
// wait for it to finish obtaining the cert and then we'll use it.
|
||||||
|
obtainCertWaitGroupsMutex.Unlock()
|
||||||
|
wg.Wait()
|
||||||
|
return GetCertificateDuringHandshake(clientHello)
|
||||||
|
}
|
||||||
|
|
||||||
|
// looks like it's up to us to do all the work and obtain the cert
|
||||||
|
wg = new(sync.WaitGroup)
|
||||||
|
wg.Add(1)
|
||||||
|
obtainCertWaitGroups[name] = wg
|
||||||
|
obtainCertWaitGroupsMutex.Unlock()
|
||||||
|
|
||||||
|
// Unblock waiters and delete waitgroup when we return
|
||||||
|
defer func() {
|
||||||
|
obtainCertWaitGroupsMutex.Lock()
|
||||||
|
wg.Done()
|
||||||
|
delete(obtainCertWaitGroups, name)
|
||||||
|
obtainCertWaitGroupsMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// obtain cert
|
||||||
|
client, err := newClientPort(DefaultEmail, AlternatePort)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("error creating client: " + err.Error())
|
||||||
|
}
|
||||||
|
err = clientObtain(client, []string{name}, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// load certificate into memory and return it
|
||||||
|
return loadCertFromDisk(name), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// obtainCertWaitGroups is used to coordinate obtaining certs for each hostname.
|
||||||
|
var obtainCertWaitGroups = make(map[string]*sync.WaitGroup)
|
||||||
|
var obtainCertWaitGroupsMutex sync.Mutex
|
||||||
|
|
||||||
|
// certCache stores certificates that have been obtained in memory.
|
||||||
|
var certCache = make(map[string]*tls.Certificate)
|
||||||
|
var certCacheMu sync.RWMutex
|
|
@ -6,6 +6,7 @@ package letsencrypt
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -82,6 +83,13 @@ func Activate(configs []server.Config) ([]server.Config, error) {
|
||||||
// keep certificates renewed and OCSP stapling updated
|
// keep certificates renewed and OCSP stapling updated
|
||||||
go maintainAssets(configs, stopChan)
|
go maintainAssets(configs, stopChan)
|
||||||
|
|
||||||
|
// TODO - experimental dynamic TLS!
|
||||||
|
for i := range configs {
|
||||||
|
if configs[i].Host == "" && configs[i].Port == "443" {
|
||||||
|
configs[i].TLS.Enabled = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return configs, nil
|
return configs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,41 +135,9 @@ func ObtainCerts(configs []server.Config, altPort string) error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
Obtain:
|
err := clientObtain(client, []string{cfg.Host}, altPort == "")
|
||||||
certificate, failures := client.ObtainCertificate([]string{cfg.Host}, true, nil)
|
if err != nil {
|
||||||
if len(failures) == 0 {
|
return err
|
||||||
// Success - immediately save the certificate resource
|
|
||||||
err := saveCertResource(certificate)
|
|
||||||
if err != nil {
|
|
||||||
return errors.New("error saving assets for " + cfg.Host + ": " + err.Error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Error - either try to fix it or report them it to the user and abort
|
|
||||||
var errMsg string // we'll combine all the failures into a single error message
|
|
||||||
var promptedForAgreement bool // only prompt user for agreement at most once
|
|
||||||
|
|
||||||
for errDomain, obtainErr := range failures {
|
|
||||||
// TODO: Double-check, will obtainErr ever be nil?
|
|
||||||
if tosErr, ok := obtainErr.(acme.TOSError); ok {
|
|
||||||
// Terms of Service agreement error; we can probably deal with this
|
|
||||||
if !Agreed && !promptedForAgreement && altPort == "" { // don't prompt if server is already running
|
|
||||||
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
|
|
||||||
promptedForAgreement = true
|
|
||||||
}
|
|
||||||
if Agreed || altPort != "" {
|
|
||||||
err := client.AgreeToTOS()
|
|
||||||
if err != nil {
|
|
||||||
return errors.New("error agreeing to updated terms: " + err.Error())
|
|
||||||
}
|
|
||||||
goto Obtain
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If user did not agree or it was any other kind of error, just append to the list of errors
|
|
||||||
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors.New(errMsg)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -447,6 +423,49 @@ func redirPlaintextHost(cfg server.Config) server.Config {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clientObtain uses client to obtain a single certificate for domains in names. If
|
||||||
|
// the user is present to provide an email address, pass in true for allowPrompt,
|
||||||
|
// otherwise pass in false. If err == nil, the certificate (and key) will be saved
|
||||||
|
// to disk in the storage folder.
|
||||||
|
func clientObtain(client *acme.Client, names []string, allowPrompt bool) error {
|
||||||
|
certificate, failures := client.ObtainCertificate(names, true, nil)
|
||||||
|
if len(failures) > 0 {
|
||||||
|
// Error - either try to fix it or report them it to the user and abort
|
||||||
|
var errMsg string // we'll combine all the failures into a single error message
|
||||||
|
var promptedForAgreement bool // only prompt user for agreement at most once
|
||||||
|
|
||||||
|
for errDomain, obtainErr := range failures {
|
||||||
|
// TODO: Double-check, will obtainErr ever be nil?
|
||||||
|
if tosErr, ok := obtainErr.(acme.TOSError); ok {
|
||||||
|
// Terms of Service agreement error; we can probably deal with this
|
||||||
|
if !Agreed && !promptedForAgreement && allowPrompt { // don't prompt if server is already running
|
||||||
|
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
|
||||||
|
promptedForAgreement = true
|
||||||
|
}
|
||||||
|
if Agreed || !allowPrompt {
|
||||||
|
err := client.AgreeToTOS()
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("error agreeing to updated terms: " + err.Error())
|
||||||
|
}
|
||||||
|
return clientObtain(client, names, allowPrompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If user did not agree or it was any other kind of error, just append to the list of errors
|
||||||
|
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
|
||||||
|
}
|
||||||
|
return errors.New(errMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success - immediately save the certificate resource
|
||||||
|
err := saveCertResource(certificate)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error saving assets for %v: %v", names, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Revoke revokes the certificate for host via ACME protocol.
|
// Revoke revokes the certificate for host via ACME protocol.
|
||||||
func Revoke(host string) error {
|
func Revoke(host string) error {
|
||||||
if !existingCertAndKey(host) {
|
if !existingCertAndKey(host) {
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -33,6 +34,7 @@ type Server struct {
|
||||||
startChan chan struct{} // used to block until server is finished starting
|
startChan chan struct{} // used to block until server is finished starting
|
||||||
connTimeout time.Duration // the maximum duration of a graceful shutdown
|
connTimeout time.Duration // the maximum duration of a graceful shutdown
|
||||||
ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request
|
ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request
|
||||||
|
SNICallback func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenerFile represents a listener.
|
// ListenerFile represents a listener.
|
||||||
|
@ -206,17 +208,39 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
|
||||||
|
|
||||||
// Here we diverge from the stdlib a bit by loading multiple certs/key pairs
|
// Here we diverge from the stdlib a bit by loading multiple certs/key pairs
|
||||||
// then we map the server names to their certs
|
// then we map the server names to their certs
|
||||||
var err error
|
for _, tlsConfig := range tlsConfigs {
|
||||||
config.Certificates = make([]tls.Certificate, len(tlsConfigs))
|
if tlsConfig.Certificate == "" || tlsConfig.Key == "" {
|
||||||
for i, tlsConfig := range tlsConfigs {
|
continue
|
||||||
config.Certificates[i], err = tls.LoadX509KeyPair(tlsConfig.Certificate, tlsConfig.Key)
|
}
|
||||||
config.Certificates[i].OCSPStaple = tlsConfig.OCSPStaple
|
cert, err := tls.LoadX509KeyPair(tlsConfig.Certificate, tlsConfig.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
defer close(s.startChan)
|
defer close(s.startChan)
|
||||||
return err
|
return fmt.Errorf("loading certificate and key pair: %v", err)
|
||||||
}
|
}
|
||||||
|
cert.OCSPStaple = tlsConfig.OCSPStaple
|
||||||
|
config.Certificates = append(config.Certificates, cert)
|
||||||
|
}
|
||||||
|
if len(config.Certificates) > 0 {
|
||||||
|
config.BuildNameToCertificate()
|
||||||
|
}
|
||||||
|
|
||||||
|
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
// TODO: When Caddy starts, if it is to issue certs dynamically, we need
|
||||||
|
// terms agreement and an email address. make sure this is enforced at server
|
||||||
|
// start if the Caddyfile enables dynamic certificate issuance!
|
||||||
|
|
||||||
|
// Check NameToCertificate like the std lib does in "getCertificate" (unexported, bah)
|
||||||
|
cert := GetCertificateFromCache(clientHello, config.NameToCertificate)
|
||||||
|
if cert != nil {
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.SNICallback != nil {
|
||||||
|
return s.SNICallback(clientHello)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
}
|
}
|
||||||
config.BuildNameToCertificate()
|
|
||||||
|
|
||||||
// Customize our TLS configuration
|
// Customize our TLS configuration
|
||||||
config.MinVersion = tlsConfigs[0].ProtocolMinVersion
|
config.MinVersion = tlsConfigs[0].ProtocolMinVersion
|
||||||
|
@ -225,7 +249,7 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
|
||||||
config.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites
|
config.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites
|
||||||
|
|
||||||
// TLS client authentication, if user enabled it
|
// TLS client authentication, if user enabled it
|
||||||
err = setupClientAuth(tlsConfigs, config)
|
err := setupClientAuth(tlsConfigs, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
defer close(s.startChan)
|
defer close(s.startChan)
|
||||||
return err
|
return err
|
||||||
|
@ -242,6 +266,36 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
|
||||||
return s.Server.Serve(ln)
|
return s.Server.Serve(ln)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Borrowed from the Go standard library, crypto/tls pacakge, common.go.
|
||||||
|
// It has been modified to fit this program.
|
||||||
|
// Original license:
|
||||||
|
//
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
func GetCertificateFromCache(clientHello *tls.ClientHelloInfo, cache map[string]*tls.Certificate) *tls.Certificate {
|
||||||
|
name := strings.ToLower(clientHello.ServerName)
|
||||||
|
for len(name) > 0 && name[len(name)-1] == '.' {
|
||||||
|
name = name[:len(name)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// exact match? great! use it
|
||||||
|
if cert, ok := cache[name]; ok {
|
||||||
|
return cert
|
||||||
|
}
|
||||||
|
|
||||||
|
// try replacing labels in the name with wildcards until we get a match.
|
||||||
|
labels := strings.Split(name, ".")
|
||||||
|
for i := range labels {
|
||||||
|
labels[i] = "*"
|
||||||
|
candidate := strings.Join(labels, ".")
|
||||||
|
if cert, ok := cache[candidate]; ok {
|
||||||
|
return cert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Stop stops the server. It blocks until the server is
|
// Stop stops the server. It blocks until the server is
|
||||||
// totally stopped. On POSIX systems, it will wait for
|
// totally stopped. On POSIX systems, it will wait for
|
||||||
// connections to close (up to a max timeout of a few
|
// connections to close (up to a max timeout of a few
|
||||||
|
|
Loading…
Reference in a new issue