mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-23 22:27:38 -05:00
PoC: on-demand TLS
Implements "on-demand TLS" as I call it, which means obtaining TLS certificates on-the-fly during TLS handshakes if a certificate for the requested hostname is not already available. Only the first request for a new hostname will experience higher latency; subsequent requests will get the new certificates right out of memory. Code still needs lots of cleanup but the feature is basically working.
This commit is contained in:
parent
b4cab78bec
commit
47079c3d24
4 changed files with 218 additions and 45 deletions
|
@ -193,6 +193,7 @@ func startServers(groupings bindingGroup) error {
|
|||
}
|
||||
s.HTTP2 = HTTP2 // TODO: This setting is temporary
|
||||
s.ReqCallback = letsencrypt.RequestCallback // ensures we can solve ACME challenges while running
|
||||
s.SNICallback = letsencrypt.GetCertificateDuringHandshake // TLS on demand -- awesome!
|
||||
|
||||
var ln server.ListenerFile
|
||||
if IsRestart() {
|
||||
|
|
99
caddy/letsencrypt/handshake.go
Normal file
99
caddy/letsencrypt/handshake.go
Normal file
|
@ -0,0 +1,99 @@
|
|||
package letsencrypt
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mholt/caddy/server"
|
||||
)
|
||||
|
||||
// GetCertificateDuringHandshake is a function that gets a certificate during a TLS handshake.
|
||||
// It first checks an in-memory cache in case the cert was requested before, then tries to load
|
||||
// a certificate in the storage folder from disk. If it can't find an existing certificate, it
|
||||
// will try to obtain one using ACME, which will then be stored on disk and cached in memory.
|
||||
//
|
||||
// This function is safe for use by multiple concurrent goroutines.
|
||||
func GetCertificateDuringHandshake(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
// Utility function to help us load a cert from disk and put it in the cache if successful
|
||||
loadCertFromDisk := func(domain string) *tls.Certificate {
|
||||
cert, err := tls.LoadX509KeyPair(storage.SiteCertFile(domain), storage.SiteKeyFile(domain))
|
||||
if err == nil {
|
||||
certCacheMu.Lock()
|
||||
if len(certCache) < 10000 { // limit size of cache to prevent a ridiculous, unusual kind of attack
|
||||
certCache[domain] = &cert
|
||||
}
|
||||
certCacheMu.Unlock()
|
||||
return &cert
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// First check our in-memory cache to see if we've already loaded it
|
||||
certCacheMu.RLock()
|
||||
cert := server.GetCertificateFromCache(clientHello, certCache)
|
||||
certCacheMu.RUnlock()
|
||||
if cert != nil {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// Then check to see if we already have one on disk; if we do, add it to cache and use it
|
||||
name := strings.ToLower(clientHello.ServerName)
|
||||
cert = loadCertFromDisk(name)
|
||||
if cert != nil {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// Only option left is to get one from LE, but the name has to qualify first
|
||||
if !HostQualifies(name) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// By this point, we need to obtain one from the CA. We must protect this process
|
||||
// from happening concurrently, so synchronize.
|
||||
obtainCertWaitGroupsMutex.Lock()
|
||||
wg, ok := obtainCertWaitGroups[name]
|
||||
if ok {
|
||||
// lucky us -- another goroutine is already obtaining the certificate.
|
||||
// wait for it to finish obtaining the cert and then we'll use it.
|
||||
obtainCertWaitGroupsMutex.Unlock()
|
||||
wg.Wait()
|
||||
return GetCertificateDuringHandshake(clientHello)
|
||||
}
|
||||
|
||||
// looks like it's up to us to do all the work and obtain the cert
|
||||
wg = new(sync.WaitGroup)
|
||||
wg.Add(1)
|
||||
obtainCertWaitGroups[name] = wg
|
||||
obtainCertWaitGroupsMutex.Unlock()
|
||||
|
||||
// Unblock waiters and delete waitgroup when we return
|
||||
defer func() {
|
||||
obtainCertWaitGroupsMutex.Lock()
|
||||
wg.Done()
|
||||
delete(obtainCertWaitGroups, name)
|
||||
obtainCertWaitGroupsMutex.Unlock()
|
||||
}()
|
||||
|
||||
// obtain cert
|
||||
client, err := newClientPort(DefaultEmail, AlternatePort)
|
||||
if err != nil {
|
||||
return nil, errors.New("error creating client: " + err.Error())
|
||||
}
|
||||
err = clientObtain(client, []string{name}, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// load certificate into memory and return it
|
||||
return loadCertFromDisk(name), nil
|
||||
}
|
||||
|
||||
// obtainCertWaitGroups is used to coordinate obtaining certs for each hostname.
|
||||
var obtainCertWaitGroups = make(map[string]*sync.WaitGroup)
|
||||
var obtainCertWaitGroupsMutex sync.Mutex
|
||||
|
||||
// certCache stores certificates that have been obtained in memory.
|
||||
var certCache = make(map[string]*tls.Certificate)
|
||||
var certCacheMu sync.RWMutex
|
|
@ -6,6 +6,7 @@ package letsencrypt
|
|||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -82,6 +83,13 @@ func Activate(configs []server.Config) ([]server.Config, error) {
|
|||
// keep certificates renewed and OCSP stapling updated
|
||||
go maintainAssets(configs, stopChan)
|
||||
|
||||
// TODO - experimental dynamic TLS!
|
||||
for i := range configs {
|
||||
if configs[i].Host == "" && configs[i].Port == "443" {
|
||||
configs[i].TLS.Enabled = true
|
||||
}
|
||||
}
|
||||
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
|
@ -127,41 +135,9 @@ func ObtainCerts(configs []server.Config, altPort string) error {
|
|||
continue
|
||||
}
|
||||
|
||||
Obtain:
|
||||
certificate, failures := client.ObtainCertificate([]string{cfg.Host}, true, nil)
|
||||
if len(failures) == 0 {
|
||||
// Success - immediately save the certificate resource
|
||||
err := saveCertResource(certificate)
|
||||
err := clientObtain(client, []string{cfg.Host}, altPort == "")
|
||||
if err != nil {
|
||||
return errors.New("error saving assets for " + cfg.Host + ": " + err.Error())
|
||||
}
|
||||
} else {
|
||||
// Error - either try to fix it or report them it to the user and abort
|
||||
var errMsg string // we'll combine all the failures into a single error message
|
||||
var promptedForAgreement bool // only prompt user for agreement at most once
|
||||
|
||||
for errDomain, obtainErr := range failures {
|
||||
// TODO: Double-check, will obtainErr ever be nil?
|
||||
if tosErr, ok := obtainErr.(acme.TOSError); ok {
|
||||
// Terms of Service agreement error; we can probably deal with this
|
||||
if !Agreed && !promptedForAgreement && altPort == "" { // don't prompt if server is already running
|
||||
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
|
||||
promptedForAgreement = true
|
||||
}
|
||||
if Agreed || altPort != "" {
|
||||
err := client.AgreeToTOS()
|
||||
if err != nil {
|
||||
return errors.New("error agreeing to updated terms: " + err.Error())
|
||||
}
|
||||
goto Obtain
|
||||
}
|
||||
}
|
||||
|
||||
// If user did not agree or it was any other kind of error, just append to the list of errors
|
||||
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
|
||||
}
|
||||
|
||||
return errors.New(errMsg)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -447,6 +423,49 @@ func redirPlaintextHost(cfg server.Config) server.Config {
|
|||
}
|
||||
}
|
||||
|
||||
// clientObtain uses client to obtain a single certificate for domains in names. If
|
||||
// the user is present to provide an email address, pass in true for allowPrompt,
|
||||
// otherwise pass in false. If err == nil, the certificate (and key) will be saved
|
||||
// to disk in the storage folder.
|
||||
func clientObtain(client *acme.Client, names []string, allowPrompt bool) error {
|
||||
certificate, failures := client.ObtainCertificate(names, true, nil)
|
||||
if len(failures) > 0 {
|
||||
// Error - either try to fix it or report them it to the user and abort
|
||||
var errMsg string // we'll combine all the failures into a single error message
|
||||
var promptedForAgreement bool // only prompt user for agreement at most once
|
||||
|
||||
for errDomain, obtainErr := range failures {
|
||||
// TODO: Double-check, will obtainErr ever be nil?
|
||||
if tosErr, ok := obtainErr.(acme.TOSError); ok {
|
||||
// Terms of Service agreement error; we can probably deal with this
|
||||
if !Agreed && !promptedForAgreement && allowPrompt { // don't prompt if server is already running
|
||||
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
|
||||
promptedForAgreement = true
|
||||
}
|
||||
if Agreed || !allowPrompt {
|
||||
err := client.AgreeToTOS()
|
||||
if err != nil {
|
||||
return errors.New("error agreeing to updated terms: " + err.Error())
|
||||
}
|
||||
return clientObtain(client, names, allowPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// If user did not agree or it was any other kind of error, just append to the list of errors
|
||||
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
|
||||
}
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
// Success - immediately save the certificate resource
|
||||
err := saveCertResource(certificate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error saving assets for %v: %v", names, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Revoke revokes the certificate for host via ACME protocol.
|
||||
func Revoke(host string) error {
|
||||
if !existingCertAndKey(host) {
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
@ -33,6 +34,7 @@ type Server struct {
|
|||
startChan chan struct{} // used to block until server is finished starting
|
||||
connTimeout time.Duration // the maximum duration of a graceful shutdown
|
||||
ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request
|
||||
SNICallback func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
}
|
||||
|
||||
// ListenerFile represents a listener.
|
||||
|
@ -206,17 +208,39 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
|
|||
|
||||
// Here we diverge from the stdlib a bit by loading multiple certs/key pairs
|
||||
// then we map the server names to their certs
|
||||
var err error
|
||||
config.Certificates = make([]tls.Certificate, len(tlsConfigs))
|
||||
for i, tlsConfig := range tlsConfigs {
|
||||
config.Certificates[i], err = tls.LoadX509KeyPair(tlsConfig.Certificate, tlsConfig.Key)
|
||||
config.Certificates[i].OCSPStaple = tlsConfig.OCSPStaple
|
||||
for _, tlsConfig := range tlsConfigs {
|
||||
if tlsConfig.Certificate == "" || tlsConfig.Key == "" {
|
||||
continue
|
||||
}
|
||||
cert, err := tls.LoadX509KeyPair(tlsConfig.Certificate, tlsConfig.Key)
|
||||
if err != nil {
|
||||
defer close(s.startChan)
|
||||
return err
|
||||
return fmt.Errorf("loading certificate and key pair: %v", err)
|
||||
}
|
||||
cert.OCSPStaple = tlsConfig.OCSPStaple
|
||||
config.Certificates = append(config.Certificates, cert)
|
||||
}
|
||||
if len(config.Certificates) > 0 {
|
||||
config.BuildNameToCertificate()
|
||||
}
|
||||
|
||||
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
// TODO: When Caddy starts, if it is to issue certs dynamically, we need
|
||||
// terms agreement and an email address. make sure this is enforced at server
|
||||
// start if the Caddyfile enables dynamic certificate issuance!
|
||||
|
||||
// Check NameToCertificate like the std lib does in "getCertificate" (unexported, bah)
|
||||
cert := GetCertificateFromCache(clientHello, config.NameToCertificate)
|
||||
if cert != nil {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
if s.SNICallback != nil {
|
||||
return s.SNICallback(clientHello)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Customize our TLS configuration
|
||||
config.MinVersion = tlsConfigs[0].ProtocolMinVersion
|
||||
|
@ -225,7 +249,7 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
|
|||
config.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites
|
||||
|
||||
// TLS client authentication, if user enabled it
|
||||
err = setupClientAuth(tlsConfigs, config)
|
||||
err := setupClientAuth(tlsConfigs, config)
|
||||
if err != nil {
|
||||
defer close(s.startChan)
|
||||
return err
|
||||
|
@ -242,6 +266,36 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
|
|||
return s.Server.Serve(ln)
|
||||
}
|
||||
|
||||
// Borrowed from the Go standard library, crypto/tls pacakge, common.go.
|
||||
// It has been modified to fit this program.
|
||||
// Original license:
|
||||
//
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
func GetCertificateFromCache(clientHello *tls.ClientHelloInfo, cache map[string]*tls.Certificate) *tls.Certificate {
|
||||
name := strings.ToLower(clientHello.ServerName)
|
||||
for len(name) > 0 && name[len(name)-1] == '.' {
|
||||
name = name[:len(name)-1]
|
||||
}
|
||||
|
||||
// exact match? great! use it
|
||||
if cert, ok := cache[name]; ok {
|
||||
return cert
|
||||
}
|
||||
|
||||
// try replacing labels in the name with wildcards until we get a match.
|
||||
labels := strings.Split(name, ".")
|
||||
for i := range labels {
|
||||
labels[i] = "*"
|
||||
candidate := strings.Join(labels, ".")
|
||||
if cert, ok := cache[candidate]; ok {
|
||||
return cert
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the server. It blocks until the server is
|
||||
// totally stopped. On POSIX systems, it will wait for
|
||||
// connections to close (up to a max timeout of a few
|
||||
|
|
Loading…
Reference in a new issue