mirror of
https://github.com/caddyserver/caddy.git
synced 2025-04-01 02:42:35 -05:00
Merge branch 'master' into diagnostics
# Conflicts: # plugins.go # vendor/manifest
This commit is contained in:
commit
269a8b5fce
270 changed files with 28231 additions and 4887 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -16,4 +16,6 @@ Caddyfile
|
|||
|
||||
og_static/
|
||||
|
||||
.vscode/
|
||||
.vscode/
|
||||
|
||||
*.bat
|
|
@ -1,5 +1,5 @@
|
|||
<p align="center">
|
||||
<a href="https://caddyserver.com"><img src="https://cloud.githubusercontent.com/assets/1128849/25305033/12916fce-2731-11e7-86ec-580d4d31cb16.png" alt="Caddy" width="400"></a>
|
||||
<a href="https://caddyserver.com"><img src="https://user-images.githubusercontent.com/1128849/36137292-bebc223a-1051-11e8-9a81-4ea9054c96ac.png" alt="Caddy" width="400"></a>
|
||||
</p>
|
||||
<h3 align="center">Every Site on HTTPS <!-- Serve Confidently --></h3>
|
||||
<p align="center">Caddy is a general-purpose HTTP/2 web server that serves HTTPS by default.</p>
|
||||
|
@ -59,7 +59,7 @@ customize your build in the browser
|
|||
pre-built, vanilla binaries
|
||||
|
||||
## Build
|
||||
To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.8 or newer). Follow these instruction for fast building:
|
||||
To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.9 or newer). Follow these instruction for fast building:
|
||||
|
||||
- Get source `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds`
|
||||
- Now `cd` to `$GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go`
|
||||
|
|
66
caddy.go
66
caddy.go
|
@ -78,8 +78,18 @@ var (
|
|||
mu sync.Mutex
|
||||
)
|
||||
|
||||
func init() {
|
||||
OnProcessExit = append(OnProcessExit, func() {
|
||||
if PidFile != "" {
|
||||
os.Remove(PidFile)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Instance contains the state of servers created as a result of
|
||||
// calling Start and can be used to access or control those servers.
|
||||
// It is literally an instance of a server type. Instance values
|
||||
// should NOT be copied. Use *Instance for safety.
|
||||
type Instance struct {
|
||||
// serverType is the name of the instance's server type
|
||||
serverType string
|
||||
|
@ -90,10 +100,11 @@ type Instance struct {
|
|||
// wg is used to wait for all servers to shut down
|
||||
wg *sync.WaitGroup
|
||||
|
||||
// context is the context created for this instance.
|
||||
// context is the context created for this instance,
|
||||
// used to coordinate the setting up of the server type
|
||||
context Context
|
||||
|
||||
// servers is the list of servers with their listeners.
|
||||
// servers is the list of servers with their listeners
|
||||
servers []ServerListener
|
||||
|
||||
// these callbacks execute when certain events occur
|
||||
|
@ -102,6 +113,18 @@ type Instance struct {
|
|||
onRestart []func() error // before restart commences
|
||||
onShutdown []func() error // stopping, even as part of a restart
|
||||
onFinalShutdown []func() error // stopping, not as part of a restart
|
||||
|
||||
// storing values on an instance is preferable to
|
||||
// global state because these will get garbage-
|
||||
// collected after in-process reloads when the
|
||||
// old instances are destroyed; use StorageMu
|
||||
// to access this value safely
|
||||
Storage map[interface{}]interface{}
|
||||
StorageMu sync.RWMutex
|
||||
}
|
||||
|
||||
func Instances() []*Instance {
|
||||
return instances
|
||||
}
|
||||
|
||||
// Servers returns the ServerListeners in i.
|
||||
|
@ -197,7 +220,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) {
|
|||
}
|
||||
|
||||
// create new instance; if the restart fails, it is simply discarded
|
||||
newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg}
|
||||
newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg, Storage: make(map[interface{}]interface{})}
|
||||
|
||||
// attempt to start new instance
|
||||
err := startWithListenerFds(newCaddyfile, newInst, restartFds)
|
||||
|
@ -456,7 +479,7 @@ func (i *Instance) Caddyfile() Input {
|
|||
//
|
||||
// This function blocks until all the servers are listening.
|
||||
func Start(cdyfile Input) (*Instance, error) {
|
||||
inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup)}
|
||||
inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})}
|
||||
err := startWithListenerFds(cdyfile, inst, nil)
|
||||
if err != nil {
|
||||
return inst, err
|
||||
|
@ -469,11 +492,34 @@ func Start(cdyfile Input) (*Instance, error) {
|
|||
}
|
||||
|
||||
func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]restartTriple) error {
|
||||
// save this instance in the list now so that
|
||||
// plugins can access it if need be, for example
|
||||
// the caddytls package, so it can perform cert
|
||||
// renewals while starting up; we just have to
|
||||
// remove the instance from the list later if
|
||||
// it fails
|
||||
instancesMu.Lock()
|
||||
instances = append(instances, inst)
|
||||
instancesMu.Unlock()
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
instancesMu.Lock()
|
||||
for i, otherInst := range instances {
|
||||
if otherInst == inst {
|
||||
instances = append(instances[:i], instances[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
instancesMu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
if cdyfile == nil {
|
||||
cdyfile = CaddyfileInput{}
|
||||
}
|
||||
|
||||
err := ValidateAndExecuteDirectives(cdyfile, inst, false)
|
||||
err = ValidateAndExecuteDirectives(cdyfile, inst, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -505,10 +551,6 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
|
|||
return err
|
||||
}
|
||||
|
||||
instancesMu.Lock()
|
||||
instances = append(instances, inst)
|
||||
instancesMu.Unlock()
|
||||
|
||||
// run any AfterStartup callbacks if this is not
|
||||
// part of a restart; then show file descriptor notice
|
||||
if restartFds == nil {
|
||||
|
@ -547,7 +589,7 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
|
|||
func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bool) error {
|
||||
// If parsing only inst will be nil, create an instance for this function call only.
|
||||
if justValidate {
|
||||
inst = &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup)}
|
||||
inst = &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})}
|
||||
}
|
||||
|
||||
stypeName := cdyfile.ServerType()
|
||||
|
@ -564,14 +606,14 @@ func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bo
|
|||
return err
|
||||
}
|
||||
|
||||
inst.context = stype.NewContext()
|
||||
inst.context = stype.NewContext(inst)
|
||||
if inst.context == nil {
|
||||
return fmt.Errorf("server type %s produced a nil Context", stypeName)
|
||||
}
|
||||
|
||||
sblocks, err = inst.context.InspectServerBlocks(cdyfile.Path(), sblocks)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error inspecting server blocks: %v", err)
|
||||
}
|
||||
|
||||
diagnostics.Set("num_server_blocks", len(sblocks))
|
||||
|
|
|
@ -148,7 +148,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
|||
case "HEAD":
|
||||
resp, err = fcgiBackend.Head(env)
|
||||
case "GET":
|
||||
resp, err = fcgiBackend.Get(env)
|
||||
resp, err = fcgiBackend.Get(env, r.Body, contentLength)
|
||||
case "OPTIONS":
|
||||
resp, err = fcgiBackend.Options(env)
|
||||
default:
|
||||
|
|
|
@ -460,12 +460,12 @@ func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Res
|
|||
}
|
||||
|
||||
// Get issues a GET request to the fcgi responder.
|
||||
func (c *FCGIClient) Get(p map[string]string) (resp *http.Response, err error) {
|
||||
func (c *FCGIClient) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) {
|
||||
|
||||
p["REQUEST_METHOD"] = "GET"
|
||||
p["CONTENT_LENGTH"] = "0"
|
||||
p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10)
|
||||
|
||||
return c.Request(p, nil)
|
||||
return c.Request(p, body)
|
||||
}
|
||||
|
||||
// Head issues a HEAD request to the fcgi responder.
|
||||
|
|
|
@ -140,7 +140,8 @@ func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[
|
|||
}
|
||||
resp, err = fcgi.PostForm(fcgiParams, values)
|
||||
} else {
|
||||
resp, err = fcgi.Get(fcgiParams)
|
||||
rd := bytes.NewReader(data)
|
||||
resp, err = fcgi.Get(fcgiParams, rd, int64(rd.Len()))
|
||||
}
|
||||
|
||||
default:
|
||||
|
|
|
@ -27,7 +27,7 @@ func activateHTTPS(cctx caddy.Context) error {
|
|||
operatorPresent := !caddy.Started()
|
||||
|
||||
if !caddy.Quiet && operatorPresent {
|
||||
fmt.Print("Activating privacy features...")
|
||||
fmt.Print("Activating privacy features... ")
|
||||
}
|
||||
|
||||
ctx := cctx.(*httpContext)
|
||||
|
@ -69,7 +69,7 @@ func activateHTTPS(cctx caddy.Context) error {
|
|||
}
|
||||
|
||||
if !caddy.Quiet && operatorPresent {
|
||||
fmt.Println(" done.")
|
||||
fmt.Println("done.")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -160,23 +160,37 @@ func hostHasOtherPort(allConfigs []*SiteConfig, thisConfigIdx int, otherPort str
|
|||
// to listen on HTTPPort. The TLS field of cfg must not be nil.
|
||||
func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
|
||||
redirPort := cfg.Addr.Port
|
||||
if redirPort == DefaultHTTPSPort {
|
||||
redirPort = "" // default port is redundant
|
||||
if redirPort == HTTPSPort {
|
||||
// By default, HTTPSPort should be DefaultHTTPSPort,
|
||||
// which of course doesn't need to be explicitly stated
|
||||
// in the Location header. Even if HTTPSPort is changed
|
||||
// so that it is no longer DefaultHTTPSPort, we shouldn't
|
||||
// append it to the URL in the Location because changing
|
||||
// the HTTPS port is assumed to be an internal-only change
|
||||
// (in other words, we assume port forwarding is going on);
|
||||
// but redirects go back to a presumably-external client.
|
||||
// (If redirect clients are also internal, that is more
|
||||
// advanced, and the user should configure HTTP->HTTPS
|
||||
// redirects themselves.)
|
||||
redirPort = ""
|
||||
}
|
||||
|
||||
redirMiddleware := func(next Handler) Handler {
|
||||
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
// Construct the URL to which to redirect. Note that the Host in a request might
|
||||
// contain a port, but we just need the hostname; we'll set the port if needed.
|
||||
// Construct the URL to which to redirect. Note that the Host in a
|
||||
// request might contain a port, but we just need the hostname from
|
||||
// it; and we'll set the port if needed.
|
||||
toURL := "https://"
|
||||
requestHost, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
requestHost = r.Host // Host did not contain a port; great
|
||||
requestHost = r.Host // Host did not contain a port, so use the whole value
|
||||
}
|
||||
if redirPort == "" {
|
||||
toURL += requestHost
|
||||
} else {
|
||||
toURL += net.JoinHostPort(requestHost, redirPort)
|
||||
}
|
||||
|
||||
toURL += r.URL.RequestURI()
|
||||
|
||||
w.Header().Set("Connection", "close")
|
||||
|
@ -184,9 +198,11 @@ func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
|
|||
return 0, nil
|
||||
})
|
||||
}
|
||||
|
||||
host := cfg.Addr.Host
|
||||
port := HTTPPort
|
||||
addr := net.JoinHostPort(host, port)
|
||||
|
||||
return &SiteConfig{
|
||||
Addr: Address{Original: addr, Host: host, Port: port},
|
||||
ListenHost: cfg.ListenHost,
|
||||
|
|
|
@ -53,7 +53,7 @@ func TestRedirPlaintextHost(t *testing.T) {
|
|||
},
|
||||
{
|
||||
Host: "foohost",
|
||||
Port: "443", // since this is the default HTTPS port, should not be included in Location value
|
||||
Port: HTTPSPort, // since this is the 'default' HTTPS port, should not be included in Location value
|
||||
},
|
||||
{
|
||||
Host: "*.example.com",
|
||||
|
|
|
@ -91,11 +91,13 @@ func hideCaddyfile(cctx caddy.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func newContext() caddy.Context {
|
||||
return &httpContext{keysToSiteConfigs: make(map[string]*SiteConfig)}
|
||||
func newContext(inst *caddy.Instance) caddy.Context {
|
||||
return &httpContext{instance: inst, keysToSiteConfigs: make(map[string]*SiteConfig)}
|
||||
}
|
||||
|
||||
type httpContext struct {
|
||||
instance *caddy.Instance
|
||||
|
||||
// keysToSiteConfigs maps an address at the top of a
|
||||
// server block (a "key") to its SiteConfig. Not all
|
||||
// SiteConfigs will be represented here, only ones
|
||||
|
@ -115,12 +117,14 @@ func (h *httpContext) saveConfig(key string, cfg *SiteConfig) {
|
|||
// executing directives and otherwise prepares the directives to
|
||||
// be parsed and executed.
|
||||
func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
|
||||
siteAddrs := make(map[string]string)
|
||||
|
||||
// For each address in each server block, make a new config
|
||||
for _, sb := range serverBlocks {
|
||||
for _, key := range sb.Keys {
|
||||
key = strings.ToLower(key)
|
||||
if _, dup := h.keysToSiteConfigs[key]; dup {
|
||||
return serverBlocks, fmt.Errorf("duplicate site address: %s", key)
|
||||
return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
|
||||
}
|
||||
addr, err := standardizeAddress(key)
|
||||
if err != nil {
|
||||
|
@ -136,6 +140,23 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
|
|||
addr.Port = Port
|
||||
}
|
||||
|
||||
// Make sure the adjusted site address is distinct
|
||||
addrCopy := addr // make copy so we don't disturb the original, carefully-parsed address struct
|
||||
if addrCopy.Port == "" && Port == DefaultPort {
|
||||
addrCopy.Port = Port
|
||||
}
|
||||
addrStr := strings.ToLower(addrCopy.String())
|
||||
if otherSiteKey, dup := siteAddrs[addrStr]; dup {
|
||||
err := fmt.Errorf("duplicate site address: %s", addrStr)
|
||||
if (addrCopy.Host == Host && Host != DefaultHost) ||
|
||||
(addrCopy.Port == Port && Port != DefaultPort) {
|
||||
err = fmt.Errorf("site defined as %s is a duplicate of %s because of modified "+
|
||||
"default host and/or port values (usually via -host or -port flags)", key, otherSiteKey)
|
||||
}
|
||||
return serverBlocks, err
|
||||
}
|
||||
siteAddrs[addrStr] = key
|
||||
|
||||
// If default HTTP or HTTPS ports have been customized,
|
||||
// make sure the ACME challenge ports match
|
||||
var altHTTPPort, altTLSSNIPort string
|
||||
|
@ -146,15 +167,19 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
|
|||
altTLSSNIPort = HTTPSPort
|
||||
}
|
||||
|
||||
// Make our caddytls.Config, which has a pointer to the
|
||||
// instance's certificate cache and enough information
|
||||
// to use automatic HTTPS when the time comes
|
||||
caddytlsConfig := caddytls.NewConfig(h.instance)
|
||||
caddytlsConfig.Hostname = addr.Host
|
||||
caddytlsConfig.AltHTTPPort = altHTTPPort
|
||||
caddytlsConfig.AltTLSSNIPort = altTLSSNIPort
|
||||
|
||||
// Save the config to our master list, and key it for lookups
|
||||
cfg := &SiteConfig{
|
||||
Addr: addr,
|
||||
Root: Root,
|
||||
TLS: &caddytls.Config{
|
||||
Hostname: addr.Host,
|
||||
AltHTTPPort: altHTTPPort,
|
||||
AltTLSSNIPort: altTLSSNIPort,
|
||||
},
|
||||
Addr: addr,
|
||||
Root: Root,
|
||||
TLS: caddytlsConfig,
|
||||
originCaddyfile: sourceFile,
|
||||
IndexPages: staticfiles.DefaultIndexPages,
|
||||
}
|
||||
|
|
|
@ -137,7 +137,7 @@ func TestAddressString(t *testing.T) {
|
|||
func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
|
||||
Port = "9999"
|
||||
filename := "Testfile"
|
||||
ctx := newContext().(*httpContext)
|
||||
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
|
||||
input := strings.NewReader(`localhost`)
|
||||
sblocks, err := caddyfile.Parse(filename, input, nil)
|
||||
if err != nil {
|
||||
|
@ -153,9 +153,26 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// See discussion on PR #2015
|
||||
func TestInspectServerBlocksWithAdjustedAddress(t *testing.T) {
|
||||
Port = DefaultPort
|
||||
Host = "example.com"
|
||||
filename := "Testfile"
|
||||
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
|
||||
input := strings.NewReader("example.com {\n}\n:2015 {\n}")
|
||||
sblocks, err := caddyfile.Parse(filename, input, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error setting up test, got: %v", err)
|
||||
}
|
||||
_, err = ctx.InspectServerBlocks(filename, sblocks)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected an error because site definitions should overlap, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) {
|
||||
filename := "Testfile"
|
||||
ctx := newContext().(*httpContext)
|
||||
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
|
||||
input := strings.NewReader("localhost {\n}\nLOCALHOST {\n}")
|
||||
sblocks, err := caddyfile.Parse(filename, input, nil)
|
||||
if err != nil {
|
||||
|
@ -207,7 +224,7 @@ func TestDirectivesList(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestContextSaveConfig(t *testing.T) {
|
||||
ctx := newContext().(*httpContext)
|
||||
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
|
||||
ctx.saveConfig("foo", new(SiteConfig))
|
||||
if _, ok := ctx.keysToSiteConfigs["foo"]; !ok {
|
||||
t.Error("Expected config to be saved, but it wasn't")
|
||||
|
@ -226,7 +243,7 @@ func TestContextSaveConfig(t *testing.T) {
|
|||
|
||||
// Test to make sure we are correctly hiding the Caddyfile
|
||||
func TestHideCaddyfile(t *testing.T) {
|
||||
ctx := newContext().(*httpContext)
|
||||
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
|
||||
ctx.saveConfig("test", &SiteConfig{
|
||||
Root: Root,
|
||||
originCaddyfile: "Testfile",
|
||||
|
|
|
@ -392,7 +392,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
|||
if vhost == nil {
|
||||
// check for ACME challenge even if vhost is nil;
|
||||
// could be a new host coming online soon
|
||||
if caddytls.HTTPChallengeHandler(w, r, "localhost", caddytls.DefaultHTTPAlternatePort) {
|
||||
if caddytls.HTTPChallengeHandler(w, r, "localhost") {
|
||||
return 0, nil
|
||||
}
|
||||
// otherwise, log the error and write a message to the client
|
||||
|
@ -408,7 +408,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
|||
|
||||
// we still check for ACME challenge if the vhost exists,
|
||||
// because we must apply its HTTP challenge config settings
|
||||
if s.proxyHTTPChallenge(vhost, w, r) {
|
||||
if caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
|
@ -416,31 +416,25 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
|||
// the URL path, so a request to example.com/foo/blog on the site
|
||||
// defined as example.com/foo appears as /blog instead of /foo/blog.
|
||||
if pathPrefix != "/" {
|
||||
r.URL.Path = strings.TrimPrefix(r.URL.Path, pathPrefix)
|
||||
if !strings.HasPrefix(r.URL.Path, "/") {
|
||||
r.URL.Path = "/" + r.URL.Path
|
||||
}
|
||||
r.URL = trimPathPrefix(r.URL, pathPrefix)
|
||||
}
|
||||
|
||||
return vhost.middlewareChain.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// proxyHTTPChallenge solves the ACME HTTP challenge if r is the HTTP
|
||||
// request for the challenge. If it is, and if the request has been
|
||||
// fulfilled (response written), true is returned; false otherwise.
|
||||
// If you don't have a vhost, just call the challenge handler directly.
|
||||
func (s *Server) proxyHTTPChallenge(vhost *SiteConfig, w http.ResponseWriter, r *http.Request) bool {
|
||||
if vhost.Addr.Port != caddytls.HTTPChallengePort {
|
||||
return false
|
||||
func trimPathPrefix(u *url.URL, prefix string) *url.URL {
|
||||
// We need to use URL.EscapedPath() when trimming the pathPrefix as
|
||||
// URL.Path is ambiguous about / or %2f - see docs. See #1927
|
||||
trimmed := strings.TrimPrefix(u.EscapedPath(), prefix)
|
||||
if !strings.HasPrefix(trimmed, "/") {
|
||||
trimmed = "/" + trimmed
|
||||
}
|
||||
if vhost.TLS != nil && vhost.TLS.Manual {
|
||||
return false
|
||||
trimmedURL, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmed, err)
|
||||
return u
|
||||
}
|
||||
altPort := caddytls.DefaultHTTPAlternatePort
|
||||
if vhost.TLS != nil && vhost.TLS.AltHTTPPort != "" {
|
||||
altPort = vhost.TLS.AltHTTPPort
|
||||
}
|
||||
return caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost, altPort)
|
||||
return trimmedURL
|
||||
}
|
||||
|
||||
// Address returns the address s was assigned to listen on.
|
||||
|
|
|
@ -16,6 +16,7 @@ package httpserver
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
@ -126,6 +127,94 @@ func TestMakeHTTPServerWithTimeouts(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTrimPathPrefix(t *testing.T) {
|
||||
for i, pt := range []struct {
|
||||
path string
|
||||
prefix string
|
||||
expected string
|
||||
shouldFail bool
|
||||
}{
|
||||
{
|
||||
path: "/my/path",
|
||||
prefix: "/my",
|
||||
expected: "/path",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my/%2f/path",
|
||||
prefix: "/my",
|
||||
expected: "/%2f/path",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my/path",
|
||||
prefix: "/my/",
|
||||
expected: "/path",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my///path",
|
||||
prefix: "/my",
|
||||
expected: "/path",
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
path: "/my///path",
|
||||
prefix: "/my",
|
||||
expected: "///path",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my/path///slash",
|
||||
prefix: "/my",
|
||||
expected: "/path///slash",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my/%2f/path/%2f",
|
||||
prefix: "/my",
|
||||
expected: "/%2f/path/%2f",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "/my/%20/path",
|
||||
prefix: "/my",
|
||||
expected: "/%20/path",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "/path",
|
||||
prefix: "",
|
||||
expected: "/path",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "/path/my/",
|
||||
prefix: "/my",
|
||||
expected: "/path/my/",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "",
|
||||
prefix: "/my",
|
||||
expected: "/",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "/apath",
|
||||
prefix: "",
|
||||
expected: "/apath",
|
||||
shouldFail: false,
|
||||
},
|
||||
} {
|
||||
|
||||
u, _ := url.Parse(pt.path)
|
||||
if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.EscapedPath() != want {
|
||||
if !pt.shouldFail {
|
||||
|
||||
t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.EscapedPath())
|
||||
}
|
||||
} else if pt.shouldFail {
|
||||
t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.EscapedPath())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeHTTPServerWithHeaderLimit(t *testing.T) {
|
||||
for name, c := range map[string]struct {
|
||||
group []*SiteConfig
|
||||
|
|
|
@ -16,6 +16,7 @@ package requestid
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -24,12 +25,29 @@ import (
|
|||
|
||||
// Handler is a middleware handler
|
||||
type Handler struct {
|
||||
Next httpserver.Handler
|
||||
Next httpserver.Handler
|
||||
HeaderName string // (optional) header from which to read an existing ID
|
||||
}
|
||||
|
||||
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
reqid := uuid.New().String()
|
||||
c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid)
|
||||
var reqid uuid.UUID
|
||||
|
||||
uuidFromHeader := r.Header.Get(h.HeaderName)
|
||||
if h.HeaderName != "" && uuidFromHeader != "" {
|
||||
// use the ID in the header field if it exists
|
||||
var err error
|
||||
reqid, err = uuid.Parse(uuidFromHeader)
|
||||
if err != nil {
|
||||
log.Printf("[NOTICE] Parsing request ID from %s header: %v", h.HeaderName, err)
|
||||
reqid = uuid.New()
|
||||
}
|
||||
} else {
|
||||
// otherwise, create a new one
|
||||
reqid = uuid.New()
|
||||
}
|
||||
|
||||
// set the request ID on the context
|
||||
c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid.String())
|
||||
r = r.WithContext(c)
|
||||
|
||||
return h.Next.ServeHTTP(w, r)
|
||||
|
|
|
@ -15,34 +15,53 @@
|
|||
package requestid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
request, err := http.NewRequest("GET", "http://localhost/", nil)
|
||||
func TestRequestIDHandler(t *testing.T) {
|
||||
handler := Handler{
|
||||
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string)
|
||||
if value == "" {
|
||||
t.Error("Request ID should not be empty")
|
||||
}
|
||||
return 0, nil
|
||||
}),
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "http://localhost/", nil)
|
||||
if err != nil {
|
||||
t.Fatal("Could not create HTTP request:", err)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
reqid := uuid.New().String()
|
||||
|
||||
c := context.WithValue(request.Context(), httpserver.RequestIDCtxKey, reqid)
|
||||
|
||||
request = request.WithContext(c)
|
||||
|
||||
// See caddyhttp/replacer.go
|
||||
value, _ := request.Context().Value(httpserver.RequestIDCtxKey).(string)
|
||||
|
||||
if value == "" {
|
||||
t.Fatal("Request ID should not be empty")
|
||||
}
|
||||
|
||||
if value != reqid {
|
||||
t.Fatal("Request ID does not match")
|
||||
}
|
||||
handler.ServeHTTP(rec, req)
|
||||
}
|
||||
|
||||
func TestRequestIDFromHeader(t *testing.T) {
|
||||
headerName := "X-Request-ID"
|
||||
headerValue := "71a75329-d9f9-4d25-957e-e689a7b68d78"
|
||||
handler := Handler{
|
||||
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string)
|
||||
if value != headerValue {
|
||||
t.Errorf("Request ID should be '%s' but got '%s'", headerValue, value)
|
||||
}
|
||||
return 0, nil
|
||||
}),
|
||||
HeaderName: headerName,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "http://localhost/", nil)
|
||||
if err != nil {
|
||||
t.Fatal("Could not create HTTP request:", err)
|
||||
}
|
||||
req.Header.Set(headerName, headerValue)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
}
|
||||
|
|
|
@ -27,14 +27,19 @@ func init() {
|
|||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
var headerName string
|
||||
|
||||
for c.Next() {
|
||||
if c.NextArg() {
|
||||
return c.ArgErr() //no arg expected.
|
||||
headerName = c.Val()
|
||||
}
|
||||
if c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
}
|
||||
|
||||
httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||
return Handler{Next: next}
|
||||
return Handler{Next: next, HeaderName: headerName}
|
||||
})
|
||||
|
||||
return nil
|
||||
|
|
|
@ -45,7 +45,15 @@ func TestSetup(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSetupWithArg(t *testing.T) {
|
||||
c := caddy.NewTestController("http", `requestid abc`)
|
||||
c := caddy.NewTestController("http", `requestid X-Request-ID`)
|
||||
err := setup(c)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupWithTooManyArgs(t *testing.T) {
|
||||
c := caddy.NewTestController("http", `requestid foo bar`)
|
||||
err := setup(c)
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, got: %v", err)
|
||||
|
|
|
@ -107,6 +107,10 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request) (int, err
|
|||
if d.IsDir() {
|
||||
// ensure there is a trailing slash
|
||||
if urlCopy.Path[len(urlCopy.Path)-1] != '/' {
|
||||
for strings.HasPrefix(urlCopy.Path, "//") {
|
||||
// prevent path-based open redirects
|
||||
urlCopy.Path = strings.TrimPrefix(urlCopy.Path, "/")
|
||||
}
|
||||
urlCopy.Path += "/"
|
||||
http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently)
|
||||
return http.StatusMovedPermanently, nil
|
||||
|
@ -131,6 +135,10 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request) (int, err
|
|||
}
|
||||
|
||||
if redir {
|
||||
for strings.HasPrefix(urlCopy.Path, "//") {
|
||||
// prevent path-based open redirects
|
||||
urlCopy.Path = strings.TrimPrefix(urlCopy.Path, "/")
|
||||
}
|
||||
http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently)
|
||||
return http.StatusMovedPermanently, nil
|
||||
}
|
||||
|
|
|
@ -77,9 +77,9 @@ func TestServeHTTP(t *testing.T) {
|
|||
{
|
||||
url: "https://foo/dirwithindex/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBodyContent: testFiles[webrootDirwithindexIndeHTML],
|
||||
expectedBodyContent: testFiles[webrootDirwithindexIndexHTML],
|
||||
expectedEtag: `"2n9cw"`,
|
||||
expectedContentLength: strconv.Itoa(len(testFiles[webrootDirwithindexIndeHTML])),
|
||||
expectedContentLength: strconv.Itoa(len(testFiles[webrootDirwithindexIndexHTML])),
|
||||
},
|
||||
// Test 4 - access folder with index file without trailing slash
|
||||
{
|
||||
|
@ -235,16 +235,38 @@ func TestServeHTTP(t *testing.T) {
|
|||
expectedBodyContent: movedPermanently,
|
||||
},
|
||||
{
|
||||
// Test 27 - Check etag
|
||||
url: "https://foo/notindex.html",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBodyContent: testFiles[webrootNotIndexHTML],
|
||||
expectedEtag: `"2n9cm"`,
|
||||
expectedContentLength: strconv.Itoa(len(testFiles[webrootNotIndexHTML])),
|
||||
},
|
||||
{
|
||||
// Test 28 - Prevent path-based open redirects (directory)
|
||||
url: "https://foo//example.com%2f..",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
expectedLocation: "https://foo/example.com/../",
|
||||
expectedBodyContent: movedPermanently,
|
||||
},
|
||||
{
|
||||
// Test 29 - Prevent path-based open redirects (file)
|
||||
url: "https://foo//example.com%2f../dirwithindex/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
expectedLocation: "https://foo/example.com/../dirwithindex/",
|
||||
expectedBodyContent: movedPermanently,
|
||||
},
|
||||
{
|
||||
// Test 29 - Prevent path-based open redirects (extra leading slashes)
|
||||
url: "https://foo///example.com%2f..",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
expectedLocation: "https://foo/example.com/../",
|
||||
expectedBodyContent: movedPermanently,
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
// set up response writer and rewuest
|
||||
// set up response writer and request
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
request, err := http.NewRequest("GET", test.url, nil)
|
||||
if err != nil {
|
||||
|
@ -518,7 +540,7 @@ var (
|
|||
webrootNotIndexHTML = filepath.Join(webrootName, "notindex.html")
|
||||
webrootDirFile2HTML = filepath.Join(webrootName, "dir", "file2.html")
|
||||
webrootDirHiddenHTML = filepath.Join(webrootName, "dir", "hidden.html")
|
||||
webrootDirwithindexIndeHTML = filepath.Join(webrootName, "dirwithindex", "index.html")
|
||||
webrootDirwithindexIndexHTML = filepath.Join(webrootName, "dirwithindex", "index.html")
|
||||
webrootSubGzippedHTML = filepath.Join(webrootName, "sub", "gzipped.html")
|
||||
webrootSubGzippedHTMLGz = filepath.Join(webrootName, "sub", "gzipped.html.gz")
|
||||
webrootSubGzippedHTMLBr = filepath.Join(webrootName, "sub", "gzipped.html.br")
|
||||
|
@ -544,7 +566,7 @@ var testFiles = map[string]string{
|
|||
webrootFile1HTML: "<h1>file1.html</h1>",
|
||||
webrootNotIndexHTML: "<h1>notindex.html</h1>",
|
||||
webrootDirFile2HTML: "<h1>dir/file2.html</h1>",
|
||||
webrootDirwithindexIndeHTML: "<h1>dirwithindex/index.html</h1>",
|
||||
webrootDirwithindexIndexHTML: "<h1>dirwithindex/index.html</h1>",
|
||||
webrootDirHiddenHTML: "<h1>dir/hidden.html</h1>",
|
||||
webrootSubGzippedHTML: "<h1>gzipped.html</h1>",
|
||||
webrootSubGzippedHTMLGz: "1.gzipped.html.gz",
|
||||
|
|
|
@ -15,9 +15,11 @@
|
|||
package caddytls
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"strings"
|
||||
|
@ -27,24 +29,104 @@ import (
|
|||
"golang.org/x/crypto/ocsp"
|
||||
)
|
||||
|
||||
// certCache stores certificates in memory,
|
||||
// keying certificates by name. Certificates
|
||||
// should not overlap in the names they serve,
|
||||
// because a name only maps to one certificate.
|
||||
var certCache = make(map[string]Certificate)
|
||||
var certCacheMu sync.RWMutex
|
||||
// certificateCache is to be an instance-wide cache of certs
|
||||
// that site-specific TLS configs can refer to. Using a
|
||||
// central map like this avoids duplication of certs in
|
||||
// memory when the cert is used by multiple sites, and makes
|
||||
// maintenance easier. Because these are not to be global,
|
||||
// the cache will get garbage collected after a config reload
|
||||
// (a new instance will take its place).
|
||||
type certificateCache struct {
|
||||
sync.RWMutex
|
||||
cache map[string]Certificate // keyed by certificate hash
|
||||
}
|
||||
|
||||
// replaceCertificate replaces oldCert with newCert in the cache, and
|
||||
// updates all configs that are pointing to the old certificate to
|
||||
// point to the new one instead. newCert must already be loaded into
|
||||
// the cache (this method does NOT load it into the cache).
|
||||
//
|
||||
// Note that all the names on the old certificate will be deleted
|
||||
// from the name lookup maps of each config, then all the names on
|
||||
// the new certificate will be added to the lookup maps as long as
|
||||
// they do not overwrite any entries.
|
||||
//
|
||||
// The newCert may be modified and its cache entry updated.
|
||||
//
|
||||
// This method is safe for concurrent use.
|
||||
func (certCache *certificateCache) replaceCertificate(oldCert, newCert Certificate) error {
|
||||
certCache.Lock()
|
||||
defer certCache.Unlock()
|
||||
|
||||
// have all the configs that are pointing to the old
|
||||
// certificate point to the new certificate instead
|
||||
for _, cfg := range oldCert.configs {
|
||||
// first delete all the name lookup entries that
|
||||
// pointed to the old certificate
|
||||
for name, certKey := range cfg.Certificates {
|
||||
if certKey == oldCert.Hash {
|
||||
delete(cfg.Certificates, name)
|
||||
}
|
||||
}
|
||||
|
||||
// then add name lookup entries for the names
|
||||
// on the new certificate, but don't overwrite
|
||||
// entries that may already exist, not only as
|
||||
// a courtesy, but importantly: because if we
|
||||
// overwrote a value here, and this config no
|
||||
// longer pointed to a certain certificate in
|
||||
// the cache, that certificate's list of configs
|
||||
// referring to it would be incorrect; so just
|
||||
// insert entries, don't overwrite any
|
||||
for _, name := range newCert.Names {
|
||||
if _, ok := cfg.Certificates[name]; !ok {
|
||||
cfg.Certificates[name] = newCert.Hash
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// since caching a new certificate attaches only the config
|
||||
// that loaded it, the new certificate needs to be given the
|
||||
// list of all the configs that use it, so copy the list
|
||||
// over from the old certificate to the new certificate
|
||||
// in the cache
|
||||
newCert.configs = oldCert.configs
|
||||
certCache.cache[newCert.Hash] = newCert
|
||||
|
||||
// finally, delete the old certificate from the cache
|
||||
delete(certCache.cache, oldCert.Hash)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reloadManagedCertificate reloads the certificate corresponding to the name(s)
|
||||
// on oldCert into the cache, from storage. This also replaces the old certificate
|
||||
// with the new one, so that all configurations that used the old cert now point
|
||||
// to the new cert.
|
||||
func (certCache *certificateCache) reloadManagedCertificate(oldCert Certificate) error {
|
||||
// get the certificate from storage and cache it
|
||||
newCert, err := oldCert.configs[0].CacheManagedCertificate(oldCert.Names[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to reload certificate for %v into cache: %v", oldCert.Names, err)
|
||||
}
|
||||
|
||||
// and replace the old certificate with the new one
|
||||
err = certCache.replaceCertificate(oldCert, newCert)
|
||||
if err != nil {
|
||||
return fmt.Errorf("replacing certificate %v: %v", oldCert.Names, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Certificate is a tls.Certificate with associated metadata tacked on.
|
||||
// Even if the metadata can be obtained by parsing the certificate,
|
||||
// we can be more efficient by extracting the metadata once so it's
|
||||
// just there, ready to use.
|
||||
// we are more efficient by extracting the metadata onto this struct.
|
||||
type Certificate struct {
|
||||
tls.Certificate
|
||||
|
||||
// Names is the list of names this certificate is written for.
|
||||
// The first is the CommonName (if any), the rest are SAN.
|
||||
// This should be the exact list of keys by which this cert
|
||||
// is accessed in the cache, careful to avoid overlap.
|
||||
Names []string
|
||||
|
||||
// NotAfter is when the certificate expires.
|
||||
|
@ -53,59 +135,21 @@ type Certificate struct {
|
|||
// OCSP contains the certificate's parsed OCSP response.
|
||||
OCSP *ocsp.Response
|
||||
|
||||
// Config is the configuration with which the certificate was
|
||||
// loaded or obtained and with which it should be maintained.
|
||||
Config *Config
|
||||
}
|
||||
// The hex-encoded hash of this cert's chain's bytes.
|
||||
Hash string
|
||||
|
||||
// getCertificate gets a certificate that matches name (a server name)
|
||||
// from the in-memory cache. If there is no exact match for name, it
|
||||
// will be checked against names of the form '*.example.com' (wildcard
|
||||
// certificates) according to RFC 6125. If a match is found, matched will
|
||||
// be true. If no matches are found, matched will be false and a default
|
||||
// certificate will be returned with defaulted set to true. If no default
|
||||
// certificate is set, defaulted will be set to false.
|
||||
//
|
||||
// The logic in this function is adapted from the Go standard library,
|
||||
// which is by the Go Authors.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
|
||||
var ok bool
|
||||
|
||||
// Not going to trim trailing dots here since RFC 3546 says,
|
||||
// "The hostname is represented ... without a trailing dot."
|
||||
// Just normalize to lowercase.
|
||||
name = strings.ToLower(name)
|
||||
|
||||
certCacheMu.RLock()
|
||||
defer certCacheMu.RUnlock()
|
||||
|
||||
// exact match? great, let's use it
|
||||
if cert, ok = certCache[name]; ok {
|
||||
matched = true
|
||||
return
|
||||
}
|
||||
|
||||
// try replacing labels in the name with wildcards until we get a match
|
||||
labels := strings.Split(name, ".")
|
||||
for i := range labels {
|
||||
labels[i] = "*"
|
||||
candidate := strings.Join(labels, ".")
|
||||
if cert, ok = certCache[candidate]; ok {
|
||||
matched = true
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// if nothing matches, use the default certificate or bust
|
||||
cert, defaulted = certCache[""]
|
||||
return
|
||||
// configs is the list of configs that use or refer to
|
||||
// The first one is assumed to be the config that is
|
||||
// "in charge" of this certificate (i.e. determines
|
||||
// whether it is managed, how it is managed, etc).
|
||||
// This field will be populated by cacheCertificate.
|
||||
// Only meddle with it if you know what you're doing!
|
||||
configs []*Config
|
||||
}
|
||||
|
||||
// CacheManagedCertificate loads the certificate for domain into the
|
||||
// cache, flagging it as Managed and, if onDemand is true, as "OnDemand"
|
||||
// (meaning that it was obtained or loaded during a TLS handshake).
|
||||
// cache, from the TLS storage for managed certificates. It returns a
|
||||
// copy of the Certificate that was put into the cache.
|
||||
//
|
||||
// This method is safe for concurrent use.
|
||||
func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) {
|
||||
|
@ -117,39 +161,24 @@ func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) {
|
|||
if err != nil {
|
||||
return Certificate{}, err
|
||||
}
|
||||
cert, err := makeCertificate(siteData.Cert, siteData.Key)
|
||||
cert, err := makeCertificateWithOCSP(siteData.Cert, siteData.Key)
|
||||
if err != nil {
|
||||
return cert, err
|
||||
}
|
||||
cert.Config = cfg
|
||||
cacheCertificate(cert)
|
||||
return cert, nil
|
||||
return cfg.cacheCertificate(cert), nil
|
||||
}
|
||||
|
||||
// cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile
|
||||
// and keyFile, which must be in PEM format. It stores the certificate in
|
||||
// memory after evicting any other entries in the cache keyed by the names
|
||||
// on this certificate. In other words, it replaces existing certificates keyed
|
||||
// by the names on this certificate. The Managed and OnDemand flags of the
|
||||
// certificate will be set to false.
|
||||
// the in-memory cache.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
|
||||
cert, err := makeCertificateFromDisk(certFile, keyFile)
|
||||
func (cfg *Config) cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
|
||||
cert, err := makeCertificateFromDiskWithOCSP(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// since this is manually managed, this call might be part of a reload after
|
||||
// the owner renewed a certificate; so clear cache of any previous cert first,
|
||||
// otherwise the renewed certificate may never be loaded
|
||||
certCacheMu.Lock()
|
||||
for _, name := range cert.Names {
|
||||
delete(certCache, name)
|
||||
}
|
||||
certCacheMu.Unlock()
|
||||
|
||||
cacheCertificate(cert)
|
||||
cfg.cacheCertificate(cert)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -157,20 +186,20 @@ func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
|
|||
// of the certificate and key, then caches it in memory.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error {
|
||||
cert, err := makeCertificate(certBytes, keyBytes)
|
||||
func (cfg *Config) cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error {
|
||||
cert, err := makeCertificateWithOCSP(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cacheCertificate(cert)
|
||||
cfg.cacheCertificate(cert)
|
||||
return nil
|
||||
}
|
||||
|
||||
// makeCertificateFromDisk makes a Certificate by loading the
|
||||
// makeCertificateFromDiskWithOCSP makes a Certificate by loading the
|
||||
// certificate and key files. It fills out all the fields in
|
||||
// the certificate except for the Managed and OnDemand flags.
|
||||
// (It is up to the caller to set those.)
|
||||
func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) {
|
||||
// (It is up to the caller to set those.) It staples OCSP.
|
||||
func makeCertificateFromDiskWithOCSP(certFile, keyFile string) (Certificate, error) {
|
||||
certPEMBlock, err := ioutil.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return Certificate{}, err
|
||||
|
@ -179,13 +208,14 @@ func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) {
|
|||
if err != nil {
|
||||
return Certificate{}, err
|
||||
}
|
||||
return makeCertificate(certPEMBlock, keyPEMBlock)
|
||||
return makeCertificateWithOCSP(certPEMBlock, keyPEMBlock)
|
||||
}
|
||||
|
||||
// makeCertificate turns a certificate PEM bundle and a key PEM block into
|
||||
// a Certificate, with OCSP and other relevant metadata tagged with it,
|
||||
// except for the OnDemand and Managed flags. It is up to the caller to
|
||||
// set those properties.
|
||||
// a Certificate with necessary metadata from parsing its bytes filled into
|
||||
// its struct fields for convenience (except for the OnDemand and Managed
|
||||
// flags; it is up to the caller to set those properties!). This function
|
||||
// does NOT staple OCSP.
|
||||
func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
|
||||
var cert Certificate
|
||||
|
||||
|
@ -195,16 +225,26 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
|
|||
return cert, err
|
||||
}
|
||||
|
||||
// Extract relevant metadata and staple OCSP
|
||||
// Extract necessary metadata
|
||||
err = fillCertFromLeaf(&cert, tlsCert)
|
||||
if err != nil {
|
||||
return cert, err
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// makeCertificateWithOCSP is the same as makeCertificate except that it also
|
||||
// staples OCSP to the certificate.
|
||||
func makeCertificateWithOCSP(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
|
||||
cert, err := makeCertificate(certPEMBlock, keyPEMBlock)
|
||||
if err != nil {
|
||||
return cert, err
|
||||
}
|
||||
err = stapleOCSP(&cert, certPEMBlock)
|
||||
if err != nil {
|
||||
log.Printf("[WARNING] Stapling OCSP: %v", err)
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
|
@ -243,65 +283,104 @@ func fillCertFromLeaf(cert *Certificate, tlsCert tls.Certificate) error {
|
|||
return errors.New("certificate has no names")
|
||||
}
|
||||
|
||||
// save the hash of this certificate (chain) and
|
||||
// expiration date, for necessity and efficiency
|
||||
cert.Hash = hashCertificateChain(cert.Certificate.Certificate)
|
||||
cert.NotAfter = leaf.NotAfter
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cacheCertificate adds cert to the in-memory cache. If the cache is
|
||||
// empty, cert will be used as the default certificate. If the cache is
|
||||
// full, random entries are deleted until there is room to map all the
|
||||
// names on the certificate.
|
||||
// hashCertificateChain computes the unique hash of certChain,
|
||||
// which is the chain of DER-encoded bytes. It returns the
|
||||
// hex encoding of the hash.
|
||||
func hashCertificateChain(certChain [][]byte) string {
|
||||
h := sha256.New()
|
||||
for _, certInChain := range certChain {
|
||||
h.Write(certInChain)
|
||||
}
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// managedCertInStorageExpiresSoon returns true if cert (being a
|
||||
// managed certificate) is expiring within RenewDurationBefore.
|
||||
// It returns false if there was an error checking the expiration
|
||||
// of the certificate as found in storage, or if the certificate
|
||||
// in storage is NOT expiring soon. A certificate that is expiring
|
||||
// soon in our cache but is not expiring soon in storage probably
|
||||
// means that another instance renewed the certificate in the
|
||||
// meantime, and it would be a good idea to simply load the cert
|
||||
// into our cache rather than repeating the renewal process again.
|
||||
func managedCertInStorageExpiresSoon(cert Certificate) (bool, error) {
|
||||
if len(cert.configs) == 0 {
|
||||
return false, fmt.Errorf("no configs for certificate")
|
||||
}
|
||||
storage, err := cert.configs[0].StorageFor(cert.configs[0].CAUrl)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
siteData, err := storage.LoadSite(cert.Names[0])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
tlsCert, err := tls.X509KeyPair(siteData.Cert, siteData.Key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
leaf, err := x509.ParseCertificate(tlsCert.Certificate[0])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
timeLeft := leaf.NotAfter.Sub(time.Now().UTC())
|
||||
return timeLeft < RenewDurationBefore, nil
|
||||
}
|
||||
|
||||
// cacheCertificate adds cert to the in-memory cache. If a certificate
|
||||
// with the same hash is already cached, it is NOT overwritten; instead,
|
||||
// cfg is added to the existing certificate's list of configs if not
|
||||
// already in the list. Then all the names on cert are used to add
|
||||
// entries to cfg.Certificates (the config's name lookup map).
|
||||
// Then the certificate is stored/updated in the cache. It returns
|
||||
// a copy of the certificate that ends up being stored in the cache.
|
||||
//
|
||||
// This certificate will be keyed to the names in cert.Names. Any names
|
||||
// already used as a cache key will NOT be replaced by this cert; in
|
||||
// other words, no overlap is allowed, and this certificate will not
|
||||
// service those pre-existing names.
|
||||
// It is VERY important, even for some test cases, that the Hash field
|
||||
// of the cert be set properly.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func cacheCertificate(cert Certificate) {
|
||||
if cert.Config == nil {
|
||||
cert.Config = new(Config)
|
||||
func (cfg *Config) cacheCertificate(cert Certificate) Certificate {
|
||||
cfg.certCache.Lock()
|
||||
defer cfg.certCache.Unlock()
|
||||
|
||||
// if this certificate already exists in the cache,
|
||||
// use it instead of overwriting it -- very important!
|
||||
if existingCert, ok := cfg.certCache.cache[cert.Hash]; ok {
|
||||
cert = existingCert
|
||||
}
|
||||
certCacheMu.Lock()
|
||||
if _, ok := certCache[""]; !ok {
|
||||
// use as default - must be *appended* to end of list, or bad things happen!
|
||||
cert.Names = append(cert.Names, "")
|
||||
}
|
||||
for len(certCache)+len(cert.Names) > 10000 {
|
||||
// for simplicity, just remove random elements
|
||||
for key := range certCache {
|
||||
if key == "" { // ... but not the default cert
|
||||
continue
|
||||
}
|
||||
delete(certCache, key)
|
||||
|
||||
// attach this config to the certificate so we know which
|
||||
// configs are referencing/using the certificate, but don't
|
||||
// duplicate entries
|
||||
var found bool
|
||||
for _, c := range cert.configs {
|
||||
if c == cfg {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(cert.Names); i++ {
|
||||
name := cert.Names[i]
|
||||
if _, ok := certCache[name]; ok {
|
||||
// do not allow certificates to overlap in the names they serve;
|
||||
// this ambiguity causes problems because it is confusing while
|
||||
// maintaining certificates; see OCSP maintenance code and
|
||||
// https://caddy.community/t/random-ocsp-response-errors-for-random-clients/2473?u=matt.
|
||||
log.Printf("[NOTICE] There is already a certificate loaded for %s, "+
|
||||
"so certificate for %v will not service that name",
|
||||
name, cert.Names)
|
||||
cert.Names = append(cert.Names[:i], cert.Names[i+1:]...)
|
||||
i--
|
||||
continue
|
||||
}
|
||||
certCache[name] = cert
|
||||
if !found {
|
||||
cert.configs = append(cert.configs, cfg)
|
||||
}
|
||||
certCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// uncacheCertificate deletes name's certificate from the
|
||||
// cache. If name is not a key in the certificate cache,
|
||||
// this function does nothing.
|
||||
func uncacheCertificate(name string) {
|
||||
certCacheMu.Lock()
|
||||
delete(certCache, name)
|
||||
certCacheMu.Unlock()
|
||||
// key the certificate by all its names for this config only,
|
||||
// this is how we find the certificate during handshakes
|
||||
// (yes, if certs overlap in the names they serve, one will
|
||||
// overwrite another here, but that's just how it goes)
|
||||
for _, name := range cert.Names {
|
||||
cfg.Certificates[name] = cert.Hash
|
||||
}
|
||||
|
||||
// store the certificate
|
||||
cfg.certCache.cache[cert.Hash] = cert
|
||||
|
||||
return cert
|
||||
}
|
||||
|
|
|
@ -17,57 +17,71 @@ package caddytls
|
|||
import "testing"
|
||||
|
||||
func TestUnexportedGetCertificate(t *testing.T) {
|
||||
defer func() { certCache = make(map[string]Certificate) }()
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
|
||||
// When cache is empty
|
||||
if _, matched, defaulted := getCertificate("example.com"); matched || defaulted {
|
||||
if _, matched, defaulted := cfg.getCertificate("example.com"); matched || defaulted {
|
||||
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
|
||||
}
|
||||
|
||||
// When cache has one certificate in it (also is default)
|
||||
defaultCert := Certificate{Names: []string{"example.com", ""}}
|
||||
certCache[""] = defaultCert
|
||||
certCache["example.com"] = defaultCert
|
||||
if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
|
||||
// When cache has one certificate in it
|
||||
firstCert := Certificate{Names: []string{"example.com"}}
|
||||
certCache.cache["0xdeadbeef"] = firstCert
|
||||
cfg.Certificates["example.com"] = "0xdeadbeef"
|
||||
if cert, matched, defaulted := cfg.getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
|
||||
t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||
}
|
||||
if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" {
|
||||
t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||
if cert, matched, defaulted := cfg.getCertificate("example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
|
||||
t.Errorf("Didn't get a cert for 'example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||
}
|
||||
|
||||
// When retrieving wildcard certificate
|
||||
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}}
|
||||
if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
|
||||
certCache.cache["0xb01dface"] = Certificate{Names: []string{"*.example.com"}}
|
||||
cfg.Certificates["*.example.com"] = "0xb01dface"
|
||||
if cert, matched, defaulted := cfg.getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
|
||||
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||
}
|
||||
|
||||
// When no certificate matches, the default is returned
|
||||
if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted {
|
||||
// When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
|
||||
if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted {
|
||||
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
|
||||
}
|
||||
|
||||
// When no certificate matches and SNI is NOT provided, a random is returned
|
||||
if cert, matched, defaulted := cfg.getCertificate(""); matched || !defaulted {
|
||||
t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
|
||||
} else if cert.Names[0] != "example.com" {
|
||||
t.Errorf("Expected default cert, got: %v", cert)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheCertificate(t *testing.T) {
|
||||
defer func() { certCache = make(map[string]Certificate) }()
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
|
||||
cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}})
|
||||
if _, ok := certCache["example.com"]; !ok {
|
||||
t.Error("Expected first cert to be cached by key 'example.com', but it wasn't")
|
||||
cfg.cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}, Hash: "foobar"})
|
||||
if len(certCache.cache) != 1 {
|
||||
t.Errorf("Expected length of certificate cache to be 1")
|
||||
}
|
||||
if _, ok := certCache["sub.example.com"]; !ok {
|
||||
t.Error("Expected first cert to be cached by key 'sub.example.com', but it wasn't")
|
||||
if _, ok := certCache.cache["foobar"]; !ok {
|
||||
t.Error("Expected first cert to be cached by key 'foobar', but it wasn't")
|
||||
}
|
||||
if cert, ok := certCache[""]; !ok || cert.Names[2] != "" {
|
||||
t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't")
|
||||
if _, ok := cfg.Certificates["example.com"]; !ok {
|
||||
t.Error("Expected first cert to be keyed by 'example.com', but it wasn't")
|
||||
}
|
||||
if _, ok := cfg.Certificates["sub.example.com"]; !ok {
|
||||
t.Error("Expected first cert to be keyed by 'sub.example.com', but it wasn't")
|
||||
}
|
||||
|
||||
cacheCertificate(Certificate{Names: []string{"example2.com"}})
|
||||
if _, ok := certCache["example2.com"]; !ok {
|
||||
t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't")
|
||||
// different config, but using same cache; and has cert with overlapping name,
|
||||
// but different hash
|
||||
cfg2 := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
cfg2.cacheCertificate(Certificate{Names: []string{"example.com"}, Hash: "barbaz"})
|
||||
if _, ok := certCache.cache["barbaz"]; !ok {
|
||||
t.Error("Expected second cert to be cached by key 'barbaz.com', but it wasn't")
|
||||
}
|
||||
if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" {
|
||||
t.Error("Expected second cert to NOT be cached as default, but it was")
|
||||
if hash, ok := cfg2.Certificates["example.com"]; !ok {
|
||||
t.Error("Expected second cert to be keyed by 'example.com', but it wasn't")
|
||||
} else if hash != "barbaz" {
|
||||
t.Errorf("Expected second cert to map to 'barbaz' but it was %s instead", hash)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ type ACMEClient struct {
|
|||
AllowPrompts bool
|
||||
config *Config
|
||||
acmeClient *acme.Client
|
||||
locker Locker
|
||||
storage Storage
|
||||
}
|
||||
|
||||
// newACMEClient creates a new ACMEClient given an email and whether
|
||||
|
@ -122,10 +122,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
|
|||
AllowPrompts: allowPrompts,
|
||||
config: config,
|
||||
acmeClient: client,
|
||||
locker: &syncLock{
|
||||
nameLocks: make(map[string]*sync.WaitGroup),
|
||||
nameLocksMu: sync.Mutex{},
|
||||
},
|
||||
storage: storage,
|
||||
}
|
||||
|
||||
if config.DNSProvider == "" {
|
||||
|
@ -161,7 +158,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
|
|||
|
||||
// See if TLS challenge needs to be handled by our own facilities
|
||||
if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
|
||||
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSniSolver{})
|
||||
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache})
|
||||
}
|
||||
|
||||
// Disable any challenges that should not be used
|
||||
|
@ -210,13 +207,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
|
|||
// Callers who have access to a Config value should use the ObtainCert
|
||||
// method on that instead of this lower-level method.
|
||||
func (c *ACMEClient) Obtain(name string) error {
|
||||
// Get access to ACME storage
|
||||
storage, err := c.config.StorageFor(c.config.CAUrl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
waiter, err := c.locker.TryLock(name)
|
||||
waiter, err := c.storage.TryLock(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -226,7 +217,7 @@ func (c *ACMEClient) Obtain(name string) error {
|
|||
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
|
||||
}
|
||||
defer func() {
|
||||
if err := c.locker.Unlock(name); err != nil {
|
||||
if err := c.storage.Unlock(name); err != nil {
|
||||
log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err)
|
||||
}
|
||||
}()
|
||||
|
@ -269,7 +260,7 @@ Attempts:
|
|||
}
|
||||
|
||||
// Success - immediately save the certificate resource
|
||||
err = saveCertResource(storage, certificate)
|
||||
err = saveCertResource(c.storage, certificate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error saving assets for %v: %v", name, err)
|
||||
}
|
||||
|
@ -282,35 +273,30 @@ Attempts:
|
|||
return nil
|
||||
}
|
||||
|
||||
// Renew renews the managed certificate for name. This function is
|
||||
// safe for concurrent use.
|
||||
// Renew renews the managed certificate for name. It puts the renewed
|
||||
// certificate into storage (not the cache). This function is safe for
|
||||
// concurrent use.
|
||||
//
|
||||
// Callers who have access to a Config value should use the RenewCert
|
||||
// method on that instead of this lower-level method.
|
||||
func (c *ACMEClient) Renew(name string) error {
|
||||
// Get access to ACME storage
|
||||
storage, err := c.config.StorageFor(c.config.CAUrl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
waiter, err := c.locker.TryLock(name)
|
||||
waiter, err := c.storage.TryLock(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if waiter != nil {
|
||||
log.Printf("[INFO] Certificate for %s is already being renewed elsewhere and stored; waiting", name)
|
||||
waiter.Wait()
|
||||
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
|
||||
return nil // assume that the worker that renewed the cert succeeded; avoid hammering this path over and over
|
||||
}
|
||||
defer func() {
|
||||
if err := c.locker.Unlock(name); err != nil {
|
||||
if err := c.storage.Unlock(name); err != nil {
|
||||
log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Prepare for renewal (load PEM cert, key, and meta)
|
||||
siteData, err := storage.LoadSite(name)
|
||||
siteData, err := c.storage.LoadSite(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -357,18 +343,13 @@ func (c *ACMEClient) Renew(name string) error {
|
|||
go diagnostics.Increment("acme_certificates_obtained")
|
||||
go diagnostics.Increment("acme_certificates_renewed")
|
||||
|
||||
return saveCertResource(storage, newCertMeta)
|
||||
return saveCertResource(c.storage, newCertMeta)
|
||||
}
|
||||
|
||||
// Revoke revokes the certificate for name and deltes
|
||||
// Revoke revokes the certificate for name and deletes
|
||||
// it from storage.
|
||||
func (c *ACMEClient) Revoke(name string) error {
|
||||
storage, err := c.config.StorageFor(c.config.CAUrl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
siteExists, err := storage.SiteExists(name)
|
||||
siteExists, err := c.storage.SiteExists(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -377,7 +358,7 @@ func (c *ACMEClient) Revoke(name string) error {
|
|||
return errors.New("no certificate and key for " + name)
|
||||
}
|
||||
|
||||
siteData, err := storage.LoadSite(name)
|
||||
siteData, err := c.storage.LoadSite(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -387,7 +368,7 @@ func (c *ACMEClient) Revoke(name string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
err = storage.DeleteSite(name)
|
||||
err = c.storage.DeleteSite(name)
|
||||
if err != nil {
|
||||
return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error())
|
||||
}
|
||||
|
|
|
@ -93,16 +93,17 @@ type Config struct {
|
|||
// an ACME challenge
|
||||
ListenHost string
|
||||
|
||||
// The alternate port (ONLY port, not host)
|
||||
// to use for the ACME HTTP challenge; this
|
||||
// port will be used if we proxy challenges
|
||||
// coming in on port 80 to this alternate port
|
||||
// The alternate port (ONLY port, not host) to
|
||||
// use for the ACME HTTP challenge; if non-empty,
|
||||
// this port will be used instead of
|
||||
// HTTPChallengePort to spin up a listener for
|
||||
// the HTTP challenge
|
||||
AltHTTPPort string
|
||||
|
||||
// The alternate port (ONLY port, not host)
|
||||
// to use for the ACME TLS-SNI challenge.
|
||||
// The system must forward the standard port
|
||||
// for the TLS-SNI challenge to this port.
|
||||
// The system must forward TLSSNIChallengePort
|
||||
// to this port for challenge to succeed
|
||||
AltTLSSNIPort string
|
||||
|
||||
// The string identifier of the DNS provider
|
||||
|
@ -134,7 +135,12 @@ type Config struct {
|
|||
// Protocol Negotiation (ALPN).
|
||||
ALPN []string
|
||||
|
||||
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
|
||||
// The map of hostname to certificate hash. This is used to complete
|
||||
// handshakes and serve the right certificate given the SNI.
|
||||
Certificates map[string]string
|
||||
|
||||
certCache *certificateCache // pointer to the Instance's certificate store
|
||||
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
|
||||
}
|
||||
|
||||
// OnDemandState contains some state relevant for providing
|
||||
|
@ -155,6 +161,25 @@ type OnDemandState struct {
|
|||
AskURL *url.URL
|
||||
}
|
||||
|
||||
// NewConfig returns a new Config with a pointer to the instance's
|
||||
// certificate cache. You will usually need to set Other fields on
|
||||
// the returned Config for successful practical use.
|
||||
func NewConfig(inst *caddy.Instance) *Config {
|
||||
inst.StorageMu.RLock()
|
||||
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
|
||||
inst.StorageMu.RUnlock()
|
||||
if !ok || certCache == nil {
|
||||
certCache = &certificateCache{cache: make(map[string]Certificate)}
|
||||
inst.StorageMu.Lock()
|
||||
inst.Storage[CertCacheInstStorageKey] = certCache
|
||||
inst.StorageMu.Unlock()
|
||||
}
|
||||
cfg := new(Config)
|
||||
cfg.Certificates = make(map[string]string)
|
||||
cfg.certCache = certCache
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ObtainCert obtains a certificate for name using c, as long
|
||||
// as a certificate does not already exist in storage for that
|
||||
// name. The name must qualify and c must be flagged as Managed.
|
||||
|
@ -330,7 +355,9 @@ func (c *Config) buildStandardTLSConfig() error {
|
|||
|
||||
// MakeTLSConfig makes a tls.Config from configs. The returned
|
||||
// tls.Config is programmed to load the matching caddytls.Config
|
||||
// based on the hostname in SNI, but that's all.
|
||||
// based on the hostname in SNI, but that's all. This is used
|
||||
// to create a single TLS configuration for a listener (a group
|
||||
// of sites).
|
||||
func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
|
||||
if len(configs) == 0 {
|
||||
return nil, nil
|
||||
|
@ -358,15 +385,28 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
|
|||
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
|
||||
}
|
||||
|
||||
// convert each caddytls.Config into a tls.Config
|
||||
// convert this caddytls.Config into a tls.Config
|
||||
if err := cfg.buildStandardTLSConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Key this config by its hostname (overwriting
|
||||
// configs with the same hostname pattern); during
|
||||
// TLS handshakes, configs are loaded based on
|
||||
// the hostname pattern, according to client's SNI.
|
||||
// if an existing config with this hostname was already
|
||||
// configured, then they must be identical (or at least
|
||||
// compatible), otherwise that is a configuration error
|
||||
if otherConfig, ok := configMap[cfg.Hostname]; ok {
|
||||
if err := assertConfigsCompatible(cfg, otherConfig); err != nil {
|
||||
return nil, fmt.Errorf("incompabile TLS configurations for the same SNI "+
|
||||
"name (%s) on the same listener: %v",
|
||||
cfg.Hostname, err)
|
||||
}
|
||||
}
|
||||
|
||||
// key this config by its hostname (overwrites
|
||||
// configs with the same hostname pattern; should
|
||||
// be OK since we already asserted they are roughly
|
||||
// the same); during TLS handshakes, configs are
|
||||
// loaded based on the hostname pattern, according
|
||||
// to client's SNI
|
||||
configMap[cfg.Hostname] = cfg
|
||||
}
|
||||
|
||||
|
@ -383,6 +423,63 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// assertConfigsCompatible returns an error if the two Configs
|
||||
// do not have the same (or roughly compatible) configurations.
|
||||
// If one of the tlsConfig pointers on either Config is nil,
|
||||
// an error will be returned. If both are nil, no error.
|
||||
func assertConfigsCompatible(cfg1, cfg2 *Config) error {
|
||||
c1, c2 := cfg1.tlsConfig, cfg2.tlsConfig
|
||||
|
||||
if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) {
|
||||
return fmt.Errorf("one config is not made")
|
||||
}
|
||||
if c1 == nil && c2 == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(c1.CipherSuites) != len(c2.CipherSuites) {
|
||||
return fmt.Errorf("different number of allowed cipher suites")
|
||||
}
|
||||
for i, ciph := range c1.CipherSuites {
|
||||
if c2.CipherSuites[i] != ciph {
|
||||
return fmt.Errorf("different cipher suites or different order")
|
||||
}
|
||||
}
|
||||
|
||||
if len(c1.CurvePreferences) != len(c2.CurvePreferences) {
|
||||
return fmt.Errorf("different number of allowed cipher suites")
|
||||
}
|
||||
for i, curve := range c1.CurvePreferences {
|
||||
if c2.CurvePreferences[i] != curve {
|
||||
return fmt.Errorf("different curve preferences or different order")
|
||||
}
|
||||
}
|
||||
|
||||
if len(c1.NextProtos) != len(c2.NextProtos) {
|
||||
return fmt.Errorf("different number of ALPN (NextProtos) values")
|
||||
}
|
||||
for i, proto := range c1.NextProtos {
|
||||
if c2.NextProtos[i] != proto {
|
||||
return fmt.Errorf("different ALPN (NextProtos) values or different order")
|
||||
}
|
||||
}
|
||||
|
||||
if c1.PreferServerCipherSuites != c2.PreferServerCipherSuites {
|
||||
return fmt.Errorf("one prefers server cipher suites, the other does not")
|
||||
}
|
||||
if c1.MinVersion != c2.MinVersion {
|
||||
return fmt.Errorf("minimum TLS version mismatch")
|
||||
}
|
||||
if c1.MaxVersion != c2.MaxVersion {
|
||||
return fmt.Errorf("maximum TLS version mismatch")
|
||||
}
|
||||
if c1.ClientAuth != c2.ClientAuth {
|
||||
return fmt.Errorf("client authentication policy mismatch")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigGetter gets a Config keyed by key.
|
||||
type ConfigGetter func(c *caddy.Controller) *Config
|
||||
|
||||
|
@ -522,7 +619,7 @@ var supportedCurvesMap = map[string]tls.CurveID{
|
|||
"P521": tls.CurveP521,
|
||||
}
|
||||
|
||||
// List of all the curves we want to use by default
|
||||
// List of all the curves we want to use by default.
|
||||
//
|
||||
// This list should only include curves which are fast by design (e.g. X25519)
|
||||
// and those for which an optimized assembly implementation exists (e.g. P256).
|
||||
|
@ -548,4 +645,8 @@ const (
|
|||
// be capable of proxying or forwarding the request to this
|
||||
// alternate port.
|
||||
DefaultHTTPAlternatePort = "5033"
|
||||
|
||||
// CertCacheInstStorageKey is the name of the key for
|
||||
// accessing the certificate storage on the *caddy.Instance.
|
||||
CertCacheInstStorageKey = "tls_cert_cache"
|
||||
)
|
||||
|
|
|
@ -237,15 +237,17 @@ func makeSelfSignedCert(config *Config) error {
|
|||
return fmt.Errorf("could not create certificate: %v", err)
|
||||
}
|
||||
|
||||
cacheCertificate(Certificate{
|
||||
chain := [][]byte{derBytes}
|
||||
|
||||
config.cacheCertificate(Certificate{
|
||||
Certificate: tls.Certificate{
|
||||
Certificate: [][]byte{derBytes},
|
||||
Certificate: chain,
|
||||
PrivateKey: privKey,
|
||||
Leaf: cert,
|
||||
},
|
||||
Names: cert.DNSNames,
|
||||
NotAfter: cert.NotAfter,
|
||||
Config: config,
|
||||
Hash: hashCertificateChain(chain),
|
||||
})
|
||||
|
||||
return nil
|
||||
|
|
|
@ -38,9 +38,9 @@ var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme")
|
|||
// Storage instance backed by the local disk. The resulting Storage
|
||||
// instance is guaranteed to be non-nil if there is no error.
|
||||
func NewFileStorage(caURL *url.URL) (Storage, error) {
|
||||
return &FileStorage{
|
||||
Path: filepath.Join(storageBasePath, caURL.Host),
|
||||
}, nil
|
||||
storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)}
|
||||
storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage}
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
// FileStorage facilitates forming file paths derived from a root
|
||||
|
@ -48,6 +48,7 @@ func NewFileStorage(caURL *url.URL) (Storage, error) {
|
|||
// cross-platform way or persisting ACME assets on the file system.
|
||||
type FileStorage struct {
|
||||
Path string
|
||||
Locker
|
||||
}
|
||||
|
||||
// sites gets the directory that stores site certificate and keys.
|
||||
|
|
127
caddytls/filestoragesync.go
Normal file
127
caddytls/filestoragesync.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// be sure to remove lock files when exiting the process!
|
||||
caddy.OnProcessExit = append(caddy.OnProcessExit, func() {
|
||||
fileStorageNameLocksMu.Lock()
|
||||
defer fileStorageNameLocksMu.Unlock()
|
||||
for key, fw := range fileStorageNameLocks {
|
||||
os.Remove(fw.filename)
|
||||
delete(fileStorageNameLocks, key)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// fileStorageLock facilitates ACME-related locking by using
|
||||
// the associated FileStorage, so multiple processes can coordinate
|
||||
// renewals on the certificates on a shared file system.
|
||||
type fileStorageLock struct {
|
||||
caURL string
|
||||
storage *FileStorage
|
||||
}
|
||||
|
||||
// TryLock attempts to get a lock for name, otherwise it returns
|
||||
// a Waiter value to wait until the other process is finished.
|
||||
func (s *fileStorageLock) TryLock(name string) (Waiter, error) {
|
||||
fileStorageNameLocksMu.Lock()
|
||||
defer fileStorageNameLocksMu.Unlock()
|
||||
|
||||
// see if lock already exists within this process
|
||||
fw, ok := fileStorageNameLocks[s.caURL+name]
|
||||
if ok {
|
||||
// lock already created within process, let caller wait on it
|
||||
return fw, nil
|
||||
}
|
||||
|
||||
// attempt to persist lock to disk by creating lock file
|
||||
fw = &fileWaiter{
|
||||
filename: s.storage.siteCertFile(name) + ".lock",
|
||||
wg: new(sync.WaitGroup),
|
||||
}
|
||||
// parent dir must exist
|
||||
if err := os.MkdirAll(s.storage.site(name), 0700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lf, err := os.OpenFile(fw.filename, os.O_CREATE|os.O_EXCL, 0644)
|
||||
if err != nil {
|
||||
if os.IsExist(err) {
|
||||
// another process has the lock; use it to wait
|
||||
return fw, nil
|
||||
}
|
||||
// otherwise, this was some unexpected error
|
||||
return nil, err
|
||||
}
|
||||
lf.Close()
|
||||
|
||||
// looks like we get the lock
|
||||
fw.wg.Add(1)
|
||||
fileStorageNameLocks[s.caURL+name] = fw
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Unlock unlocks name.
|
||||
func (s *fileStorageLock) Unlock(name string) error {
|
||||
fileStorageNameLocksMu.Lock()
|
||||
defer fileStorageNameLocksMu.Unlock()
|
||||
fw, ok := fileStorageNameLocks[s.caURL+name]
|
||||
if !ok {
|
||||
return fmt.Errorf("FileStorage: no lock to release for %s", name)
|
||||
}
|
||||
os.Remove(fw.filename)
|
||||
fw.wg.Done()
|
||||
delete(fileStorageNameLocks, s.caURL+name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// fileWaiter waits for a file to disappear; it polls
|
||||
// the file system to check for the existence of a file.
|
||||
// It also has a WaitGroup which will be faster than
|
||||
// polling, for when locking need only happen within this
|
||||
// process.
|
||||
type fileWaiter struct {
|
||||
filename string
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
// Wait waits until the lock is released.
|
||||
func (fw *fileWaiter) Wait() {
|
||||
start := time.Now()
|
||||
fw.wg.Wait()
|
||||
for time.Since(start) < 1*time.Hour {
|
||||
_, err := os.Stat(fw.filename)
|
||||
if os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
var fileStorageNameLocks = make(map[string]*fileWaiter) // keyed by CA + name
|
||||
var fileStorageNameLocksMu sync.Mutex
|
||||
|
||||
var _ Locker = &fileStorageLock{}
|
||||
var _ Waiter = &fileWaiter{}
|
|
@ -61,15 +61,15 @@ func (cg configGroup) getConfig(name string) *Config {
|
|||
}
|
||||
}
|
||||
|
||||
// as a fallback, try a config that serves all names
|
||||
// try a config that serves all names (this
|
||||
// is basically the same as a config defined
|
||||
// for "*" -- I think -- but the above loop
|
||||
// doesn't try an empty string)
|
||||
if config, ok := cg[""]; ok {
|
||||
return config
|
||||
}
|
||||
|
||||
// as a last resort, use a random config
|
||||
// (even if the config isn't for that hostname,
|
||||
// it should help us serve clients without SNI
|
||||
// or at least defer TLS alerts to the cert)
|
||||
// no matches, so just serve up a random config
|
||||
for _, config := range cg {
|
||||
return config
|
||||
}
|
||||
|
@ -121,6 +121,86 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
|
|||
return &cert.Certificate, err
|
||||
}
|
||||
|
||||
// getCertificate gets a certificate that matches name (a server name)
|
||||
// from the in-memory cache, according to the lookup table associated with
|
||||
// cfg. The lookup then points to a certificate in the Instance certificate
|
||||
// cache.
|
||||
//
|
||||
// If there is no exact match for name, it will be checked against names of
|
||||
// the form '*.example.com' (wildcard certificates) according to RFC 6125.
|
||||
// If a match is found, matched will be true. If no matches are found, matched
|
||||
// will be false and a "default" certificate will be returned with defaulted
|
||||
// set to true. If defaulted is false, then no certificates were available.
|
||||
//
|
||||
// The logic in this function is adapted from the Go standard library,
|
||||
// which is by the Go Authors.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defaulted bool) {
|
||||
var certKey string
|
||||
var ok bool
|
||||
|
||||
// Not going to trim trailing dots here since RFC 3546 says,
|
||||
// "The hostname is represented ... without a trailing dot."
|
||||
// Just normalize to lowercase.
|
||||
name = strings.ToLower(name)
|
||||
|
||||
cfg.certCache.RLock()
|
||||
defer cfg.certCache.RUnlock()
|
||||
|
||||
// exact match? great, let's use it
|
||||
if certKey, ok = cfg.Certificates[name]; ok {
|
||||
cert = cfg.certCache.cache[certKey]
|
||||
matched = true
|
||||
return
|
||||
}
|
||||
|
||||
// try replacing labels in the name with wildcards until we get a match
|
||||
labels := strings.Split(name, ".")
|
||||
for i := range labels {
|
||||
labels[i] = "*"
|
||||
candidate := strings.Join(labels, ".")
|
||||
if certKey, ok = cfg.Certificates[candidate]; ok {
|
||||
cert = cfg.certCache.cache[certKey]
|
||||
matched = true
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// check the certCache directly to see if the SNI name is
|
||||
// already the key of the certificate it wants! this is vital
|
||||
// for supporting the TLS-SNI challenge, since the tlsSNISolver
|
||||
// just puts the temporary certificate in the instance cache,
|
||||
// with no regard for configs; this also means that the SNI
|
||||
// can contain the hash of a specific cert (chain) it wants
|
||||
// and we will still be able to serve it up
|
||||
// (this behavior, by the way, could be controversial as to
|
||||
// whether it complies with RFC 6066 about SNI, but I think
|
||||
// it does soooo...)
|
||||
// NOTE/TODO: TLS-SNI challenge is changing, as of Jan. 2018
|
||||
// but what will be different, if it ever returns, is unclear
|
||||
if directCert, ok := cfg.certCache.cache[name]; ok {
|
||||
cert = directCert
|
||||
matched = true
|
||||
return
|
||||
}
|
||||
|
||||
// if nothing matches and SNI was not provided, use a random
|
||||
// certificate; at least there's a chance this older client
|
||||
// can connect, and in the future we won't need this provision
|
||||
// (if SNI is present, it's probably best to just raise a TLS
|
||||
// alert by not serving a certificate)
|
||||
if name == "" {
|
||||
for _, certKey := range cfg.Certificates {
|
||||
defaulted = true
|
||||
cert = cfg.certCache.cache[certKey]
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// getCertDuringHandshake will get a certificate for name. It first tries
|
||||
// the in-memory cache. If no certificate for name is in the cache, the
|
||||
// config most closely corresponding to name will be loaded. If that config
|
||||
|
@ -134,7 +214,7 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
|
|||
// This function is safe for concurrent use.
|
||||
func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
|
||||
// First check our in-memory cache to see if we've already loaded it
|
||||
cert, matched, defaulted := getCertificate(name)
|
||||
cert, matched, defaulted := cfg.getCertificate(name)
|
||||
if matched {
|
||||
return cert, nil
|
||||
}
|
||||
|
@ -277,7 +357,7 @@ func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) {
|
|||
obtainCertWaitChans[name] = wait
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
|
||||
// do the obtain
|
||||
// obtain the certificate
|
||||
log.Printf("[INFO] Obtaining new certificate for %s", name)
|
||||
err := cfg.ObtainCert(name, false)
|
||||
|
||||
|
@ -336,9 +416,9 @@ func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certific
|
|||
// quite common considering not all certs have issuer URLs that support it.
|
||||
log.Printf("[ERROR] Getting OCSP for %s: %v", name, err)
|
||||
}
|
||||
certCacheMu.Lock()
|
||||
certCache[name] = cert
|
||||
certCacheMu.Unlock()
|
||||
cfg.certCache.Lock()
|
||||
cfg.certCache.cache[cert.Hash] = cert
|
||||
cfg.certCache.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -367,29 +447,22 @@ func (cfg *Config) renewDynamicCertificate(name string, currentCert Certificate)
|
|||
obtainCertWaitChans[name] = wait
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
|
||||
// do the renew and reload the certificate
|
||||
// renew and reload the certificate
|
||||
log.Printf("[INFO] Renewing certificate for %s", name)
|
||||
err := cfg.RenewCert(name, false)
|
||||
if err == nil {
|
||||
// immediately flush this certificate from the cache so
|
||||
// the name doesn't overlap when we try to replace it,
|
||||
// which would fail, because overlapping existing cert
|
||||
// names isn't allowed
|
||||
certCacheMu.Lock()
|
||||
for _, certName := range currentCert.Names {
|
||||
delete(certCache, certName)
|
||||
}
|
||||
certCacheMu.Unlock()
|
||||
|
||||
// even though the recursive nature of the dynamic cert loading
|
||||
// would just call this function anyway, we do it here to
|
||||
// make the replacement as atomic as possible. (TODO: similar
|
||||
// to the note in maintain.go, it'd be nice if the clearing of
|
||||
// the cache entries above and this load function were truly
|
||||
// atomic...)
|
||||
_, err := currentCert.Config.CacheManagedCertificate(name)
|
||||
// make the replacement as atomic as possible.
|
||||
newCert, err := currentCert.configs[0].CacheManagedCertificate(name)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] loading renewed certificate: %v", err)
|
||||
log.Printf("[ERROR] loading renewed certificate for %s: %v", name, err)
|
||||
} else {
|
||||
// replace the old certificate with the new one
|
||||
err = cfg.certCache.replaceCertificate(currentCert, newCert)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Replacing certificate for %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -21,9 +21,8 @@ import (
|
|||
)
|
||||
|
||||
func TestGetCertificate(t *testing.T) {
|
||||
defer func() { certCache = make(map[string]Certificate) }()
|
||||
|
||||
cfg := new(Config)
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
|
||||
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
|
||||
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
|
||||
|
@ -38,33 +37,40 @@ func TestGetCertificate(t *testing.T) {
|
|||
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
|
||||
}
|
||||
|
||||
// When cache has one certificate in it (also is default)
|
||||
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
|
||||
certCache[""] = defaultCert
|
||||
certCache["example.com"] = defaultCert
|
||||
// When cache has one certificate in it
|
||||
firstCert := Certificate{Names: []string{"example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
|
||||
cfg.cacheCertificate(firstCert)
|
||||
if cert, err := cfg.GetCertificate(hello); err != nil {
|
||||
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
|
||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
|
||||
}
|
||||
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
|
||||
if _, err := cfg.GetCertificate(helloNoSNI); err != nil {
|
||||
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
|
||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
|
||||
}
|
||||
|
||||
// When retrieving wildcard certificate
|
||||
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
|
||||
wildcardCert := Certificate{
|
||||
Names: []string{"*.example.com"},
|
||||
Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}},
|
||||
Hash: "(don't overwrite the first one)",
|
||||
}
|
||||
cfg.cacheCertificate(wildcardCert)
|
||||
if cert, err := cfg.GetCertificate(helloSub); err != nil {
|
||||
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
|
||||
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
|
||||
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
|
||||
}
|
||||
|
||||
// When no certificate matches, the default is returned
|
||||
if cert, err := cfg.GetCertificate(helloNoMatch); err != nil {
|
||||
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
|
||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||
t.Errorf("Expected default cert with no matches, got: %v", cert)
|
||||
// When cache is NOT empty but there's no SNI
|
||||
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
|
||||
t.Errorf("Expected random certificate with no error when no SNI, got err: %v", err)
|
||||
} else if cert == nil || len(cert.Leaf.DNSNames) == 0 {
|
||||
t.Errorf("Expected random cert with no matches, got: %v", cert)
|
||||
}
|
||||
|
||||
// When no certificate matches, raise an alert
|
||||
if _, err := cfg.GetCertificate(helloNoMatch); err == nil {
|
||||
t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,10 +27,11 @@ import (
|
|||
const challengeBasePath = "/.well-known/acme-challenge"
|
||||
|
||||
// HTTPChallengeHandler proxies challenge requests to ACME client if the
|
||||
// request path starts with challengeBasePath. It returns true if it
|
||||
// handled the request and no more needs to be done; it returns false
|
||||
// if this call was a no-op and the request still needs handling.
|
||||
func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, altPort string) bool {
|
||||
// request path starts with challengeBasePath, if the HTTP challenge is not
|
||||
// disabled, and if we are known to be obtaining a certificate for the name.
|
||||
// It returns true if it handled the request and no more needs to be done;
|
||||
// it returns false if this call was a no-op and the request still needs handling.
|
||||
func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost string) bool {
|
||||
if !strings.HasPrefix(r.URL.Path, challengeBasePath) {
|
||||
return false
|
||||
}
|
||||
|
@ -50,7 +51,11 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, al
|
|||
listenHost = "localhost"
|
||||
}
|
||||
|
||||
upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, altPort))
|
||||
// always proxy to the DefaultHTTPAlternatePort because obviously the
|
||||
// ACME challenge request already got into one of our HTTP handlers, so
|
||||
// it means we must have started a HTTP listener on the alternate
|
||||
// port instead; which is only accessible via listenHost
|
||||
upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, DefaultHTTPAlternatePort))
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
log.Printf("[ERROR] ACME proxy handler: %v", err)
|
||||
|
|
|
@ -39,7 +39,7 @@ func TestHTTPChallengeHandlerNoOp(t *testing.T) {
|
|||
t.Fatalf("Could not craft request, got error: %v", err)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
if HTTPChallengeHandler(rw, req, "", DefaultHTTPAlternatePort) {
|
||||
if HTTPChallengeHandler(rw, req, "") {
|
||||
t.Errorf("Got true with this URL, but shouldn't have: %s", url)
|
||||
}
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ func TestHTTPChallengeHandlerSuccess(t *testing.T) {
|
|||
}
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
HTTPChallengeHandler(rw, req, "", DefaultHTTPAlternatePort)
|
||||
HTTPChallengeHandler(rw, req, "")
|
||||
|
||||
if !proxySuccess {
|
||||
t.Fatal("Expected request to be proxied, but it wasn't")
|
||||
|
|
|
@ -87,103 +87,127 @@ func maintainAssets(stopChan chan struct{}) {
|
|||
// RenewManagedCertificates renews managed certificates,
|
||||
// including ones loaded on-demand.
|
||||
func RenewManagedCertificates(allowPrompts bool) (err error) {
|
||||
var renewQueue, deleteQueue []Certificate
|
||||
visitedNames := make(map[string]struct{})
|
||||
|
||||
certCacheMu.RLock()
|
||||
for name, cert := range certCache {
|
||||
if !cert.Config.Managed || cert.Config.SelfSigned {
|
||||
for _, inst := range caddy.Instances() {
|
||||
inst.StorageMu.RLock()
|
||||
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
|
||||
inst.StorageMu.RUnlock()
|
||||
if !ok || certCache == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// the list of names on this cert should never be empty...
|
||||
if cert.Names == nil || len(cert.Names) == 0 {
|
||||
log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v - removing from cache", name, cert.Names)
|
||||
deleteQueue = append(deleteQueue, cert)
|
||||
continue
|
||||
}
|
||||
// we use the queues for a very important reason: to do any and all
|
||||
// operations that could require an exclusive write lock outside
|
||||
// of the read lock! otherwise we get a deadlock, yikes. in other
|
||||
// words, our first iteration through the certificate cache does NOT
|
||||
// perform any operations--only queues them--so that more fine-grained
|
||||
// write locks may be obtained during the actual operations.
|
||||
var renewQueue, reloadQueue, deleteQueue []Certificate
|
||||
|
||||
// skip names whose certificate we've already renewed
|
||||
if _, ok := visitedNames[name]; ok {
|
||||
continue
|
||||
}
|
||||
for _, name := range cert.Names {
|
||||
visitedNames[name] = struct{}{}
|
||||
}
|
||||
|
||||
// if its time is up or ending soon, we need to try to renew it
|
||||
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
||||
if timeLeft < RenewDurationBefore {
|
||||
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
|
||||
|
||||
if cert.Config == nil {
|
||||
log.Printf("[ERROR] %s: No associated TLS config; unable to renew", name)
|
||||
certCache.RLock()
|
||||
for certKey, cert := range certCache.cache {
|
||||
if len(cert.configs) == 0 {
|
||||
// this is bad if this happens, probably a programmer error (oops)
|
||||
log.Printf("[ERROR] No associated TLS config for certificate with names %v; unable to manage", cert.Names)
|
||||
continue
|
||||
}
|
||||
if !cert.configs[0].Managed || cert.configs[0].SelfSigned {
|
||||
continue
|
||||
}
|
||||
|
||||
// queue for renewal when we aren't in a read lock anymore
|
||||
// (the TLS-SNI challenge will need a write lock in order to
|
||||
// present the certificate, so we renew outside of read lock)
|
||||
renewQueue = append(renewQueue, cert)
|
||||
}
|
||||
}
|
||||
certCacheMu.RUnlock()
|
||||
|
||||
// Perform renewals that are queued
|
||||
for _, cert := range renewQueue {
|
||||
// Get the name which we should use to renew this certificate;
|
||||
// we only support managing certificates with one name per cert,
|
||||
// so this should be easy. We can't rely on cert.Config.Hostname
|
||||
// because it may be a wildcard value from the Caddyfile (e.g.
|
||||
// *.something.com) which, as of Jan. 2017, is not supported by ACME.
|
||||
var renewName string
|
||||
for _, name := range cert.Names {
|
||||
if name != "" {
|
||||
renewName = name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// perform renewal
|
||||
err := cert.Config.RenewCert(renewName, allowPrompts)
|
||||
if err != nil {
|
||||
if allowPrompts {
|
||||
// Certificate renewal failed and the operator is present. See a discussion
|
||||
// about this in issue 642. For a while, we only stopped if the certificate
|
||||
// was expired, but in reality, there is no difference between reporting
|
||||
// it now versus later, except that there's somebody present to deal with
|
||||
// it right now.
|
||||
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
||||
if timeLeft < RenewDurationBeforeAtStartup {
|
||||
// See issue 1680. Only fail at startup if the certificate is dangerously
|
||||
// close to expiration.
|
||||
return err
|
||||
}
|
||||
}
|
||||
log.Printf("[ERROR] %v", err)
|
||||
if cert.Config.OnDemand {
|
||||
// loaded dynamically, removed dynamically
|
||||
// the list of names on this cert should never be empty... programmer error?
|
||||
if cert.Names == nil || len(cert.Names) == 0 {
|
||||
log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v - removing from cache", certKey, cert.Names)
|
||||
deleteQueue = append(deleteQueue, cert)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
|
||||
// if time is up or expires soon, we need to try to renew it
|
||||
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
||||
if timeLeft < RenewDurationBefore {
|
||||
// see if the certificate in storage has already been renewed, possibly by another
|
||||
// instance of Caddy that didn't coordinate with this one; if so, just load it (this
|
||||
// might happen if another instance already renewed it - kinda sloppy but checking disk
|
||||
// first is a simple way to possibly drastically reduce rate limit problems)
|
||||
storedCertExpiring, err := managedCertInStorageExpiresSoon(cert)
|
||||
if err != nil {
|
||||
// hmm, weird, but not a big deal, maybe it was deleted or something
|
||||
log.Printf("[NOTICE] Error while checking if certificate for %v in storage is also expiring soon: %v",
|
||||
cert.Names, err)
|
||||
} else if !storedCertExpiring {
|
||||
// if the certificate is NOT expiring soon and there was no error, then we
|
||||
// are good to just reload the certificate from storage instead of repeating
|
||||
// a likely-unnecessary renewal procedure
|
||||
reloadQueue = append(reloadQueue, cert)
|
||||
continue
|
||||
}
|
||||
|
||||
// the certificate in storage has not been renewed yet, so we will do it
|
||||
// NOTE 1: This is not correct 100% of the time, if multiple Caddy instances
|
||||
// happen to run their maintenance checks at approximately the same times;
|
||||
// both might start renewal at about the same time and do two renewals and one
|
||||
// will overwrite the other. Hence TLS storage plugins. This is sort of a TODO.
|
||||
// NOTE 2: It is super-important to note that the TLS-SNI challenge requires
|
||||
// a write lock on the cache in order to complete its challenge, so it is extra
|
||||
// vital that this renew operation does not happen inside our read lock!
|
||||
renewQueue = append(renewQueue, cert)
|
||||
}
|
||||
}
|
||||
certCache.RUnlock()
|
||||
|
||||
// Reload certificates that merely need to be updated in memory
|
||||
for _, oldCert := range reloadQueue {
|
||||
timeLeft := oldCert.NotAfter.Sub(time.Now().UTC())
|
||||
log.Printf("[INFO] Certificate for %v expires in %v, but is already renewed in storage; reloading stored certificate",
|
||||
oldCert.Names, timeLeft)
|
||||
|
||||
err = certCache.reloadManagedCertificate(oldCert)
|
||||
if err != nil {
|
||||
if allowPrompts {
|
||||
return err // operator is present, so report error immediately
|
||||
}
|
||||
log.Printf("[ERROR] Loading renewed certificate: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Renewal queue
|
||||
for _, oldCert := range renewQueue {
|
||||
timeLeft := oldCert.NotAfter.Sub(time.Now().UTC())
|
||||
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", oldCert.Names, timeLeft)
|
||||
|
||||
// Get the name which we should use to renew this certificate;
|
||||
// we only support managing certificates with one name per cert,
|
||||
// so this should be easy. We can't rely on cert.Config.Hostname
|
||||
// because it may be a wildcard value from the Caddyfile (e.g.
|
||||
// *.something.com) which, as of Jan. 2017, is not supported by ACME.
|
||||
// TODO: ^ ^ ^ (wildcards)
|
||||
renewName := oldCert.Names[0]
|
||||
|
||||
// perform renewal
|
||||
err := oldCert.configs[0].RenewCert(renewName, allowPrompts)
|
||||
if err != nil {
|
||||
if allowPrompts {
|
||||
// Certificate renewal failed and the operator is present. See a discussion
|
||||
// about this in issue 642. For a while, we only stopped if the certificate
|
||||
// was expired, but in reality, there is no difference between reporting
|
||||
// it now versus later, except that there's somebody present to deal with
|
||||
// it right now. Follow-up: See issue 1680. Only fail in this case if the
|
||||
// certificate is dangerously close to expiration.
|
||||
timeLeft := oldCert.NotAfter.Sub(time.Now().UTC())
|
||||
if timeLeft < RenewDurationBeforeAtStartup {
|
||||
return err
|
||||
}
|
||||
}
|
||||
log.Printf("[ERROR] %v", err)
|
||||
if oldCert.configs[0].OnDemand {
|
||||
// loaded dynamically, remove dynamically
|
||||
deleteQueue = append(deleteQueue, oldCert)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// successful renewal, so update in-memory cache by loading
|
||||
// renewed certificate so it will be used with handshakes
|
||||
|
||||
// we must delete all the names this cert services from the cache
|
||||
// so that we can replace the certificate, because replacing names
|
||||
// already in the cache is not allowed, to avoid later conflicts
|
||||
// with renewals.
|
||||
// TODO: It would be nice if this whole operation were idempotent;
|
||||
// i.e. a thread-safe function to replace a certificate in the cache,
|
||||
// see also handshake.go for on-demand maintenance.
|
||||
certCacheMu.Lock()
|
||||
for _, name := range cert.Names {
|
||||
delete(certCache, name)
|
||||
}
|
||||
certCacheMu.Unlock()
|
||||
|
||||
// put the certificate in the cache
|
||||
_, err := cert.Config.CacheManagedCertificate(cert.Names[0])
|
||||
err = certCache.reloadManagedCertificate(oldCert)
|
||||
if err != nil {
|
||||
if allowPrompts {
|
||||
return err // operator is present, so report error immediately
|
||||
|
@ -191,15 +215,22 @@ func RenewManagedCertificates(allowPrompts bool) (err error) {
|
|||
log.Printf("[ERROR] %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply queued deletion changes to the cache
|
||||
for _, cert := range deleteQueue {
|
||||
certCacheMu.Lock()
|
||||
for _, name := range cert.Names {
|
||||
delete(certCache, name)
|
||||
// Deletion queue
|
||||
for _, cert := range deleteQueue {
|
||||
certCache.Lock()
|
||||
// remove any pointers to this certificate from Configs
|
||||
for _, cfg := range cert.configs {
|
||||
for name, certKey := range cfg.Certificates {
|
||||
if certKey == cert.Hash {
|
||||
delete(cfg.Certificates, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
// then delete the certificate from the cache
|
||||
delete(certCache.cache, cert.Hash)
|
||||
certCache.Unlock()
|
||||
}
|
||||
certCacheMu.Unlock()
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -212,91 +243,75 @@ func RenewManagedCertificates(allowPrompts bool) (err error) {
|
|||
// Ryan Sleevi's recommendations for good OCSP support:
|
||||
// https://gist.github.com/sleevi/5efe9ef98961ecfb4da8
|
||||
func UpdateOCSPStaples() {
|
||||
// Create a temporary place to store updates
|
||||
// until we release the potentially long-lived
|
||||
// read lock and use a short-lived write lock.
|
||||
type ocspUpdate struct {
|
||||
rawBytes []byte
|
||||
parsed *ocsp.Response
|
||||
}
|
||||
updated := make(map[string]ocspUpdate)
|
||||
|
||||
// A single SAN certificate maps to multiple names, so we use this
|
||||
// set to make sure we don't waste cycles checking OCSP for the same
|
||||
// certificate multiple times.
|
||||
visited := make(map[string]struct{})
|
||||
|
||||
certCacheMu.RLock()
|
||||
for name, cert := range certCache {
|
||||
// skip this certificate if we've already visited it,
|
||||
// and if not, mark all the names as visited
|
||||
if _, ok := visited[name]; ok {
|
||||
continue
|
||||
}
|
||||
for _, n := range cert.Names {
|
||||
visited[n] = struct{}{}
|
||||
}
|
||||
|
||||
// no point in updating OCSP for expired certificates
|
||||
if time.Now().After(cert.NotAfter) {
|
||||
for _, inst := range caddy.Instances() {
|
||||
inst.StorageMu.RLock()
|
||||
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
|
||||
inst.StorageMu.RUnlock()
|
||||
if !ok || certCache == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var lastNextUpdate time.Time
|
||||
if cert.OCSP != nil {
|
||||
lastNextUpdate = cert.OCSP.NextUpdate
|
||||
if freshOCSP(cert.OCSP) {
|
||||
// no need to update staple if ours is still fresh
|
||||
// Create a temporary place to store updates
|
||||
// until we release the potentially long-lived
|
||||
// read lock and use a short-lived write lock
|
||||
// on the certificate cache.
|
||||
type ocspUpdate struct {
|
||||
rawBytes []byte
|
||||
parsed *ocsp.Response
|
||||
}
|
||||
updated := make(map[string]ocspUpdate)
|
||||
|
||||
certCache.RLock()
|
||||
for certHash, cert := range certCache.cache {
|
||||
// no point in updating OCSP for expired certificates
|
||||
if time.Now().After(cert.NotAfter) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
err := stapleOCSP(&cert, nil)
|
||||
if err != nil {
|
||||
var lastNextUpdate time.Time
|
||||
if cert.OCSP != nil {
|
||||
// if there was no staple before, that's fine; otherwise we should log the error
|
||||
log.Printf("[ERROR] Checking OCSP: %v", err)
|
||||
lastNextUpdate = cert.OCSP.NextUpdate
|
||||
if freshOCSP(cert.OCSP) {
|
||||
continue // no need to update staple if ours is still fresh
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// By this point, we've obtained the latest OCSP response.
|
||||
// If there was no staple before, or if the response is updated, make
|
||||
// sure we apply the update to all names on the certificate.
|
||||
if cert.OCSP != nil && (lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate) {
|
||||
log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
|
||||
cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
|
||||
for _, n := range cert.Names {
|
||||
// BUG: If this certificate has names on it that appear on another
|
||||
// certificate in the cache, AND the other certificate is keyed by
|
||||
// that name in the cache, then this method of 'queueing' the staple
|
||||
// update will cause this certificate's new OCSP to be stapled to
|
||||
// a different certificate! See:
|
||||
// https://caddy.community/t/random-ocsp-response-errors-for-random-clients/2473?u=matt
|
||||
// This problem should be avoided if names on certificates in the
|
||||
// cache don't overlap with regards to the cache keys.
|
||||
// (This is isn't a bug anymore, since we're careful when we add
|
||||
// certificates to the cache by skipping keying when key already exists.)
|
||||
updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
|
||||
err := stapleOCSP(&cert, nil)
|
||||
if err != nil {
|
||||
if cert.OCSP != nil {
|
||||
// if there was no staple before, that's fine; otherwise we should log the error
|
||||
log.Printf("[ERROR] Checking OCSP: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// By this point, we've obtained the latest OCSP response.
|
||||
// If there was no staple before, or if the response is updated, make
|
||||
// sure we apply the update to all names on the certificate.
|
||||
if cert.OCSP != nil && (lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate) {
|
||||
log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
|
||||
cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
|
||||
updated[certHash] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
|
||||
}
|
||||
}
|
||||
}
|
||||
certCacheMu.RUnlock()
|
||||
certCache.RUnlock()
|
||||
|
||||
// This write lock should be brief since we have all the info we need now.
|
||||
certCacheMu.Lock()
|
||||
for name, update := range updated {
|
||||
cert := certCache[name]
|
||||
cert.OCSP = update.parsed
|
||||
cert.Certificate.OCSPStaple = update.rawBytes
|
||||
certCache[name] = cert
|
||||
// These write locks should be brief since we have all the info we need now.
|
||||
for certKey, update := range updated {
|
||||
certCache.Lock()
|
||||
cert := certCache.cache[certKey]
|
||||
cert.OCSP = update.parsed
|
||||
cert.Certificate.OCSPStaple = update.rawBytes
|
||||
certCache.cache[certKey] = cert
|
||||
certCache.Unlock()
|
||||
}
|
||||
}
|
||||
certCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// DeleteOldStapleFiles deletes cached OCSP staples that have expired.
|
||||
// TODO: Should we do this for certificates too?
|
||||
func DeleteOldStapleFiles() {
|
||||
// TODO: Upgrade caddytls.Storage to support OCSP operations too
|
||||
files, err := ioutil.ReadDir(ocspFolder)
|
||||
if err != nil {
|
||||
// maybe just hasn't been created yet; no big deal
|
||||
|
|
|
@ -38,6 +38,7 @@ func init() {
|
|||
// are specified by the user in the config file. All the automatic HTTPS
|
||||
// stuff comes later outside of this function.
|
||||
func setupTLS(c *caddy.Controller) error {
|
||||
// obtain the configGetter, which loads the config we're, uh, configuring
|
||||
configGetter, ok := configGetters[c.ServerType()]
|
||||
if !ok {
|
||||
return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType())
|
||||
|
@ -47,6 +48,14 @@ func setupTLS(c *caddy.Controller) error {
|
|||
return fmt.Errorf("no caddytls.Config to set up for %s", c.Key)
|
||||
}
|
||||
|
||||
// the certificate cache is tied to the current caddy.Instance; get a pointer to it
|
||||
certCache, ok := c.Get(CertCacheInstStorageKey).(*certificateCache)
|
||||
if !ok || certCache == nil {
|
||||
certCache = &certificateCache{cache: make(map[string]Certificate)}
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
}
|
||||
config.certCache = certCache
|
||||
|
||||
config.Enabled = true
|
||||
|
||||
for c.Next() {
|
||||
|
@ -237,7 +246,7 @@ func setupTLS(c *caddy.Controller) error {
|
|||
|
||||
// load a single certificate and key, if specified
|
||||
if certificateFile != "" && keyFile != "" {
|
||||
err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
|
||||
err := config.cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
|
||||
if err != nil {
|
||||
return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err)
|
||||
}
|
||||
|
@ -246,7 +255,7 @@ func setupTLS(c *caddy.Controller) error {
|
|||
|
||||
// load a directory of certificates, if specified
|
||||
if loadDir != "" {
|
||||
err := loadCertsInDir(c, loadDir)
|
||||
err := loadCertsInDir(config, c, loadDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -273,7 +282,7 @@ func setupTLS(c *caddy.Controller) error {
|
|||
// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
|
||||
//
|
||||
// This function may write to the log as it walks the directory tree.
|
||||
func loadCertsInDir(c *caddy.Controller, dir string) error {
|
||||
func loadCertsInDir(cfg *Config, c *caddy.Controller, dir string) error {
|
||||
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
|
||||
|
@ -336,7 +345,7 @@ func loadCertsInDir(c *caddy.Controller, dir string) error {
|
|||
return c.Errf("%s: no private key block found", path)
|
||||
}
|
||||
|
||||
err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
|
||||
err = cfg.cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
|
||||
if err != nil {
|
||||
return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err)
|
||||
}
|
||||
|
|
|
@ -46,9 +46,12 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
|
||||
func TestSetupParseBasic(t *testing.T) {
|
||||
cfg := new(Config)
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", `tls `+certFile+` `+keyFile+``)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err != nil {
|
||||
|
@ -124,9 +127,12 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
|
|||
must_staple
|
||||
alpn http/1.1
|
||||
}`
|
||||
cfg := new(Config)
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err != nil {
|
||||
|
@ -158,9 +164,11 @@ func TestSetupDefaultWithOptionalParams(t *testing.T) {
|
|||
params := `tls {
|
||||
ciphers RSA-3DES-EDE-CBC-SHA
|
||||
}`
|
||||
cfg := new(Config)
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err != nil {
|
||||
|
@ -176,9 +184,12 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
|||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||
protocols ssl tls
|
||||
}`
|
||||
cfg := new(Config)
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err == nil {
|
||||
t.Errorf("Expected errors, but no error returned")
|
||||
|
@ -191,6 +202,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
|||
cfg = new(Config)
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c = caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
err = setupTLS(c)
|
||||
if err == nil {
|
||||
t.Error("Expected errors, but no error returned")
|
||||
|
@ -215,6 +227,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
|||
cfg = new(Config)
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c = caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
err = setupTLS(c)
|
||||
if err == nil {
|
||||
t.Error("Expected errors, but no error returned")
|
||||
|
@ -226,7 +239,8 @@ func TestSetupParseWithClientAuth(t *testing.T) {
|
|||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||
clients
|
||||
}`
|
||||
cfg := new(Config)
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
err := setupTLS(c)
|
||||
|
@ -259,9 +273,11 @@ func TestSetupParseWithClientAuth(t *testing.T) {
|
|||
clients verify_if_given
|
||||
}`, tls.VerifyClientCertIfGiven, true, noCAs},
|
||||
} {
|
||||
cfg := new(Config)
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", caseData.params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
err := setupTLS(c)
|
||||
if caseData.expectedErr {
|
||||
if err == nil {
|
||||
|
@ -311,9 +327,11 @@ func TestSetupParseWithCAUrl(t *testing.T) {
|
|||
ca 1 2
|
||||
}`, true, ""},
|
||||
} {
|
||||
cfg := new(Config)
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", caseData.params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
err := setupTLS(c)
|
||||
if caseData.expectedErr {
|
||||
if err == nil {
|
||||
|
@ -335,9 +353,11 @@ func TestSetupParseWithKeyType(t *testing.T) {
|
|||
params := `tls {
|
||||
key_type p384
|
||||
}`
|
||||
cfg := new(Config)
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err != nil {
|
||||
|
@ -353,9 +373,11 @@ func TestSetupParseWithCurves(t *testing.T) {
|
|||
params := `tls {
|
||||
curves x25519 p256 p384 p521
|
||||
}`
|
||||
cfg := new(Config)
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err != nil {
|
||||
|
@ -380,9 +402,11 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) {
|
|||
params := `tls {
|
||||
protocols tls1.2
|
||||
}`
|
||||
cfg := new(Config)
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err != nil {
|
||||
|
|
|
@ -107,6 +107,10 @@ type Storage interface {
|
|||
// in StoreUser. The result is an empty string if there are no
|
||||
// persisted users in storage.
|
||||
MostRecentUserEmail() string
|
||||
|
||||
// Locker is necessary because synchronizing certificate maintenance
|
||||
// depends on how storage is implemented.
|
||||
Locker
|
||||
}
|
||||
|
||||
// ErrNotExist is returned by Storage implementations when
|
||||
|
|
|
@ -1,57 +0,0 @@
|
|||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var _ Locker = &syncLock{}
|
||||
|
||||
type syncLock struct {
|
||||
nameLocks map[string]*sync.WaitGroup
|
||||
nameLocksMu sync.Mutex
|
||||
}
|
||||
|
||||
// TryLock attempts to get a lock for name, otherwise it returns
|
||||
// a Waiter value to wait until the other process is finished.
|
||||
func (s *syncLock) TryLock(name string) (Waiter, error) {
|
||||
s.nameLocksMu.Lock()
|
||||
defer s.nameLocksMu.Unlock()
|
||||
wg, ok := s.nameLocks[name]
|
||||
if ok {
|
||||
// lock already obtained, let caller wait on it
|
||||
return wg, nil
|
||||
}
|
||||
// caller gets lock
|
||||
wg = new(sync.WaitGroup)
|
||||
wg.Add(1)
|
||||
s.nameLocks[name] = wg
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Unlock unlocks name.
|
||||
func (s *syncLock) Unlock(name string) error {
|
||||
s.nameLocksMu.Lock()
|
||||
defer s.nameLocksMu.Unlock()
|
||||
wg, ok := s.nameLocks[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("FileStorage: no lock to release for %s", name)
|
||||
}
|
||||
wg.Done()
|
||||
delete(s.nameLocks, name)
|
||||
return nil
|
||||
}
|
|
@ -88,30 +88,38 @@ func Revoke(host string) error {
|
|||
return client.Revoke(host)
|
||||
}
|
||||
|
||||
// tlsSniSolver is a type that can solve tls-sni challenges using
|
||||
// tlsSNISolver is a type that can solve TLS-SNI challenges using
|
||||
// an existing listener and our custom, in-memory certificate cache.
|
||||
type tlsSniSolver struct{}
|
||||
type tlsSNISolver struct {
|
||||
certCache *certificateCache
|
||||
}
|
||||
|
||||
// Present adds the challenge certificate to the cache.
|
||||
func (s tlsSniSolver) Present(domain, token, keyAuth string) error {
|
||||
func (s tlsSNISolver) Present(domain, token, keyAuth string) error {
|
||||
cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cacheCertificate(Certificate{
|
||||
certHash := hashCertificateChain(cert.Certificate)
|
||||
s.certCache.Lock()
|
||||
s.certCache.cache[acmeDomain] = Certificate{
|
||||
Certificate: cert,
|
||||
Names: []string{acmeDomain},
|
||||
})
|
||||
Hash: certHash, // perhaps not necesssary
|
||||
}
|
||||
s.certCache.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanUp removes the challenge certificate from the cache.
|
||||
func (s tlsSniSolver) CleanUp(domain, token, keyAuth string) error {
|
||||
func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error {
|
||||
_, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uncacheCertificate(acmeDomain)
|
||||
s.certCache.Lock()
|
||||
delete(s.certCache.cache, acmeDomain)
|
||||
s.certCache.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -103,6 +103,20 @@ func (c *Controller) Context() Context {
|
|||
return c.instance.context
|
||||
}
|
||||
|
||||
// Get safely gets a value from the Instance's storage.
|
||||
func (c *Controller) Get(key interface{}) interface{} {
|
||||
c.instance.StorageMu.RLock()
|
||||
defer c.instance.StorageMu.RUnlock()
|
||||
return c.instance.Storage[key]
|
||||
}
|
||||
|
||||
// Set safely sets a value on the Instance's storage.
|
||||
func (c *Controller) Set(key, val interface{}) {
|
||||
c.instance.StorageMu.Lock()
|
||||
c.instance.Storage[key] = val
|
||||
c.instance.StorageMu.Unlock()
|
||||
}
|
||||
|
||||
// NewTestController creates a new Controller for
|
||||
// the server type and input specified. The filename
|
||||
// is "Testfile". If the server type is not empty and
|
||||
|
@ -113,12 +127,12 @@ func (c *Controller) Context() Context {
|
|||
// Used only for testing, but exported so plugins can
|
||||
// use this for convenience.
|
||||
func NewTestController(serverType, input string) *Controller {
|
||||
var ctx Context
|
||||
testInst := &Instance{serverType: serverType, Storage: make(map[interface{}]interface{})}
|
||||
if stype, err := getServerType(serverType); err == nil {
|
||||
ctx = stype.NewContext()
|
||||
testInst.context = stype.NewContext(testInst)
|
||||
}
|
||||
return &Controller{
|
||||
instance: &Instance{serverType: serverType, context: ctx},
|
||||
instance: testInst,
|
||||
Dispenser: caddyfile.NewDispenser("Testfile", strings.NewReader(input)),
|
||||
OncePerServerBlock: func(f func() error) error { return f() },
|
||||
}
|
||||
|
|
1
dist/init/linux-systemd/README.md
vendored
1
dist/init/linux-systemd/README.md
vendored
|
@ -91,6 +91,7 @@ Install the systemd service unit configuration file, reload the systemd daemon,
|
|||
and start caddy:
|
||||
|
||||
```bash
|
||||
wget https://raw.githubusercontent.com/mholt/caddy/master/dist/init/linux-systemd/caddy.service
|
||||
sudo cp caddy.service /etc/systemd/system/
|
||||
sudo chown root:root /etc/systemd/system/caddy.service
|
||||
sudo chmod 644 /etc/systemd/system/caddy.service
|
||||
|
|
4
dist/init/linux-systemd/caddy.service
vendored
4
dist/init/linux-systemd/caddy.service
vendored
|
@ -30,8 +30,8 @@ LimitNPROC=512
|
|||
|
||||
; Use private /tmp and /var/tmp, which are discarded after caddy stops.
|
||||
PrivateTmp=true
|
||||
; Use a minimal /dev
|
||||
PrivateDevices=true
|
||||
; Use a minimal /dev (May bring additional security if switched to 'true', but it may not work on Raspberry Pi's or other devices, so it has been disabled in this dist.)
|
||||
PrivateDevices=false
|
||||
; Hide /home, /root, and /run/user. Nobody will steal your SSH-keys.
|
||||
ProtectHome=true
|
||||
; Make /usr, /boot, /etc and possibly some more folders read-only.
|
||||
|
|
41
plugins.go
41
plugins.go
|
@ -19,6 +19,7 @@ import (
|
|||
"log"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
)
|
||||
|
@ -38,7 +39,7 @@ var (
|
|||
|
||||
// eventHooks is a map of hook name to Hook. All hooks plugins
|
||||
// must have a name.
|
||||
eventHooks = make(map[string]EventHook)
|
||||
eventHooks = sync.Map{}
|
||||
|
||||
// parsingCallbacks maps server type to map of directive
|
||||
// to list of callback functions. These aren't really
|
||||
|
@ -98,11 +99,15 @@ func ListPlugins() map[string][]string {
|
|||
p["caddyfile_loaders"] = append(p["caddyfile_loaders"], defaultCaddyfileLoader.name)
|
||||
}
|
||||
|
||||
// event hook plugins
|
||||
if len(eventHooks) > 0 {
|
||||
for name := range eventHooks {
|
||||
p["event_hooks"] = append(p["event_hooks"], name)
|
||||
}
|
||||
// List the event hook plugins
|
||||
hooks := ""
|
||||
eventHooks.Range(func(k, _ interface{}) bool {
|
||||
hooks += " hook." + k.(string) + "\n"
|
||||
return true
|
||||
})
|
||||
if hooks != "" {
|
||||
str += "\nEvent hook plugins:\n"
|
||||
str += hooks
|
||||
}
|
||||
|
||||
// alphabetize the rest of the plugins
|
||||
|
@ -220,7 +225,7 @@ type ServerType struct {
|
|||
// startup phases before this one. It's a way to keep
|
||||
// each set of server instances separate and to reduce
|
||||
// the amount of global state you need.
|
||||
NewContext func() Context
|
||||
NewContext func(inst *Instance) Context
|
||||
}
|
||||
|
||||
// Plugin is a type which holds information about a plugin.
|
||||
|
@ -277,23 +282,23 @@ func RegisterEventHook(name string, hook EventHook) {
|
|||
if name == "" {
|
||||
panic("event hook must have a name")
|
||||
}
|
||||
if _, dup := eventHooks[name]; dup {
|
||||
_, dup := eventHooks.LoadOrStore(name, hook)
|
||||
if dup {
|
||||
panic("hook named " + name + " already registered")
|
||||
}
|
||||
eventHooks[name] = hook
|
||||
}
|
||||
|
||||
// EmitEvent executes the different hooks passing the EventType as an
|
||||
// argument. This is a blocking function. Hook developers should
|
||||
// use 'go' keyword if they don't want to block Caddy.
|
||||
func EmitEvent(event EventName, info interface{}) {
|
||||
for name, hook := range eventHooks {
|
||||
err := hook(event, info)
|
||||
|
||||
eventHooks.Range(func(k, v interface{}) bool {
|
||||
err := v.(EventHook)(event, info)
|
||||
if err != nil {
|
||||
log.Printf("error on '%s' hook: %v", name, err)
|
||||
log.Printf("error on '%s' hook: %v", k.(string), err)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// ParsingCallback is a function that is called after
|
||||
|
@ -412,6 +417,14 @@ func loadCaddyfileInput(serverType string) (Input, error) {
|
|||
return caddyfileToUse, nil
|
||||
}
|
||||
|
||||
// OnProcessExit is a list of functions to run when the process
|
||||
// exits -- they are ONLY for cleanup and should not block,
|
||||
// return errors, or do anything fancy. They will be run with
|
||||
// every signal, even if "shutdown callbacks" are not executed.
|
||||
// This variable must only be modified in the main goroutine
|
||||
// from init() functions.
|
||||
var OnProcessExit []func()
|
||||
|
||||
// caddyfileLoader pairs the name of a loader to the loader.
|
||||
type caddyfileLoader struct {
|
||||
name string
|
||||
|
|
|
@ -44,16 +44,17 @@ func trapSignalsCrossPlatform() {
|
|||
|
||||
if i > 0 {
|
||||
log.Println("[INFO] SIGINT: Force quit")
|
||||
if PidFile != "" {
|
||||
os.Remove(PidFile)
|
||||
for _, f := range OnProcessExit {
|
||||
f() // important cleanup actions only
|
||||
}
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
log.Println("[INFO] SIGINT: Shutting down")
|
||||
|
||||
if PidFile != "" {
|
||||
os.Remove(PidFile)
|
||||
// important cleanup actions before shutdown callbacks
|
||||
for _, f := range OnProcessExit {
|
||||
f()
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
|
|
@ -33,22 +33,22 @@ func trapSignalsPosix() {
|
|||
switch sig {
|
||||
case syscall.SIGQUIT:
|
||||
log.Println("[INFO] SIGQUIT: Quitting process immediately")
|
||||
if PidFile != "" {
|
||||
os.Remove(PidFile)
|
||||
for _, f := range OnProcessExit {
|
||||
f() // only perform important cleanup actions
|
||||
}
|
||||
os.Exit(0)
|
||||
|
||||
case syscall.SIGTERM:
|
||||
log.Println("[INFO] SIGTERM: Shutting down servers then terminating")
|
||||
exitCode := executeShutdownCallbacks("SIGTERM")
|
||||
for _, f := range OnProcessExit {
|
||||
f() // only perform important cleanup actions
|
||||
}
|
||||
err := Stop()
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] SIGTERM stop: %v", err)
|
||||
exitCode = 3
|
||||
}
|
||||
if PidFile != "" {
|
||||
os.Remove(PidFile)
|
||||
}
|
||||
os.Exit(exitCode)
|
||||
|
||||
case syscall.SIGUSR1:
|
||||
|
|
21
vendor/github.com/bifurcation/mint/LICENSE.md
generated
vendored
Normal file
21
vendor/github.com/bifurcation/mint/LICENSE.md
generated
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2016 Richard Barnes
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
99
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
Normal file
99
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
Normal file
|
@ -0,0 +1,99 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mint
|
||||
|
||||
import "strconv"
|
||||
|
||||
type Alert uint8
|
||||
|
||||
const (
|
||||
// alert level
|
||||
AlertLevelWarning = 1
|
||||
AlertLevelError = 2
|
||||
)
|
||||
|
||||
const (
|
||||
AlertCloseNotify Alert = 0
|
||||
AlertUnexpectedMessage Alert = 10
|
||||
AlertBadRecordMAC Alert = 20
|
||||
AlertDecryptionFailed Alert = 21
|
||||
AlertRecordOverflow Alert = 22
|
||||
AlertDecompressionFailure Alert = 30
|
||||
AlertHandshakeFailure Alert = 40
|
||||
AlertBadCertificate Alert = 42
|
||||
AlertUnsupportedCertificate Alert = 43
|
||||
AlertCertificateRevoked Alert = 44
|
||||
AlertCertificateExpired Alert = 45
|
||||
AlertCertificateUnknown Alert = 46
|
||||
AlertIllegalParameter Alert = 47
|
||||
AlertUnknownCA Alert = 48
|
||||
AlertAccessDenied Alert = 49
|
||||
AlertDecodeError Alert = 50
|
||||
AlertDecryptError Alert = 51
|
||||
AlertProtocolVersion Alert = 70
|
||||
AlertInsufficientSecurity Alert = 71
|
||||
AlertInternalError Alert = 80
|
||||
AlertInappropriateFallback Alert = 86
|
||||
AlertUserCanceled Alert = 90
|
||||
AlertNoRenegotiation Alert = 100
|
||||
AlertMissingExtension Alert = 109
|
||||
AlertUnsupportedExtension Alert = 110
|
||||
AlertCertificateUnobtainable Alert = 111
|
||||
AlertUnrecognizedName Alert = 112
|
||||
AlertBadCertificateStatsResponse Alert = 113
|
||||
AlertBadCertificateHashValue Alert = 114
|
||||
AlertUnknownPSKIdentity Alert = 115
|
||||
AlertNoApplicationProtocol Alert = 120
|
||||
AlertWouldBlock Alert = 254
|
||||
AlertNoAlert Alert = 255
|
||||
)
|
||||
|
||||
var alertText = map[Alert]string{
|
||||
AlertCloseNotify: "close notify",
|
||||
AlertUnexpectedMessage: "unexpected message",
|
||||
AlertBadRecordMAC: "bad record MAC",
|
||||
AlertDecryptionFailed: "decryption failed",
|
||||
AlertRecordOverflow: "record overflow",
|
||||
AlertDecompressionFailure: "decompression failure",
|
||||
AlertHandshakeFailure: "handshake failure",
|
||||
AlertBadCertificate: "bad certificate",
|
||||
AlertUnsupportedCertificate: "unsupported certificate",
|
||||
AlertCertificateRevoked: "revoked certificate",
|
||||
AlertCertificateExpired: "expired certificate",
|
||||
AlertCertificateUnknown: "unknown certificate",
|
||||
AlertIllegalParameter: "illegal parameter",
|
||||
AlertUnknownCA: "unknown certificate authority",
|
||||
AlertAccessDenied: "access denied",
|
||||
AlertDecodeError: "error decoding message",
|
||||
AlertDecryptError: "error decrypting message",
|
||||
AlertProtocolVersion: "protocol version not supported",
|
||||
AlertInsufficientSecurity: "insufficient security level",
|
||||
AlertInternalError: "internal error",
|
||||
AlertInappropriateFallback: "inappropriate fallback",
|
||||
AlertUserCanceled: "user canceled",
|
||||
AlertMissingExtension: "missing extension",
|
||||
AlertUnsupportedExtension: "unsupported extension",
|
||||
AlertCertificateUnobtainable: "certificate unobtainable",
|
||||
AlertUnrecognizedName: "unrecognized name",
|
||||
AlertBadCertificateStatsResponse: "bad certificate status response",
|
||||
AlertBadCertificateHashValue: "bad certificate hash value",
|
||||
AlertUnknownPSKIdentity: "unknown PSK identity",
|
||||
AlertNoApplicationProtocol: "no application protocol",
|
||||
AlertNoRenegotiation: "no renegotiation",
|
||||
AlertWouldBlock: "would have blocked",
|
||||
AlertNoAlert: "no alert",
|
||||
}
|
||||
|
||||
func (e Alert) String() string {
|
||||
s, ok := alertText[e]
|
||||
if ok {
|
||||
return s
|
||||
}
|
||||
return "alert(" + strconv.Itoa(int(e)) + ")"
|
||||
}
|
||||
|
||||
func (e Alert) Error() string {
|
||||
return e.String()
|
||||
}
|
42
vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go
generated
vendored
Normal file
42
vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go
generated
vendored
Normal file
|
@ -0,0 +1,42 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
)
|
||||
|
||||
var url string
|
||||
|
||||
func main() {
|
||||
url := flag.String("url", "https://localhost:4430", "URL to send request")
|
||||
flag.Parse()
|
||||
mintdial := func(network, addr string) (net.Conn, error) {
|
||||
return mint.Dial(network, addr, nil)
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
DialTLS: mintdial,
|
||||
DisableCompression: true,
|
||||
}
|
||||
client := &http.Client{Transport: tr}
|
||||
|
||||
response, err := client.Get(*url)
|
||||
if err != nil {
|
||||
fmt.Println("err:", err)
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
contents, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
fmt.Printf("%s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("%s\n", string(contents))
|
||||
}
|
37
vendor/github.com/bifurcation/mint/bin/mint-client/main.go
generated
vendored
Normal file
37
vendor/github.com/bifurcation/mint/bin/mint-client/main.go
generated
vendored
Normal file
|
@ -0,0 +1,37 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
)
|
||||
|
||||
var addr string
|
||||
|
||||
func main() {
|
||||
flag.StringVar(&addr, "addr", "localhost:4430", "port")
|
||||
flag.Parse()
|
||||
|
||||
conn, err := mint.Dial("tcp", addr, nil)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("TLS handshake failed:", err)
|
||||
return
|
||||
}
|
||||
|
||||
request := "GET / HTTP/1.0\r\n\r\n"
|
||||
conn.Write([]byte(request))
|
||||
|
||||
response := ""
|
||||
buffer := make([]byte, 1024)
|
||||
var read int
|
||||
for err == nil {
|
||||
read, err = conn.Read(buffer)
|
||||
fmt.Println(" ~~ read: ", read)
|
||||
response += string(buffer)
|
||||
}
|
||||
fmt.Println("err:", err)
|
||||
fmt.Println("Received from server:")
|
||||
fmt.Println(response)
|
||||
}
|
226
vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go
generated
vendored
Normal file
226
vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go
generated
vendored
Normal file
|
@ -0,0 +1,226 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
port string
|
||||
serverName string
|
||||
certFile string
|
||||
keyFile string
|
||||
responseFile string
|
||||
h2 bool
|
||||
sendTickets bool
|
||||
)
|
||||
|
||||
type responder []byte
|
||||
|
||||
func (rsp responder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(rsp)
|
||||
}
|
||||
|
||||
// ParsePrivateKeyDER parses a PKCS #1, PKCS #8, or elliptic curve
|
||||
// PEM-encoded private key.
|
||||
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) {
|
||||
keyDER, _ := pem.Decode(keyPEM)
|
||||
if keyDER == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER.Bytes)
|
||||
if err != nil {
|
||||
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER.Bytes)
|
||||
if err != nil {
|
||||
generalKey, err = x509.ParseECPrivateKey(keyDER.Bytes)
|
||||
if err != nil {
|
||||
// We don't include the actual error into
|
||||
// the final error. The reason might be
|
||||
// we don't want to leak any info about
|
||||
// the private key.
|
||||
return nil, fmt.Errorf("No successful private key decoder")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch generalKey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return generalKey.(*rsa.PrivateKey), nil
|
||||
case *ecdsa.PrivateKey:
|
||||
return generalKey.(*ecdsa.PrivateKey), nil
|
||||
}
|
||||
|
||||
// should never reach here
|
||||
return nil, fmt.Errorf("Should be unreachable")
|
||||
}
|
||||
|
||||
// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object,
|
||||
// either a raw x509 certificate or a PKCS #7 structure possibly containing
|
||||
// multiple certificates, from the top of certsPEM, which itself may
|
||||
// contain multiple PEM encoded certificate objects.
|
||||
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||
func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) {
|
||||
block, rest := pem.Decode(certsPEM)
|
||||
if block == nil {
|
||||
return nil, rest, nil
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
var certs = []*x509.Certificate{cert}
|
||||
return certs, rest, err
|
||||
}
|
||||
|
||||
// ParseCertificatesPEM parses a sequence of PEM-encoded certificate and returns them,
|
||||
// can handle PEM encoded PKCS #7 structures.
|
||||
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||
func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
|
||||
var certs []*x509.Certificate
|
||||
var err error
|
||||
certsPEM = bytes.TrimSpace(certsPEM)
|
||||
for len(certsPEM) > 0 {
|
||||
var cert []*x509.Certificate
|
||||
cert, certsPEM, err = ParseOneCertificateFromPEM(certsPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if cert == nil {
|
||||
break
|
||||
}
|
||||
|
||||
certs = append(certs, cert...)
|
||||
}
|
||||
if len(certsPEM) > 0 {
|
||||
return nil, fmt.Errorf("Trailing PEM data")
|
||||
}
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.StringVar(&port, "port", "4430", "port")
|
||||
flag.StringVar(&serverName, "host", "example.com", "hostname")
|
||||
flag.StringVar(&certFile, "cert", "", "certificate chain in PEM or DER")
|
||||
flag.StringVar(&keyFile, "key", "", "private key in PEM format")
|
||||
flag.StringVar(&responseFile, "response", "", "file to serve")
|
||||
flag.BoolVar(&h2, "h2", false, "whether to use HTTP/2 (exclusively)")
|
||||
flag.BoolVar(&sendTickets, "tickets", true, "whether to send session tickets")
|
||||
flag.Parse()
|
||||
|
||||
var certChain []*x509.Certificate
|
||||
var priv crypto.Signer
|
||||
var response []byte
|
||||
var err error
|
||||
|
||||
// Load the key and certificate chain
|
||||
if certFile != "" {
|
||||
certs, err := ioutil.ReadFile(certFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
} else {
|
||||
certChain, err = ParseCertificatesPEM(certs)
|
||||
if err != nil {
|
||||
certChain, err = x509.ParseCertificates(certs)
|
||||
if err != nil {
|
||||
log.Fatalf("Error parsing certificates: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if keyFile != "" {
|
||||
keyPEM, err := ioutil.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
} else {
|
||||
priv, err = ParsePrivateKeyPEM(keyPEM)
|
||||
if priv == nil || err != nil {
|
||||
log.Fatalf("Error parsing private key: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
}
|
||||
|
||||
// Load response file
|
||||
if responseFile != "" {
|
||||
log.Printf("Loading response file: %v", responseFile)
|
||||
response, err = ioutil.ReadFile(responseFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
}
|
||||
} else {
|
||||
response = []byte("Welcome to the TLS 1.3 zone!")
|
||||
}
|
||||
handler := responder(response)
|
||||
|
||||
config := mint.Config{
|
||||
SendSessionTickets: true,
|
||||
ServerName: serverName,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
}
|
||||
|
||||
if h2 {
|
||||
config.NextProtos = []string{"h2"}
|
||||
}
|
||||
|
||||
config.SendSessionTickets = sendTickets
|
||||
|
||||
if certChain != nil && priv != nil {
|
||||
log.Printf("Loading cert: %v key: %v", certFile, keyFile)
|
||||
config.Certificates = []*mint.Certificate{
|
||||
{
|
||||
Chain: certChain,
|
||||
PrivateKey: priv,
|
||||
},
|
||||
}
|
||||
}
|
||||
config.Init(false)
|
||||
|
||||
service := "0.0.0.0:" + port
|
||||
srv := &http.Server{Handler: handler}
|
||||
|
||||
log.Printf("Listening on port %v", port)
|
||||
// Need the inner loop here because the h1 server errors on a dropped connection
|
||||
// Need the outer loop here because the h2 server is per-connection
|
||||
for {
|
||||
listener, err := mint.Listen("tcp", service, &config)
|
||||
if err != nil {
|
||||
log.Printf("Listen Error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !h2 {
|
||||
alert := srv.Serve(listener)
|
||||
if alert != mint.AlertNoAlert {
|
||||
log.Printf("Serve Error: %v", err)
|
||||
}
|
||||
} else {
|
||||
srv2 := new(http2.Server)
|
||||
opts := &http2.ServeConnOpts{
|
||||
Handler: handler,
|
||||
BaseConfig: srv,
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("Accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
go srv2.ServeConn(conn, opts)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
65
vendor/github.com/bifurcation/mint/bin/mint-server/main.go
generated
vendored
Normal file
65
vendor/github.com/bifurcation/mint/bin/mint-server/main.go
generated
vendored
Normal file
|
@ -0,0 +1,65 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
)
|
||||
|
||||
var port string
|
||||
|
||||
func main() {
|
||||
var config mint.Config
|
||||
config.SendSessionTickets = true
|
||||
config.ServerName = "localhost"
|
||||
config.Init(false)
|
||||
|
||||
flag.StringVar(&port, "port", "4430", "port")
|
||||
flag.Parse()
|
||||
|
||||
service := "0.0.0.0:" + port
|
||||
listener, err := mint.Listen("tcp", service, &config)
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("server: listen: %s", err)
|
||||
}
|
||||
log.Print("server: listening")
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("server: accept: %s", err)
|
||||
break
|
||||
}
|
||||
defer conn.Close()
|
||||
log.Printf("server: accepted from %s", conn.RemoteAddr())
|
||||
go handleClient(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func handleClient(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 10)
|
||||
for {
|
||||
log.Print("server: conn: waiting")
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != nil {
|
||||
log.Printf("server: conn: read: %s", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
n, err = conn.Write([]byte("hello world"))
|
||||
log.Printf("server: conn: wrote %d bytes", n)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("server: write: %s", err)
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
log.Println("server: conn: closed")
|
||||
}
|
942
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
Normal file
942
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
Normal file
|
@ -0,0 +1,942 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"hash"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client State Machine
|
||||
//
|
||||
// START <----+
|
||||
// Send ClientHello | | Recv HelloRetryRequest
|
||||
// / v |
|
||||
// | WAIT_SH ---+
|
||||
// Can | | Recv ServerHello
|
||||
// send | V
|
||||
// early | WAIT_EE
|
||||
// data | | Recv EncryptedExtensions
|
||||
// | +--------+--------+
|
||||
// | Using | | Using certificate
|
||||
// | PSK | v
|
||||
// | | WAIT_CERT_CR
|
||||
// | | Recv | | Recv CertificateRequest
|
||||
// | | Certificate | v
|
||||
// | | | WAIT_CERT
|
||||
// | | | | Recv Certificate
|
||||
// | | v v
|
||||
// | | WAIT_CV
|
||||
// | | | Recv CertificateVerify
|
||||
// | +> WAIT_FINISHED <+
|
||||
// | | Recv Finished
|
||||
// \ |
|
||||
// | [Send EndOfEarlyData]
|
||||
// | [Send Certificate [+ CertificateVerify]]
|
||||
// | Send Finished
|
||||
// Can send v
|
||||
// app data --> CONNECTED
|
||||
// after
|
||||
// here
|
||||
//
|
||||
// State Instructions
|
||||
// START Send(CH); [RekeyOut; SendEarlyData]
|
||||
// WAIT_SH Send(CH) || RekeyIn
|
||||
// WAIT_EE {}
|
||||
// WAIT_CERT_CR {}
|
||||
// WAIT_CERT {}
|
||||
// WAIT_CV {}
|
||||
// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut;
|
||||
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
||||
|
||||
type ClientStateStart struct {
|
||||
Caps Capabilities
|
||||
Opts ConnectionOptions
|
||||
Params ConnectionParameters
|
||||
|
||||
cookie []byte
|
||||
firstClientHello *HandshakeMessage
|
||||
helloRetryRequest *HandshakeMessage
|
||||
}
|
||||
|
||||
func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
// key_shares
|
||||
offeredDH := map[NamedGroup][]byte{}
|
||||
ks := KeyShareExtension{
|
||||
HandshakeType: HandshakeTypeClientHello,
|
||||
Shares: make([]KeyShareEntry, len(state.Caps.Groups)),
|
||||
}
|
||||
for i, group := range state.Caps.Groups {
|
||||
pub, priv, err := newKeyShare(group)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
ks.Shares[i].Group = group
|
||||
ks.Shares[i].KeyExchange = pub
|
||||
offeredDH[group] = priv
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "opts: %+v", state.Opts)
|
||||
|
||||
// supported_versions, supported_groups, signature_algorithms, server_name
|
||||
sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}}
|
||||
sni := ServerNameExtension(state.Opts.ServerName)
|
||||
sg := SupportedGroupsExtension{Groups: state.Caps.Groups}
|
||||
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
|
||||
|
||||
state.Params.ServerName = state.Opts.ServerName
|
||||
|
||||
// Application Layer Protocol Negotiation
|
||||
var alpn *ALPNExtension
|
||||
if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) {
|
||||
alpn = &ALPNExtension{Protocols: state.Opts.NextProtos}
|
||||
}
|
||||
|
||||
// Construct base ClientHello
|
||||
ch := &ClientHelloBody{
|
||||
CipherSuites: state.Caps.CipherSuites,
|
||||
}
|
||||
_, err := prng.Read(ch.Random[:])
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} {
|
||||
err := ch.Extensions.Add(ext)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
// XXX: These optional extensions can't be folded into the above because Go
|
||||
// interface-typed values are never reported as nil
|
||||
if alpn != nil {
|
||||
err := ch.Extensions.Add(alpn)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
if state.cookie != nil {
|
||||
err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Run the external extension handler.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PSK and EarlyData just before transmitting, so that we can
|
||||
// calculate the PSK binder value
|
||||
var psk *PreSharedKeyExtension
|
||||
var ed *EarlyDataExtension
|
||||
var offeredPSK PreSharedKey
|
||||
var earlyHash crypto.Hash
|
||||
var earlySecret []byte
|
||||
var clientEarlyTrafficKeys keySet
|
||||
var clientHello *HandshakeMessage
|
||||
if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok {
|
||||
offeredPSK = key
|
||||
|
||||
// Narrow ciphersuites to ones that match PSK hash
|
||||
params, ok := cipherSuiteMap[key.CipherSuite]
|
||||
if !ok {
|
||||
logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite")
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
compatibleSuites := []CipherSuite{}
|
||||
for _, suite := range ch.CipherSuites {
|
||||
if cipherSuiteMap[suite].Hash == params.Hash {
|
||||
compatibleSuites = append(compatibleSuites, suite)
|
||||
}
|
||||
}
|
||||
ch.CipherSuites = compatibleSuites
|
||||
|
||||
// Signal early data if we're going to do it
|
||||
if len(state.Opts.EarlyData) > 0 {
|
||||
state.Params.ClientSendingEarlyData = true
|
||||
ed = &EarlyDataExtension{}
|
||||
err = ch.Extensions.Add(ed)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "Error adding early data extension: %v", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Signal supported PSK key exchange modes
|
||||
if len(state.Caps.PSKModes) == 0 {
|
||||
logf(logTypeHandshake, "PSK selected, but no PSKModes")
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes}
|
||||
err = ch.Extensions.Add(kem)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
// Add the shim PSK extension to the ClientHello
|
||||
logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity)
|
||||
psk = &PreSharedKeyExtension{
|
||||
HandshakeType: HandshakeTypeClientHello,
|
||||
Identities: []PSKIdentity{
|
||||
{
|
||||
Identity: key.Identity,
|
||||
ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd,
|
||||
},
|
||||
},
|
||||
Binders: []PSKBinderEntry{
|
||||
// Note: Stub to get the length fields right
|
||||
{Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())},
|
||||
},
|
||||
}
|
||||
ch.Extensions.Add(psk)
|
||||
|
||||
// Compute the binder key
|
||||
h0 := params.Hash.New().Sum(nil)
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
|
||||
earlyHash = params.Hash
|
||||
earlySecret = HkdfExtract(params.Hash, zero, key.Key)
|
||||
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
|
||||
|
||||
binderLabel := labelExternalBinder
|
||||
if key.IsResumption {
|
||||
binderLabel = labelResumptionBinder
|
||||
}
|
||||
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
|
||||
logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey)
|
||||
|
||||
// Compute the binder value
|
||||
trunc, err := ch.Truncated()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
truncHash := params.Hash.New()
|
||||
truncHash.Write(trunc)
|
||||
|
||||
binder := computeFinishedData(params, binderKey, truncHash.Sum(nil))
|
||||
|
||||
// Replace the PSK extension
|
||||
psk.Binders[0].Binder = binder
|
||||
ch.Extensions.Add(psk)
|
||||
|
||||
// If we got here, the earlier marshal succeeded (in ch.Truncated()), so
|
||||
// this one should too.
|
||||
clientHello, _ = HandshakeMessageFromBody(ch)
|
||||
|
||||
// Compute early traffic keys
|
||||
h := params.Hash.New()
|
||||
h.Write(clientHello.Marshal())
|
||||
chHash := h.Sum(nil)
|
||||
|
||||
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
|
||||
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
|
||||
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
|
||||
} else if len(state.Opts.EarlyData) > 0 {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
|
||||
return nil, nil, AlertInternalError
|
||||
} else {
|
||||
clientHello, err = HandshakeMessageFromBody(ch)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
|
||||
nextState := ClientStateWaitSH{
|
||||
Caps: state.Caps,
|
||||
Opts: state.Opts,
|
||||
Params: state.Params,
|
||||
OfferedDH: offeredDH,
|
||||
OfferedPSK: offeredPSK,
|
||||
|
||||
earlySecret: earlySecret,
|
||||
earlyHash: earlyHash,
|
||||
|
||||
firstClientHello: state.firstClientHello,
|
||||
helloRetryRequest: state.helloRetryRequest,
|
||||
clientHello: clientHello,
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
SendHandshakeMessage{clientHello},
|
||||
}
|
||||
if state.Params.ClientSendingEarlyData {
|
||||
toSend = append(toSend, []HandshakeAction{
|
||||
RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys},
|
||||
SendEarlyData{},
|
||||
}...)
|
||||
}
|
||||
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
type ClientStateWaitSH struct {
|
||||
Caps Capabilities
|
||||
Opts ConnectionOptions
|
||||
Params ConnectionParameters
|
||||
OfferedDH map[NamedGroup][]byte
|
||||
OfferedPSK PreSharedKey
|
||||
PSK []byte
|
||||
|
||||
earlySecret []byte
|
||||
earlyHash crypto.Hash
|
||||
|
||||
firstClientHello *HandshakeMessage
|
||||
helloRetryRequest *HandshakeMessage
|
||||
clientHello *HandshakeMessage
|
||||
}
|
||||
|
||||
func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
bodyGeneric, err := hm.ToBody()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
switch body := bodyGeneric.(type) {
|
||||
case *HelloRetryRequestBody:
|
||||
hrr := body
|
||||
|
||||
if state.helloRetryRequest != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
// Check that the version sent by the server is the one we support
|
||||
if hrr.Version != supportedVersion {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version)
|
||||
return nil, nil, AlertProtocolVersion
|
||||
}
|
||||
|
||||
// Check that the server provided a supported ciphersuite
|
||||
supportedCipherSuite := false
|
||||
for _, suite := range state.Caps.CipherSuites {
|
||||
supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite)
|
||||
}
|
||||
if !supportedCipherSuite {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Narrow the supported ciphersuites to the server-provided one
|
||||
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite}
|
||||
|
||||
// Handle external extensions.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// The only thing we know how to respond to in an HRR is the Cookie
|
||||
// extension, so if there is either no Cookie extension or anything other
|
||||
// than a Cookie extension, we have to fail.
|
||||
serverCookie := new(CookieExtension)
|
||||
foundCookie := hrr.Extensions.Find(serverCookie)
|
||||
if !foundCookie || len(hrr.Extensions) != 1 {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions))
|
||||
return nil, nil, AlertIllegalParameter
|
||||
}
|
||||
|
||||
// Hash the body into a pseudo-message
|
||||
// XXX: Ignoring some errors here
|
||||
params := cipherSuiteMap[hrr.CipherSuite]
|
||||
h := params.Hash.New()
|
||||
h.Write(state.clientHello.Marshal())
|
||||
firstClientHello := &HandshakeMessage{
|
||||
msgType: HandshakeTypeMessageHash,
|
||||
body: h.Sum(nil),
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
|
||||
return ClientStateStart{
|
||||
Caps: state.Caps,
|
||||
Opts: state.Opts,
|
||||
cookie: serverCookie.Cookie,
|
||||
firstClientHello: firstClientHello,
|
||||
helloRetryRequest: hm,
|
||||
}.Next(nil)
|
||||
|
||||
case *ServerHelloBody:
|
||||
sh := body
|
||||
|
||||
// Check that the version sent by the server is the one we support
|
||||
if sh.Version != supportedVersion {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version)
|
||||
return nil, nil, AlertProtocolVersion
|
||||
}
|
||||
|
||||
// Check that the server provided a supported ciphersuite
|
||||
supportedCipherSuite := false
|
||||
for _, suite := range state.Caps.CipherSuites {
|
||||
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
|
||||
}
|
||||
if !supportedCipherSuite {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Handle external extensions.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Do PSK or key agreement depending on extensions
|
||||
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
|
||||
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
|
||||
|
||||
foundPSK := sh.Extensions.Find(&serverPSK)
|
||||
foundKeyShare := sh.Extensions.Find(&serverKeyShare)
|
||||
|
||||
if foundPSK && (serverPSK.SelectedIdentity == 0) {
|
||||
state.Params.UsingPSK = true
|
||||
}
|
||||
|
||||
var dhSecret []byte
|
||||
if foundKeyShare {
|
||||
sks := serverKeyShare.Shares[0]
|
||||
priv, ok := state.OfferedDH[sks.Group]
|
||||
if !ok {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group")
|
||||
return nil, nil, AlertIllegalParameter
|
||||
}
|
||||
|
||||
state.Params.UsingDH = true
|
||||
dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv)
|
||||
}
|
||||
|
||||
suite := sh.CipherSuite
|
||||
state.Params.CipherSuite = suite
|
||||
|
||||
params, ok := cipherSuiteMap[suite]
|
||||
if !ok {
|
||||
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Start up the handshake hash
|
||||
handshakeHash := params.Hash.New()
|
||||
handshakeHash.Write(state.firstClientHello.Marshal())
|
||||
handshakeHash.Write(state.helloRetryRequest.Marshal())
|
||||
handshakeHash.Write(state.clientHello.Marshal())
|
||||
handshakeHash.Write(hm.Marshal())
|
||||
|
||||
// Compute handshake secrets
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
|
||||
var earlySecret []byte
|
||||
if state.Params.UsingPSK {
|
||||
if params.Hash != state.earlyHash {
|
||||
logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]",
|
||||
state.earlyHash, suite, params.Hash)
|
||||
}
|
||||
|
||||
earlySecret = state.earlySecret
|
||||
} else {
|
||||
earlySecret = HkdfExtract(params.Hash, zero, zero)
|
||||
}
|
||||
|
||||
if dhSecret == nil {
|
||||
dhSecret = zero
|
||||
}
|
||||
|
||||
h0 := params.Hash.New().Sum(nil)
|
||||
h2 := handshakeHash.Sum(nil)
|
||||
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
|
||||
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret)
|
||||
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
|
||||
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
|
||||
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
|
||||
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
|
||||
|
||||
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
|
||||
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
|
||||
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
|
||||
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
|
||||
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
|
||||
|
||||
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
|
||||
nextState := ClientStateWaitEE{
|
||||
Caps: state.Caps,
|
||||
Params: state.Params,
|
||||
cryptoParams: params,
|
||||
handshakeHash: handshakeHash,
|
||||
certificates: state.Caps.Certificates,
|
||||
masterSecret: masterSecret,
|
||||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
|
||||
}
|
||||
toSend := []HandshakeAction{
|
||||
RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys},
|
||||
}
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType)
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
type ClientStateWaitEE struct {
|
||||
Caps Capabilities
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
certificates []*Certificate
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions {
|
||||
logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
ee := EncryptedExtensionsBody{}
|
||||
_, err := ee.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
// Handle external extensions.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
serverALPN := ALPNExtension{}
|
||||
serverEarlyData := EarlyDataExtension{}
|
||||
|
||||
gotALPN := ee.Extensions.Find(&serverALPN)
|
||||
state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData)
|
||||
|
||||
if gotALPN && len(serverALPN.Protocols) > 0 {
|
||||
state.Params.NextProto = serverALPN.Protocols[0]
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
if state.Params.UsingPSK {
|
||||
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
|
||||
nextState := ClientStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
|
||||
nextState := ClientStateWaitCertCR{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ClientStateWaitCertCR struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
certificates []*Certificate
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
bodyGeneric, err := hm.ToBody()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
switch body := bodyGeneric.(type) {
|
||||
case *CertificateBody:
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]")
|
||||
nextState := ClientStateWaitCV{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
serverCertificate: body,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
|
||||
case *CertificateRequestBody:
|
||||
// A certificate request in the handshake should have a zero-length context
|
||||
if len(body.CertificateRequestContext) > 0 {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err)
|
||||
return nil, nil, AlertIllegalParameter
|
||||
}
|
||||
|
||||
state.Params.UsingClientAuth = true
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]")
|
||||
nextState := ClientStateWaitCert{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
serverCertificateRequest: body,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
type ClientStateWaitCert struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
|
||||
certificates []*Certificate
|
||||
serverCertificateRequest *CertificateRequestBody
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeCertificate {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
cert := &CertificateBody{}
|
||||
_, err := cert.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]")
|
||||
nextState := ClientStateWaitCV{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
serverCertificate: cert,
|
||||
serverCertificateRequest: state.serverCertificateRequest,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ClientStateWaitCV struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
|
||||
certificates []*Certificate
|
||||
serverCertificate *CertificateBody
|
||||
serverCertificateRequest *CertificateRequestBody
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
certVerify := CertificateVerifyBody{}
|
||||
_, err := certVerify.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
hcv := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||
|
||||
serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey
|
||||
if err := certVerify.Verify(serverPublicKey, hcv); err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify")
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
if state.AuthCertificate != nil {
|
||||
err := state.AuthCertificate(state.serverCertificate.CertificateList)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate")
|
||||
return nil, nil, AlertBadCertificate
|
||||
}
|
||||
} else {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate")
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]")
|
||||
nextState := ClientStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
serverCertificateRequest: state.serverCertificateRequest,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ClientStateWaitFinished struct {
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
|
||||
certificates []*Certificate
|
||||
serverCertificateRequest *CertificateRequestBody
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeFinished {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
// Verify server's Finished
|
||||
h3 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
|
||||
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
|
||||
|
||||
serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3)
|
||||
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
|
||||
|
||||
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
|
||||
_, err := fin.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
if !bytes.Equal(fin.VerifyData, serverFinishedData) {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]",
|
||||
fin.VerifyData, serverFinishedData)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Update the handshake hash with the Finished
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal())
|
||||
h4 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4)
|
||||
|
||||
// Compute traffic secrets and keys
|
||||
clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4)
|
||||
serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4)
|
||||
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
|
||||
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
|
||||
|
||||
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret)
|
||||
serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret)
|
||||
|
||||
exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4)
|
||||
logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
|
||||
|
||||
// Assemble client's second flight
|
||||
toSend := []HandshakeAction{}
|
||||
|
||||
if state.Params.UsingEarlyData {
|
||||
// Note: We only send EOED if the server is actually going to use the early
|
||||
// data. Otherwise, it will never see it, and the transcripts will
|
||||
// mismatch.
|
||||
// EOED marshal is infallible
|
||||
eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{})
|
||||
toSend = append(toSend, SendHandshakeMessage{eoedm})
|
||||
state.handshakeHash.Write(eoedm.Marshal())
|
||||
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
|
||||
}
|
||||
|
||||
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
|
||||
toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys})
|
||||
|
||||
if state.Params.UsingClientAuth {
|
||||
// Extract constraints from certicateRequest
|
||||
schemes := SignatureAlgorithmsExtension{}
|
||||
gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes)
|
||||
if !gotSchemes {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
||||
return nil, nil, AlertIllegalParameter
|
||||
}
|
||||
|
||||
// Select a certificate
|
||||
cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates)
|
||||
if err != nil {
|
||||
// XXX: Signal this to the application layer?
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
||||
|
||||
certificate := &CertificateBody{}
|
||||
certm, err := HandshakeMessageFromBody(certificate)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||
state.handshakeHash.Write(certm.Marshal())
|
||||
} else {
|
||||
// Create and send Certificate, CertificateVerify
|
||||
certificate := &CertificateBody{
|
||||
CertificateList: make([]CertificateEntry, len(cert.Chain)),
|
||||
}
|
||||
for i, entry := range cert.Chain {
|
||||
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
|
||||
}
|
||||
certm, err := HandshakeMessageFromBody(certificate)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||
state.handshakeHash.Write(certm.Marshal())
|
||||
|
||||
hcv := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||
|
||||
certificateVerify := &CertificateVerifyBody{Algorithm: certScheme}
|
||||
logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash)
|
||||
|
||||
err = certificateVerify.Sign(cert.PrivateKey, hcv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
certvm, err := HandshakeMessageFromBody(certificateVerify)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certvm})
|
||||
state.handshakeHash.Write(certvm.Marshal())
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the client's Finished message
|
||||
h5 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
|
||||
|
||||
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
|
||||
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
|
||||
|
||||
fin = &FinishedBody{
|
||||
VerifyDataLen: len(clientFinishedData),
|
||||
VerifyData: clientFinishedData,
|
||||
}
|
||||
finm, err := HandshakeMessageFromBody(fin)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
// Compute the resumption secret
|
||||
state.handshakeHash.Write(finm.Marshal())
|
||||
h6 := state.handshakeHash.Sum(nil)
|
||||
|
||||
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
|
||||
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
|
||||
|
||||
toSend = append(toSend, []HandshakeAction{
|
||||
SendHandshakeMessage{finm},
|
||||
RekeyIn{Label: "application", KeySet: serverTrafficKeys},
|
||||
RekeyOut{Label: "application", KeySet: clientTrafficKeys},
|
||||
}...)
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
|
||||
nextState := StateConnected{
|
||||
Params: state.Params,
|
||||
isClient: true,
|
||||
cryptoParams: state.cryptoParams,
|
||||
resumptionSecret: resumptionSecret,
|
||||
clientTrafficSecret: clientTrafficSecret,
|
||||
serverTrafficSecret: serverTrafficSecret,
|
||||
exporterSecret: exporterSecret,
|
||||
}
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
152
vendor/github.com/bifurcation/mint/common.go
generated
vendored
Normal file
152
vendor/github.com/bifurcation/mint/common.go
generated
vendored
Normal file
|
@ -0,0 +1,152 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
supportedVersion uint16 = 0x7f15 // draft-21
|
||||
|
||||
// Flags for some minor compat issues
|
||||
allowWrongVersionNumber = true
|
||||
allowPKCS1 = true
|
||||
)
|
||||
|
||||
// enum {...} ContentType;
|
||||
type RecordType byte
|
||||
|
||||
const (
|
||||
RecordTypeAlert RecordType = 21
|
||||
RecordTypeHandshake RecordType = 22
|
||||
RecordTypeApplicationData RecordType = 23
|
||||
)
|
||||
|
||||
// enum {...} HandshakeType;
|
||||
type HandshakeType byte
|
||||
|
||||
const (
|
||||
// Omitted: *_RESERVED
|
||||
HandshakeTypeClientHello HandshakeType = 1
|
||||
HandshakeTypeServerHello HandshakeType = 2
|
||||
HandshakeTypeNewSessionTicket HandshakeType = 4
|
||||
HandshakeTypeEndOfEarlyData HandshakeType = 5
|
||||
HandshakeTypeHelloRetryRequest HandshakeType = 6
|
||||
HandshakeTypeEncryptedExtensions HandshakeType = 8
|
||||
HandshakeTypeCertificate HandshakeType = 11
|
||||
HandshakeTypeCertificateRequest HandshakeType = 13
|
||||
HandshakeTypeCertificateVerify HandshakeType = 15
|
||||
HandshakeTypeServerConfiguration HandshakeType = 17
|
||||
HandshakeTypeFinished HandshakeType = 20
|
||||
HandshakeTypeKeyUpdate HandshakeType = 24
|
||||
HandshakeTypeMessageHash HandshakeType = 254
|
||||
)
|
||||
|
||||
// uint8 CipherSuite[2];
|
||||
type CipherSuite uint16
|
||||
|
||||
const (
|
||||
// XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero
|
||||
// value for this type so that we can detect when a field is set.
|
||||
CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000
|
||||
TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301
|
||||
TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302
|
||||
TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303
|
||||
TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304
|
||||
TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305
|
||||
)
|
||||
|
||||
func (c CipherSuite) String() string {
|
||||
switch c {
|
||||
case CIPHER_SUITE_UNKNOWN:
|
||||
return "unknown"
|
||||
case TLS_AES_128_GCM_SHA256:
|
||||
return "TLS_AES_128_GCM_SHA256"
|
||||
case TLS_AES_256_GCM_SHA384:
|
||||
return "TLS_AES_256_GCM_SHA384"
|
||||
case TLS_CHACHA20_POLY1305_SHA256:
|
||||
return "TLS_CHACHA20_POLY1305_SHA256"
|
||||
case TLS_AES_128_CCM_SHA256:
|
||||
return "TLS_AES_128_CCM_SHA256"
|
||||
case TLS_AES_256_CCM_8_SHA256:
|
||||
return "TLS_AES_256_CCM_8_SHA256"
|
||||
}
|
||||
// cannot use %x here, since it calls String(), leading to infinite recursion
|
||||
return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16))
|
||||
}
|
||||
|
||||
// enum {...} SignatureScheme
|
||||
type SignatureScheme uint16
|
||||
|
||||
const (
|
||||
// RSASSA-PKCS1-v1_5 algorithms
|
||||
RSA_PKCS1_SHA1 SignatureScheme = 0x0201
|
||||
RSA_PKCS1_SHA256 SignatureScheme = 0x0401
|
||||
RSA_PKCS1_SHA384 SignatureScheme = 0x0501
|
||||
RSA_PKCS1_SHA512 SignatureScheme = 0x0601
|
||||
// ECDSA algorithms
|
||||
ECDSA_P256_SHA256 SignatureScheme = 0x0403
|
||||
ECDSA_P384_SHA384 SignatureScheme = 0x0503
|
||||
ECDSA_P521_SHA512 SignatureScheme = 0x0603
|
||||
// RSASSA-PSS algorithms
|
||||
RSA_PSS_SHA256 SignatureScheme = 0x0804
|
||||
RSA_PSS_SHA384 SignatureScheme = 0x0805
|
||||
RSA_PSS_SHA512 SignatureScheme = 0x0806
|
||||
// EdDSA algorithms
|
||||
Ed25519 SignatureScheme = 0x0807
|
||||
Ed448 SignatureScheme = 0x0808
|
||||
)
|
||||
|
||||
// enum {...} ExtensionType
|
||||
type ExtensionType uint16
|
||||
|
||||
const (
|
||||
ExtensionTypeServerName ExtensionType = 0
|
||||
ExtensionTypeSupportedGroups ExtensionType = 10
|
||||
ExtensionTypeSignatureAlgorithms ExtensionType = 13
|
||||
ExtensionTypeALPN ExtensionType = 16
|
||||
ExtensionTypeKeyShare ExtensionType = 40
|
||||
ExtensionTypePreSharedKey ExtensionType = 41
|
||||
ExtensionTypeEarlyData ExtensionType = 42
|
||||
ExtensionTypeSupportedVersions ExtensionType = 43
|
||||
ExtensionTypeCookie ExtensionType = 44
|
||||
ExtensionTypePSKKeyExchangeModes ExtensionType = 45
|
||||
ExtensionTypeTicketEarlyDataInfo ExtensionType = 46
|
||||
)
|
||||
|
||||
// enum {...} NamedGroup
|
||||
type NamedGroup uint16
|
||||
|
||||
const (
|
||||
// Elliptic Curve Groups.
|
||||
P256 NamedGroup = 23
|
||||
P384 NamedGroup = 24
|
||||
P521 NamedGroup = 25
|
||||
// ECDH functions.
|
||||
X25519 NamedGroup = 29
|
||||
X448 NamedGroup = 30
|
||||
// Finite field groups.
|
||||
FFDHE2048 NamedGroup = 256
|
||||
FFDHE3072 NamedGroup = 257
|
||||
FFDHE4096 NamedGroup = 258
|
||||
FFDHE6144 NamedGroup = 259
|
||||
FFDHE8192 NamedGroup = 260
|
||||
)
|
||||
|
||||
// enum {...} PskKeyExchangeMode;
|
||||
type PSKKeyExchangeMode uint8
|
||||
|
||||
const (
|
||||
PSKModeKE PSKKeyExchangeMode = 0
|
||||
PSKModeDHEKE PSKKeyExchangeMode = 1
|
||||
)
|
||||
|
||||
// enum {
|
||||
// update_not_requested(0), update_requested(1), (255)
|
||||
// } KeyUpdateRequest;
|
||||
type KeyUpdateRequest uint8
|
||||
|
||||
const (
|
||||
KeyUpdateNotRequested KeyUpdateRequest = 0
|
||||
KeyUpdateRequested KeyUpdateRequest = 1
|
||||
)
|
819
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
Normal file
819
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
Normal file
|
@ -0,0 +1,819 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var WouldBlock = fmt.Errorf("Would have blocked")
|
||||
|
||||
type Certificate struct {
|
||||
Chain []*x509.Certificate
|
||||
PrivateKey crypto.Signer
|
||||
}
|
||||
|
||||
type PreSharedKey struct {
|
||||
CipherSuite CipherSuite
|
||||
IsResumption bool
|
||||
Identity []byte
|
||||
Key []byte
|
||||
NextProto string
|
||||
ReceivedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
TicketAgeAdd uint32
|
||||
}
|
||||
|
||||
type PreSharedKeyCache interface {
|
||||
Get(string) (PreSharedKey, bool)
|
||||
Put(string, PreSharedKey)
|
||||
Size() int
|
||||
}
|
||||
|
||||
type PSKMapCache map[string]PreSharedKey
|
||||
|
||||
// A CookieHandler does two things:
|
||||
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
|
||||
// - validates this byte string echoed by the client in the ClientHello
|
||||
type CookieHandler interface {
|
||||
Generate(*Conn) ([]byte, error)
|
||||
Validate(*Conn, []byte) bool
|
||||
}
|
||||
|
||||
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
|
||||
psk, ok = cache[key]
|
||||
return
|
||||
}
|
||||
|
||||
func (cache *PSKMapCache) Put(key string, psk PreSharedKey) {
|
||||
(*cache)[key] = psk
|
||||
}
|
||||
|
||||
func (cache PSKMapCache) Size() int {
|
||||
return len(cache)
|
||||
}
|
||||
|
||||
// Config is the struct used to pass configuration settings to a TLS client or
|
||||
// server instance. The settings for client and server are pretty different,
|
||||
// but we just throw them all in here.
|
||||
type Config struct {
|
||||
// Client fields
|
||||
ServerName string
|
||||
|
||||
// Server fields
|
||||
SendSessionTickets bool
|
||||
TicketLifetime uint32
|
||||
TicketLen int
|
||||
EarlyDataLifetime uint32
|
||||
AllowEarlyData bool
|
||||
// Require the client to echo a cookie.
|
||||
RequireCookie bool
|
||||
// If cookies are required and no CookieHandler is set, a default cookie handler is used.
|
||||
// The default cookie handler uses 32 random bytes as a cookie.
|
||||
CookieHandler CookieHandler
|
||||
RequireClientAuth bool
|
||||
|
||||
// Shared fields
|
||||
Certificates []*Certificate
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
CipherSuites []CipherSuite
|
||||
Groups []NamedGroup
|
||||
SignatureSchemes []SignatureScheme
|
||||
NextProtos []string
|
||||
PSKs PreSharedKeyCache
|
||||
PSKModes []PSKKeyExchangeMode
|
||||
NonBlocking bool
|
||||
|
||||
// The same config object can be shared among different connections, so it
|
||||
// needs its own mutex
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// Clone returns a shallow clone of c. It is safe to clone a Config that is
|
||||
// being used concurrently by a TLS client or server.
|
||||
func (c *Config) Clone() *Config {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
return &Config{
|
||||
ServerName: c.ServerName,
|
||||
|
||||
SendSessionTickets: c.SendSessionTickets,
|
||||
TicketLifetime: c.TicketLifetime,
|
||||
TicketLen: c.TicketLen,
|
||||
EarlyDataLifetime: c.EarlyDataLifetime,
|
||||
AllowEarlyData: c.AllowEarlyData,
|
||||
RequireCookie: c.RequireCookie,
|
||||
RequireClientAuth: c.RequireClientAuth,
|
||||
|
||||
Certificates: c.Certificates,
|
||||
AuthCertificate: c.AuthCertificate,
|
||||
CipherSuites: c.CipherSuites,
|
||||
Groups: c.Groups,
|
||||
SignatureSchemes: c.SignatureSchemes,
|
||||
NextProtos: c.NextProtos,
|
||||
PSKs: c.PSKs,
|
||||
PSKModes: c.PSKModes,
|
||||
NonBlocking: c.NonBlocking,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) Init(isClient bool) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Set defaults
|
||||
if len(c.CipherSuites) == 0 {
|
||||
c.CipherSuites = defaultSupportedCipherSuites
|
||||
}
|
||||
if len(c.Groups) == 0 {
|
||||
c.Groups = defaultSupportedGroups
|
||||
}
|
||||
if len(c.SignatureSchemes) == 0 {
|
||||
c.SignatureSchemes = defaultSignatureSchemes
|
||||
}
|
||||
if c.TicketLen == 0 {
|
||||
c.TicketLen = defaultTicketLen
|
||||
}
|
||||
if !reflect.ValueOf(c.PSKs).IsValid() {
|
||||
c.PSKs = &PSKMapCache{}
|
||||
}
|
||||
if len(c.PSKModes) == 0 {
|
||||
c.PSKModes = defaultPSKModes
|
||||
}
|
||||
|
||||
// If there is no certificate, generate one
|
||||
if !isClient && len(c.Certificates) == 0 {
|
||||
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
|
||||
priv, err := newSigningKey(RSA_PSS_SHA256)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Certificates = []*Certificate{
|
||||
{
|
||||
Chain: []*x509.Certificate{cert},
|
||||
PrivateKey: priv,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) ValidForServer() bool {
|
||||
return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) ||
|
||||
(len(c.Certificates) > 0 &&
|
||||
len(c.Certificates[0].Chain) > 0 &&
|
||||
c.Certificates[0].PrivateKey != nil)
|
||||
}
|
||||
|
||||
func (c *Config) ValidForClient() bool {
|
||||
return len(c.ServerName) > 0
|
||||
}
|
||||
|
||||
var (
|
||||
defaultSupportedCipherSuites = []CipherSuite{
|
||||
TLS_AES_128_GCM_SHA256,
|
||||
TLS_AES_256_GCM_SHA384,
|
||||
}
|
||||
|
||||
defaultSupportedGroups = []NamedGroup{
|
||||
P256,
|
||||
P384,
|
||||
FFDHE2048,
|
||||
X25519,
|
||||
}
|
||||
|
||||
defaultSignatureSchemes = []SignatureScheme{
|
||||
RSA_PSS_SHA256,
|
||||
RSA_PSS_SHA384,
|
||||
RSA_PSS_SHA512,
|
||||
ECDSA_P256_SHA256,
|
||||
ECDSA_P384_SHA384,
|
||||
ECDSA_P521_SHA512,
|
||||
}
|
||||
|
||||
defaultTicketLen = 16
|
||||
|
||||
defaultPSKModes = []PSKKeyExchangeMode{
|
||||
PSKModeKE,
|
||||
PSKModeDHEKE,
|
||||
}
|
||||
)
|
||||
|
||||
type ConnectionState struct {
|
||||
HandshakeState string // string representation of the handshake state.
|
||||
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
|
||||
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement
|
||||
NextProto string // Selected ALPN proto
|
||||
}
|
||||
|
||||
// Conn implements the net.Conn interface, as with "crypto/tls"
|
||||
// * Read, Write, and Close are provided locally
|
||||
// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
|
||||
type Conn struct {
|
||||
config *Config
|
||||
conn net.Conn
|
||||
isClient bool
|
||||
|
||||
EarlyData []byte
|
||||
|
||||
state StateConnected
|
||||
hState HandshakeState
|
||||
handshakeMutex sync.Mutex
|
||||
handshakeAlert Alert
|
||||
handshakeComplete bool
|
||||
|
||||
readBuffer []byte
|
||||
in, out *RecordLayer
|
||||
hIn, hOut *HandshakeLayer
|
||||
|
||||
extHandler AppExtensionHandler
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
|
||||
c := &Conn{conn: conn, config: config, isClient: isClient}
|
||||
c.in = NewRecordLayer(c.conn)
|
||||
c.out = NewRecordLayer(c.conn)
|
||||
c.hIn = NewHandshakeLayer(c.in)
|
||||
c.hIn.nonblocking = c.config.NonBlocking
|
||||
c.hOut = NewHandshakeLayer(c.out)
|
||||
return c
|
||||
}
|
||||
|
||||
// Read up
|
||||
func (c *Conn) consumeRecord() error {
|
||||
pt, err := c.in.ReadRecord()
|
||||
if pt == nil {
|
||||
logf(logTypeIO, "extendBuffer returns error %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
switch pt.contentType {
|
||||
case RecordTypeHandshake:
|
||||
logf(logTypeHandshake, "Received post-handshake message")
|
||||
// We do not support fragmentation of post-handshake handshake messages.
|
||||
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
|
||||
start := 0
|
||||
for start < len(pt.fragment) {
|
||||
if len(pt.fragment[start:]) < handshakeHeaderLen {
|
||||
return fmt.Errorf("Post-handshake handshake message too short for header")
|
||||
}
|
||||
|
||||
hm := &HandshakeMessage{}
|
||||
hm.msgType = HandshakeType(pt.fragment[start])
|
||||
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
|
||||
|
||||
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen {
|
||||
return fmt.Errorf("Post-handshake handshake message too short for body")
|
||||
}
|
||||
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen]
|
||||
|
||||
// Advance state machine
|
||||
state, actions, alert := c.state.Next(hm)
|
||||
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
// XXX: If we want to support more advanced cases, e.g., post-handshake
|
||||
// authentication, we'll need to allow transitions other than
|
||||
// Connected -> Connected
|
||||
var connected bool
|
||||
c.state, connected = state.(StateConnected)
|
||||
if !connected {
|
||||
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
start += handshakeHeaderLen + hmLen
|
||||
}
|
||||
case RecordTypeAlert:
|
||||
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||
if len(pt.fragment) != 2 {
|
||||
c.sendAlert(AlertUnexpectedMessage)
|
||||
return io.EOF
|
||||
}
|
||||
if Alert(pt.fragment[1]) == AlertCloseNotify {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
switch pt.fragment[0] {
|
||||
case AlertLevelWarning:
|
||||
// drop on the floor
|
||||
case AlertLevelError:
|
||||
return Alert(pt.fragment[1])
|
||||
default:
|
||||
c.sendAlert(AlertUnexpectedMessage)
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
case RecordTypeApplicationData:
|
||||
c.readBuffer = append(c.readBuffer, pt.fragment...)
|
||||
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Read application data up to the size of buffer. Handshake and alert records
|
||||
// are consumed by the Conn object directly.
|
||||
func (c *Conn) Read(buffer []byte) (int, error) {
|
||||
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
|
||||
if alert := c.Handshake(); alert != AlertNoAlert {
|
||||
return 0, alert
|
||||
}
|
||||
|
||||
if len(buffer) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Lock the input channel
|
||||
c.in.Lock()
|
||||
defer c.in.Unlock()
|
||||
for len(c.readBuffer) == 0 {
|
||||
err := c.consumeRecord()
|
||||
|
||||
// err can be nil if consumeRecord processed a non app-data
|
||||
// record.
|
||||
if err != nil {
|
||||
if c.config.NonBlocking || err != WouldBlock {
|
||||
logf(logTypeIO, "conn.Read returns err=%v", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var read int
|
||||
n := len(buffer)
|
||||
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
|
||||
if len(c.readBuffer) <= n {
|
||||
buffer = buffer[:len(c.readBuffer)]
|
||||
copy(buffer, c.readBuffer)
|
||||
read = len(c.readBuffer)
|
||||
c.readBuffer = c.readBuffer[:0]
|
||||
} else {
|
||||
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
|
||||
copy(buffer[:n], c.readBuffer[:n])
|
||||
c.readBuffer = c.readBuffer[n:]
|
||||
read = n
|
||||
}
|
||||
|
||||
logf(logTypeVerbose, "Returning %v", string(buffer))
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// Write application data
|
||||
func (c *Conn) Write(buffer []byte) (int, error) {
|
||||
// Lock the output channel
|
||||
c.out.Lock()
|
||||
defer c.out.Unlock()
|
||||
|
||||
// Send full-size fragments
|
||||
var start int
|
||||
sent := 0
|
||||
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||||
err := c.out.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeApplicationData,
|
||||
fragment: buffer[start : start+maxFragmentLen],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent += maxFragmentLen
|
||||
}
|
||||
|
||||
// Send a final partial fragment if necessary
|
||||
if start < len(buffer) {
|
||||
err := c.out.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeApplicationData,
|
||||
fragment: buffer[start:],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent += len(buffer[start:])
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
// sendAlert sends a TLS alert message.
|
||||
// c.out.Mutex <= L.
|
||||
func (c *Conn) sendAlert(err Alert) error {
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
var level int
|
||||
switch err {
|
||||
case AlertNoRenegotiation, AlertCloseNotify:
|
||||
level = AlertLevelWarning
|
||||
default:
|
||||
level = AlertLevelError
|
||||
}
|
||||
|
||||
buf := []byte{byte(err), byte(level)}
|
||||
c.out.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeAlert,
|
||||
fragment: buf,
|
||||
})
|
||||
|
||||
// close_notify and end_of_early_data are not actually errors
|
||||
if level == AlertLevelWarning {
|
||||
return &net.OpError{Op: "local error", Err: err}
|
||||
}
|
||||
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *Conn) Close() error {
|
||||
// XXX crypto/tls has an interlock with Write here. Do we need that?
|
||||
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote network address.
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// SetDeadline sets the read and write deadlines associated with the connection.
|
||||
// A zero value for t means Read and Write will not time out.
|
||||
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline on the underlying connection.
|
||||
// A zero value for t means Read will not time out.
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the write deadline on the underlying connection.
|
||||
// A zero value for t means Write will not time out.
|
||||
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
|
||||
label := "[server]"
|
||||
if c.isClient {
|
||||
label = "[client]"
|
||||
}
|
||||
|
||||
switch action := actionGeneric.(type) {
|
||||
case SendHandshakeMessage:
|
||||
err := c.hOut.WriteMessage(action.Message)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case RekeyIn:
|
||||
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet)
|
||||
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case RekeyOut:
|
||||
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet)
|
||||
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case SendEarlyData:
|
||||
logf(logTypeHandshake, "%s Sending early data...", label)
|
||||
_, err := c.Write(c.EarlyData)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case ReadPastEarlyData:
|
||||
logf(logTypeHandshake, "%s Reading past early data...", label)
|
||||
// Scan past all records that fail to decrypt
|
||||
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
_, ok := err.(DecryptError)
|
||||
|
||||
for ok {
|
||||
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
_, ok = err.(DecryptError)
|
||||
}
|
||||
|
||||
case ReadEarlyData:
|
||||
logf(logTypeHandshake, "%s Reading early data...", label)
|
||||
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
|
||||
|
||||
for t == RecordTypeApplicationData {
|
||||
// Read a record into the buffer. Note that this is safe
|
||||
// in blocking mode because we read the record in in
|
||||
// PeekRecordType.
|
||||
pt, err := c.in.ReadRecord()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
|
||||
c.EarlyData = append(c.EarlyData, pt.fragment...)
|
||||
|
||||
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
|
||||
}
|
||||
logf(logTypeHandshake, "%s Done reading early data", label)
|
||||
|
||||
case StorePSK:
|
||||
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
|
||||
if c.isClient {
|
||||
// Clients look up PSKs based on server name
|
||||
c.config.PSKs.Put(c.config.ServerName, action.PSK)
|
||||
} else {
|
||||
// Servers look them up based on the identity in the extension
|
||||
c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK)
|
||||
}
|
||||
|
||||
default:
|
||||
logf(logTypeHandshake, "%s Unknown actionuction type", label)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
func (c *Conn) HandshakeSetup() Alert {
|
||||
var state HandshakeState
|
||||
var actions []HandshakeAction
|
||||
var alert Alert
|
||||
|
||||
if err := c.config.Init(c.isClient); err != nil {
|
||||
logf(logTypeHandshake, "Error initializing config: %v", err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
// Set things up
|
||||
caps := Capabilities{
|
||||
CipherSuites: c.config.CipherSuites,
|
||||
Groups: c.config.Groups,
|
||||
SignatureSchemes: c.config.SignatureSchemes,
|
||||
PSKs: c.config.PSKs,
|
||||
PSKModes: c.config.PSKModes,
|
||||
AllowEarlyData: c.config.AllowEarlyData,
|
||||
RequireCookie: c.config.RequireCookie,
|
||||
CookieHandler: c.config.CookieHandler,
|
||||
RequireClientAuth: c.config.RequireClientAuth,
|
||||
NextProtos: c.config.NextProtos,
|
||||
Certificates: c.config.Certificates,
|
||||
ExtensionHandler: c.extHandler,
|
||||
}
|
||||
opts := ConnectionOptions{
|
||||
ServerName: c.config.ServerName,
|
||||
NextProtos: c.config.NextProtos,
|
||||
EarlyData: c.EarlyData,
|
||||
}
|
||||
|
||||
if caps.RequireCookie && caps.CookieHandler == nil {
|
||||
caps.CookieHandler = &defaultCookieHandler{}
|
||||
}
|
||||
|
||||
if c.isClient {
|
||||
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error initializing client state: %v", alert)
|
||||
return alert
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
return alert
|
||||
}
|
||||
}
|
||||
} else {
|
||||
state = ServerStateStart{Caps: caps, conn: c}
|
||||
}
|
||||
|
||||
c.hState = state
|
||||
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
// Handshake causes a TLS handshake on the connection. The `isClient` member
|
||||
// determines whether a client or server handshake is performed. If a
|
||||
// handshake has already been performed, then its result will be returned.
|
||||
func (c *Conn) Handshake() Alert {
|
||||
label := "[server]"
|
||||
if c.isClient {
|
||||
label = "[client]"
|
||||
}
|
||||
|
||||
// TODO Lock handshakeMutex
|
||||
// TODO Remove CloseNotify hack
|
||||
if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify {
|
||||
logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert)
|
||||
return c.handshakeAlert
|
||||
}
|
||||
if c.handshakeComplete {
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
var alert Alert
|
||||
if c.hState == nil {
|
||||
logf(logTypeHandshake, "%s First time through handshake, setting up", label)
|
||||
alert = c.HandshakeSetup()
|
||||
if alert != AlertNoAlert {
|
||||
return alert
|
||||
}
|
||||
} else {
|
||||
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
|
||||
}
|
||||
|
||||
state := c.hState
|
||||
_, connected := state.(StateConnected)
|
||||
|
||||
var actions []HandshakeAction
|
||||
|
||||
for !connected {
|
||||
// Read a handshake message
|
||||
hm, err := c.hIn.ReadMessage()
|
||||
if err == WouldBlock {
|
||||
logf(logTypeHandshake, "%s Would block reading message: %v", label, err)
|
||||
return AlertWouldBlock
|
||||
}
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading message: %v", label, err)
|
||||
c.sendAlert(AlertCloseNotify)
|
||||
return AlertCloseNotify
|
||||
}
|
||||
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
|
||||
|
||||
// Advance the state machine
|
||||
state, actions, alert = state.Next(hm)
|
||||
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||
return alert
|
||||
}
|
||||
|
||||
for index, action := range actions {
|
||||
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return alert
|
||||
}
|
||||
}
|
||||
|
||||
c.hState = state
|
||||
logf(logTypeHandshake, "state is now %s", c.GetHsState())
|
||||
|
||||
_, connected = state.(StateConnected)
|
||||
}
|
||||
|
||||
c.state = state.(StateConnected)
|
||||
|
||||
// Send NewSessionTicket if acting as server
|
||||
if !c.isClient && c.config.SendSessionTickets {
|
||||
actions, alert := c.state.NewSessionTicket(
|
||||
c.config.TicketLen,
|
||||
c.config.TicketLifetime,
|
||||
c.config.EarlyDataLifetime)
|
||||
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return alert
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.handshakeComplete = true
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
|
||||
if !c.handshakeComplete {
|
||||
return fmt.Errorf("Cannot update keys until after handshake")
|
||||
}
|
||||
|
||||
request := KeyUpdateNotRequested
|
||||
if requestUpdate {
|
||||
request = KeyUpdateRequested
|
||||
}
|
||||
|
||||
// Create the key update and update state
|
||||
actions, alert := c.state.KeyUpdate(request)
|
||||
if alert != AlertNoAlert {
|
||||
c.sendAlert(alert)
|
||||
return fmt.Errorf("Alert while generating key update: %v", alert)
|
||||
}
|
||||
|
||||
// Take actions (send key update and rekey)
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
c.sendAlert(alert)
|
||||
return fmt.Errorf("Alert during key update actions: %v", alert)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) GetHsState() string {
|
||||
return reflect.TypeOf(c.hState).Name()
|
||||
}
|
||||
|
||||
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
_, connected := c.hState.(StateConnected)
|
||||
if !connected {
|
||||
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
|
||||
}
|
||||
|
||||
if c.state.exporterSecret == nil {
|
||||
return nil, fmt.Errorf("Internal error: no exporter secret")
|
||||
}
|
||||
|
||||
h0 := c.state.cryptoParams.Hash.New().Sum(nil)
|
||||
tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0)
|
||||
|
||||
hc := c.state.cryptoParams.Hash.New().Sum(context)
|
||||
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
|
||||
}
|
||||
|
||||
func (c *Conn) State() ConnectionState {
|
||||
state := ConnectionState{
|
||||
HandshakeState: c.GetHsState(),
|
||||
}
|
||||
|
||||
if c.handshakeComplete {
|
||||
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
|
||||
state.NextProto = c.state.Params.NextProto
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
|
||||
if c.hState != nil {
|
||||
return fmt.Errorf("Can't set extension handler after setup")
|
||||
}
|
||||
|
||||
c.extHandler = h
|
||||
return nil
|
||||
}
|
654
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
Normal file
654
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
Normal file
|
@ -0,0 +1,654 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
|
||||
// Blank includes to ensure hash support
|
||||
_ "crypto/sha1"
|
||||
_ "crypto/sha256"
|
||||
_ "crypto/sha512"
|
||||
)
|
||||
|
||||
var prng = rand.Reader
|
||||
|
||||
type aeadFactory func(key []byte) (cipher.AEAD, error)
|
||||
|
||||
type CipherSuiteParams struct {
|
||||
Suite CipherSuite
|
||||
Cipher aeadFactory // Cipher factory
|
||||
Hash crypto.Hash // Hash function
|
||||
KeyLen int // Key length in octets
|
||||
IvLen int // IV length in octets
|
||||
}
|
||||
|
||||
type signatureAlgorithm uint8
|
||||
|
||||
const (
|
||||
signatureAlgorithmUnknown = iota
|
||||
signatureAlgorithmRSA_PKCS1
|
||||
signatureAlgorithmRSA_PSS
|
||||
signatureAlgorithmECDSA
|
||||
)
|
||||
|
||||
var (
|
||||
hashMap = map[SignatureScheme]crypto.Hash{
|
||||
RSA_PKCS1_SHA1: crypto.SHA1,
|
||||
RSA_PKCS1_SHA256: crypto.SHA256,
|
||||
RSA_PKCS1_SHA384: crypto.SHA384,
|
||||
RSA_PKCS1_SHA512: crypto.SHA512,
|
||||
ECDSA_P256_SHA256: crypto.SHA256,
|
||||
ECDSA_P384_SHA384: crypto.SHA384,
|
||||
ECDSA_P521_SHA512: crypto.SHA512,
|
||||
RSA_PSS_SHA256: crypto.SHA256,
|
||||
RSA_PSS_SHA384: crypto.SHA384,
|
||||
RSA_PSS_SHA512: crypto.SHA512,
|
||||
}
|
||||
|
||||
sigMap = map[SignatureScheme]signatureAlgorithm{
|
||||
RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1,
|
||||
RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1,
|
||||
RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1,
|
||||
RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1,
|
||||
ECDSA_P256_SHA256: signatureAlgorithmECDSA,
|
||||
ECDSA_P384_SHA384: signatureAlgorithmECDSA,
|
||||
ECDSA_P521_SHA512: signatureAlgorithmECDSA,
|
||||
RSA_PSS_SHA256: signatureAlgorithmRSA_PSS,
|
||||
RSA_PSS_SHA384: signatureAlgorithmRSA_PSS,
|
||||
RSA_PSS_SHA512: signatureAlgorithmRSA_PSS,
|
||||
}
|
||||
|
||||
curveMap = map[SignatureScheme]NamedGroup{
|
||||
ECDSA_P256_SHA256: P256,
|
||||
ECDSA_P384_SHA384: P384,
|
||||
ECDSA_P521_SHA512: P521,
|
||||
}
|
||||
|
||||
newAESGCM = func(key []byte) (cipher.AEAD, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TLS always uses 12-byte nonces
|
||||
return cipher.NewGCMWithNonceSize(block, 12)
|
||||
}
|
||||
|
||||
cipherSuiteMap = map[CipherSuite]CipherSuiteParams{
|
||||
TLS_AES_128_GCM_SHA256: {
|
||||
Suite: TLS_AES_128_GCM_SHA256,
|
||||
Cipher: newAESGCM,
|
||||
Hash: crypto.SHA256,
|
||||
KeyLen: 16,
|
||||
IvLen: 12,
|
||||
},
|
||||
TLS_AES_256_GCM_SHA384: {
|
||||
Suite: TLS_AES_256_GCM_SHA384,
|
||||
Cipher: newAESGCM,
|
||||
Hash: crypto.SHA384,
|
||||
KeyLen: 32,
|
||||
IvLen: 12,
|
||||
},
|
||||
}
|
||||
|
||||
x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{
|
||||
RSA_PKCS1_SHA1: x509.SHA1WithRSA,
|
||||
RSA_PKCS1_SHA256: x509.SHA256WithRSA,
|
||||
RSA_PKCS1_SHA384: x509.SHA384WithRSA,
|
||||
RSA_PKCS1_SHA512: x509.SHA512WithRSA,
|
||||
ECDSA_P256_SHA256: x509.ECDSAWithSHA256,
|
||||
ECDSA_P384_SHA384: x509.ECDSAWithSHA384,
|
||||
ECDSA_P521_SHA512: x509.ECDSAWithSHA512,
|
||||
}
|
||||
|
||||
defaultRSAKeySize = 2048
|
||||
)
|
||||
|
||||
func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) {
|
||||
switch group {
|
||||
case P256:
|
||||
crv = elliptic.P256()
|
||||
case P384:
|
||||
crv = elliptic.P384()
|
||||
case P521:
|
||||
crv = elliptic.P521()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) {
|
||||
switch key.Curve.Params().Name {
|
||||
case elliptic.P256().Params().Name:
|
||||
g = P256
|
||||
case elliptic.P384().Params().Name:
|
||||
g = P384
|
||||
case elliptic.P521().Params().Name:
|
||||
g = P521
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) {
|
||||
size = 0
|
||||
switch group {
|
||||
case X25519:
|
||||
size = 32
|
||||
case P256:
|
||||
size = 65
|
||||
case P384:
|
||||
size = 97
|
||||
case P521:
|
||||
size = 133
|
||||
case FFDHE2048:
|
||||
size = 256
|
||||
case FFDHE3072:
|
||||
size = 384
|
||||
case FFDHE4096:
|
||||
size = 512
|
||||
case FFDHE6144:
|
||||
size = 768
|
||||
case FFDHE8192:
|
||||
size = 1024
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func primeFromNamedGroup(group NamedGroup) (p *big.Int) {
|
||||
switch group {
|
||||
case FFDHE2048:
|
||||
p = finiteFieldPrime2048
|
||||
case FFDHE3072:
|
||||
p = finiteFieldPrime3072
|
||||
case FFDHE4096:
|
||||
p = finiteFieldPrime4096
|
||||
case FFDHE6144:
|
||||
p = finiteFieldPrime6144
|
||||
case FFDHE8192:
|
||||
p = finiteFieldPrime8192
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool {
|
||||
sigType := sigMap[alg]
|
||||
switch key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS
|
||||
case *ecdsa.PrivateKey:
|
||||
return sigType == signatureAlgorithmECDSA
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) {
|
||||
primeLen := len(p.Bytes())
|
||||
for {
|
||||
// g = 2 for all ffdhe groups
|
||||
priv, err = rand.Int(prng, p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pub = big.NewInt(0)
|
||||
pub.Exp(big.NewInt(2), priv, p)
|
||||
|
||||
if len(pub.Bytes()) == primeLen {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) {
|
||||
switch group {
|
||||
case P256, P384, P521:
|
||||
var x, y *big.Int
|
||||
crv := curveFromNamedGroup(group)
|
||||
priv, x, y, err = elliptic.GenerateKey(crv, prng)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pub = elliptic.Marshal(crv, x, y)
|
||||
return
|
||||
|
||||
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
|
||||
p := primeFromNamedGroup(group)
|
||||
x, X, err2 := ffdheKeyShareFromPrime(p)
|
||||
if err2 != nil {
|
||||
err = err2
|
||||
return
|
||||
}
|
||||
|
||||
priv = x.Bytes()
|
||||
pubBytes := X.Bytes()
|
||||
|
||||
numBytes := keyExchangeSizeFromNamedGroup(group)
|
||||
|
||||
pub = make([]byte, numBytes)
|
||||
copy(pub[numBytes-len(pubBytes):], pubBytes)
|
||||
|
||||
return
|
||||
|
||||
case X25519:
|
||||
var private, public [32]byte
|
||||
_, err = prng.Read(private[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
curve25519.ScalarBaseMult(&public, &private)
|
||||
priv = private[:]
|
||||
pub = public[:]
|
||||
return
|
||||
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group)
|
||||
}
|
||||
}
|
||||
|
||||
func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) {
|
||||
switch group {
|
||||
case P256, P384, P521:
|
||||
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
|
||||
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||
}
|
||||
|
||||
crv := curveFromNamedGroup(group)
|
||||
pubX, pubY := elliptic.Unmarshal(crv, pub)
|
||||
x, _ := crv.Params().ScalarMult(pubX, pubY, priv)
|
||||
xBytes := x.Bytes()
|
||||
|
||||
numBytes := len(crv.Params().P.Bytes())
|
||||
|
||||
ret := make([]byte, numBytes)
|
||||
copy(ret[numBytes-len(xBytes):], xBytes)
|
||||
|
||||
return ret, nil
|
||||
|
||||
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
|
||||
numBytes := keyExchangeSizeFromNamedGroup(group)
|
||||
if len(pub) != numBytes {
|
||||
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||
}
|
||||
p := primeFromNamedGroup(group)
|
||||
x := big.NewInt(0).SetBytes(priv)
|
||||
Y := big.NewInt(0).SetBytes(pub)
|
||||
ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes()
|
||||
|
||||
ret := make([]byte, numBytes)
|
||||
copy(ret[numBytes-len(ZBytes):], ZBytes)
|
||||
|
||||
return ret, nil
|
||||
|
||||
case X25519:
|
||||
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
|
||||
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||
}
|
||||
|
||||
var private, public, ret [32]byte
|
||||
copy(private[:], priv)
|
||||
copy(public[:], pub)
|
||||
curve25519.ScalarMult(&ret, &private, &public)
|
||||
|
||||
return ret[:], nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group)
|
||||
}
|
||||
}
|
||||
|
||||
func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
|
||||
switch sig {
|
||||
case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256,
|
||||
RSA_PKCS1_SHA384, RSA_PKCS1_SHA512,
|
||||
RSA_PSS_SHA256, RSA_PSS_SHA384,
|
||||
RSA_PSS_SHA512:
|
||||
return rsa.GenerateKey(prng, defaultRSAKeySize)
|
||||
case ECDSA_P256_SHA256:
|
||||
return ecdsa.GenerateKey(elliptic.P256(), prng)
|
||||
case ECDSA_P384_SHA384:
|
||||
return ecdsa.GenerateKey(elliptic.P384(), prng)
|
||||
case ECDSA_P521_SHA512:
|
||||
return ecdsa.GenerateKey(elliptic.P521(), prng)
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig)
|
||||
}
|
||||
}
|
||||
|
||||
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
|
||||
sigAlg, ok := x509AlgMap[alg]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
|
||||
}
|
||||
if len(name) == 0 {
|
||||
return nil, fmt.Errorf("tls.selfsigned: No name provided")
|
||||
}
|
||||
|
||||
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(0, 0, 1),
|
||||
SignatureAlgorithm: sigAlg,
|
||||
Subject: pkix.Name{CommonName: name},
|
||||
DNSNames: []string{name},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// It is safe to ignore the error here because we're parsing known-good data
|
||||
cert, _ := x509.ParseCertificate(der)
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// XXX(rlb): Copied from crypto/x509
|
||||
type ecdsaSignature struct {
|
||||
R, S *big.Int
|
||||
}
|
||||
|
||||
func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) {
|
||||
var opts crypto.SignerOpts
|
||||
|
||||
hash := hashMap[alg]
|
||||
if hash == crypto.SHA1 {
|
||||
return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
|
||||
}
|
||||
|
||||
sigType := sigMap[alg]
|
||||
var realInput []byte
|
||||
switch key := privateKey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
switch {
|
||||
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||
logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size())
|
||||
opts = hash
|
||||
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||
fallthrough
|
||||
case sigType == signatureAlgorithmRSA_PSS:
|
||||
logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size())
|
||||
opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key")
|
||||
}
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput = h.Sum(nil)
|
||||
case *ecdsa.PrivateKey:
|
||||
if sigType != signatureAlgorithmECDSA {
|
||||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key")
|
||||
}
|
||||
|
||||
algGroup := curveMap[alg]
|
||||
keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey))
|
||||
if algGroup != keyGroup {
|
||||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination")
|
||||
}
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput = h.Sum(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type")
|
||||
}
|
||||
|
||||
sig, err := privateKey.Sign(prng, realInput, opts)
|
||||
logf(logTypeCrypto, "signature: %x", sig)
|
||||
return sig, err
|
||||
}
|
||||
|
||||
func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error {
|
||||
hash := hashMap[alg]
|
||||
|
||||
if hash == crypto.SHA1 {
|
||||
return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
|
||||
}
|
||||
|
||||
sigType := sigMap[alg]
|
||||
switch pub := publicKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
switch {
|
||||
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||
logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size())
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput := h.Sum(nil)
|
||||
return rsa.VerifyPKCS1v15(pub, hash, realInput, sig)
|
||||
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||
fallthrough
|
||||
case sigType == signatureAlgorithmRSA_PSS:
|
||||
logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size())
|
||||
opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput := h.Sum(nil)
|
||||
return rsa.VerifyPSS(pub, hash, realInput, sig, opts)
|
||||
default:
|
||||
return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key")
|
||||
}
|
||||
|
||||
case *ecdsa.PublicKey:
|
||||
if sigType != signatureAlgorithmECDSA {
|
||||
return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key")
|
||||
}
|
||||
|
||||
if curveMap[alg] != namedGroupFromECDSAKey(pub) {
|
||||
return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key")
|
||||
}
|
||||
|
||||
ecdsaSig := new(ecdsaSignature)
|
||||
if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
|
||||
return err
|
||||
} else if len(rest) != 0 {
|
||||
return fmt.Errorf("tls.verify: trailing data after ECDSA signature")
|
||||
}
|
||||
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
|
||||
return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values")
|
||||
}
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput := h.Sum(nil)
|
||||
if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) {
|
||||
return fmt.Errorf("tls.verify: ECDSA verification failure")
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("tls.verify: Unsupported key type")
|
||||
}
|
||||
}
|
||||
|
||||
// 0
|
||||
// |
|
||||
// v
|
||||
// PSK -> HKDF-Extract = Early Secret
|
||||
// |
|
||||
// +-----> Derive-Secret(.,
|
||||
// | "ext binder" |
|
||||
// | "res binder",
|
||||
// | "")
|
||||
// | = binder_key
|
||||
// |
|
||||
// +-----> Derive-Secret(., "c e traffic",
|
||||
// | ClientHello)
|
||||
// | = client_early_traffic_secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "e exp master",
|
||||
// | ClientHello)
|
||||
// | = early_exporter_master_secret
|
||||
// v
|
||||
// Derive-Secret(., "derived", "")
|
||||
// |
|
||||
// v
|
||||
// (EC)DHE -> HKDF-Extract = Handshake Secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "c hs traffic",
|
||||
// | ClientHello...ServerHello)
|
||||
// | = client_handshake_traffic_secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "s hs traffic",
|
||||
// | ClientHello...ServerHello)
|
||||
// | = server_handshake_traffic_secret
|
||||
// v
|
||||
// Derive-Secret(., "derived", "")
|
||||
// |
|
||||
// v
|
||||
// 0 -> HKDF-Extract = Master Secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "c ap traffic",
|
||||
// | ClientHello...server Finished)
|
||||
// | = client_application_traffic_secret_0
|
||||
// |
|
||||
// +-----> Derive-Secret(., "s ap traffic",
|
||||
// | ClientHello...server Finished)
|
||||
// | = server_application_traffic_secret_0
|
||||
// |
|
||||
// +-----> Derive-Secret(., "exp master",
|
||||
// | ClientHello...server Finished)
|
||||
// | = exporter_master_secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "res master",
|
||||
// ClientHello...client Finished)
|
||||
// = resumption_master_secret
|
||||
|
||||
// From RFC 5869
|
||||
// PRK = HMAC-Hash(salt, IKM)
|
||||
func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte {
|
||||
salt := saltIn
|
||||
|
||||
// if [salt is] not provided, it is set to a string of HashLen zeros
|
||||
if salt == nil {
|
||||
salt = bytes.Repeat([]byte{0}, hash.Size())
|
||||
}
|
||||
|
||||
h := hmac.New(hash.New, salt)
|
||||
h.Write(input)
|
||||
out := h.Sum(nil)
|
||||
|
||||
logf(logTypeCrypto, "HKDF Extract:\n")
|
||||
logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt)
|
||||
logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input)
|
||||
logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
const (
|
||||
labelExternalBinder = "ext binder"
|
||||
labelResumptionBinder = "res binder"
|
||||
labelEarlyTrafficSecret = "c e traffic"
|
||||
labelEarlyExporterSecret = "e exp master"
|
||||
labelClientHandshakeTrafficSecret = "c hs traffic"
|
||||
labelServerHandshakeTrafficSecret = "s hs traffic"
|
||||
labelClientApplicationTrafficSecret = "c ap traffic"
|
||||
labelServerApplicationTrafficSecret = "s ap traffic"
|
||||
labelExporterSecret = "exp master"
|
||||
labelResumptionSecret = "res master"
|
||||
labelDerived = "derived"
|
||||
labelFinished = "finished"
|
||||
labelResumption = "resumption"
|
||||
)
|
||||
|
||||
// struct HkdfLabel {
|
||||
// uint16 length;
|
||||
// opaque label<9..255>;
|
||||
// opaque hash_value<0..255>;
|
||||
// };
|
||||
func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte {
|
||||
label := "tls13 " + labelIn
|
||||
|
||||
labelLen := len(label)
|
||||
hashLen := len(hashValue)
|
||||
hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen)
|
||||
hkdfLabel[0] = byte(outLen >> 8)
|
||||
hkdfLabel[1] = byte(outLen)
|
||||
hkdfLabel[2] = byte(labelLen)
|
||||
copy(hkdfLabel[3:3+labelLen], []byte(label))
|
||||
hkdfLabel[3+labelLen] = byte(hashLen)
|
||||
copy(hkdfLabel[3+labelLen+1:], hashValue)
|
||||
|
||||
return hkdfLabel
|
||||
}
|
||||
|
||||
func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte {
|
||||
out := []byte{}
|
||||
T := []byte{}
|
||||
i := byte(1)
|
||||
for len(out) < outLen {
|
||||
block := append(T, info...)
|
||||
block = append(block, i)
|
||||
|
||||
h := hmac.New(hash.New, prk)
|
||||
h.Write(block)
|
||||
|
||||
T = h.Sum(nil)
|
||||
out = append(out, T...)
|
||||
i++
|
||||
}
|
||||
return out[:outLen]
|
||||
}
|
||||
|
||||
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte {
|
||||
info := hkdfEncodeLabel(label, hashValue, outLen)
|
||||
derived := HkdfExpand(hash, secret, info, outLen)
|
||||
|
||||
logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen)
|
||||
logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret)
|
||||
logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue)
|
||||
logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info)
|
||||
logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived)
|
||||
|
||||
return derived
|
||||
}
|
||||
|
||||
func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte {
|
||||
return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size())
|
||||
}
|
||||
|
||||
func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte {
|
||||
macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size())
|
||||
mac := hmac.New(params.Hash.New, macKey)
|
||||
mac.Write(input)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
type keySet struct {
|
||||
cipher aeadFactory
|
||||
key []byte
|
||||
iv []byte
|
||||
}
|
||||
|
||||
func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
|
||||
logf(logTypeCrypto, "making traffic keys: secret=%x", secret)
|
||||
return keySet{
|
||||
cipher: params.Cipher,
|
||||
key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen),
|
||||
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
|
||||
}
|
||||
}
|
586
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
Normal file
586
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
Normal file
|
@ -0,0 +1,586 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/bifurcation/mint/syntax"
|
||||
)
|
||||
|
||||
type ExtensionBody interface {
|
||||
Type() ExtensionType
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal(data []byte) (int, error)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ExtensionType extension_type;
|
||||
// opaque extension_data<0..2^16-1>;
|
||||
// } Extension;
|
||||
type Extension struct {
|
||||
ExtensionType ExtensionType
|
||||
ExtensionData []byte `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (ext Extension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(ext)
|
||||
}
|
||||
|
||||
func (ext *Extension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, ext)
|
||||
}
|
||||
|
||||
type ExtensionList []Extension
|
||||
|
||||
type extensionListInner struct {
|
||||
List []Extension `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (el ExtensionList) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(extensionListInner{el})
|
||||
}
|
||||
|
||||
func (el *ExtensionList) Unmarshal(data []byte) (int, error) {
|
||||
var list extensionListInner
|
||||
read, err := syntax.Unmarshal(data, &list)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
*el = list.List
|
||||
return read, nil
|
||||
}
|
||||
|
||||
func (el *ExtensionList) Add(src ExtensionBody) error {
|
||||
data, err := src.Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if el == nil {
|
||||
el = new(ExtensionList)
|
||||
}
|
||||
|
||||
// If one already exists with this type, replace it
|
||||
for i := range *el {
|
||||
if (*el)[i].ExtensionType == src.Type() {
|
||||
(*el)[i].ExtensionData = data
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise append
|
||||
*el = append(*el, Extension{
|
||||
ExtensionType: src.Type(),
|
||||
ExtensionData: data,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (el ExtensionList) Find(dst ExtensionBody) bool {
|
||||
for _, ext := range el {
|
||||
if ext.ExtensionType == dst.Type() {
|
||||
_, err := dst.Unmarshal(ext.ExtensionData)
|
||||
return err == nil
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// struct {
|
||||
// NameType name_type;
|
||||
// select (name_type) {
|
||||
// case host_name: HostName;
|
||||
// } name;
|
||||
// } ServerName;
|
||||
//
|
||||
// enum {
|
||||
// host_name(0), (255)
|
||||
// } NameType;
|
||||
//
|
||||
// opaque HostName<1..2^16-1>;
|
||||
//
|
||||
// struct {
|
||||
// ServerName server_name_list<1..2^16-1>
|
||||
// } ServerNameList;
|
||||
//
|
||||
// But we only care about the case where there's a single DNS hostname. We
|
||||
// will never create anything else, and throw if we receive something else
|
||||
//
|
||||
// 2 1 2
|
||||
// | listLen | NameType | nameLen | name |
|
||||
type ServerNameExtension string
|
||||
|
||||
type serverNameInner struct {
|
||||
NameType uint8
|
||||
HostName []byte `tls:"head=2,min=1"`
|
||||
}
|
||||
|
||||
type serverNameListInner struct {
|
||||
ServerNameList []serverNameInner `tls:"head=2,min=1"`
|
||||
}
|
||||
|
||||
func (sni ServerNameExtension) Type() ExtensionType {
|
||||
return ExtensionTypeServerName
|
||||
}
|
||||
|
||||
func (sni ServerNameExtension) Marshal() ([]byte, error) {
|
||||
list := serverNameListInner{
|
||||
ServerNameList: []serverNameInner{{
|
||||
NameType: 0x00, // host_name
|
||||
HostName: []byte(sni),
|
||||
}},
|
||||
}
|
||||
|
||||
return syntax.Marshal(list)
|
||||
}
|
||||
|
||||
func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) {
|
||||
var list serverNameListInner
|
||||
read, err := syntax.Unmarshal(data, &list)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Syntax requires at least one entry
|
||||
// Entries beyond the first are ignored
|
||||
if nameType := list.ServerNameList[0].NameType; nameType != 0x00 {
|
||||
return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType)
|
||||
}
|
||||
|
||||
*sni = ServerNameExtension(list.ServerNameList[0].HostName)
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// NamedGroup group;
|
||||
// opaque key_exchange<1..2^16-1>;
|
||||
// } KeyShareEntry;
|
||||
//
|
||||
// struct {
|
||||
// select (Handshake.msg_type) {
|
||||
// case client_hello:
|
||||
// KeyShareEntry client_shares<0..2^16-1>;
|
||||
//
|
||||
// case hello_retry_request:
|
||||
// NamedGroup selected_group;
|
||||
//
|
||||
// case server_hello:
|
||||
// KeyShareEntry server_share;
|
||||
// };
|
||||
// } KeyShare;
|
||||
type KeyShareEntry struct {
|
||||
Group NamedGroup
|
||||
KeyExchange []byte `tls:"head=2,min=1"`
|
||||
}
|
||||
|
||||
func (kse KeyShareEntry) SizeValid() bool {
|
||||
return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group)
|
||||
}
|
||||
|
||||
type KeyShareExtension struct {
|
||||
HandshakeType HandshakeType
|
||||
SelectedGroup NamedGroup
|
||||
Shares []KeyShareEntry
|
||||
}
|
||||
|
||||
type KeyShareClientHelloInner struct {
|
||||
ClientShares []KeyShareEntry `tls:"head=2,min=0"`
|
||||
}
|
||||
type KeyShareHelloRetryInner struct {
|
||||
SelectedGroup NamedGroup
|
||||
}
|
||||
type KeyShareServerHelloInner struct {
|
||||
ServerShare KeyShareEntry
|
||||
}
|
||||
|
||||
func (ks KeyShareExtension) Type() ExtensionType {
|
||||
return ExtensionTypeKeyShare
|
||||
}
|
||||
|
||||
func (ks KeyShareExtension) Marshal() ([]byte, error) {
|
||||
switch ks.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
for _, share := range ks.Shares {
|
||||
if !share.SizeValid() {
|
||||
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||
}
|
||||
}
|
||||
return syntax.Marshal(KeyShareClientHelloInner{ks.Shares})
|
||||
|
||||
case HandshakeTypeHelloRetryRequest:
|
||||
if len(ks.Shares) > 0 {
|
||||
return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest")
|
||||
}
|
||||
|
||||
return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup})
|
||||
|
||||
case HandshakeTypeServerHello:
|
||||
if len(ks.Shares) != 1 {
|
||||
return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share")
|
||||
}
|
||||
|
||||
if !ks.Shares[0].SizeValid() {
|
||||
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||
}
|
||||
|
||||
return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]})
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) {
|
||||
switch ks.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
var inner KeyShareClientHelloInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for _, share := range inner.ClientShares {
|
||||
if !share.SizeValid() {
|
||||
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||
}
|
||||
}
|
||||
|
||||
ks.Shares = inner.ClientShares
|
||||
return read, nil
|
||||
|
||||
case HandshakeTypeHelloRetryRequest:
|
||||
var inner KeyShareHelloRetryInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
ks.SelectedGroup = inner.SelectedGroup
|
||||
return read, nil
|
||||
|
||||
case HandshakeTypeServerHello:
|
||||
var inner KeyShareServerHelloInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if !inner.ServerShare.SizeValid() {
|
||||
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||
}
|
||||
|
||||
ks.Shares = []KeyShareEntry{inner.ServerShare}
|
||||
return read, nil
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// struct {
|
||||
// NamedGroup named_group_list<2..2^16-1>;
|
||||
// } NamedGroupList;
|
||||
type SupportedGroupsExtension struct {
|
||||
Groups []NamedGroup `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (sg SupportedGroupsExtension) Type() ExtensionType {
|
||||
return ExtensionTypeSupportedGroups
|
||||
}
|
||||
|
||||
func (sg SupportedGroupsExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sg)
|
||||
}
|
||||
|
||||
func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sg)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// SignatureScheme supported_signature_algorithms<2..2^16-2>;
|
||||
// } SignatureSchemeList
|
||||
type SignatureAlgorithmsExtension struct {
|
||||
Algorithms []SignatureScheme `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (sa SignatureAlgorithmsExtension) Type() ExtensionType {
|
||||
return ExtensionTypeSignatureAlgorithms
|
||||
}
|
||||
|
||||
func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sa)
|
||||
}
|
||||
|
||||
func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sa)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque identity<1..2^16-1>;
|
||||
// uint32 obfuscated_ticket_age;
|
||||
// } PskIdentity;
|
||||
//
|
||||
// opaque PskBinderEntry<32..255>;
|
||||
//
|
||||
// struct {
|
||||
// select (Handshake.msg_type) {
|
||||
// case client_hello:
|
||||
// PskIdentity identities<7..2^16-1>;
|
||||
// PskBinderEntry binders<33..2^16-1>;
|
||||
//
|
||||
// case server_hello:
|
||||
// uint16 selected_identity;
|
||||
// };
|
||||
//
|
||||
// } PreSharedKeyExtension;
|
||||
type PSKIdentity struct {
|
||||
Identity []byte `tls:"head=2,min=1"`
|
||||
ObfuscatedTicketAge uint32
|
||||
}
|
||||
|
||||
type PSKBinderEntry struct {
|
||||
Binder []byte `tls:"head=1,min=32"`
|
||||
}
|
||||
|
||||
type PreSharedKeyExtension struct {
|
||||
HandshakeType HandshakeType
|
||||
Identities []PSKIdentity
|
||||
Binders []PSKBinderEntry
|
||||
SelectedIdentity uint16
|
||||
}
|
||||
|
||||
type preSharedKeyClientInner struct {
|
||||
Identities []PSKIdentity `tls:"head=2,min=7"`
|
||||
Binders []PSKBinderEntry `tls:"head=2,min=33"`
|
||||
}
|
||||
|
||||
type preSharedKeyServerInner struct {
|
||||
SelectedIdentity uint16
|
||||
}
|
||||
|
||||
func (psk PreSharedKeyExtension) Type() ExtensionType {
|
||||
return ExtensionTypePreSharedKey
|
||||
}
|
||||
|
||||
func (psk PreSharedKeyExtension) Marshal() ([]byte, error) {
|
||||
switch psk.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
return syntax.Marshal(preSharedKeyClientInner{
|
||||
Identities: psk.Identities,
|
||||
Binders: psk.Binders,
|
||||
})
|
||||
|
||||
case HandshakeTypeServerHello:
|
||||
if len(psk.Identities) > 0 || len(psk.Binders) > 0 {
|
||||
return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index")
|
||||
}
|
||||
return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity})
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported")
|
||||
}
|
||||
}
|
||||
|
||||
func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) {
|
||||
switch psk.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
var inner preSharedKeyClientInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(inner.Identities) != len(inner.Binders) {
|
||||
return 0, fmt.Errorf("Lengths of identities and binders not equal")
|
||||
}
|
||||
|
||||
psk.Identities = inner.Identities
|
||||
psk.Binders = inner.Binders
|
||||
return read, nil
|
||||
|
||||
case HandshakeTypeServerHello:
|
||||
var inner preSharedKeyServerInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
psk.SelectedIdentity = inner.SelectedIdentity
|
||||
return read, nil
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported")
|
||||
}
|
||||
}
|
||||
|
||||
func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) {
|
||||
for i, localID := range psk.Identities {
|
||||
if bytes.Equal(localID.Identity, id) {
|
||||
return psk.Binders[i].Binder, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
|
||||
//
|
||||
// struct {
|
||||
// PskKeyExchangeMode ke_modes<1..255>;
|
||||
// } PskKeyExchangeModes;
|
||||
type PSKKeyExchangeModesExtension struct {
|
||||
KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"`
|
||||
}
|
||||
|
||||
func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType {
|
||||
return ExtensionTypePSKKeyExchangeModes
|
||||
}
|
||||
|
||||
func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(pkem)
|
||||
}
|
||||
|
||||
func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, pkem)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// } EarlyDataIndication;
|
||||
|
||||
type EarlyDataExtension struct{}
|
||||
|
||||
func (ed EarlyDataExtension) Type() ExtensionType {
|
||||
return ExtensionTypeEarlyData
|
||||
}
|
||||
|
||||
func (ed EarlyDataExtension) Marshal() ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// uint32 max_early_data_size;
|
||||
// } TicketEarlyDataInfo;
|
||||
|
||||
type TicketEarlyDataInfoExtension struct {
|
||||
MaxEarlyDataSize uint32
|
||||
}
|
||||
|
||||
func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType {
|
||||
return ExtensionTypeTicketEarlyDataInfo
|
||||
}
|
||||
|
||||
func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(tedi)
|
||||
}
|
||||
|
||||
func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, tedi)
|
||||
}
|
||||
|
||||
// opaque ProtocolName<1..2^8-1>;
|
||||
//
|
||||
// struct {
|
||||
// ProtocolName protocol_name_list<2..2^16-1>
|
||||
// } ProtocolNameList;
|
||||
type ALPNExtension struct {
|
||||
Protocols []string
|
||||
}
|
||||
|
||||
type protocolNameInner struct {
|
||||
Name []byte `tls:"head=1,min=1"`
|
||||
}
|
||||
|
||||
type alpnExtensionInner struct {
|
||||
Protocols []protocolNameInner `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (alpn ALPNExtension) Type() ExtensionType {
|
||||
return ExtensionTypeALPN
|
||||
}
|
||||
|
||||
func (alpn ALPNExtension) Marshal() ([]byte, error) {
|
||||
protocols := make([]protocolNameInner, len(alpn.Protocols))
|
||||
for i, protocol := range alpn.Protocols {
|
||||
protocols[i] = protocolNameInner{[]byte(protocol)}
|
||||
}
|
||||
return syntax.Marshal(alpnExtensionInner{protocols})
|
||||
}
|
||||
|
||||
func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
|
||||
var inner alpnExtensionInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
alpn.Protocols = make([]string, len(inner.Protocols))
|
||||
for i, protocol := range inner.Protocols {
|
||||
alpn.Protocols[i] = string(protocol.Name)
|
||||
}
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion versions<2..254>;
|
||||
// } SupportedVersions;
|
||||
type SupportedVersionsExtension struct {
|
||||
Versions []uint16 `tls:"head=1,min=2,max=254"`
|
||||
}
|
||||
|
||||
func (sv SupportedVersionsExtension) Type() ExtensionType {
|
||||
return ExtensionTypeSupportedVersions
|
||||
}
|
||||
|
||||
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sv)
|
||||
}
|
||||
|
||||
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sv)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque cookie<1..2^16-1>;
|
||||
// } Cookie;
|
||||
type CookieExtension struct {
|
||||
Cookie []byte `tls:"head=2,min=1"`
|
||||
}
|
||||
|
||||
func (c CookieExtension) Type() ExtensionType {
|
||||
return ExtensionTypeCookie
|
||||
}
|
||||
|
||||
func (c CookieExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(c)
|
||||
}
|
||||
|
||||
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, c)
|
||||
}
|
||||
|
||||
// defaultCookieLength is the default length of a cookie
|
||||
const defaultCookieLength = 32
|
||||
|
||||
type defaultCookieHandler struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
var _ CookieHandler = &defaultCookieHandler{}
|
||||
|
||||
// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data
|
||||
func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) {
|
||||
h.data = make([]byte, defaultCookieLength)
|
||||
if _, err := prng.Read(h.data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h.data, nil
|
||||
}
|
||||
|
||||
func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool {
|
||||
return bytes.Equal(h.data, data)
|
||||
}
|
147
vendor/github.com/bifurcation/mint/ffdhe.go
generated
vendored
Normal file
147
vendor/github.com/bifurcation/mint/ffdhe.go
generated
vendored
Normal file
|
@ -0,0 +1,147 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
var (
|
||||
finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B423861285C97FFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex)
|
||||
finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes)
|
||||
|
||||
finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||
"3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex)
|
||||
finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes)
|
||||
|
||||
finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" +
|
||||
"FFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex)
|
||||
finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes)
|
||||
|
||||
finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
|
||||
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
|
||||
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
|
||||
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
|
||||
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
|
||||
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
|
||||
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
|
||||
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
|
||||
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
|
||||
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
|
||||
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
|
||||
"A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex)
|
||||
finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes)
|
||||
|
||||
finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
|
||||
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
|
||||
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
|
||||
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
|
||||
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
|
||||
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
|
||||
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
|
||||
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
|
||||
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
|
||||
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
|
||||
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
|
||||
"A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" +
|
||||
"1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" +
|
||||
"0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" +
|
||||
"CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" +
|
||||
"2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" +
|
||||
"BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" +
|
||||
"51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" +
|
||||
"D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" +
|
||||
"1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" +
|
||||
"FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" +
|
||||
"97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" +
|
||||
"D68C8BB7C5C6424CFFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex)
|
||||
finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes)
|
||||
)
|
98
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
Normal file
98
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
Normal file
|
@ -0,0 +1,98 @@
|
|||
// Read a generic "framed" packet consisting of a header and a
|
||||
// This is used for both TLS Records and TLS Handshake Messages
|
||||
package mint
|
||||
|
||||
type framing interface {
|
||||
headerLen() int
|
||||
defaultReadLen() int
|
||||
frameLen(hdr []byte) (int, error)
|
||||
}
|
||||
|
||||
const (
|
||||
kFrameReaderHdr = 0
|
||||
kFrameReaderBody = 1
|
||||
)
|
||||
|
||||
type frameNextAction func(f *frameReader) error
|
||||
|
||||
type frameReader struct {
|
||||
details framing
|
||||
state uint8
|
||||
header []byte
|
||||
body []byte
|
||||
working []byte
|
||||
writeOffset int
|
||||
remainder []byte
|
||||
}
|
||||
|
||||
func newFrameReader(d framing) *frameReader {
|
||||
hdr := make([]byte, d.headerLen())
|
||||
return &frameReader{
|
||||
d,
|
||||
kFrameReaderHdr,
|
||||
hdr,
|
||||
nil,
|
||||
hdr,
|
||||
0,
|
||||
nil,
|
||||
}
|
||||
}
|
||||
|
||||
func dup(a []byte) []byte {
|
||||
r := make([]byte, len(a))
|
||||
copy(r, a)
|
||||
return r
|
||||
}
|
||||
|
||||
func (f *frameReader) needed() int {
|
||||
tmp := (len(f.working) - f.writeOffset) - len(f.remainder)
|
||||
if tmp < 0 {
|
||||
return 0
|
||||
}
|
||||
return tmp
|
||||
}
|
||||
|
||||
func (f *frameReader) addChunk(in []byte) {
|
||||
// Append to the buffer.
|
||||
logf(logTypeFrameReader, "Appending %v", len(in))
|
||||
f.remainder = append(f.remainder, in...)
|
||||
}
|
||||
|
||||
func (f *frameReader) process() (hdr []byte, body []byte, err error) {
|
||||
for f.needed() == 0 {
|
||||
logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset)
|
||||
// Fill out our working block
|
||||
copied := copy(f.working[f.writeOffset:], f.remainder)
|
||||
f.remainder = f.remainder[copied:]
|
||||
f.writeOffset += copied
|
||||
if f.writeOffset < len(f.working) {
|
||||
logf(logTypeFrameReader, "Read would have blocked 1")
|
||||
return nil, nil, WouldBlock
|
||||
}
|
||||
// Reset the write offset, because we are now full.
|
||||
f.writeOffset = 0
|
||||
|
||||
// We have read a full frame
|
||||
if f.state == kFrameReaderBody {
|
||||
logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder))
|
||||
f.state = kFrameReaderHdr
|
||||
f.working = f.header
|
||||
return dup(f.header), dup(f.body), nil
|
||||
}
|
||||
|
||||
// We have read the header
|
||||
bodyLen, err := f.details.frameLen(f.header)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen)
|
||||
|
||||
f.body = make([]byte, bodyLen)
|
||||
f.working = f.body
|
||||
f.writeOffset = 0
|
||||
f.state = kFrameReaderBody
|
||||
}
|
||||
|
||||
logf(logTypeFrameReader, "Read would have blocked 2")
|
||||
return nil, nil, WouldBlock
|
||||
}
|
253
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
Normal file
253
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
Normal file
|
@ -0,0 +1,253 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
handshakeHeaderLen = 4 // handshake message header length
|
||||
maxHandshakeMessageLen = 1 << 24 // max handshake message length
|
||||
)
|
||||
|
||||
// struct {
|
||||
// HandshakeType msg_type; /* handshake type */
|
||||
// uint24 length; /* bytes in message */
|
||||
// select (HandshakeType) {
|
||||
// ...
|
||||
// } body;
|
||||
// } Handshake;
|
||||
//
|
||||
// We do the select{...} part in a different layer, so we treat the
|
||||
// actual message body as opaque:
|
||||
//
|
||||
// struct {
|
||||
// HandshakeType msg_type;
|
||||
// opaque msg<0..2^24-1>
|
||||
// } Handshake;
|
||||
//
|
||||
// TODO: File a spec bug
|
||||
type HandshakeMessage struct {
|
||||
// Omitted: length
|
||||
msgType HandshakeType
|
||||
body []byte
|
||||
}
|
||||
|
||||
// Note: This could be done with the `syntax` module, using the simplified
|
||||
// syntax as discussed above. However, since this is so simple, there's not
|
||||
// much benefit to doing so.
|
||||
func (hm *HandshakeMessage) Marshal() []byte {
|
||||
if hm == nil {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
msgLen := len(hm.body)
|
||||
data := make([]byte, 4+len(hm.body))
|
||||
data[0] = byte(hm.msgType)
|
||||
data[1] = byte(msgLen >> 16)
|
||||
data[2] = byte(msgLen >> 8)
|
||||
data[3] = byte(msgLen)
|
||||
copy(data[4:], hm.body)
|
||||
return data
|
||||
}
|
||||
|
||||
func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
|
||||
logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
|
||||
|
||||
var body HandshakeMessageBody
|
||||
switch hm.msgType {
|
||||
case HandshakeTypeClientHello:
|
||||
body = new(ClientHelloBody)
|
||||
case HandshakeTypeServerHello:
|
||||
body = new(ServerHelloBody)
|
||||
case HandshakeTypeHelloRetryRequest:
|
||||
body = new(HelloRetryRequestBody)
|
||||
case HandshakeTypeEncryptedExtensions:
|
||||
body = new(EncryptedExtensionsBody)
|
||||
case HandshakeTypeCertificate:
|
||||
body = new(CertificateBody)
|
||||
case HandshakeTypeCertificateRequest:
|
||||
body = new(CertificateRequestBody)
|
||||
case HandshakeTypeCertificateVerify:
|
||||
body = new(CertificateVerifyBody)
|
||||
case HandshakeTypeFinished:
|
||||
body = &FinishedBody{VerifyDataLen: len(hm.body)}
|
||||
case HandshakeTypeNewSessionTicket:
|
||||
body = new(NewSessionTicketBody)
|
||||
case HandshakeTypeKeyUpdate:
|
||||
body = new(KeyUpdateBody)
|
||||
case HandshakeTypeEndOfEarlyData:
|
||||
body = new(EndOfEarlyDataBody)
|
||||
default:
|
||||
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
|
||||
}
|
||||
|
||||
_, err := body.Unmarshal(hm.body)
|
||||
return body, err
|
||||
}
|
||||
|
||||
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
|
||||
data, err := body.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &HandshakeMessage{
|
||||
msgType: body.Type(),
|
||||
body: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type HandshakeLayer struct {
|
||||
nonblocking bool // Should we operate in nonblocking mode
|
||||
conn *RecordLayer // Used for reading/writing records
|
||||
frame *frameReader // The buffered frame reader
|
||||
}
|
||||
|
||||
type handshakeLayerFrameDetails struct{}
|
||||
|
||||
func (d handshakeLayerFrameDetails) headerLen() int {
|
||||
return handshakeHeaderLen
|
||||
}
|
||||
|
||||
func (d handshakeLayerFrameDetails) defaultReadLen() int {
|
||||
return handshakeHeaderLen + maxFragmentLen
|
||||
}
|
||||
|
||||
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||
logf(logTypeIO, "Header=%x", hdr)
|
||||
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil
|
||||
}
|
||||
|
||||
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer {
|
||||
h := HandshakeLayer{}
|
||||
h.conn = r
|
||||
h.frame = newFrameReader(&handshakeLayerFrameDetails{})
|
||||
return &h
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) readRecord() error {
|
||||
logf(logTypeIO, "Trying to read record")
|
||||
pt, err := h.conn.ReadRecord()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pt.contentType != RecordTypeHandshake &&
|
||||
pt.contentType != RecordTypeAlert {
|
||||
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
|
||||
}
|
||||
|
||||
if pt.contentType == RecordTypeAlert {
|
||||
logf(logTypeIO, "read alert %v", pt.fragment[1])
|
||||
if len(pt.fragment) < 2 {
|
||||
h.sendAlert(AlertUnexpectedMessage)
|
||||
return io.EOF
|
||||
}
|
||||
return Alert(pt.fragment[1])
|
||||
}
|
||||
|
||||
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment))
|
||||
h.frame.addChunk(pt.fragment)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendAlert sends a TLS alert message.
|
||||
func (h *HandshakeLayer) sendAlert(err Alert) error {
|
||||
tmp := make([]byte, 2)
|
||||
tmp[0] = AlertLevelError
|
||||
tmp[1] = byte(err)
|
||||
h.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeAlert,
|
||||
fragment: tmp},
|
||||
)
|
||||
|
||||
// closeNotify is a special case in that it isn't an error:
|
||||
if err != AlertCloseNotify {
|
||||
return &net.OpError{Op: "local error", Err: err}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
|
||||
var hdr, body []byte
|
||||
var err error
|
||||
|
||||
for {
|
||||
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder))
|
||||
if h.frame.needed() > 0 {
|
||||
logf(logTypeHandshake, "Trying to read a new record")
|
||||
err = h.readRecord()
|
||||
}
|
||||
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hdr, body, err = h.frame.process()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "read handshake message")
|
||||
|
||||
hm := &HandshakeMessage{}
|
||||
hm.msgType = HandshakeType(hdr[0])
|
||||
|
||||
hm.body = make([]byte, len(body))
|
||||
copy(hm.body, body)
|
||||
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error {
|
||||
return h.WriteMessages([]*HandshakeMessage{hm})
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error {
|
||||
for _, hm := range hms {
|
||||
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
|
||||
}
|
||||
|
||||
// Write out headers and bodies
|
||||
buffer := []byte{}
|
||||
for _, msg := range hms {
|
||||
msgLen := len(msg.body)
|
||||
if msgLen > maxHandshakeMessageLen {
|
||||
return fmt.Errorf("tls.handshakelayer: Message too large to send")
|
||||
}
|
||||
|
||||
buffer = append(buffer, msg.Marshal()...)
|
||||
}
|
||||
|
||||
// Send full-size fragments
|
||||
var start int
|
||||
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||||
err := h.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeHandshake,
|
||||
fragment: buffer[start : start+maxFragmentLen],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Send a final partial fragment if necessary
|
||||
if start < len(buffer) {
|
||||
err := h.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeHandshake,
|
||||
fragment: buffer[start:],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
450
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
Normal file
450
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
Normal file
|
@ -0,0 +1,450 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/bifurcation/mint/syntax"
|
||||
)
|
||||
|
||||
type HandshakeMessageBody interface {
|
||||
Type() HandshakeType
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal(data []byte) (int, error)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
|
||||
// Random random;
|
||||
// opaque legacy_session_id<0..32>;
|
||||
// CipherSuite cipher_suites<2..2^16-2>;
|
||||
// opaque legacy_compression_methods<1..2^8-1>;
|
||||
// Extension extensions<0..2^16-1>;
|
||||
// } ClientHello;
|
||||
type ClientHelloBody struct {
|
||||
// Omitted: clientVersion
|
||||
// Omitted: legacySessionID
|
||||
// Omitted: legacyCompressionMethods
|
||||
Random [32]byte
|
||||
CipherSuites []CipherSuite
|
||||
Extensions ExtensionList
|
||||
}
|
||||
|
||||
type clientHelloBodyInner struct {
|
||||
LegacyVersion uint16
|
||||
Random [32]byte
|
||||
LegacySessionID []byte `tls:"head=1,max=32"`
|
||||
CipherSuites []CipherSuite `tls:"head=2,min=2"`
|
||||
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
|
||||
Extensions []Extension `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (ch ClientHelloBody) Type() HandshakeType {
|
||||
return HandshakeTypeClientHello
|
||||
}
|
||||
|
||||
func (ch ClientHelloBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(clientHelloBodyInner{
|
||||
LegacyVersion: 0x0303,
|
||||
Random: ch.Random,
|
||||
LegacySessionID: []byte{},
|
||||
CipherSuites: ch.CipherSuites,
|
||||
LegacyCompressionMethods: []byte{0},
|
||||
Extensions: ch.Extensions,
|
||||
})
|
||||
}
|
||||
|
||||
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
|
||||
var inner clientHelloBodyInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// We are strict about these things because we only support 1.3
|
||||
if inner.LegacyVersion != 0x0303 {
|
||||
return 0, fmt.Errorf("tls.clienthello: Incorrect version number")
|
||||
}
|
||||
|
||||
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
|
||||
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
|
||||
}
|
||||
|
||||
ch.Random = inner.Random
|
||||
ch.CipherSuites = inner.CipherSuites
|
||||
ch.Extensions = inner.Extensions
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// TODO: File a spec bug to clarify this
|
||||
func (ch ClientHelloBody) Truncated() ([]byte, error) {
|
||||
if len(ch.Extensions) == 0 {
|
||||
return nil, fmt.Errorf("tls.clienthello.truncate: No extensions")
|
||||
}
|
||||
|
||||
pskExt := ch.Extensions[len(ch.Extensions)-1]
|
||||
if pskExt.ExtensionType != ExtensionTypePreSharedKey {
|
||||
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
|
||||
}
|
||||
|
||||
chm, err := HandshakeMessageFromBody(&ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chData := chm.Marshal()
|
||||
|
||||
psk := PreSharedKeyExtension{
|
||||
HandshakeType: HandshakeTypeClientHello,
|
||||
}
|
||||
_, err = psk.Unmarshal(pskExt.ExtensionData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Marshal just the binders so that we know how much to truncate
|
||||
binders := struct {
|
||||
Binders []PSKBinderEntry `tls:"head=2,min=33"`
|
||||
}{Binders: psk.Binders}
|
||||
binderData, _ := syntax.Marshal(binders)
|
||||
binderLen := len(binderData)
|
||||
|
||||
chLen := len(chData)
|
||||
return chData[:chLen-binderLen], nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion server_version;
|
||||
// CipherSuite cipher_suite;
|
||||
// Extension extensions<2..2^16-1>;
|
||||
// } HelloRetryRequest;
|
||||
type HelloRetryRequestBody struct {
|
||||
Version uint16
|
||||
CipherSuite CipherSuite
|
||||
Extensions ExtensionList `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (hrr HelloRetryRequestBody) Type() HandshakeType {
|
||||
return HandshakeTypeHelloRetryRequest
|
||||
}
|
||||
|
||||
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(hrr)
|
||||
}
|
||||
|
||||
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, hrr)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion version;
|
||||
// Random random;
|
||||
// CipherSuite cipher_suite;
|
||||
// Extension extensions<0..2^16-1>;
|
||||
// } ServerHello;
|
||||
type ServerHelloBody struct {
|
||||
Version uint16
|
||||
Random [32]byte
|
||||
CipherSuite CipherSuite
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (sh ServerHelloBody) Type() HandshakeType {
|
||||
return HandshakeTypeServerHello
|
||||
}
|
||||
|
||||
func (sh ServerHelloBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sh)
|
||||
}
|
||||
|
||||
func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sh)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque verify_data[verify_data_length];
|
||||
// } Finished;
|
||||
//
|
||||
// verifyDataLen is not a field in the TLS struct, but we add it here so
|
||||
// that calling code can tell us how much data to expect when we marshal /
|
||||
// unmarshal. (We could add this to the marshal/unmarshal methods, but let's
|
||||
// try to keep the signature consistent for now.)
|
||||
//
|
||||
// For similar reasons, we don't use the `syntax` module here, because this
|
||||
// struct doesn't map well to standard TLS presentation language concepts.
|
||||
//
|
||||
// TODO: File a spec bug
|
||||
type FinishedBody struct {
|
||||
VerifyDataLen int
|
||||
VerifyData []byte
|
||||
}
|
||||
|
||||
func (fin FinishedBody) Type() HandshakeType {
|
||||
return HandshakeTypeFinished
|
||||
}
|
||||
|
||||
func (fin FinishedBody) Marshal() ([]byte, error) {
|
||||
if len(fin.VerifyData) != fin.VerifyDataLen {
|
||||
return nil, fmt.Errorf("tls.finished: data length mismatch")
|
||||
}
|
||||
|
||||
body := make([]byte, len(fin.VerifyData))
|
||||
copy(body, fin.VerifyData)
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (fin *FinishedBody) Unmarshal(data []byte) (int, error) {
|
||||
if len(data) < fin.VerifyDataLen {
|
||||
return 0, fmt.Errorf("tls.finished: Malformed finished; too short")
|
||||
}
|
||||
|
||||
fin.VerifyData = make([]byte, fin.VerifyDataLen)
|
||||
copy(fin.VerifyData, data[:fin.VerifyDataLen])
|
||||
return fin.VerifyDataLen, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// Extension extensions<0..2^16-1>;
|
||||
// } EncryptedExtensions;
|
||||
//
|
||||
// Marshal() and Unmarshal() are handled by ExtensionList
|
||||
type EncryptedExtensionsBody struct {
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (ee EncryptedExtensionsBody) Type() HandshakeType {
|
||||
return HandshakeTypeEncryptedExtensions
|
||||
}
|
||||
|
||||
func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(ee)
|
||||
}
|
||||
|
||||
func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, ee)
|
||||
}
|
||||
|
||||
// opaque ASN1Cert<1..2^24-1>;
|
||||
//
|
||||
// struct {
|
||||
// ASN1Cert cert_data;
|
||||
// Extension extensions<0..2^16-1>
|
||||
// } CertificateEntry;
|
||||
//
|
||||
// struct {
|
||||
// opaque certificate_request_context<0..2^8-1>;
|
||||
// CertificateEntry certificate_list<0..2^24-1>;
|
||||
// } Certificate;
|
||||
type CertificateEntry struct {
|
||||
CertData *x509.Certificate
|
||||
Extensions ExtensionList
|
||||
}
|
||||
|
||||
type CertificateBody struct {
|
||||
CertificateRequestContext []byte
|
||||
CertificateList []CertificateEntry
|
||||
}
|
||||
|
||||
type certificateEntryInner struct {
|
||||
CertData []byte `tls:"head=3,min=1"`
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
type certificateBodyInner struct {
|
||||
CertificateRequestContext []byte `tls:"head=1"`
|
||||
CertificateList []certificateEntryInner `tls:"head=3"`
|
||||
}
|
||||
|
||||
func (c CertificateBody) Type() HandshakeType {
|
||||
return HandshakeTypeCertificate
|
||||
}
|
||||
|
||||
func (c CertificateBody) Marshal() ([]byte, error) {
|
||||
inner := certificateBodyInner{
|
||||
CertificateRequestContext: c.CertificateRequestContext,
|
||||
CertificateList: make([]certificateEntryInner, len(c.CertificateList)),
|
||||
}
|
||||
|
||||
for i, entry := range c.CertificateList {
|
||||
inner.CertificateList[i] = certificateEntryInner{
|
||||
CertData: entry.CertData.Raw,
|
||||
Extensions: entry.Extensions,
|
||||
}
|
||||
}
|
||||
|
||||
return syntax.Marshal(inner)
|
||||
}
|
||||
|
||||
func (c *CertificateBody) Unmarshal(data []byte) (int, error) {
|
||||
inner := certificateBodyInner{}
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return read, err
|
||||
}
|
||||
|
||||
c.CertificateRequestContext = inner.CertificateRequestContext
|
||||
c.CertificateList = make([]CertificateEntry, len(inner.CertificateList))
|
||||
|
||||
for i, entry := range inner.CertificateList {
|
||||
c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err)
|
||||
}
|
||||
|
||||
c.CertificateList[i].Extensions = entry.Extensions
|
||||
}
|
||||
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// SignatureScheme algorithm;
|
||||
// opaque signature<0..2^16-1>;
|
||||
// } CertificateVerify;
|
||||
type CertificateVerifyBody struct {
|
||||
Algorithm SignatureScheme
|
||||
Signature []byte `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (cv CertificateVerifyBody) Type() HandshakeType {
|
||||
return HandshakeTypeCertificateVerify
|
||||
}
|
||||
|
||||
func (cv CertificateVerifyBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(cv)
|
||||
}
|
||||
|
||||
func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, cv)
|
||||
}
|
||||
|
||||
func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte {
|
||||
// TODO: Change context for client auth
|
||||
// TODO: Put this in a const
|
||||
const context = "TLS 1.3, server CertificateVerify"
|
||||
sigInput := bytes.Repeat([]byte{0x20}, 64)
|
||||
sigInput = append(sigInput, []byte(context)...)
|
||||
sigInput = append(sigInput, []byte{0}...)
|
||||
sigInput = append(sigInput, data...)
|
||||
return sigInput
|
||||
}
|
||||
|
||||
func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) {
|
||||
sigInput := cv.EncodeSignatureInput(handshakeHash)
|
||||
cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput)
|
||||
logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
|
||||
return
|
||||
}
|
||||
|
||||
func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error {
|
||||
sigInput := cv.EncodeSignatureInput(handshakeHash)
|
||||
logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
|
||||
return verify(cv.Algorithm, publicKey, sigInput, cv.Signature)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque certificate_request_context<0..2^8-1>;
|
||||
// Extension extensions<2..2^16-1>;
|
||||
// } CertificateRequest;
|
||||
type CertificateRequestBody struct {
|
||||
CertificateRequestContext []byte `tls:"head=1"`
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (cr CertificateRequestBody) Type() HandshakeType {
|
||||
return HandshakeTypeCertificateRequest
|
||||
}
|
||||
|
||||
func (cr CertificateRequestBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(cr)
|
||||
}
|
||||
|
||||
func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, cr)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// uint32 ticket_lifetime;
|
||||
// uint32 ticket_age_add;
|
||||
// opaque ticket_nonce<1..255>;
|
||||
// opaque ticket<1..2^16-1>;
|
||||
// Extension extensions<0..2^16-2>;
|
||||
// } NewSessionTicket;
|
||||
type NewSessionTicketBody struct {
|
||||
TicketLifetime uint32
|
||||
TicketAgeAdd uint32
|
||||
TicketNonce []byte `tls:"head=1,min=1"`
|
||||
Ticket []byte `tls:"head=2,min=1"`
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
const ticketNonceLen = 16
|
||||
|
||||
func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) {
|
||||
buf := make([]byte, 4+ticketNonceLen+ticketLen)
|
||||
_, err := prng.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tkt := &NewSessionTicketBody{
|
||||
TicketLifetime: ticketLifetime,
|
||||
TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]),
|
||||
TicketNonce: buf[4 : 4+ticketNonceLen],
|
||||
Ticket: buf[4+ticketNonceLen:],
|
||||
}
|
||||
|
||||
return tkt, err
|
||||
}
|
||||
|
||||
func (tkt NewSessionTicketBody) Type() HandshakeType {
|
||||
return HandshakeTypeNewSessionTicket
|
||||
}
|
||||
|
||||
func (tkt NewSessionTicketBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(tkt)
|
||||
}
|
||||
|
||||
func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, tkt)
|
||||
}
|
||||
|
||||
// enum {
|
||||
// update_not_requested(0), update_requested(1), (255)
|
||||
// } KeyUpdateRequest;
|
||||
//
|
||||
// struct {
|
||||
// KeyUpdateRequest request_update;
|
||||
// } KeyUpdate;
|
||||
type KeyUpdateBody struct {
|
||||
KeyUpdateRequest KeyUpdateRequest
|
||||
}
|
||||
|
||||
func (ku KeyUpdateBody) Type() HandshakeType {
|
||||
return HandshakeTypeKeyUpdate
|
||||
}
|
||||
|
||||
func (ku KeyUpdateBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(ku)
|
||||
}
|
||||
|
||||
func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, ku)
|
||||
}
|
||||
|
||||
// struct {} EndOfEarlyData;
|
||||
type EndOfEarlyDataBody struct{}
|
||||
|
||||
func (eoed EndOfEarlyDataBody) Type() HandshakeType {
|
||||
return HandshakeTypeEndOfEarlyData
|
||||
}
|
||||
|
||||
func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
55
vendor/github.com/bifurcation/mint/log.go
generated
vendored
Normal file
55
vendor/github.com/bifurcation/mint/log.go
generated
vendored
Normal file
|
@ -0,0 +1,55 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// We use this environment variable to control logging. It should be a
|
||||
// comma-separated list of log tags (see below) or "*" to enable all logging.
|
||||
const logConfigVar = "MINT_LOG"
|
||||
|
||||
// Pre-defined log types
|
||||
const (
|
||||
logTypeCrypto = "crypto"
|
||||
logTypeHandshake = "handshake"
|
||||
logTypeNegotiation = "negotiation"
|
||||
logTypeIO = "io"
|
||||
logTypeFrameReader = "frame"
|
||||
logTypeVerbose = "verbose"
|
||||
)
|
||||
|
||||
var (
|
||||
logFunction = log.Printf
|
||||
logAll = false
|
||||
logSettings = map[string]bool{}
|
||||
)
|
||||
|
||||
func init() {
|
||||
parseLogEnv(os.Environ())
|
||||
}
|
||||
|
||||
func parseLogEnv(env []string) {
|
||||
for _, stmt := range env {
|
||||
if strings.HasPrefix(stmt, logConfigVar+"=") {
|
||||
val := stmt[len(logConfigVar)+1:]
|
||||
|
||||
if val == "*" {
|
||||
logAll = true
|
||||
} else {
|
||||
for _, t := range strings.Split(val, ",") {
|
||||
logSettings[t] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func logf(tag string, format string, args ...interface{}) {
|
||||
if logAll || logSettings[tag] {
|
||||
fullFormat := fmt.Sprintf("[%s] %s", tag, format)
|
||||
logFunction(fullFormat, args...)
|
||||
}
|
||||
}
|
217
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
Normal file
217
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
Normal file
|
@ -0,0 +1,217 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func VersionNegotiation(offered, supported []uint16) (bool, uint16) {
|
||||
for _, offeredVersion := range offered {
|
||||
for _, supportedVersion := range supported {
|
||||
logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion)
|
||||
if offeredVersion == supportedVersion {
|
||||
// XXX: Should probably be highest supported version, but for now, we
|
||||
// only support one version, so it doesn't really matter.
|
||||
return true, offeredVersion
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, 0
|
||||
}
|
||||
|
||||
func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) {
|
||||
for _, share := range keyShares {
|
||||
for _, group := range groups {
|
||||
if group != share.Group {
|
||||
continue
|
||||
}
|
||||
|
||||
pub, priv, err := newKeyShare(share.Group)
|
||||
if err != nil {
|
||||
// If we encounter an error, just keep looking
|
||||
continue
|
||||
}
|
||||
|
||||
dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv)
|
||||
if err != nil {
|
||||
// If we encounter an error, just keep looking
|
||||
continue
|
||||
}
|
||||
|
||||
return true, group, pub, dhSecret
|
||||
}
|
||||
}
|
||||
|
||||
return false, 0, nil, nil
|
||||
}
|
||||
|
||||
const (
|
||||
ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds
|
||||
)
|
||||
|
||||
func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) {
|
||||
logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size())
|
||||
for i, id := range identities {
|
||||
identityHex := hex.EncodeToString(id.Identity)
|
||||
|
||||
psk, ok := psks.Get(identityHex)
|
||||
if !ok {
|
||||
logf(logTypeNegotiation, "No PSK for identity %x", identityHex)
|
||||
continue
|
||||
}
|
||||
|
||||
// For resumption, make sure the ticket age is correct
|
||||
if psk.IsResumption {
|
||||
extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd
|
||||
knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond)
|
||||
ticketAgeDelta := knownTicketAge - extTicketAge
|
||||
if knownTicketAge < extTicketAge {
|
||||
ticketAgeDelta = extTicketAge - knownTicketAge
|
||||
}
|
||||
if ticketAgeDelta > ticketAgeTolerance {
|
||||
logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity)
|
||||
logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]",
|
||||
extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance)
|
||||
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity)
|
||||
}
|
||||
}
|
||||
|
||||
params, ok := cipherSuiteMap[psk.CipherSuite]
|
||||
if !ok {
|
||||
err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite)
|
||||
return false, 0, nil, CipherSuiteParams{}, err
|
||||
}
|
||||
|
||||
// Compute binder
|
||||
binderLabel := labelExternalBinder
|
||||
if psk.IsResumption {
|
||||
binderLabel = labelResumptionBinder
|
||||
}
|
||||
|
||||
h0 := params.Hash.New().Sum(nil)
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
earlySecret := HkdfExtract(params.Hash, zero, psk.Key)
|
||||
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
|
||||
|
||||
// context = ClientHello[truncated]
|
||||
// context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated]
|
||||
ctxHash := params.Hash.New()
|
||||
ctxHash.Write(context)
|
||||
|
||||
binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil))
|
||||
if !bytes.Equal(binder, binders[i].Binder) {
|
||||
logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder)
|
||||
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity)
|
||||
}
|
||||
|
||||
logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity)
|
||||
return true, i, &psk, params, nil
|
||||
}
|
||||
|
||||
logf(logTypeNegotiation, "Failed to find a usable PSK")
|
||||
return false, 0, nil, CipherSuiteParams{}, nil
|
||||
}
|
||||
|
||||
func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) {
|
||||
logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes)
|
||||
dhAllowed := false
|
||||
dhRequired := true
|
||||
for _, mode := range modes {
|
||||
dhAllowed = dhAllowed || (mode == PSKModeDHEKE)
|
||||
dhRequired = dhRequired && (mode == PSKModeDHEKE)
|
||||
}
|
||||
|
||||
// Use PSK if we can meet DH requirement and modes were provided
|
||||
usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0)
|
||||
|
||||
// Use DH if allowed
|
||||
usingDH := canDoDH && (dhAllowed || !usingPSK)
|
||||
|
||||
logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK)
|
||||
return usingDH, usingPSK
|
||||
}
|
||||
|
||||
func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) {
|
||||
// Select for server name if provided
|
||||
candidates := certs
|
||||
if serverName != nil {
|
||||
candidatesByName := []*Certificate{}
|
||||
for _, cert := range certs {
|
||||
for _, name := range cert.Chain[0].DNSNames {
|
||||
if len(*serverName) > 0 && name == *serverName {
|
||||
candidatesByName = append(candidatesByName, cert)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(candidatesByName) == 0 {
|
||||
return nil, 0, fmt.Errorf("No certificates available for server name")
|
||||
}
|
||||
|
||||
candidates = candidatesByName
|
||||
}
|
||||
|
||||
// Select for signature scheme
|
||||
for _, cert := range candidates {
|
||||
for _, scheme := range signatureSchemes {
|
||||
if !schemeValidForKey(scheme, cert.PrivateKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
return cert, scheme, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
|
||||
}
|
||||
|
||||
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool {
|
||||
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData
|
||||
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData)
|
||||
return usingEarlyData
|
||||
}
|
||||
|
||||
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {
|
||||
for _, s1 := range offered {
|
||||
if psk != nil {
|
||||
if s1 == psk.CipherSuite {
|
||||
return s1, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for _, s2 := range supported {
|
||||
if s1 == s2 {
|
||||
return s1, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil)
|
||||
}
|
||||
|
||||
func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) {
|
||||
for _, p1 := range offered {
|
||||
if psk != nil {
|
||||
if p1 != psk.NextProto {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, p2 := range supported {
|
||||
if p1 == p2 {
|
||||
return p1, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the client offers ALPN on resumption, it must match the earlier one
|
||||
var err error
|
||||
if psk != nil && psk.IsResumption && (len(offered) > 0) {
|
||||
err = fmt.Errorf("ALPN for PSK not provided")
|
||||
}
|
||||
return "", err
|
||||
}
|
296
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
Normal file
296
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
Normal file
|
@ -0,0 +1,296 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
sequenceNumberLen = 8 // sequence number length
|
||||
recordHeaderLen = 5 // record header length
|
||||
maxFragmentLen = 1 << 14 // max number of bytes in a record
|
||||
)
|
||||
|
||||
type DecryptError string
|
||||
|
||||
func (err DecryptError) Error() string {
|
||||
return string(err)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ContentType type;
|
||||
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */
|
||||
// uint16 length;
|
||||
// opaque fragment[TLSPlaintext.length];
|
||||
// } TLSPlaintext;
|
||||
type TLSPlaintext struct {
|
||||
// Omitted: record_version (static)
|
||||
// Omitted: length (computed from fragment)
|
||||
contentType RecordType
|
||||
fragment []byte
|
||||
}
|
||||
|
||||
type RecordLayer struct {
|
||||
sync.Mutex
|
||||
|
||||
conn io.ReadWriter // The underlying connection
|
||||
frame *frameReader // The buffered frame reader
|
||||
nextData []byte // The next record to send
|
||||
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
|
||||
cachedError error // Error on the last record read
|
||||
|
||||
ivLength int // Length of the seq and nonce fields
|
||||
seq []byte // Zero-padded sequence number
|
||||
nonce []byte // Buffer for per-record nonces
|
||||
cipher cipher.AEAD // AEAD cipher
|
||||
}
|
||||
|
||||
type recordLayerFrameDetails struct{}
|
||||
|
||||
func (d recordLayerFrameDetails) headerLen() int {
|
||||
return recordHeaderLen
|
||||
}
|
||||
|
||||
func (d recordLayerFrameDetails) defaultReadLen() int {
|
||||
return recordHeaderLen + maxFragmentLen
|
||||
}
|
||||
|
||||
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||
return (int(hdr[3]) << 8) | int(hdr[4]), nil
|
||||
}
|
||||
|
||||
func NewRecordLayer(conn io.ReadWriter) *RecordLayer {
|
||||
r := RecordLayer{}
|
||||
r.conn = conn
|
||||
r.frame = newFrameReader(recordLayerFrameDetails{})
|
||||
r.ivLength = 0
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error {
|
||||
var err error
|
||||
r.cipher, err = cipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.ivLength = len(iv)
|
||||
r.seq = bytes.Repeat([]byte{0}, r.ivLength)
|
||||
r.nonce = make([]byte, r.ivLength)
|
||||
copy(r.nonce, iv)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) incrementSequenceNumber() {
|
||||
if r.ivLength == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- {
|
||||
r.seq[i]++
|
||||
r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i]
|
||||
if r.seq[i] != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Not allowed to let sequence number wrap.
|
||||
// Instead, must renegotiate before it does.
|
||||
// Not likely enough to bother.
|
||||
panic("TLS: sequence number wraparound")
|
||||
}
|
||||
|
||||
func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
|
||||
// Expand the fragment to hold contentType, padding, and overhead
|
||||
originalLen := len(pt.fragment)
|
||||
plaintextLen := originalLen + 1 + padLen
|
||||
ciphertextLen := plaintextLen + r.cipher.Overhead()
|
||||
|
||||
// Assemble the revised plaintext
|
||||
out := &TLSPlaintext{
|
||||
contentType: RecordTypeApplicationData,
|
||||
fragment: make([]byte, ciphertextLen),
|
||||
}
|
||||
copy(out.fragment, pt.fragment)
|
||||
out.fragment[originalLen] = byte(pt.contentType)
|
||||
for i := 1; i <= padLen; i++ {
|
||||
out.fragment[originalLen+i] = 0
|
||||
}
|
||||
|
||||
// Encrypt the fragment
|
||||
payload := out.fragment[:plaintextLen]
|
||||
r.cipher.Seal(payload[:0], r.nonce, payload, nil)
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
|
||||
if len(pt.fragment) < r.cipher.Overhead() {
|
||||
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead())
|
||||
return nil, 0, DecryptError(msg)
|
||||
}
|
||||
|
||||
decryptLen := len(pt.fragment) - r.cipher.Overhead()
|
||||
out := &TLSPlaintext{
|
||||
contentType: pt.contentType,
|
||||
fragment: make([]byte, decryptLen),
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
_, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil)
|
||||
if err != nil {
|
||||
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
|
||||
}
|
||||
|
||||
// Find the padding boundary
|
||||
padLen := 0
|
||||
for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ {
|
||||
}
|
||||
|
||||
// Transfer the content type
|
||||
newLen := decryptLen - padLen - 1
|
||||
out.contentType = RecordType(out.fragment[newLen])
|
||||
|
||||
// Truncate the message to remove contentType, padding, overhead
|
||||
out.fragment = out.fragment[:newLen]
|
||||
return out, padLen, nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
|
||||
var pt *TLSPlaintext
|
||||
var err error
|
||||
|
||||
for {
|
||||
pt, err = r.nextRecord()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if !block || err != WouldBlock {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return pt.contentType, nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
|
||||
pt, err := r.nextRecord()
|
||||
|
||||
// Consume the cached record if there was one
|
||||
r.cachedRecord = nil
|
||||
r.cachedError = nil
|
||||
|
||||
return pt, err
|
||||
}
|
||||
|
||||
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
||||
if r.cachedRecord != nil {
|
||||
logf(logTypeIO, "Returning cached record")
|
||||
return r.cachedRecord, r.cachedError
|
||||
}
|
||||
|
||||
// Loop until one of three things happens:
|
||||
//
|
||||
// 1. We get a frame
|
||||
// 2. We try to read off the socket and get nothing, in which case
|
||||
// return WouldBlock
|
||||
// 3. We get an error.
|
||||
err := WouldBlock
|
||||
var header, body []byte
|
||||
|
||||
for err != nil {
|
||||
if r.frame.needed() > 0 {
|
||||
buf := make([]byte, recordHeaderLen+maxFragmentLen)
|
||||
n, err := r.conn.Read(buf)
|
||||
if err != nil {
|
||||
logf(logTypeIO, "Error reading, %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return nil, WouldBlock
|
||||
}
|
||||
|
||||
logf(logTypeIO, "Read %v bytes", n)
|
||||
|
||||
buf = buf[:n]
|
||||
r.frame.addChunk(buf)
|
||||
}
|
||||
|
||||
header, body, err = r.frame.process()
|
||||
// Loop around on WouldBlock to see if some
|
||||
// data is now available.
|
||||
if err != nil && err != WouldBlock {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pt := &TLSPlaintext{}
|
||||
// Validate content type
|
||||
switch RecordType(header[0]) {
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
|
||||
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData:
|
||||
pt.contentType = RecordType(header[0])
|
||||
}
|
||||
|
||||
// Validate version
|
||||
if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
|
||||
return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
|
||||
}
|
||||
|
||||
// Validate size < max
|
||||
size := (int(header[3]) << 8) + int(header[4])
|
||||
if size > maxFragmentLen+256 {
|
||||
return nil, fmt.Errorf("tls.record: Ciphertext size too big")
|
||||
}
|
||||
|
||||
pt.fragment = make([]byte, size)
|
||||
copy(pt.fragment, body)
|
||||
|
||||
// Attempt to decrypt fragment
|
||||
if r.cipher != nil {
|
||||
pt, _, err = r.decrypt(pt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Check that plaintext length is not too long
|
||||
if len(pt.fragment) > maxFragmentLen {
|
||||
return nil, fmt.Errorf("tls.record: Plaintext size too big")
|
||||
}
|
||||
|
||||
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment)
|
||||
|
||||
r.cachedRecord = pt
|
||||
r.incrementSequenceNumber()
|
||||
return pt, nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
|
||||
return r.WriteRecordWithPadding(pt, 0)
|
||||
}
|
||||
|
||||
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
|
||||
if r.cipher != nil {
|
||||
pt = r.encrypt(pt, padLen)
|
||||
} else if padLen > 0 {
|
||||
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
|
||||
}
|
||||
|
||||
if len(pt.fragment) > maxFragmentLen {
|
||||
return fmt.Errorf("tls.record: Record size too big")
|
||||
}
|
||||
|
||||
length := len(pt.fragment)
|
||||
header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)}
|
||||
record := append(header, pt.fragment...)
|
||||
|
||||
logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment)
|
||||
|
||||
r.incrementSequenceNumber()
|
||||
_, err := r.conn.Write(record)
|
||||
return err
|
||||
}
|
898
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
Normal file
898
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
Normal file
|
@ -0,0 +1,898 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"hash"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Server State Machine
|
||||
//
|
||||
// START <-----+
|
||||
// Recv ClientHello | | Send HelloRetryRequest
|
||||
// v |
|
||||
// RECVD_CH ----+
|
||||
// | Select parameters
|
||||
// | Send ServerHello
|
||||
// v
|
||||
// NEGOTIATED
|
||||
// | Send EncryptedExtensions
|
||||
// | [Send CertificateRequest]
|
||||
// Can send | [Send Certificate + CertificateVerify]
|
||||
// app data --> | Send Finished
|
||||
// after +--------+--------+
|
||||
// here No 0-RTT | | 0-RTT
|
||||
// | v
|
||||
// | WAIT_EOED <---+
|
||||
// | Recv | | | Recv
|
||||
// | EndOfEarlyData | | | early data
|
||||
// | | +-----+
|
||||
// +> WAIT_FLIGHT2 <-+
|
||||
// |
|
||||
// +--------+--------+
|
||||
// No auth | | Client auth
|
||||
// | |
|
||||
// | v
|
||||
// | WAIT_CERT
|
||||
// | Recv | | Recv Certificate
|
||||
// | empty | v
|
||||
// | Certificate | WAIT_CV
|
||||
// | | | Recv
|
||||
// | v | CertificateVerify
|
||||
// +-> WAIT_FINISHED <---+
|
||||
// | Recv Finished
|
||||
// v
|
||||
// CONNECTED
|
||||
//
|
||||
// NB: Not using state RECVD_CH
|
||||
//
|
||||
// State Instructions
|
||||
// START {}
|
||||
// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)]
|
||||
// WAIT_EOED RekeyIn;
|
||||
// WAIT_FLIGHT2 {}
|
||||
// WAIT_CERT_CR {}
|
||||
// WAIT_CERT {}
|
||||
// WAIT_CV {}
|
||||
// WAIT_FINISHED RekeyIn; RekeyOut;
|
||||
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
||||
|
||||
type ServerStateStart struct {
|
||||
Caps Capabilities
|
||||
conn *Conn
|
||||
|
||||
cookieSent bool
|
||||
firstClientHello *HandshakeMessage
|
||||
helloRetryRequest *HandshakeMessage
|
||||
}
|
||||
|
||||
func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeClientHello {
|
||||
logf(logTypeHandshake, "[ServerStateStart] unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
ch := &ClientHelloBody{}
|
||||
_, err := ch.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
clientHello := hm
|
||||
connParams := ConnectionParameters{}
|
||||
|
||||
supportedVersions := new(SupportedVersionsExtension)
|
||||
serverName := new(ServerNameExtension)
|
||||
supportedGroups := new(SupportedGroupsExtension)
|
||||
signatureAlgorithms := new(SignatureAlgorithmsExtension)
|
||||
clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello}
|
||||
clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello}
|
||||
clientEarlyData := &EarlyDataExtension{}
|
||||
clientALPN := new(ALPNExtension)
|
||||
clientPSKModes := new(PSKKeyExchangeModesExtension)
|
||||
clientCookie := new(CookieExtension)
|
||||
|
||||
// Handle external extensions.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
gotSupportedVersions := ch.Extensions.Find(supportedVersions)
|
||||
gotServerName := ch.Extensions.Find(serverName)
|
||||
gotSupportedGroups := ch.Extensions.Find(supportedGroups)
|
||||
gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms)
|
||||
gotEarlyData := ch.Extensions.Find(clientEarlyData)
|
||||
ch.Extensions.Find(clientKeyShares)
|
||||
ch.Extensions.Find(clientPSK)
|
||||
ch.Extensions.Find(clientALPN)
|
||||
ch.Extensions.Find(clientPSKModes)
|
||||
ch.Extensions.Find(clientCookie)
|
||||
|
||||
if gotServerName {
|
||||
connParams.ServerName = string(*serverName)
|
||||
}
|
||||
|
||||
// If the client didn't send supportedVersions or doesn't support 1.3,
|
||||
// then we're done here.
|
||||
if !gotSupportedVersions {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions")
|
||||
return nil, nil, AlertProtocolVersion
|
||||
}
|
||||
versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion})
|
||||
if !versionOK {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version")
|
||||
return nil, nil, AlertProtocolVersion
|
||||
}
|
||||
|
||||
if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch")
|
||||
return nil, nil, AlertAccessDenied
|
||||
}
|
||||
|
||||
// Figure out if we can do DH
|
||||
canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups)
|
||||
|
||||
// Figure out if we can do PSK
|
||||
canDoPSK := false
|
||||
var selectedPSK int
|
||||
var psk *PreSharedKey
|
||||
var params CipherSuiteParams
|
||||
if len(clientPSK.Identities) > 0 {
|
||||
contextBase := []byte{}
|
||||
if state.helloRetryRequest != nil {
|
||||
chBytes := state.firstClientHello.Marshal()
|
||||
hrrBytes := state.helloRetryRequest.Marshal()
|
||||
contextBase = append(chBytes, hrrBytes...)
|
||||
}
|
||||
|
||||
chTrunc, err := ch.Truncated()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
context := append(contextBase, chTrunc...)
|
||||
|
||||
canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Figure out if we actually should do DH / PSK
|
||||
connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes)
|
||||
|
||||
// Select a ciphersuite
|
||||
connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Send a cookie if required
|
||||
// NB: Need to do this here because it's after ciphersuite selection, which
|
||||
// has to be after PSK selection.
|
||||
// XXX: Doing this statefully for now, could be stateless
|
||||
var cookieData []byte
|
||||
if state.Caps.RequireCookie && !state.cookieSent {
|
||||
var err error
|
||||
cookieData, err = state.Caps.CookieHandler.Generate(state.conn)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
if cookieData != nil {
|
||||
// Ignoring errors because everything here is newly constructed, so there
|
||||
// shouldn't be marshal errors
|
||||
hrr := &HelloRetryRequestBody{
|
||||
Version: supportedVersion,
|
||||
CipherSuite: connParams.CipherSuite,
|
||||
}
|
||||
hrr.Extensions.Add(&CookieExtension{Cookie: cookieData})
|
||||
|
||||
// Run the external extension handler.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
helloRetryRequest, err := HandshakeMessageFromBody(hrr)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
params := cipherSuiteMap[connParams.CipherSuite]
|
||||
h := params.Hash.New()
|
||||
h.Write(clientHello.Marshal())
|
||||
firstClientHello := &HandshakeMessage{
|
||||
msgType: HandshakeTypeMessageHash,
|
||||
body: h.Sum(nil),
|
||||
}
|
||||
|
||||
nextState := ServerStateStart{
|
||||
Caps: state.Caps,
|
||||
conn: state.conn,
|
||||
cookieSent: true,
|
||||
firstClientHello: firstClientHello,
|
||||
helloRetryRequest: helloRetryRequest,
|
||||
}
|
||||
toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}}
|
||||
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]")
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
// If we've got no entropy to make keys from, fail
|
||||
if !connParams.UsingDH && !connParams.UsingPSK {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated")
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
var pskSecret []byte
|
||||
var cert *Certificate
|
||||
var certScheme SignatureScheme
|
||||
if connParams.UsingPSK {
|
||||
pskSecret = psk.Key
|
||||
} else {
|
||||
psk = nil
|
||||
|
||||
// If we're not using a PSK mode, then we need to have certain extensions
|
||||
if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)",
|
||||
gotServerName, gotSupportedGroups, gotSignatureAlgorithms)
|
||||
return nil, nil, AlertMissingExtension
|
||||
}
|
||||
|
||||
// Select a certificate
|
||||
name := string(*serverName)
|
||||
var err error
|
||||
cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err)
|
||||
return nil, nil, AlertAccessDenied
|
||||
}
|
||||
}
|
||||
|
||||
if !connParams.UsingDH {
|
||||
dhSecret = nil
|
||||
}
|
||||
|
||||
// Figure out if we're going to do early data
|
||||
var clientEarlyTrafficSecret []byte
|
||||
connParams.ClientSendingEarlyData = gotEarlyData
|
||||
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData)
|
||||
if connParams.UsingEarlyData {
|
||||
|
||||
h := params.Hash.New()
|
||||
h.Write(clientHello.Marshal())
|
||||
chHash := h.Sum(nil)
|
||||
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
earlySecret := HkdfExtract(params.Hash, zero, pskSecret)
|
||||
clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
|
||||
}
|
||||
|
||||
// Select a next protocol
|
||||
connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err)
|
||||
return nil, nil, AlertNoApplicationProtocol
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]")
|
||||
return ServerStateNegotiated{
|
||||
Caps: state.Caps,
|
||||
Params: connParams,
|
||||
|
||||
dhGroup: dhGroup,
|
||||
dhPublic: dhPublic,
|
||||
dhSecret: dhSecret,
|
||||
pskSecret: pskSecret,
|
||||
selectedPSK: selectedPSK,
|
||||
cert: cert,
|
||||
certScheme: certScheme,
|
||||
clientEarlyTrafficSecret: clientEarlyTrafficSecret,
|
||||
|
||||
firstClientHello: state.firstClientHello,
|
||||
helloRetryRequest: state.helloRetryRequest,
|
||||
clientHello: clientHello,
|
||||
}.Next(nil)
|
||||
}
|
||||
|
||||
type ServerStateNegotiated struct {
|
||||
Caps Capabilities
|
||||
Params ConnectionParameters
|
||||
|
||||
dhGroup NamedGroup
|
||||
dhPublic []byte
|
||||
dhSecret []byte
|
||||
pskSecret []byte
|
||||
clientEarlyTrafficSecret []byte
|
||||
selectedPSK int
|
||||
cert *Certificate
|
||||
certScheme SignatureScheme
|
||||
|
||||
firstClientHello *HandshakeMessage
|
||||
helloRetryRequest *HandshakeMessage
|
||||
clientHello *HandshakeMessage
|
||||
}
|
||||
|
||||
func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
// Create the ServerHello
|
||||
sh := &ServerHelloBody{
|
||||
Version: supportedVersion,
|
||||
CipherSuite: state.Params.CipherSuite,
|
||||
}
|
||||
_, err := prng.Read(sh.Random[:])
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
if state.Params.UsingDH {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension")
|
||||
err = sh.Extensions.Add(&KeyShareExtension{
|
||||
HandshakeType: HandshakeTypeServerHello,
|
||||
Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}},
|
||||
})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
if state.Params.UsingPSK {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension")
|
||||
err = sh.Extensions.Add(&PreSharedKeyExtension{
|
||||
HandshakeType: HandshakeTypeServerHello,
|
||||
SelectedIdentity: uint16(state.selectedPSK),
|
||||
})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Run the external extension handler.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
serverHello, err := HandshakeMessageFromBody(sh)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
// Look up crypto params
|
||||
params, ok := cipherSuiteMap[sh.CipherSuite]
|
||||
if !ok {
|
||||
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Start up the handshake hash
|
||||
handshakeHash := params.Hash.New()
|
||||
handshakeHash.Write(state.firstClientHello.Marshal())
|
||||
handshakeHash.Write(state.helloRetryRequest.Marshal())
|
||||
handshakeHash.Write(state.clientHello.Marshal())
|
||||
handshakeHash.Write(serverHello.Marshal())
|
||||
|
||||
// Compute handshake secrets
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
|
||||
var earlySecret []byte
|
||||
if state.Params.UsingPSK {
|
||||
earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret)
|
||||
} else {
|
||||
earlySecret = HkdfExtract(params.Hash, zero, zero)
|
||||
}
|
||||
|
||||
if state.dhSecret == nil {
|
||||
state.dhSecret = zero
|
||||
}
|
||||
|
||||
h0 := params.Hash.New().Sum(nil)
|
||||
h2 := handshakeHash.Sum(nil)
|
||||
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
|
||||
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret)
|
||||
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
|
||||
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
|
||||
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
|
||||
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
|
||||
|
||||
logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret)
|
||||
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
|
||||
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
|
||||
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
|
||||
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
|
||||
|
||||
clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret)
|
||||
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
|
||||
|
||||
// Send an EncryptedExtensions message (even if it's empty)
|
||||
eeList := ExtensionList{}
|
||||
if state.Params.NextProto != "" {
|
||||
logf(logTypeHandshake, "[server] sending ALPN extension")
|
||||
err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
if state.Params.UsingEarlyData {
|
||||
logf(logTypeHandshake, "[server] sending EDI extension")
|
||||
err = eeList.Add(&EarlyDataExtension{})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
ee := &EncryptedExtensionsBody{eeList}
|
||||
|
||||
// Run the external extension handler.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
eem, err := HandshakeMessageFromBody(ee)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
handshakeHash.Write(eem.Marshal())
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
SendHandshakeMessage{serverHello},
|
||||
RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys},
|
||||
SendHandshakeMessage{eem},
|
||||
}
|
||||
|
||||
// Authenticate with a certificate if required
|
||||
if !state.Params.UsingPSK {
|
||||
// Send a CertificateRequest message if we want client auth
|
||||
if state.Caps.RequireClientAuth {
|
||||
state.Params.UsingClientAuth = true
|
||||
|
||||
// XXX: We don't support sending any constraints besides a list of
|
||||
// supported signature algorithms
|
||||
cr := &CertificateRequestBody{}
|
||||
schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
|
||||
err := cr.Extensions.Add(schemes)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
crm, err := HandshakeMessageFromBody(cr)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
//TODO state.state.serverCertificateRequest = cr
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{crm})
|
||||
handshakeHash.Write(crm.Marshal())
|
||||
}
|
||||
|
||||
// Create and send Certificate, CertificateVerify
|
||||
certificate := &CertificateBody{
|
||||
CertificateList: make([]CertificateEntry, len(state.cert.Chain)),
|
||||
}
|
||||
for i, entry := range state.cert.Chain {
|
||||
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
|
||||
}
|
||||
certm, err := HandshakeMessageFromBody(certificate)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||
handshakeHash.Write(certm.Marshal())
|
||||
|
||||
certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme}
|
||||
logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash)
|
||||
|
||||
hcv := handshakeHash.Sum(nil)
|
||||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||
|
||||
err = certificateVerify.Sign(state.cert.PrivateKey, hcv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
certvm, err := HandshakeMessageFromBody(certificateVerify)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certvm})
|
||||
handshakeHash.Write(certvm.Marshal())
|
||||
}
|
||||
|
||||
// Compute secrets resulting from the server's first flight
|
||||
h3 := handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
|
||||
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
|
||||
|
||||
serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3)
|
||||
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
|
||||
|
||||
// Assemble the Finished message
|
||||
fin := &FinishedBody{
|
||||
VerifyDataLen: len(serverFinishedData),
|
||||
VerifyData: serverFinishedData,
|
||||
}
|
||||
finm, _ := HandshakeMessageFromBody(fin)
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{finm})
|
||||
handshakeHash.Write(finm.Marshal())
|
||||
|
||||
// Compute traffic secrets
|
||||
h4 := handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4)
|
||||
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4)
|
||||
|
||||
clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4)
|
||||
serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4)
|
||||
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
|
||||
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
|
||||
|
||||
serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret)
|
||||
toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys})
|
||||
|
||||
exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4)
|
||||
logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
|
||||
|
||||
if state.Params.UsingEarlyData {
|
||||
clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret)
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]")
|
||||
nextState := ServerStateWaitEOED{
|
||||
AuthCertificate: state.Caps.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: params,
|
||||
handshakeHash: handshakeHash,
|
||||
masterSecret: masterSecret,
|
||||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||
clientTrafficSecret: clientTrafficSecret,
|
||||
serverTrafficSecret: serverTrafficSecret,
|
||||
exporterSecret: exporterSecret,
|
||||
}
|
||||
toSend = append(toSend, []HandshakeAction{
|
||||
RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys},
|
||||
ReadEarlyData{},
|
||||
}...)
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]")
|
||||
toSend = append(toSend, []HandshakeAction{
|
||||
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
|
||||
ReadPastEarlyData{},
|
||||
}...)
|
||||
waitFlight2 := ServerStateWaitFlight2{
|
||||
AuthCertificate: state.Caps.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: params,
|
||||
handshakeHash: handshakeHash,
|
||||
masterSecret: masterSecret,
|
||||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||
clientTrafficSecret: clientTrafficSecret,
|
||||
serverTrafficSecret: serverTrafficSecret,
|
||||
exporterSecret: exporterSecret,
|
||||
}
|
||||
nextState, moreToSend, alert := waitFlight2.Next(nil)
|
||||
toSend = append(toSend, moreToSend...)
|
||||
return nextState, toSend, alert
|
||||
}
|
||||
|
||||
type ServerStateWaitEOED struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData {
|
||||
logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
if len(hm.body) > 0 {
|
||||
logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]")
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]")
|
||||
toSend := []HandshakeAction{
|
||||
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
|
||||
}
|
||||
waitFlight2 := ServerStateWaitFlight2{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
nextState, moreToSend, alert := waitFlight2.Next(nil)
|
||||
toSend = append(toSend, moreToSend...)
|
||||
return nextState, toSend, alert
|
||||
}
|
||||
|
||||
type ServerStateWaitFlight2 struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
if state.Params.UsingClientAuth {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]")
|
||||
nextState := ServerStateWaitCert{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]")
|
||||
nextState := ServerStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
handshakeHash: state.handshakeHash,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ServerStateWaitCert struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeCertificate {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
cert := &CertificateBody{}
|
||||
_, err := cert.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
if len(cert.CertificateList) == 0 {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate")
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]")
|
||||
nextState := ServerStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
handshakeHash: state.handshakeHash,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]")
|
||||
nextState := ServerStateWaitCV{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
handshakeHash: state.handshakeHash,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
clientCertificate: cert,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ServerStateWaitCV struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
|
||||
clientCertificate *CertificateBody
|
||||
}
|
||||
|
||||
func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm))
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
certVerify := &CertificateVerifyBody{}
|
||||
_, err := certVerify.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
// Verify client signature over handshake hash
|
||||
hcv := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||
|
||||
clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey
|
||||
if err := certVerify.Verify(clientPublicKey, hcv); err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
if state.AuthCertificate != nil {
|
||||
err := state.AuthCertificate(state.clientCertificate.CertificateList)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate")
|
||||
return nil, nil, AlertBadCertificate
|
||||
}
|
||||
} else {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate")
|
||||
}
|
||||
|
||||
// If it passes, record the certificateVerify in the transcript hash
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]")
|
||||
nextState := ServerStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
handshakeHash: state.handshakeHash,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ServerStateWaitFinished struct {
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeFinished {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()}
|
||||
_, err := fin.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
// Verify client Finished data
|
||||
h5 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
|
||||
|
||||
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
|
||||
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
|
||||
|
||||
if !bytes.Equal(fin.VerifyData, clientFinishedData) {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify")
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Compute the resumption secret
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
h6 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6)
|
||||
|
||||
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
|
||||
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
|
||||
|
||||
// Compute client traffic keys
|
||||
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]")
|
||||
nextState := StateConnected{
|
||||
Params: state.Params,
|
||||
isClient: false,
|
||||
cryptoParams: state.cryptoParams,
|
||||
resumptionSecret: resumptionSecret,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
toSend := []HandshakeAction{
|
||||
RekeyIn{Label: "application", KeySet: clientTrafficKeys},
|
||||
}
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
230
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
Normal file
230
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
Normal file
|
@ -0,0 +1,230 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Marker interface for actions that an implementation should take based on
|
||||
// state transitions.
|
||||
type HandshakeAction interface{}
|
||||
|
||||
type SendHandshakeMessage struct {
|
||||
Message *HandshakeMessage
|
||||
}
|
||||
|
||||
type SendEarlyData struct{}
|
||||
|
||||
type ReadEarlyData struct{}
|
||||
|
||||
type ReadPastEarlyData struct{}
|
||||
|
||||
type RekeyIn struct {
|
||||
Label string
|
||||
KeySet keySet
|
||||
}
|
||||
|
||||
type RekeyOut struct {
|
||||
Label string
|
||||
KeySet keySet
|
||||
}
|
||||
|
||||
type StorePSK struct {
|
||||
PSK PreSharedKey
|
||||
}
|
||||
|
||||
type HandshakeState interface {
|
||||
Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert)
|
||||
}
|
||||
|
||||
type AppExtensionHandler interface {
|
||||
Send(hs HandshakeType, el *ExtensionList) error
|
||||
Receive(hs HandshakeType, el *ExtensionList) error
|
||||
}
|
||||
|
||||
// Capabilities objects represent the capabilities of a TLS client or server,
|
||||
// as an input to TLS negotiation
|
||||
type Capabilities struct {
|
||||
// For both client and server
|
||||
CipherSuites []CipherSuite
|
||||
Groups []NamedGroup
|
||||
SignatureSchemes []SignatureScheme
|
||||
PSKs PreSharedKeyCache
|
||||
Certificates []*Certificate
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
ExtensionHandler AppExtensionHandler
|
||||
|
||||
// For client
|
||||
PSKModes []PSKKeyExchangeMode
|
||||
|
||||
// For server
|
||||
NextProtos []string
|
||||
AllowEarlyData bool
|
||||
RequireCookie bool
|
||||
CookieHandler CookieHandler
|
||||
RequireClientAuth bool
|
||||
}
|
||||
|
||||
// ConnectionOptions objects represent per-connection settings for a client
|
||||
// initiating a connection
|
||||
type ConnectionOptions struct {
|
||||
ServerName string
|
||||
NextProtos []string
|
||||
EarlyData []byte
|
||||
}
|
||||
|
||||
// ConnectionParameters objects represent the parameters negotiated for a
|
||||
// connection.
|
||||
type ConnectionParameters struct {
|
||||
UsingPSK bool
|
||||
UsingDH bool
|
||||
ClientSendingEarlyData bool
|
||||
UsingEarlyData bool
|
||||
UsingClientAuth bool
|
||||
|
||||
CipherSuite CipherSuite
|
||||
ServerName string
|
||||
NextProto string
|
||||
}
|
||||
|
||||
// StateConnected is symmetric between client and server
|
||||
type StateConnected struct {
|
||||
Params ConnectionParameters
|
||||
isClient bool
|
||||
cryptoParams CipherSuiteParams
|
||||
resumptionSecret []byte
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
|
||||
var trafficKeys keySet
|
||||
if state.isClient {
|
||||
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
||||
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||
} else {
|
||||
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
|
||||
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||
}
|
||||
|
||||
kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
SendHandshakeMessage{kum},
|
||||
RekeyOut{Label: "update", KeySet: trafficKeys},
|
||||
}
|
||||
return toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
|
||||
tkt, err := NewSessionTicket(length, lifetime)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
||||
labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size())
|
||||
|
||||
newPSK := PreSharedKey{
|
||||
CipherSuite: state.cryptoParams.Suite,
|
||||
IsResumption: true,
|
||||
Identity: tkt.Ticket,
|
||||
Key: resumptionKey,
|
||||
NextProto: state.Params.NextProto,
|
||||
ReceivedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second),
|
||||
TicketAgeAdd: tkt.TicketAgeAdd,
|
||||
}
|
||||
|
||||
tktm, err := HandshakeMessageFromBody(tkt)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
StorePSK{newPSK},
|
||||
SendHandshakeMessage{tktm},
|
||||
}
|
||||
return toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
bodyGeneric, err := hm.ToBody()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
switch body := bodyGeneric.(type) {
|
||||
case *KeyUpdateBody:
|
||||
var trafficKeys keySet
|
||||
if !state.isClient {
|
||||
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
||||
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||
} else {
|
||||
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
|
||||
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}}
|
||||
|
||||
// If requested, roll outbound keys and send a KeyUpdate
|
||||
if body.KeyUpdateRequest == KeyUpdateRequested {
|
||||
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
|
||||
if alert != AlertNoAlert {
|
||||
return nil, nil, alert
|
||||
}
|
||||
|
||||
toSend = append(toSend, moreToSend...)
|
||||
}
|
||||
|
||||
return state, toSend, AlertNoAlert
|
||||
|
||||
case *NewSessionTicketBody:
|
||||
// XXX: Allow NewSessionTicket in both directions?
|
||||
if !state.isClient {
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
||||
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
|
||||
|
||||
psk := PreSharedKey{
|
||||
CipherSuite: state.cryptoParams.Suite,
|
||||
IsResumption: true,
|
||||
Identity: body.Ticket,
|
||||
Key: resumptionKey,
|
||||
NextProto: state.Params.NextProto,
|
||||
ReceivedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second),
|
||||
TicketAgeAdd: body.TicketAgeAdd,
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{StorePSK{psk}}
|
||||
return state, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType)
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
243
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
Normal file
243
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
Normal file
|
@ -0,0 +1,243 @@
|
|||
package syntax
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func Unmarshal(data []byte, v interface{}) (int, error) {
|
||||
// Check for well-formedness.
|
||||
// Avoids filling out half a data structure
|
||||
// before discovering a JSON syntax error.
|
||||
d := decodeState{}
|
||||
d.Write(data)
|
||||
return d.unmarshal(v)
|
||||
}
|
||||
|
||||
// These are the options that can be specified in the struct tag. Right now,
|
||||
// all of them apply to variable-length vectors and nothing else
|
||||
type decOpts struct {
|
||||
head uint // length of length in bytes
|
||||
min uint // minimum size in bytes
|
||||
max uint // maximum size in bytes
|
||||
}
|
||||
|
||||
type decodeState struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (d *decodeState) unmarshal(v interface{}) (read int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if _, ok := r.(runtime.Error); ok {
|
||||
panic(r)
|
||||
}
|
||||
if s, ok := r.(string); ok {
|
||||
panic(s)
|
||||
}
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Ptr || rv.IsNil() {
|
||||
return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)")
|
||||
}
|
||||
|
||||
read = d.value(rv)
|
||||
return read, nil
|
||||
}
|
||||
|
||||
func (e *decodeState) value(v reflect.Value) int {
|
||||
return valueDecoder(v)(e, v, decOpts{})
|
||||
}
|
||||
|
||||
type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int
|
||||
|
||||
func valueDecoder(v reflect.Value) decoderFunc {
|
||||
return typeDecoder(v.Type().Elem())
|
||||
}
|
||||
|
||||
func typeDecoder(t reflect.Type) decoderFunc {
|
||||
// Note: Omits the caching / wait-group things that encoding/json uses
|
||||
return newTypeDecoder(t)
|
||||
}
|
||||
|
||||
func newTypeDecoder(t reflect.Type) decoderFunc {
|
||||
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return uintDecoder
|
||||
case reflect.Array:
|
||||
return newArrayDecoder(t)
|
||||
case reflect.Slice:
|
||||
return newSliceDecoder(t)
|
||||
case reflect.Struct:
|
||||
return newStructDecoder(t)
|
||||
default:
|
||||
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||
}
|
||||
}
|
||||
|
||||
///// Specific decoders below
|
||||
|
||||
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
var uintLen int
|
||||
switch v.Elem().Kind() {
|
||||
case reflect.Uint8:
|
||||
uintLen = 1
|
||||
case reflect.Uint16:
|
||||
uintLen = 2
|
||||
case reflect.Uint32:
|
||||
uintLen = 4
|
||||
case reflect.Uint64:
|
||||
uintLen = 8
|
||||
}
|
||||
|
||||
buf := make([]byte, uintLen)
|
||||
n, err := d.Read(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if n != uintLen {
|
||||
panic(fmt.Errorf("Insufficient data to read uint"))
|
||||
}
|
||||
|
||||
val := uint64(0)
|
||||
for _, b := range buf {
|
||||
val = (val << 8) + uint64(b)
|
||||
}
|
||||
|
||||
v.Elem().SetUint(val)
|
||||
return uintLen
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type arrayDecoder struct {
|
||||
elemDec decoderFunc
|
||||
}
|
||||
|
||||
func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
n := v.Elem().Type().Len()
|
||||
read := 0
|
||||
for i := 0; i < n; i += 1 {
|
||||
read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts)
|
||||
}
|
||||
return read
|
||||
}
|
||||
|
||||
func newArrayDecoder(t reflect.Type) decoderFunc {
|
||||
dec := &arrayDecoder{typeDecoder(t.Elem())}
|
||||
return dec.decode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type sliceDecoder struct {
|
||||
elementType reflect.Type
|
||||
elementDec decoderFunc
|
||||
}
|
||||
|
||||
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
if opts.head == 0 {
|
||||
panic(fmt.Errorf("Cannot decode a slice without a header length"))
|
||||
}
|
||||
|
||||
lengthBytes := make([]byte, opts.head)
|
||||
n, err := d.Read(lengthBytes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if uint(n) != opts.head {
|
||||
panic(fmt.Errorf("Not enough data to read header"))
|
||||
}
|
||||
|
||||
length := uint(0)
|
||||
for _, b := range lengthBytes {
|
||||
length = (length << 8) + uint(b)
|
||||
}
|
||||
|
||||
if opts.max > 0 && length > opts.max {
|
||||
panic(fmt.Errorf("Length of vector exceeds declared max"))
|
||||
}
|
||||
if length < opts.min {
|
||||
panic(fmt.Errorf("Length of vector below declared min"))
|
||||
}
|
||||
|
||||
data := make([]byte, length)
|
||||
n, err = d.Read(data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if uint(n) != length {
|
||||
panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length))
|
||||
}
|
||||
|
||||
elemBuf := &decodeState{}
|
||||
elemBuf.Write(data)
|
||||
elems := []reflect.Value{}
|
||||
read := int(opts.head)
|
||||
for elemBuf.Len() > 0 {
|
||||
elem := reflect.New(sd.elementType)
|
||||
read += sd.elementDec(elemBuf, elem, opts)
|
||||
elems = append(elems, elem)
|
||||
}
|
||||
|
||||
v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems)))
|
||||
for i := 0; i < len(elems); i += 1 {
|
||||
v.Elem().Index(i).Set(elems[i].Elem())
|
||||
}
|
||||
return read
|
||||
}
|
||||
|
||||
func newSliceDecoder(t reflect.Type) decoderFunc {
|
||||
dec := &sliceDecoder{
|
||||
elementType: t.Elem(),
|
||||
elementDec: typeDecoder(t.Elem()),
|
||||
}
|
||||
return dec.decode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type structDecoder struct {
|
||||
fieldOpts []decOpts
|
||||
fieldDecs []decoderFunc
|
||||
}
|
||||
|
||||
func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
read := 0
|
||||
for i := range sd.fieldDecs {
|
||||
read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i])
|
||||
}
|
||||
return read
|
||||
}
|
||||
|
||||
func newStructDecoder(t reflect.Type) decoderFunc {
|
||||
n := t.NumField()
|
||||
sd := structDecoder{
|
||||
fieldOpts: make([]decOpts, n),
|
||||
fieldDecs: make([]decoderFunc, n),
|
||||
}
|
||||
|
||||
for i := 0; i < n; i += 1 {
|
||||
f := t.Field(i)
|
||||
|
||||
tag := f.Tag.Get("tls")
|
||||
tagOpts := parseTag(tag)
|
||||
|
||||
sd.fieldOpts[i] = decOpts{
|
||||
head: tagOpts["head"],
|
||||
max: tagOpts["max"],
|
||||
min: tagOpts["min"],
|
||||
}
|
||||
|
||||
sd.fieldDecs[i] = typeDecoder(f.Type)
|
||||
}
|
||||
|
||||
return sd.decode
|
||||
}
|
187
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
Normal file
187
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
Normal file
|
@ -0,0 +1,187 @@
|
|||
package syntax
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func Marshal(v interface{}) ([]byte, error) {
|
||||
e := &encodeState{}
|
||||
err := e.marshal(v, encOpts{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return e.Bytes(), nil
|
||||
}
|
||||
|
||||
// These are the options that can be specified in the struct tag. Right now,
|
||||
// all of them apply to variable-length vectors and nothing else
|
||||
type encOpts struct {
|
||||
head uint // length of length in bytes
|
||||
min uint // minimum size in bytes
|
||||
max uint // maximum size in bytes
|
||||
}
|
||||
|
||||
type encodeState struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if _, ok := r.(runtime.Error); ok {
|
||||
panic(r)
|
||||
}
|
||||
if s, ok := r.(string); ok {
|
||||
panic(s)
|
||||
}
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
e.reflectValue(reflect.ValueOf(v), opts)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
|
||||
valueEncoder(v)(e, v, opts)
|
||||
}
|
||||
|
||||
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
|
||||
|
||||
func valueEncoder(v reflect.Value) encoderFunc {
|
||||
if !v.IsValid() {
|
||||
panic(fmt.Errorf("Cannot encode an invalid value"))
|
||||
}
|
||||
return typeEncoder(v.Type())
|
||||
}
|
||||
|
||||
func typeEncoder(t reflect.Type) encoderFunc {
|
||||
// Note: Omits the caching / wait-group things that encoding/json uses
|
||||
return newTypeEncoder(t)
|
||||
}
|
||||
|
||||
func newTypeEncoder(t reflect.Type) encoderFunc {
|
||||
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return uintEncoder
|
||||
case reflect.Array:
|
||||
return newArrayEncoder(t)
|
||||
case reflect.Slice:
|
||||
return newSliceEncoder(t)
|
||||
case reflect.Struct:
|
||||
return newStructEncoder(t)
|
||||
default:
|
||||
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||
}
|
||||
}
|
||||
|
||||
///// Specific encoders below
|
||||
|
||||
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
u := v.Uint()
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Uint8:
|
||||
e.WriteByte(byte(u))
|
||||
case reflect.Uint16:
|
||||
e.Write([]byte{byte(u >> 8), byte(u)})
|
||||
case reflect.Uint32:
|
||||
e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
||||
case reflect.Uint64:
|
||||
e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32),
|
||||
byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
||||
}
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type arrayEncoder struct {
|
||||
elemEnc encoderFunc
|
||||
}
|
||||
|
||||
func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
n := v.Len()
|
||||
for i := 0; i < n; i += 1 {
|
||||
ae.elemEnc(e, v.Index(i), opts)
|
||||
}
|
||||
}
|
||||
|
||||
func newArrayEncoder(t reflect.Type) encoderFunc {
|
||||
enc := &arrayEncoder{typeEncoder(t.Elem())}
|
||||
return enc.encode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type sliceEncoder struct {
|
||||
ae *arrayEncoder
|
||||
}
|
||||
|
||||
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
if opts.head == 0 {
|
||||
panic(fmt.Errorf("Cannot encode a slice without a header length"))
|
||||
}
|
||||
|
||||
arrayState := &encodeState{}
|
||||
se.ae.encode(arrayState, v, opts)
|
||||
|
||||
n := uint(arrayState.Len())
|
||||
if opts.max > 0 && n > opts.max {
|
||||
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
|
||||
}
|
||||
if n>>(8*opts.head) > 0 {
|
||||
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
|
||||
}
|
||||
if n < opts.min {
|
||||
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
|
||||
}
|
||||
|
||||
for i := int(opts.head - 1); i >= 0; i -= 1 {
|
||||
e.WriteByte(byte(n >> (8 * uint(i))))
|
||||
}
|
||||
e.Write(arrayState.Bytes())
|
||||
}
|
||||
|
||||
func newSliceEncoder(t reflect.Type) encoderFunc {
|
||||
enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}}
|
||||
return enc.encode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type structEncoder struct {
|
||||
fieldOpts []encOpts
|
||||
fieldEncs []encoderFunc
|
||||
}
|
||||
|
||||
func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
for i := range se.fieldEncs {
|
||||
se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i])
|
||||
}
|
||||
}
|
||||
|
||||
func newStructEncoder(t reflect.Type) encoderFunc {
|
||||
n := t.NumField()
|
||||
se := structEncoder{
|
||||
fieldOpts: make([]encOpts, n),
|
||||
fieldEncs: make([]encoderFunc, n),
|
||||
}
|
||||
|
||||
for i := 0; i < n; i += 1 {
|
||||
f := t.Field(i)
|
||||
tag := f.Tag.Get("tls")
|
||||
tagOpts := parseTag(tag)
|
||||
|
||||
se.fieldOpts[i] = encOpts{
|
||||
head: tagOpts["head"],
|
||||
max: tagOpts["max"],
|
||||
min: tagOpts["min"],
|
||||
}
|
||||
se.fieldEncs[i] = typeEncoder(f.Type)
|
||||
}
|
||||
|
||||
return se.encode
|
||||
}
|
30
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
Normal file
30
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
Normal file
|
@ -0,0 +1,30 @@
|
|||
package syntax
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// `tls:"head=2,min=2,max=255"`
|
||||
|
||||
type tagOptions map[string]uint
|
||||
|
||||
// parseTag parses a struct field's "tls" tag as a comma-separated list of
|
||||
// name=value pairs, where the values MUST be unsigned integers
|
||||
func parseTag(tag string) tagOptions {
|
||||
opts := tagOptions{}
|
||||
for _, token := range strings.Split(tag, ",") {
|
||||
if strings.Index(token, "=") == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(token, "=")
|
||||
if len(parts[0]) == 0 {
|
||||
continue
|
||||
}
|
||||
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
|
||||
opts[parts[0]] = uint(val)
|
||||
}
|
||||
}
|
||||
return opts
|
||||
}
|
168
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
Normal file
168
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
Normal file
|
@ -0,0 +1,168 @@
|
|||
package mint
|
||||
|
||||
// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Server returns a new TLS server side connection
|
||||
// using conn as the underlying transport.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func Server(conn net.Conn, config *Config) *Conn {
|
||||
return NewConn(conn, config, false)
|
||||
}
|
||||
|
||||
// Client returns a new TLS client side connection
|
||||
// using conn as the underlying transport.
|
||||
// The config cannot be nil: users must set either ServerName or
|
||||
// InsecureSkipVerify in the config.
|
||||
func Client(conn net.Conn, config *Config) *Conn {
|
||||
return NewConn(conn, config, true)
|
||||
}
|
||||
|
||||
// A listener implements a network listener (net.Listener) for TLS connections.
|
||||
type Listener struct {
|
||||
net.Listener
|
||||
config *Config
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next incoming TLS connection.
|
||||
// The returned connection c is a *tls.Conn.
|
||||
func (l *Listener) Accept() (c net.Conn, err error) {
|
||||
c, err = l.Listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
server := Server(c, l.config)
|
||||
err = server.Handshake()
|
||||
if err == AlertNoAlert {
|
||||
err = nil
|
||||
}
|
||||
c = server
|
||||
return
|
||||
}
|
||||
|
||||
// NewListener creates a Listener which accepts connections from an inner
|
||||
// Listener and wraps each connection with Server.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func NewListener(inner net.Listener, config *Config) net.Listener {
|
||||
l := new(Listener)
|
||||
l.Listener = inner
|
||||
l.config = config
|
||||
return l
|
||||
}
|
||||
|
||||
// Listen creates a TLS listener accepting connections on the
|
||||
// given network address using net.Listen.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func Listen(network, laddr string, config *Config) (net.Listener, error) {
|
||||
if config == nil || !config.ValidForServer() {
|
||||
return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
|
||||
}
|
||||
l, err := net.Listen(network, laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewListener(l, config), nil
|
||||
}
|
||||
|
||||
type TimeoutError struct{}
|
||||
|
||||
func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" }
|
||||
func (TimeoutError) Timeout() bool { return true }
|
||||
func (TimeoutError) Temporary() bool { return true }
|
||||
|
||||
// DialWithDialer connects to the given network address using dialer.Dial and
|
||||
// then initiates a TLS handshake, returning the resulting TLS connection. Any
|
||||
// timeout or deadline given in the dialer apply to connection and TLS
|
||||
// handshake as a whole.
|
||||
//
|
||||
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
||||
// configuration; see the documentation of Config for the defaults.
|
||||
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||
// We want the Timeout and Deadline values from dialer to cover the
|
||||
// whole process: TCP connection and TLS handshake. This means that we
|
||||
// also need to start our own timers now.
|
||||
timeout := dialer.Timeout
|
||||
|
||||
if !dialer.Deadline.IsZero() {
|
||||
deadlineTimeout := dialer.Deadline.Sub(time.Now())
|
||||
if timeout == 0 || deadlineTimeout < timeout {
|
||||
timeout = deadlineTimeout
|
||||
}
|
||||
}
|
||||
|
||||
var errChannel chan error
|
||||
|
||||
if timeout != 0 {
|
||||
errChannel = make(chan error, 2)
|
||||
time.AfterFunc(timeout, func() {
|
||||
errChannel <- TimeoutError{}
|
||||
})
|
||||
}
|
||||
|
||||
rawConn, err := dialer.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
colonPos := strings.LastIndex(addr, ":")
|
||||
if colonPos == -1 {
|
||||
colonPos = len(addr)
|
||||
}
|
||||
hostname := addr[:colonPos]
|
||||
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
// If no ServerName is set, infer the ServerName
|
||||
// from the hostname we're connecting to.
|
||||
if config.ServerName == "" {
|
||||
// Make a copy to avoid polluting argument or default.
|
||||
c := config.Clone()
|
||||
c.ServerName = hostname
|
||||
config = c
|
||||
}
|
||||
|
||||
conn := Client(rawConn, config)
|
||||
|
||||
if timeout == 0 {
|
||||
err = conn.Handshake()
|
||||
if err == AlertNoAlert {
|
||||
err = nil
|
||||
}
|
||||
} else {
|
||||
go func() {
|
||||
errChannel <- conn.Handshake()
|
||||
}()
|
||||
|
||||
err = <-errChannel
|
||||
if err == AlertNoAlert {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Dial connects to the given network address using net.Dial
|
||||
// and then initiates a TLS handshake, returning the resulting
|
||||
// TLS connection.
|
||||
// Dial interprets a nil configuration as equivalent to
|
||||
// the zero configuration; see the documentation of Config
|
||||
// for the defaults.
|
||||
func Dial(network, addr string, config *Config) (*Conn, error) {
|
||||
return DialWithDialer(new(net.Dialer), network, addr, config)
|
||||
}
|
32
vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go
generated
vendored
32
vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go
generated
vendored
|
@ -1,32 +0,0 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// SentPacketHandler handles ACKs received for outgoing packets
|
||||
type SentPacketHandler interface {
|
||||
// SentPacket may modify the packet
|
||||
SentPacket(packet *Packet) error
|
||||
ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error
|
||||
|
||||
SendingAllowed() bool
|
||||
GetStopWaitingFrame(force bool) *frames.StopWaitingFrame
|
||||
DequeuePacketForRetransmission() (packet *Packet)
|
||||
GetLeastUnacked() protocol.PacketNumber
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
OnAlarm()
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
type ReceivedPacketHandler interface {
|
||||
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
|
||||
SetLowerLimit(protocol.PacketNumber)
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
GetAckFrame() *frames.AckFrame
|
||||
}
|
2
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
|
@ -3,7 +3,7 @@ package quic
|
|||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var bufferPool sync.Pool
|
||||
|
|
360
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
360
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
|
@ -10,32 +10,39 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
mutex sync.Mutex
|
||||
listenErr error
|
||||
mutex sync.Mutex
|
||||
|
||||
conn connection
|
||||
hostname string
|
||||
|
||||
errorChan chan struct{}
|
||||
handshakeChan <-chan handshakeEvent
|
||||
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
||||
versionNegotiated bool // has the server accepted our version
|
||||
receivedVersionNegotiationPacket bool
|
||||
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
versionNegotiated bool // has version negotiation completed yet
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
tls handshake.MintTLS // only used when using TLS
|
||||
|
||||
connectionID protocol.ConnectionID
|
||||
version protocol.VersionNumber
|
||||
|
||||
initialVersion protocol.VersionNumber
|
||||
version protocol.VersionNumber
|
||||
|
||||
session packetHandler
|
||||
}
|
||||
|
||||
var (
|
||||
// make it possible to mock connection ID generation in the tests
|
||||
generateConnectionID = utils.GenerateConnectionID
|
||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||
)
|
||||
|
||||
|
@ -53,71 +60,6 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
|
|||
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddrNonFWSecure(
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func DialNonFWSecure(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
connID, err := utils.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
}
|
||||
|
||||
if hostname == "" {
|
||||
hostname, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
clientConfig := populateClientConfig(config)
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
connectionID: connID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
errorChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
err = c.createNewSession(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
|
||||
return c.session.(NonFWSession), c.establishSecureConnection()
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func Dial(
|
||||
|
@ -127,15 +69,39 @@ func Dial(
|
|||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
|
||||
connID, err := generateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = sess.WaitUntilHandshakeComplete()
|
||||
if err != nil {
|
||||
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
}
|
||||
if hostname == "" {
|
||||
hostname, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
clientConfig := populateClientConfig(config)
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
connectionID: connID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
versionNegotiationChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
|
||||
if err := c.dial(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
|
@ -153,6 +119,10 @@ func populateClientConfig(config *Config) *Config {
|
|||
if config.HandshakeTimeout != 0 {
|
||||
handshakeTimeout = config.HandshakeTimeout
|
||||
}
|
||||
idleTimeout := protocol.DefaultIdleTimeout
|
||||
if config.IdleTimeout != 0 {
|
||||
idleTimeout = config.IdleTimeout
|
||||
}
|
||||
|
||||
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
|
||||
if maxReceiveStreamFlowControlWindow == 0 {
|
||||
|
@ -166,32 +136,109 @@ func populateClientConfig(config *Config) *Config {
|
|||
return &Config{
|
||||
Versions: versions,
|
||||
HandshakeTimeout: handshakeTimeout,
|
||||
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
|
||||
IdleTimeout: idleTimeout,
|
||||
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||
KeepAlive: config.KeepAlive,
|
||||
}
|
||||
}
|
||||
|
||||
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
func (c *client) dial() error {
|
||||
var err error
|
||||
if c.version.UsesTLS() {
|
||||
err = c.dialTLS()
|
||||
} else {
|
||||
err = c.dialGQUIC()
|
||||
}
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return c.dial()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) dialGQUIC() error {
|
||||
if err := c.createNewGQUICSession(); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
return c.establishSecureConnection()
|
||||
}
|
||||
|
||||
func (c *client) dialTLS() error {
|
||||
params := &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
IdleTimeout: c.config.IdleTimeout,
|
||||
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||
// TODO(#523): make these values configurable
|
||||
MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
|
||||
MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
|
||||
}
|
||||
csc := handshake.NewCryptoStreamConn(nil)
|
||||
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
|
||||
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mintConf.ExtensionHandler = extHandler
|
||||
mintConf.ServerName = c.hostname
|
||||
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
|
||||
|
||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
if err != handshake.ErrCloseSessionForRetry {
|
||||
return err
|
||||
}
|
||||
utils.Infof("Received a Retry packet. Recreating session.")
|
||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||
// It returns:
|
||||
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||
// - any other error that might occur
|
||||
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
var runErr error
|
||||
errorChan := make(chan struct{})
|
||||
go func() {
|
||||
runErr = c.session.run() // returns as soon as the session is closed
|
||||
close(errorChan)
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
|
||||
c.conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// wait until the server accepts the QUIC version (or an error occurs)
|
||||
select {
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case <-c.versionNegotiationChan:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.errorChan:
|
||||
return c.listenErr
|
||||
case ev := <-c.handshakeChan:
|
||||
if ev.err != nil {
|
||||
return ev.err
|
||||
}
|
||||
if ev.encLevel != protocol.EncryptionSecure {
|
||||
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
|
||||
}
|
||||
return nil
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case err := <-c.session.handshakeStatus():
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Listen listens
|
||||
// Listen listens on the underlying connection and passes packets on for handling.
|
||||
// It returns when the connection is closed.
|
||||
func (c *client) listen() {
|
||||
var err error
|
||||
|
||||
|
@ -205,13 +252,15 @@ func (c *client) listen() {
|
|||
n, addr, err = c.conn.Read(data)
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
c.session.Close(err)
|
||||
c.mutex.Lock()
|
||||
if c.session != nil {
|
||||
c.session.Close(err)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
break
|
||||
}
|
||||
data = data[:n]
|
||||
|
||||
c.handlePacket(addr, data)
|
||||
c.handlePacket(addr, data[:n])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -219,10 +268,14 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||
rcvTime := time.Now()
|
||||
|
||||
r := bytes.NewReader(packet)
|
||||
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer)
|
||||
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
|
||||
if err != nil {
|
||||
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
||||
// drop this packet if we can't parse the Public Header
|
||||
// drop this packet if we can't parse the header
|
||||
return
|
||||
}
|
||||
// reject packets with truncated connection id if we didn't request truncation
|
||||
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
||||
return
|
||||
}
|
||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||
|
@ -230,6 +283,11 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// reject packets with the wrong connection ID
|
||||
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.ResetFlag {
|
||||
cr := c.conn.RemoteAddr()
|
||||
// check if the remote address and the connection ID match
|
||||
|
@ -238,44 +296,48 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||
utils.Infof("Received a spoofed Public Reset. Ignoring.")
|
||||
return
|
||||
}
|
||||
pr, err := parsePublicReset(r)
|
||||
pr, err := wire.ParsePublicReset(r)
|
||||
if err != nil {
|
||||
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.")
|
||||
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||
return
|
||||
}
|
||||
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber)
|
||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber)))
|
||||
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
|
||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||
return
|
||||
}
|
||||
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.versionNegotiated && hdr.VersionFlag {
|
||||
return
|
||||
}
|
||||
// handle Version Negotiation Packets
|
||||
if hdr.IsVersionNegotiation {
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||
return
|
||||
}
|
||||
|
||||
// this is the first packet after the client sent a packet with the VersionFlag set
|
||||
// if the server doesn't send a version negotiation packet, it supports the suggested version
|
||||
if !hdr.VersionFlag && !c.versionNegotiated {
|
||||
c.versionNegotiated = true
|
||||
}
|
||||
|
||||
if hdr.VersionFlag {
|
||||
// version negotiation packets have no payload
|
||||
if err := c.handlePacketWithVersionFlag(hdr); err != nil {
|
||||
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
|
||||
c.session.Close(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// this is the first packet we are receiving
|
||||
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
||||
if !c.versionNegotiated {
|
||||
c.versionNegotiated = true
|
||||
close(c.versionNegotiationChan)
|
||||
}
|
||||
|
||||
// TODO: validate packet number and connection ID on Retry packets (for IETF QUIC)
|
||||
|
||||
c.session.handlePacket(&receivedPacket{
|
||||
remoteAddr: remoteAddr,
|
||||
publicHeader: hdr,
|
||||
data: packet[len(packet)-r.Len():],
|
||||
rcvTime: rcvTime,
|
||||
remoteAddr: remoteAddr,
|
||||
header: hdr,
|
||||
data: packet[len(packet)-r.Len():],
|
||||
rcvTime: rcvTime,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
|
||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||
for _, v := range hdr.SupportedVersions {
|
||||
if v == c.version {
|
||||
// the version negotiation packet contains the version that we offered
|
||||
|
@ -285,51 +347,57 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
|
|||
}
|
||||
}
|
||||
|
||||
newVersion := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||
if newVersion == protocol.VersionUnsupported {
|
||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||
if !ok {
|
||||
return qerr.InvalidVersion
|
||||
}
|
||||
c.receivedVersionNegotiationPacket = true
|
||||
c.negotiatedVersions = hdr.SupportedVersions
|
||||
|
||||
// switch to negotiated version
|
||||
c.initialVersion = c.version
|
||||
c.version = newVersion
|
||||
c.versionNegotiated = true
|
||||
var err error
|
||||
c.connectionID, err = utils.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
|
||||
|
||||
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
|
||||
c.session.Close(errCloseSessionForNewVersion)
|
||||
return c.createNewSession(hdr.SupportedVersions)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
|
||||
var err error
|
||||
c.session, c.handshakeChan, err = newClientSession(
|
||||
func (c *client) createNewGQUICSession() (err error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.session, err = newClientSession(
|
||||
c.conn,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.connectionID,
|
||||
c.tlsConf,
|
||||
c.config,
|
||||
negotiatedVersions,
|
||||
c.initialVersion,
|
||||
c.negotiatedVersions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
// session.run() returns as soon as the session is closed
|
||||
err := c.session.run()
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return
|
||||
}
|
||||
c.listenErr = err
|
||||
close(c.errorChan)
|
||||
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
c.conn.Close()
|
||||
}()
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) createNewTLSSession(
|
||||
paramsChan <-chan handshake.TransportParameters,
|
||||
version protocol.VersionNumber,
|
||||
) (err error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.session, err = newTLSClientSession(
|
||||
c.conn,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.connectionID,
|
||||
c.config,
|
||||
c.tls,
|
||||
paramsChan,
|
||||
1,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
|
58
vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go
generated
vendored
58
vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go
generated
vendored
|
@ -1,58 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"errors"
|
||||
|
||||
"github.com/lucas-clemente/aes12"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
type aeadAESGCM struct {
|
||||
otherIV []byte
|
||||
myIV []byte
|
||||
encrypter cipher.AEAD
|
||||
decrypter cipher.AEAD
|
||||
}
|
||||
|
||||
// NewAEADAESGCM creates a AEAD using AES-GCM with 12 bytes tag size
|
||||
//
|
||||
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
|
||||
// tag size, and couples the cipher and aes packages closely.
|
||||
// See https://github.com/lucas-clemente/aes12.
|
||||
func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
||||
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
|
||||
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
|
||||
}
|
||||
encrypterCipher, err := aes12.NewCipher(myKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encrypter, err := aes12.NewGCM(encrypterCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypterCipher, err := aes12.NewCipher(otherKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypter, err := aes12.NewGCM(decrypterCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &aeadAESGCM{
|
||||
otherIV: otherIV,
|
||||
myIV: myIV,
|
||||
encrypter: encrypter,
|
||||
decrypter: decrypter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||
}
|
14
vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go
generated
vendored
|
@ -1,14 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
func makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||
res := make([]byte, 12)
|
||||
copy(res[0:4], iv)
|
||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
||||
return res
|
||||
}
|
76
vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go
generated
vendored
76
vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go
generated
vendored
|
@ -1,76 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// StkSource is used to create and verify source address tokens
|
||||
type StkSource interface {
|
||||
// NewToken creates a new token
|
||||
NewToken([]byte) ([]byte, error)
|
||||
// DecodeToken decodes a token
|
||||
DecodeToken([]byte) ([]byte, error)
|
||||
}
|
||||
|
||||
type stkSource struct {
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
const stkKeySize = 16
|
||||
|
||||
// Chrome currently sets this to 12, but discusses changing it to 16. We start
|
||||
// at 16 :)
|
||||
const stkNonceSize = 16
|
||||
|
||||
// NewStkSource creates a source for source address tokens
|
||||
func NewStkSource() (StkSource, error) {
|
||||
secret := make([]byte, 32)
|
||||
if _, err := rand.Read(secret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, err := deriveKey(secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &stkSource{aead: aead}, nil
|
||||
}
|
||||
|
||||
func (s *stkSource) NewToken(data []byte) ([]byte, error) {
|
||||
nonce := make([]byte, stkNonceSize)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.aead.Seal(nonce, nonce, data, nil), nil
|
||||
}
|
||||
|
||||
func (s *stkSource) DecodeToken(p []byte) ([]byte, error) {
|
||||
if len(p) < stkNonceSize {
|
||||
return nil, fmt.Errorf("STK too short: %d", len(p))
|
||||
}
|
||||
nonce := p[:stkNonceSize]
|
||||
return s.aead.Open(nil, nonce, p[stkNonceSize:], nil)
|
||||
}
|
||||
|
||||
func deriveKey(secret []byte) ([]byte, error) {
|
||||
r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key"))
|
||||
key := make([]byte, stkKeySize)
|
||||
if _, err := io.ReadFull(r, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
41
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
Normal file
41
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
Normal file
|
@ -0,0 +1,41 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoStreamI interface {
|
||||
StreamID() protocol.StreamID
|
||||
io.Reader
|
||||
io.Writer
|
||||
handleStreamFrame(*wire.StreamFrame) error
|
||||
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||
closeForShutdown(error)
|
||||
setReadOffset(protocol.ByteCount)
|
||||
// methods needed for flow control
|
||||
getWindowUpdate() protocol.ByteCount
|
||||
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
||||
}
|
||||
|
||||
type cryptoStream struct {
|
||||
*stream
|
||||
}
|
||||
|
||||
var _ cryptoStreamI = &cryptoStream{}
|
||||
|
||||
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
|
||||
str := newStream(version.CryptoStreamID(), sender, flowController, version)
|
||||
return &cryptoStream{str}
|
||||
}
|
||||
|
||||
// SetReadOffset sets the read offset.
|
||||
// It is only needed for the crypto stream.
|
||||
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
|
||||
s.receiveStream.readOffset = offset
|
||||
s.receiveStream.frameQueue.readPosition = offset
|
||||
}
|
14
vendor/github.com/lucas-clemente/quic-go/example/client/main.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/example/client/main.go
generated
vendored
|
@ -7,12 +7,15 @@ import (
|
|||
"net/http"
|
||||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
func main() {
|
||||
verbose := flag.Bool("v", false, "verbose")
|
||||
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||
flag.Parse()
|
||||
urls := flag.Args()
|
||||
|
||||
|
@ -23,8 +26,17 @@ func main() {
|
|||
}
|
||||
utils.SetLogTimeFormat("")
|
||||
|
||||
versions := protocol.SupportedVersions
|
||||
if *tls {
|
||||
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
||||
}
|
||||
|
||||
roundTripper := &h2quic.RoundTripper{
|
||||
QuicConfig: &quic.Config{Versions: versions},
|
||||
}
|
||||
defer roundTripper.Close()
|
||||
hclient := &http.Client{
|
||||
Transport: &h2quic.RoundTripper{},
|
||||
Transport: roundTripper,
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
|
14
vendor/github.com/lucas-clemente/quic-go/example/main.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/example/main.go
generated
vendored
|
@ -17,7 +17,9 @@ import (
|
|||
|
||||
_ "net/http/pprof"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -121,6 +123,7 @@ func main() {
|
|||
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
|
||||
www := flag.String("www", "/var/www", "www data")
|
||||
tcp := flag.Bool("tcp", false, "also listen on TCP")
|
||||
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||
flag.Parse()
|
||||
|
||||
if *verbose {
|
||||
|
@ -130,6 +133,11 @@ func main() {
|
|||
}
|
||||
utils.SetLogTimeFormat("")
|
||||
|
||||
versions := protocol.SupportedVersions
|
||||
if *tls {
|
||||
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
||||
}
|
||||
|
||||
certFile := *certPath + "/fullchain.pem"
|
||||
keyFile := *certPath + "/privkey.pem"
|
||||
|
||||
|
@ -148,7 +156,11 @@ func main() {
|
|||
if *tcp {
|
||||
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
|
||||
} else {
|
||||
err = h2quic.ListenAndServeQUIC(bCap, certFile, keyFile, nil)
|
||||
server := h2quic.Server{
|
||||
Server: &http.Server{Addr: bCap},
|
||||
QuicConfig: &quic.Config{Versions: versions},
|
||||
}
|
||||
err = server.ListenAndServeTLS(certFile, keyFile)
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
|
|
240
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go
generated
vendored
240
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go
generated
vendored
|
@ -1,240 +0,0 @@
|
|||
package flowcontrol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type flowControlManager struct {
|
||||
connectionParameters handshake.ConnectionParametersManager
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
streamFlowController map[protocol.StreamID]*flowController
|
||||
connFlowController *flowController
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
var _ FlowControlManager = &flowControlManager{}
|
||||
|
||||
var errMapAccess = errors.New("Error accessing the flowController map.")
|
||||
|
||||
// NewFlowControlManager creates a new flow control manager
|
||||
func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager {
|
||||
return &flowControlManager{
|
||||
connectionParameters: connectionParameters,
|
||||
rttStats: rttStats,
|
||||
streamFlowController: make(map[protocol.StreamID]*flowController),
|
||||
connFlowController: newFlowController(0, false, connectionParameters, rttStats),
|
||||
}
|
||||
}
|
||||
|
||||
// NewStream creates new flow controllers for a stream
|
||||
// it does nothing if the stream already exists
|
||||
func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesToConnection bool) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if _, ok := f.streamFlowController[streamID]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connectionParameters, f.rttStats)
|
||||
}
|
||||
|
||||
// RemoveStream removes a closed stream from flow control
|
||||
func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) {
|
||||
f.mutex.Lock()
|
||||
delete(f.streamFlowController, streamID)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// ResetStream should be called when receiving a RstStreamFrame
|
||||
// it updates the byte offset to the value in the RstStreamFrame
|
||||
// streamID must not be 0 here
|
||||
func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
streamFlowController, err := f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increment, err := streamFlowController.UpdateHighestReceived(byteOffset)
|
||||
if err != nil {
|
||||
return qerr.StreamDataAfterTermination
|
||||
}
|
||||
|
||||
if streamFlowController.CheckFlowControlViolation() {
|
||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
|
||||
}
|
||||
|
||||
if streamFlowController.ContributesToConnection() {
|
||||
f.connFlowController.IncrementHighestReceived(increment)
|
||||
if f.connFlowController.CheckFlowControlViolation() {
|
||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateHighestReceived updates the highest received byte offset for a stream
|
||||
// it adds the number of additional bytes to connection level flow control
|
||||
// streamID must not be 0 here
|
||||
func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
streamFlowController, err := f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered
|
||||
// this error can be ignored here
|
||||
increment, _ := streamFlowController.UpdateHighestReceived(byteOffset)
|
||||
|
||||
if streamFlowController.CheckFlowControlViolation() {
|
||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
|
||||
}
|
||||
|
||||
if streamFlowController.ContributesToConnection() {
|
||||
f.connFlowController.IncrementHighestReceived(increment)
|
||||
if f.connFlowController.CheckFlowControlViolation() {
|
||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamID must not be 0 here
|
||||
func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
fc, err := f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fc.AddBytesRead(n)
|
||||
if fc.ContributesToConnection() {
|
||||
f.connFlowController.AddBytesRead(n)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
// get WindowUpdates for streams
|
||||
for id, fc := range f.streamFlowController {
|
||||
if necessary, newIncrement, offset := fc.MaybeUpdateWindow(); necessary {
|
||||
res = append(res, WindowUpdate{StreamID: id, Offset: offset})
|
||||
if fc.ContributesToConnection() && newIncrement != 0 {
|
||||
f.connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
}
|
||||
}
|
||||
// get a WindowUpdate for the connection
|
||||
if necessary, _, offset := f.connFlowController.MaybeUpdateWindow(); necessary {
|
||||
res = append(res, WindowUpdate{StreamID: 0, Offset: offset})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
|
||||
f.mutex.RLock()
|
||||
defer f.mutex.RUnlock()
|
||||
|
||||
// StreamID can be 0 when retransmitting
|
||||
if streamID == 0 {
|
||||
return f.connFlowController.receiveWindow, nil
|
||||
}
|
||||
|
||||
flowController, err := f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return flowController.receiveWindow, nil
|
||||
}
|
||||
|
||||
// streamID must not be 0 here
|
||||
func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
fc, err := f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fc.AddBytesSent(n)
|
||||
if fc.ContributesToConnection() {
|
||||
f.connFlowController.AddBytesSent(n)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// must not be called with StreamID 0
|
||||
func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) {
|
||||
f.mutex.RLock()
|
||||
defer f.mutex.RUnlock()
|
||||
|
||||
fc, err := f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
res := fc.SendWindowSize()
|
||||
|
||||
if fc.ContributesToConnection() {
|
||||
res = utils.MinByteCount(res, f.connFlowController.SendWindowSize())
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
|
||||
f.mutex.RLock()
|
||||
defer f.mutex.RUnlock()
|
||||
|
||||
return f.connFlowController.SendWindowSize()
|
||||
}
|
||||
|
||||
// streamID may be 0 here
|
||||
func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
var fc *flowController
|
||||
if streamID == 0 {
|
||||
fc = f.connFlowController
|
||||
} else {
|
||||
var err error
|
||||
fc, err = f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
return fc.UpdateSendWindow(offset), nil
|
||||
}
|
||||
|
||||
func (f *flowControlManager) getFlowController(streamID protocol.StreamID) (*flowController, error) {
|
||||
streamFlowController, ok := f.streamFlowController[streamID]
|
||||
if !ok {
|
||||
return nil, errMapAccess
|
||||
}
|
||||
return streamFlowController, nil
|
||||
}
|
198
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go
generated
vendored
198
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go
generated
vendored
|
@ -1,198 +0,0 @@
|
|||
package flowcontrol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
type flowController struct {
|
||||
streamID protocol.StreamID
|
||||
contributesToConnection bool // does the stream contribute to connection level flow control
|
||||
|
||||
connectionParameters handshake.ConnectionParametersManager
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
bytesSent protocol.ByteCount
|
||||
sendWindow protocol.ByteCount
|
||||
|
||||
lastWindowUpdateTime time.Time
|
||||
|
||||
bytesRead protocol.ByteCount
|
||||
highestReceived protocol.ByteCount
|
||||
receiveWindow protocol.ByteCount
|
||||
receiveWindowIncrement protocol.ByteCount
|
||||
maxReceiveWindowIncrement protocol.ByteCount
|
||||
}
|
||||
|
||||
// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously
|
||||
var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset")
|
||||
|
||||
// newFlowController gets a new flow controller
|
||||
func newFlowController(streamID protocol.StreamID, contributesToConnection bool, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController {
|
||||
fc := flowController{
|
||||
streamID: streamID,
|
||||
contributesToConnection: contributesToConnection,
|
||||
connectionParameters: connectionParameters,
|
||||
rttStats: rttStats,
|
||||
}
|
||||
|
||||
if streamID == 0 {
|
||||
fc.receiveWindow = connectionParameters.GetReceiveConnectionFlowControlWindow()
|
||||
fc.receiveWindowIncrement = fc.receiveWindow
|
||||
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow()
|
||||
} else {
|
||||
fc.receiveWindow = connectionParameters.GetReceiveStreamFlowControlWindow()
|
||||
fc.receiveWindowIncrement = fc.receiveWindow
|
||||
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow()
|
||||
}
|
||||
|
||||
return &fc
|
||||
}
|
||||
|
||||
func (c *flowController) ContributesToConnection() bool {
|
||||
return c.contributesToConnection
|
||||
}
|
||||
|
||||
func (c *flowController) getSendWindow() protocol.ByteCount {
|
||||
if c.sendWindow == 0 {
|
||||
if c.streamID == 0 {
|
||||
return c.connectionParameters.GetSendConnectionFlowControlWindow()
|
||||
}
|
||||
return c.connectionParameters.GetSendStreamFlowControlWindow()
|
||||
}
|
||||
return c.sendWindow
|
||||
}
|
||||
|
||||
func (c *flowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.bytesSent += n
|
||||
}
|
||||
|
||||
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
|
||||
// it returns true if the window was actually updated
|
||||
func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool {
|
||||
if newOffset > c.sendWindow {
|
||||
c.sendWindow = newOffset
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *flowController) SendWindowSize() protocol.ByteCount {
|
||||
sendWindow := c.getSendWindow()
|
||||
|
||||
if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here
|
||||
return 0
|
||||
}
|
||||
return sendWindow - c.bytesSent
|
||||
}
|
||||
|
||||
func (c *flowController) SendWindowOffset() protocol.ByteCount {
|
||||
return c.getSendWindow()
|
||||
}
|
||||
|
||||
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
|
||||
// Should **only** be used for the stream-level FlowController
|
||||
// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
|
||||
// This error occurs every time StreamFrames get reordered and has to be ignored in that case
|
||||
// It should only be treated as an error when resetting a stream
|
||||
func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) {
|
||||
if byteOffset == c.highestReceived {
|
||||
return 0, nil
|
||||
}
|
||||
if byteOffset > c.highestReceived {
|
||||
increment := byteOffset - c.highestReceived
|
||||
c.highestReceived = byteOffset
|
||||
return increment, nil
|
||||
}
|
||||
return 0, ErrReceivedSmallerByteOffset
|
||||
}
|
||||
|
||||
// IncrementHighestReceived adds an increment to the highestReceived value
|
||||
// Should **only** be used for the connection-level FlowController
|
||||
func (c *flowController) IncrementHighestReceived(increment protocol.ByteCount) {
|
||||
c.highestReceived += increment
|
||||
}
|
||||
|
||||
func (c *flowController) AddBytesRead(n protocol.ByteCount) {
|
||||
// pretend we sent a WindowUpdate when reading the first byte
|
||||
// this way auto-tuning of the window increment already works for the first WindowUpdate
|
||||
if c.bytesRead == 0 {
|
||||
c.lastWindowUpdateTime = time.Now()
|
||||
}
|
||||
c.bytesRead += n
|
||||
}
|
||||
|
||||
// MaybeUpdateWindow updates the receive window, if necessary
|
||||
// if the receive window increment is changed, the new value is returned, otherwise a 0
|
||||
// the last return value is the new offset of the receive window
|
||||
func (c *flowController) MaybeUpdateWindow() (bool, protocol.ByteCount /* new increment */, protocol.ByteCount /* new offset */) {
|
||||
diff := c.receiveWindow - c.bytesRead
|
||||
|
||||
// Chromium implements the same threshold
|
||||
if diff < (c.receiveWindowIncrement / 2) {
|
||||
var newWindowIncrement protocol.ByteCount
|
||||
oldWindowIncrement := c.receiveWindowIncrement
|
||||
|
||||
c.maybeAdjustWindowIncrement()
|
||||
if c.receiveWindowIncrement != oldWindowIncrement {
|
||||
newWindowIncrement = c.receiveWindowIncrement
|
||||
}
|
||||
|
||||
c.lastWindowUpdateTime = time.Now()
|
||||
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement
|
||||
return true, newWindowIncrement, c.receiveWindow
|
||||
}
|
||||
|
||||
return false, 0, 0
|
||||
}
|
||||
|
||||
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
|
||||
func (c *flowController) maybeAdjustWindowIncrement() {
|
||||
if c.lastWindowUpdateTime.IsZero() {
|
||||
return
|
||||
}
|
||||
|
||||
rtt := c.rttStats.SmoothedRTT()
|
||||
if rtt == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime)
|
||||
|
||||
// interval between the window updates is sufficiently large, no need to increase the increment
|
||||
if timeSinceLastWindowUpdate >= 2*rtt {
|
||||
return
|
||||
}
|
||||
|
||||
oldWindowSize := c.receiveWindowIncrement
|
||||
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement)
|
||||
|
||||
// debug log, if the window size was actually increased
|
||||
if oldWindowSize < c.receiveWindowIncrement {
|
||||
newWindowSize := c.receiveWindowIncrement / (1 << 10)
|
||||
if c.streamID == 0 {
|
||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize)
|
||||
} else {
|
||||
utils.Debugf("Increasing receive flow control window increment for stream %d to %d kB", c.streamID, newWindowSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureMinimumWindowIncrement sets a minimum window increment
|
||||
// it is intended be used for the connection-level flow controller
|
||||
// it should make sure that the connection-level window is increased when a stream-level window grows
|
||||
func (c *flowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) {
|
||||
if inc > c.receiveWindowIncrement {
|
||||
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement)
|
||||
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
|
||||
}
|
||||
}
|
||||
|
||||
func (c *flowController) CheckFlowControlViolation() bool {
|
||||
return c.highestReceived > c.receiveWindow
|
||||
}
|
26
vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go
generated
vendored
26
vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go
generated
vendored
|
@ -1,26 +0,0 @@
|
|||
package flowcontrol
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/protocol"
|
||||
|
||||
// WindowUpdate provides the data for WindowUpdateFrames.
|
||||
type WindowUpdate struct {
|
||||
StreamID protocol.StreamID
|
||||
Offset protocol.ByteCount
|
||||
}
|
||||
|
||||
// A FlowControlManager manages the flow control
|
||||
type FlowControlManager interface {
|
||||
NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool)
|
||||
RemoveStream(streamID protocol.StreamID)
|
||||
// methods needed for receiving data
|
||||
ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
|
||||
UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
|
||||
AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error
|
||||
GetWindowUpdates() []WindowUpdate
|
||||
GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error)
|
||||
// methods needed for sending data
|
||||
AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error
|
||||
SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error)
|
||||
RemainingConnectionWindowSize() protocol.ByteCount
|
||||
UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error)
|
||||
}
|
9
vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go
generated
vendored
9
vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go
generated
vendored
|
@ -1,9 +0,0 @@
|
|||
package frames
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/protocol"
|
||||
|
||||
// AckRange is an ACK range
|
||||
type AckRange struct {
|
||||
FirstPacketNumber protocol.PacketNumber
|
||||
LastPacketNumber protocol.PacketNumber
|
||||
}
|
44
vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go
generated
vendored
44
vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go
generated
vendored
|
@ -1,44 +0,0 @@
|
|||
package frames
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// A BlockedFrame in QUIC
|
||||
type BlockedFrame struct {
|
||||
StreamID protocol.StreamID
|
||||
}
|
||||
|
||||
//Write writes a BlockedFrame frame
|
||||
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
b.WriteByte(0x05)
|
||||
utils.WriteUint32(b, uint32(f.StreamID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
return 1 + 4, nil
|
||||
}
|
||||
|
||||
// ParseBlockedFrame parses a BLOCKED frame
|
||||
func ParseBlockedFrame(r *bytes.Reader) (*BlockedFrame, error) {
|
||||
frame := &BlockedFrame{}
|
||||
|
||||
// read the TypeByte
|
||||
_, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sid, err := utils.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.StreamID = protocol.StreamID(sid)
|
||||
|
||||
return frame, nil
|
||||
}
|
73
vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go
generated
vendored
73
vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go
generated
vendored
|
@ -1,73 +0,0 @@
|
|||
package frames
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// A ConnectionCloseFrame in QUIC
|
||||
type ConnectionCloseFrame struct {
|
||||
ErrorCode qerr.ErrorCode
|
||||
ReasonPhrase string
|
||||
}
|
||||
|
||||
// ParseConnectionCloseFrame reads a CONNECTION_CLOSE frame
|
||||
func ParseConnectionCloseFrame(r *bytes.Reader) (*ConnectionCloseFrame, error) {
|
||||
frame := &ConnectionCloseFrame{}
|
||||
|
||||
// read the TypeByte
|
||||
_, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errorCode, err := utils.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.ErrorCode = qerr.ErrorCode(errorCode)
|
||||
|
||||
reasonPhraseLen, err := utils.ReadUint16(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if reasonPhraseLen > uint16(protocol.MaxPacketSize) {
|
||||
return nil, qerr.Error(qerr.InvalidConnectionCloseData, "reason phrase too long")
|
||||
}
|
||||
|
||||
reasonPhrase := make([]byte, reasonPhraseLen)
|
||||
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.ReasonPhrase = string(reasonPhrase)
|
||||
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil
|
||||
}
|
||||
|
||||
// Write writes an CONNECTION_CLOSE frame.
|
||||
func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
b.WriteByte(0x02)
|
||||
utils.WriteUint32(b, uint32(f.ErrorCode))
|
||||
|
||||
if len(f.ReasonPhrase) > math.MaxUint16 {
|
||||
return errors.New("ConnectionFrame: ReasonPhrase too long")
|
||||
}
|
||||
|
||||
reasonPhraseLen := uint16(len(f.ReasonPhrase))
|
||||
utils.WriteUint16(b, reasonPhraseLen)
|
||||
b.WriteString(f.ReasonPhrase)
|
||||
|
||||
return nil
|
||||
}
|
13
vendor/github.com/lucas-clemente/quic-go/frames/frame.go
generated
vendored
13
vendor/github.com/lucas-clemente/quic-go/frames/frame.go
generated
vendored
|
@ -1,13 +0,0 @@
|
|||
package frames
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// A Frame in QUIC
|
||||
type Frame interface {
|
||||
Write(b *bytes.Buffer, version protocol.VersionNumber) error
|
||||
MinLength(version protocol.VersionNumber) (protocol.ByteCount, error)
|
||||
}
|
28
vendor/github.com/lucas-clemente/quic-go/frames/log.go
generated
vendored
28
vendor/github.com/lucas-clemente/quic-go/frames/log.go
generated
vendored
|
@ -1,28 +0,0 @@
|
|||
package frames
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
// LogFrame logs a frame, either sent or received
|
||||
func LogFrame(frame Frame, sent bool) {
|
||||
if !utils.Debug() {
|
||||
return
|
||||
}
|
||||
dir := "<-"
|
||||
if sent {
|
||||
dir = "->"
|
||||
}
|
||||
switch f := frame.(type) {
|
||||
case *StreamFrame:
|
||||
utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
|
||||
case *StopWaitingFrame:
|
||||
if sent {
|
||||
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen)
|
||||
} else {
|
||||
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked)
|
||||
}
|
||||
case *AckFrame:
|
||||
utils.Debugf("\t%s &frames.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String())
|
||||
default:
|
||||
utils.Debugf("\t%s %#v", dir, frame)
|
||||
}
|
||||
}
|
59
vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go
generated
vendored
59
vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go
generated
vendored
|
@ -1,59 +0,0 @@
|
|||
package frames
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// A RstStreamFrame in QUIC
|
||||
type RstStreamFrame struct {
|
||||
StreamID protocol.StreamID
|
||||
ErrorCode uint32
|
||||
ByteOffset protocol.ByteCount
|
||||
}
|
||||
|
||||
//Write writes a RST_STREAM frame
|
||||
func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
b.WriteByte(0x01)
|
||||
utils.WriteUint32(b, uint32(f.StreamID))
|
||||
utils.WriteUint64(b, uint64(f.ByteOffset))
|
||||
utils.WriteUint32(b, f.ErrorCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
return 1 + 4 + 8 + 4, nil
|
||||
}
|
||||
|
||||
// ParseRstStreamFrame parses a RST_STREAM frame
|
||||
func ParseRstStreamFrame(r *bytes.Reader) (*RstStreamFrame, error) {
|
||||
frame := &RstStreamFrame{}
|
||||
|
||||
// read the TypeByte
|
||||
_, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sid, err := utils.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.StreamID = protocol.StreamID(sid)
|
||||
|
||||
byteOffset, err := utils.ReadUint64(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.ByteOffset = protocol.ByteCount(byteOffset)
|
||||
|
||||
frame.ErrorCode, err = utils.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return frame, nil
|
||||
}
|
54
vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go
generated
vendored
54
vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go
generated
vendored
|
@ -1,54 +0,0 @@
|
|||
package frames
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// A WindowUpdateFrame in QUIC
|
||||
type WindowUpdateFrame struct {
|
||||
StreamID protocol.StreamID
|
||||
ByteOffset protocol.ByteCount
|
||||
}
|
||||
|
||||
//Write writes a RST_STREAM frame
|
||||
func (f *WindowUpdateFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
typeByte := uint8(0x04)
|
||||
b.WriteByte(typeByte)
|
||||
|
||||
utils.WriteUint32(b, uint32(f.StreamID))
|
||||
utils.WriteUint64(b, uint64(f.ByteOffset))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *WindowUpdateFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
return 1 + 4 + 8, nil
|
||||
}
|
||||
|
||||
// ParseWindowUpdateFrame parses a RST_STREAM frame
|
||||
func ParseWindowUpdateFrame(r *bytes.Reader) (*WindowUpdateFrame, error) {
|
||||
frame := &WindowUpdateFrame{}
|
||||
|
||||
// read the TypeByte
|
||||
_, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sid, err := utils.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.StreamID = protocol.StreamID(sid)
|
||||
|
||||
byteOffset, err := utils.ReadUint64(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.ByteOffset = protocol.ByteCount(byteOffset)
|
||||
|
||||
return frame, nil
|
||||
}
|
109
vendor/github.com/lucas-clemente/quic-go/h2quic/client.go
generated
vendored
109
vendor/github.com/lucas-clemente/quic-go/h2quic/client.go
generated
vendored
|
@ -15,8 +15,8 @@ import (
|
|||
"golang.org/x/net/idna"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
|
@ -34,10 +34,10 @@ type client struct {
|
|||
config *quic.Config
|
||||
opts *roundTripperOpts
|
||||
|
||||
hostname string
|
||||
encryptionLevel protocol.EncryptionLevel
|
||||
handshakeErr error
|
||||
dialOnce sync.Once
|
||||
hostname string
|
||||
handshakeErr error
|
||||
dialOnce sync.Once
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||
|
||||
session quic.Session
|
||||
headerStream quic.Stream
|
||||
|
@ -51,8 +51,8 @@ type client struct {
|
|||
var _ http.RoundTripper = &client{}
|
||||
|
||||
var defaultQuicConfig = &quic.Config{
|
||||
RequestConnectionIDTruncation: true,
|
||||
KeepAlive: true,
|
||||
RequestConnectionIDOmission: true,
|
||||
KeepAlive: true,
|
||||
}
|
||||
|
||||
// newClient creates a new client
|
||||
|
@ -61,26 +61,31 @@ func newClient(
|
|||
tlsConfig *tls.Config,
|
||||
opts *roundTripperOpts,
|
||||
quicConfig *quic.Config,
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
||||
) *client {
|
||||
config := defaultQuicConfig
|
||||
if quicConfig != nil {
|
||||
config = quicConfig
|
||||
}
|
||||
return &client{
|
||||
hostname: authorityAddr("https", hostname),
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
encryptionLevel: protocol.EncryptionUnencrypted,
|
||||
tlsConf: tlsConfig,
|
||||
config: config,
|
||||
opts: opts,
|
||||
headerErrored: make(chan struct{}),
|
||||
hostname: authorityAddr("https", hostname),
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
tlsConf: tlsConfig,
|
||||
config: config,
|
||||
opts: opts,
|
||||
headerErrored: make(chan struct{}),
|
||||
dialer: dialer,
|
||||
}
|
||||
}
|
||||
|
||||
// dial dials the connection
|
||||
func (c *client) dial() error {
|
||||
var err error
|
||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||
if c.dialer != nil {
|
||||
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
||||
} else {
|
||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -90,9 +95,6 @@ func (c *client) dial() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.headerStream.StreamID() != 3 {
|
||||
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
|
||||
}
|
||||
c.requestWriter = newRequestWriter(c.headerStream)
|
||||
go c.handleHeaderStream()
|
||||
return nil
|
||||
|
@ -102,45 +104,44 @@ func (c *client) handleHeaderStream() {
|
|||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
||||
h2framer := http2.NewFramer(nil, c.headerStream)
|
||||
|
||||
var lastStream protocol.StreamID
|
||||
var err error
|
||||
for err == nil {
|
||||
err = c.readResponse(h2framer, decoder)
|
||||
}
|
||||
utils.Debugf("Error handling header stream: %s", err)
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
|
||||
// stop all running request
|
||||
close(c.headerErrored)
|
||||
}
|
||||
|
||||
for {
|
||||
frame, err := h2framer.ReadFrame()
|
||||
if err != nil {
|
||||
c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
|
||||
break
|
||||
}
|
||||
lastStream = protocol.StreamID(frame.Header().StreamID)
|
||||
hframe, ok := frame.(*http2.HeadersFrame)
|
||||
if !ok {
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")
|
||||
break
|
||||
}
|
||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
||||
if err != nil {
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")
|
||||
break
|
||||
}
|
||||
|
||||
c.mutex.RLock()
|
||||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||
c.mutex.RUnlock()
|
||||
if !ok {
|
||||
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
|
||||
break
|
||||
}
|
||||
|
||||
rsp, err := responseFromHeaders(mhframe)
|
||||
if err != nil {
|
||||
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
|
||||
}
|
||||
responseChan <- rsp
|
||||
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
|
||||
frame, err := h2framer.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hframe, ok := frame.(*http2.HeadersFrame)
|
||||
if !ok {
|
||||
return errors.New("not a headers frame")
|
||||
}
|
||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read header fields: %s", err.Error())
|
||||
}
|
||||
|
||||
// stop all running request
|
||||
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
|
||||
close(c.headerErrored)
|
||||
c.mutex.RLock()
|
||||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||
c.mutex.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
|
||||
}
|
||||
|
||||
rsp, err := responseFromHeaders(mhframe)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
responseChan <- rsp
|
||||
return nil
|
||||
}
|
||||
|
||||
// Roundtrip executes a request and returns a response
|
||||
|
|
2
vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go
generated
vendored
|
@ -13,8 +13,8 @@ import (
|
|||
"golang.org/x/net/lex/httplex"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
type requestWriter struct {
|
||||
|
|
4
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go
generated
vendored
4
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go
generated
vendored
|
@ -8,8 +8,8 @@ import (
|
|||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
@ -83,7 +83,7 @@ func (w *responseWriter) Write(p []byte) (int, error) {
|
|||
|
||||
func (w *responseWriter) Flush() {}
|
||||
|
||||
// TODO: Implement a functional CloseNotify method.
|
||||
// This is a NOP. Use http.Request.Context
|
||||
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
|
||||
|
||||
// test that we implement http.Flusher
|
||||
|
|
13
vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go
generated
vendored
13
vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go
generated
vendored
|
@ -41,6 +41,11 @@ type RoundTripper struct {
|
|||
// If nil, reasonable default values will be used.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
// Dial specifies an optional dial function for creating QUIC
|
||||
// connections for requests.
|
||||
// If Dial is nil, quic.DialAddr will be used.
|
||||
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||
|
||||
clients map[string]roundTripCloser
|
||||
}
|
||||
|
||||
|
@ -120,7 +125,13 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
|
|||
if onlyCached {
|
||||
return nil, ErrNoCachedConn
|
||||
}
|
||||
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
|
||||
client = newClient(
|
||||
hostname,
|
||||
r.TLSClientConfig,
|
||||
&roundTripperOpts{DisableCompression: r.DisableCompression},
|
||||
r.QuicConfig,
|
||||
r.Dial,
|
||||
)
|
||||
r.clients[hostname] = client
|
||||
}
|
||||
return client, nil
|
||||
|
|
87
vendor/github.com/lucas-clemente/quic-go/h2quic/server.go
generated
vendored
87
vendor/github.com/lucas-clemente/quic-go/h2quic/server.go
generated
vendored
|
@ -7,14 +7,14 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
@ -50,6 +50,7 @@ type Server struct {
|
|||
|
||||
listenerMutex sync.Mutex
|
||||
listener quic.Listener
|
||||
closed bool
|
||||
|
||||
supportedVersionsAsString string
|
||||
}
|
||||
|
@ -88,6 +89,10 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
|||
return errors.New("use of h2quic.Server without http.Server")
|
||||
}
|
||||
s.listenerMutex.Lock()
|
||||
if s.closed {
|
||||
s.listenerMutex.Unlock()
|
||||
return errors.New("Server is already closed")
|
||||
}
|
||||
if s.listener != nil {
|
||||
s.listenerMutex.Unlock()
|
||||
return errors.New("ListenAndServe may only be called once")
|
||||
|
@ -122,29 +127,23 @@ func (s *Server) handleHeaderStream(session streamCreator) {
|
|||
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
|
||||
return
|
||||
}
|
||||
if stream.StreamID() != 3 {
|
||||
session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3"))
|
||||
return
|
||||
}
|
||||
|
||||
hpackDecoder := hpack.NewDecoder(4096, nil)
|
||||
h2framer := http2.NewFramer(nil, stream)
|
||||
|
||||
go func() {
|
||||
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
||||
for {
|
||||
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
|
||||
// QuicErrors must originate from stream.Read() returning an error.
|
||||
// In this case, the session has already logged the error, so we don't
|
||||
// need to log it again.
|
||||
if _, ok := err.(*qerr.QuicError); !ok {
|
||||
utils.Errorf("error handling h2 request: %s", err.Error())
|
||||
}
|
||||
session.Close(err)
|
||||
return
|
||||
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
||||
for {
|
||||
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
|
||||
// QuicErrors must originate from stream.Read() returning an error.
|
||||
// In this case, the session has already logged the error, so we don't
|
||||
// need to log it again.
|
||||
if _, ok := err.(*qerr.QuicError); !ok {
|
||||
utils.Errorf("error handling h2 request: %s", err.Error())
|
||||
}
|
||||
session.Close(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
|
||||
|
@ -170,8 +169,6 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
return err
|
||||
}
|
||||
|
||||
req.RemoteAddr = session.RemoteAddr().String()
|
||||
|
||||
if utils.Debug() {
|
||||
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
|
||||
} else {
|
||||
|
@ -187,19 +184,25 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
return nil
|
||||
}
|
||||
|
||||
var streamEnded bool
|
||||
if h2headersFrame.StreamEnded() {
|
||||
dataStream.(remoteCloser).CloseRemote(0)
|
||||
streamEnded = true
|
||||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
||||
}
|
||||
|
||||
reqBody := newRequestBody(dataStream)
|
||||
req.Body = reqBody
|
||||
|
||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
|
||||
|
||||
// handleRequest should be as non-blocking as possible to minimize
|
||||
// head-of-line blocking. Potentially blocking code is run in a separate
|
||||
// goroutine, enabling handleRequest to return before the code is executed.
|
||||
go func() {
|
||||
streamEnded := h2headersFrame.StreamEnded()
|
||||
if streamEnded {
|
||||
dataStream.(remoteCloser).CloseRemote(0)
|
||||
streamEnded = true
|
||||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
||||
}
|
||||
|
||||
req = req.WithContext(dataStream.Context())
|
||||
reqBody := newRequestBody(dataStream)
|
||||
req.Body = reqBody
|
||||
|
||||
req.RemoteAddr = session.RemoteAddr().String()
|
||||
|
||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
|
||||
|
||||
handler := s.Handler
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
|
@ -225,7 +228,8 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
}
|
||||
if responseWriter.dataStream != nil {
|
||||
if !streamEnded && !reqBody.requestRead {
|
||||
responseWriter.dataStream.Reset(nil)
|
||||
// in gQUIC, the error code doesn't matter, so just use 0 here
|
||||
responseWriter.dataStream.CancelRead(0)
|
||||
}
|
||||
responseWriter.dataStream.Close()
|
||||
}
|
||||
|
@ -243,6 +247,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
func (s *Server) Close() error {
|
||||
s.listenerMutex.Lock()
|
||||
defer s.listenerMutex.Unlock()
|
||||
s.closed = true
|
||||
if s.listener != nil {
|
||||
err := s.listener.Close()
|
||||
s.listener = nil
|
||||
|
@ -279,12 +284,11 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
|||
}
|
||||
|
||||
if s.supportedVersionsAsString == "" {
|
||||
for i, v := range protocol.SupportedVersions {
|
||||
s.supportedVersionsAsString += strconv.Itoa(int(v))
|
||||
if i != len(protocol.SupportedVersions)-1 {
|
||||
s.supportedVersionsAsString += ","
|
||||
}
|
||||
var versions []string
|
||||
for _, v := range protocol.SupportedVersions {
|
||||
versions = append(versions, v.ToAltSvc())
|
||||
}
|
||||
s.supportedVersionsAsString = strings.Join(versions, ",")
|
||||
}
|
||||
|
||||
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
|
||||
|
@ -344,6 +348,9 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
|
|||
}
|
||||
defer tcpConn.Close()
|
||||
|
||||
tlsConn := tls.NewListener(tcpConn, config)
|
||||
defer tlsConn.Close()
|
||||
|
||||
// Start the servers
|
||||
httpServer := &http.Server{
|
||||
Addr: addr,
|
||||
|
@ -365,7 +372,7 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
|
|||
hErr := make(chan error)
|
||||
qErr := make(chan error)
|
||||
go func() {
|
||||
hErr <- httpServer.Serve(tcpConn)
|
||||
hErr <- httpServer.Serve(tlsConn)
|
||||
}()
|
||||
go func() {
|
||||
qErr <- quicServer.Serve(udpConn)
|
||||
|
|
265
vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go
generated
vendored
265
vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go
generated
vendored
|
@ -1,265 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// ConnectionParametersManager negotiates and stores the connection parameters
|
||||
// A ConnectionParametersManager can be used for a server as well as a client
|
||||
// For the server:
|
||||
// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation
|
||||
// 2. call GetHelloMap to get the values to send in the SHLO
|
||||
// For the client:
|
||||
// 1. call GetHelloMap to get the values to send in a CHLO
|
||||
// 2. call SetFromMap with the values received in the SHLO
|
||||
type ConnectionParametersManager interface {
|
||||
SetFromMap(map[Tag][]byte) error
|
||||
GetHelloMap() (map[Tag][]byte, error)
|
||||
|
||||
GetSendStreamFlowControlWindow() protocol.ByteCount
|
||||
GetSendConnectionFlowControlWindow() protocol.ByteCount
|
||||
GetReceiveStreamFlowControlWindow() protocol.ByteCount
|
||||
GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount
|
||||
GetReceiveConnectionFlowControlWindow() protocol.ByteCount
|
||||
GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount
|
||||
GetMaxOutgoingStreams() uint32
|
||||
GetMaxIncomingStreams() uint32
|
||||
GetIdleConnectionStateLifetime() time.Duration
|
||||
TruncateConnectionID() bool
|
||||
}
|
||||
|
||||
type connectionParametersManager struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
version protocol.VersionNumber
|
||||
perspective protocol.Perspective
|
||||
|
||||
flowControlNegotiated bool
|
||||
|
||||
truncateConnectionID bool
|
||||
maxStreamsPerConnection uint32
|
||||
maxIncomingDynamicStreamsPerConnection uint32
|
||||
idleConnectionStateLifetime time.Duration
|
||||
sendStreamFlowControlWindow protocol.ByteCount
|
||||
sendConnectionFlowControlWindow protocol.ByteCount
|
||||
receiveStreamFlowControlWindow protocol.ByteCount
|
||||
receiveConnectionFlowControlWindow protocol.ByteCount
|
||||
maxReceiveStreamFlowControlWindow protocol.ByteCount
|
||||
maxReceiveConnectionFlowControlWindow protocol.ByteCount
|
||||
}
|
||||
|
||||
var _ ConnectionParametersManager = &connectionParametersManager{}
|
||||
|
||||
// ErrMalformedTag is returned when the tag value cannot be read
|
||||
var (
|
||||
ErrMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
|
||||
ErrFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported")
|
||||
)
|
||||
|
||||
// NewConnectionParamatersManager creates a new connection parameters manager
|
||||
func NewConnectionParamatersManager(
|
||||
pers protocol.Perspective, v protocol.VersionNumber,
|
||||
maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount,
|
||||
) ConnectionParametersManager {
|
||||
h := &connectionParametersManager{
|
||||
perspective: pers,
|
||||
version: v,
|
||||
sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client
|
||||
sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client
|
||||
receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||
maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||
}
|
||||
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
h.idleConnectionStateLifetime = protocol.DefaultIdleTimeout
|
||||
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
|
||||
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective
|
||||
} else {
|
||||
h.idleConnectionStateLifetime = protocol.MaxIdleTimeoutClient
|
||||
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
|
||||
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the server's perspective
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// SetFromMap reads all params
|
||||
func (h *connectionParametersManager) SetFromMap(params map[Tag][]byte) error {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer {
|
||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return ErrMalformedTag
|
||||
}
|
||||
h.truncateConnectionID = (clientValue == 0)
|
||||
}
|
||||
if value, ok := params[TagMSPC]; ok {
|
||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return ErrMalformedTag
|
||||
}
|
||||
h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue)
|
||||
}
|
||||
if value, ok := params[TagMIDS]; ok {
|
||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return ErrMalformedTag
|
||||
}
|
||||
h.maxIncomingDynamicStreamsPerConnection = h.negotiateMaxIncomingDynamicStreamsPerConnection(clientValue)
|
||||
}
|
||||
if value, ok := params[TagICSL]; ok {
|
||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return ErrMalformedTag
|
||||
}
|
||||
h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(time.Duration(clientValue) * time.Second)
|
||||
}
|
||||
if value, ok := params[TagSFCW]; ok {
|
||||
if h.flowControlNegotiated {
|
||||
return ErrFlowControlRenegotiationNotSupported
|
||||
}
|
||||
sendStreamFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return ErrMalformedTag
|
||||
}
|
||||
h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow)
|
||||
}
|
||||
if value, ok := params[TagCFCW]; ok {
|
||||
if h.flowControlNegotiated {
|
||||
return ErrFlowControlRenegotiationNotSupported
|
||||
}
|
||||
sendConnectionFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return ErrMalformedTag
|
||||
}
|
||||
h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow)
|
||||
}
|
||||
|
||||
_, containsSFCW := params[TagSFCW]
|
||||
_, containsCFCW := params[TagCFCW]
|
||||
if containsCFCW || containsSFCW {
|
||||
h.flowControlNegotiated = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *connectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 {
|
||||
return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection)
|
||||
}
|
||||
|
||||
func (h *connectionParametersManager) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 {
|
||||
return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection)
|
||||
}
|
||||
|
||||
func (h *connectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration {
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutServer)
|
||||
}
|
||||
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutClient)
|
||||
}
|
||||
|
||||
// GetHelloMap gets all parameters needed for the Hello message
|
||||
func (h *connectionParametersManager) GetHelloMap() (map[Tag][]byte, error) {
|
||||
sfcw := bytes.NewBuffer([]byte{})
|
||||
utils.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow()))
|
||||
cfcw := bytes.NewBuffer([]byte{})
|
||||
utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow()))
|
||||
mspc := bytes.NewBuffer([]byte{})
|
||||
utils.WriteUint32(mspc, h.maxStreamsPerConnection)
|
||||
mids := bytes.NewBuffer([]byte{})
|
||||
utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection)
|
||||
icsl := bytes.NewBuffer([]byte{})
|
||||
utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second))
|
||||
|
||||
return map[Tag][]byte{
|
||||
TagICSL: icsl.Bytes(),
|
||||
TagMSPC: mspc.Bytes(),
|
||||
TagMIDS: mids.Bytes(),
|
||||
TagCFCW: cfcw.Bytes(),
|
||||
TagSFCW: sfcw.Bytes(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data
|
||||
func (h *connectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.sendStreamFlowControlWindow
|
||||
}
|
||||
|
||||
// GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data
|
||||
func (h *connectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.sendConnectionFlowControlWindow
|
||||
}
|
||||
|
||||
// GetReceiveStreamFlowControlWindow gets the size of the stream-level flow control window for receiving data
|
||||
func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.receiveStreamFlowControlWindow
|
||||
}
|
||||
|
||||
// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
|
||||
func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
|
||||
return h.maxReceiveStreamFlowControlWindow
|
||||
}
|
||||
|
||||
// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data
|
||||
func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.receiveConnectionFlowControlWindow
|
||||
}
|
||||
|
||||
// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
|
||||
func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
|
||||
return h.maxReceiveConnectionFlowControlWindow
|
||||
}
|
||||
|
||||
// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection
|
||||
func (h *connectionParametersManager) GetMaxOutgoingStreams() uint32 {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
return h.maxIncomingDynamicStreamsPerConnection
|
||||
}
|
||||
|
||||
// GetMaxIncomingStreams get the maximum number of incoming streams per connection
|
||||
func (h *connectionParametersManager) GetMaxIncomingStreams() uint32 {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
maxStreams := protocol.MaxIncomingDynamicStreamsPerConnection
|
||||
return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier))
|
||||
}
|
||||
|
||||
// GetIdleConnectionStateLifetime gets the idle timeout
|
||||
func (h *connectionParametersManager) GetIdleConnectionStateLifetime() time.Duration {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.idleConnectionStateLifetime
|
||||
}
|
||||
|
||||
// TruncateConnectionID determines if the client requests truncated ConnectionIDs
|
||||
func (h *connectionParametersManager) TruncateConnectionID() bool {
|
||||
if h.perspective == protocol.PerspectiveClient {
|
||||
return false
|
||||
}
|
||||
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.truncateConnectionID
|
||||
}
|
24
vendor/github.com/lucas-clemente/quic-go/handshake/interface.go
generated
vendored
24
vendor/github.com/lucas-clemente/quic-go/handshake/interface.go
generated
vendored
|
@ -1,24 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/protocol"
|
||||
|
||||
// Sealer seals a packet
|
||||
type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
||||
|
||||
// CryptoSetup is a crypto setup
|
||||
type CryptoSetup interface {
|
||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||
HandleCryptoStream() error
|
||||
// TODO: clean up this interface
|
||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
||||
|
||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
||||
}
|
||||
|
||||
// TransportParameters are parameters sent to the peer during the handshake
|
||||
type TransportParameters struct {
|
||||
RequestConnectionIDTruncation bool
|
||||
}
|
100
vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go
generated
vendored
100
vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go
generated
vendored
|
@ -1,100 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
stkPrefixIP byte = iota
|
||||
stkPrefixString
|
||||
)
|
||||
|
||||
// An STK is a source address token
|
||||
type STK struct {
|
||||
RemoteAddr string
|
||||
SentTime time.Time
|
||||
}
|
||||
|
||||
// token is the struct that is used for ASN1 serialization and deserialization
|
||||
type token struct {
|
||||
Data []byte
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
// An STKGenerator generates STKs
|
||||
type STKGenerator struct {
|
||||
stkSource crypto.StkSource
|
||||
}
|
||||
|
||||
// NewSTKGenerator initializes a new STKGenerator
|
||||
func NewSTKGenerator() (*STKGenerator, error) {
|
||||
stkSource, err := crypto.NewStkSource()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &STKGenerator{
|
||||
stkSource: stkSource,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewToken generates a new STK token for a given source address
|
||||
func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
||||
data, err := asn1.Marshal(token{
|
||||
Data: encodeRemoteAddr(raddr),
|
||||
Timestamp: time.Now().Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.stkSource.NewToken(data)
|
||||
}
|
||||
|
||||
// DecodeToken decodes an STK token
|
||||
func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) {
|
||||
// if the client didn't send any STK, DecodeToken will be called with a nil-slice
|
||||
if len(encrypted) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := g.stkSource.DecodeToken(encrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := &token{}
|
||||
rest, err := asn1.Unmarshal(data, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
|
||||
}
|
||||
return &STK{
|
||||
RemoteAddr: decodeRemoteAddr(t.Data),
|
||||
SentTime: time.Unix(t.Timestamp, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
|
||||
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
||||
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
|
||||
return append([]byte{stkPrefixIP}, udpAddr.IP...)
|
||||
}
|
||||
return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...)
|
||||
}
|
||||
|
||||
// decodeRemoteAddr decodes the remote address saved in the STK
|
||||
func decodeRemoteAddr(data []byte) string {
|
||||
// data will never be empty for an STK that we generated. Check it to be on the safe side
|
||||
if len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
if data[0] == stkPrefixIP {
|
||||
return net.IP(data[1:]).String()
|
||||
}
|
||||
return string(data[1:])
|
||||
}
|
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/chrome/chrome.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/chrome/chrome.go
generated
vendored
|
@ -1 +0,0 @@
|
|||
package chrome
|
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/gquic/gquic.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/gquic/gquic.go
generated
vendored
|
@ -1 +0,0 @@
|
|||
package gquic
|
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/self/self.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/self/self.go
generated
vendored
|
@ -1 +0,0 @@
|
|||
package self
|
73
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go
generated
vendored
73
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go
generated
vendored
|
@ -1,14 +1,12 @@
|
|||
package quicproxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// Connection is a UDP connection
|
||||
|
@ -28,21 +26,43 @@ const (
|
|||
DirectionIncoming Direction = iota
|
||||
// DirectionOutgoing is the direction from the server to the client.
|
||||
DirectionOutgoing
|
||||
// DirectionBoth is both incoming and outgoing
|
||||
DirectionBoth
|
||||
)
|
||||
|
||||
func (d Direction) String() string {
|
||||
switch d {
|
||||
case DirectionIncoming:
|
||||
return "incoming"
|
||||
case DirectionOutgoing:
|
||||
return "outgoing"
|
||||
case DirectionBoth:
|
||||
return "both"
|
||||
default:
|
||||
panic("unknown direction")
|
||||
}
|
||||
}
|
||||
|
||||
func (d Direction) Is(dir Direction) bool {
|
||||
if d == DirectionBoth || dir == DirectionBoth {
|
||||
return true
|
||||
}
|
||||
return d == dir
|
||||
}
|
||||
|
||||
// DropCallback is a callback that determines which packet gets dropped.
|
||||
type DropCallback func(Direction, protocol.PacketNumber) bool
|
||||
type DropCallback func(dir Direction, packetCount uint64) bool
|
||||
|
||||
// NoDropper doesn't drop packets.
|
||||
var NoDropper DropCallback = func(Direction, protocol.PacketNumber) bool {
|
||||
var NoDropper DropCallback = func(Direction, uint64) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// DelayCallback is a callback that determines how much delay to apply to a packet.
|
||||
type DelayCallback func(Direction, protocol.PacketNumber) time.Duration
|
||||
type DelayCallback func(dir Direction, packetCount uint64) time.Duration
|
||||
|
||||
// NoDelay doesn't apply a delay.
|
||||
var NoDelay DelayCallback = func(Direction, protocol.PacketNumber) time.Duration {
|
||||
var NoDelay DelayCallback = func(Direction, uint64) time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
|
@ -62,6 +82,8 @@ type Opts struct {
|
|||
type QuicProxy struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
version protocol.VersionNumber
|
||||
|
||||
conn *net.UDPConn
|
||||
serverAddr *net.UDPAddr
|
||||
|
||||
|
@ -73,7 +95,10 @@ type QuicProxy struct {
|
|||
}
|
||||
|
||||
// NewQuicProxy creates a new UDP proxy
|
||||
func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) {
|
||||
func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) {
|
||||
if opts == nil {
|
||||
opts = &Opts{}
|
||||
}
|
||||
laddr, err := net.ResolveUDPAddr("udp", local)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -103,6 +128,7 @@ func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) {
|
|||
serverAddr: raddr,
|
||||
dropPacket: packetDropper,
|
||||
delayPacket: packetDelayer,
|
||||
version: version,
|
||||
}
|
||||
|
||||
go p.runProxy()
|
||||
|
@ -119,6 +145,7 @@ func (p *QuicProxy) LocalAddr() net.Addr {
|
|||
return p.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// LocalPort is the UDP port number the proxy is listening on.
|
||||
func (p *QuicProxy) LocalPort() int {
|
||||
return p.conn.LocalAddr().(*net.UDPAddr).Port
|
||||
}
|
||||
|
@ -137,7 +164,7 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
|||
// runProxy listens on the proxy address and handles incoming packets.
|
||||
func (p *QuicProxy) runProxy() error {
|
||||
for {
|
||||
buffer := make([]byte, protocol.MaxPacketSize)
|
||||
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -159,20 +186,14 @@ func (p *QuicProxy) runProxy() error {
|
|||
}
|
||||
p.mutex.Unlock()
|
||||
|
||||
atomic.AddUint64(&conn.incomingPacketCounter, 1)
|
||||
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
|
||||
|
||||
r := bytes.NewReader(raw)
|
||||
hdr, err := quic.ParsePublicHeader(r, protocol.PerspectiveClient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.dropPacket(DirectionIncoming, hdr.PacketNumber) {
|
||||
if p.dropPacket(DirectionIncoming, packetCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet to the server
|
||||
delay := p.delayPacket(DirectionIncoming, hdr.PacketNumber)
|
||||
delay := p.delayPacket(DirectionIncoming, packetCount)
|
||||
if delay != 0 {
|
||||
time.AfterFunc(delay, func() {
|
||||
// TODO: handle error
|
||||
|
@ -190,28 +211,20 @@ func (p *QuicProxy) runProxy() error {
|
|||
// runConnection handles packets from server to a single client
|
||||
func (p *QuicProxy) runConnection(conn *connection) error {
|
||||
for {
|
||||
buffer := make([]byte, protocol.MaxPacketSize)
|
||||
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||
n, err := conn.ServerConn.Read(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw := buffer[0:n]
|
||||
|
||||
// TODO: Switch back to using the public header once Chrome properly sets the type byte.
|
||||
// r := bytes.NewReader(raw)
|
||||
// , err := quic.ParsePublicHeader(r, protocol.PerspectiveServer)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
|
||||
|
||||
v := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
|
||||
|
||||
packetNumber := protocol.PacketNumber(v)
|
||||
if p.dropPacket(DirectionOutgoing, packetNumber) {
|
||||
if p.dropPacket(DirectionOutgoing, packetCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
delay := p.delayPacket(DirectionOutgoing, packetNumber)
|
||||
delay := p.delayPacket(DirectionOutgoing, packetCount)
|
||||
if delay != 0 {
|
||||
time.AfterFunc(delay, func() {
|
||||
// TODO: handle error
|
||||
|
|
2
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go
generated
vendored
|
@ -27,7 +27,7 @@ var _ = BeforeEach(func() {
|
|||
|
||||
if len(logFileName) > 0 {
|
||||
var err error
|
||||
logFile, err = os.Create("./log.txt")
|
||||
logFile, err = os.Create(logFileName)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
log.SetOutput(logFile)
|
||||
utils.SetLogLevel(utils.LogLevelDebug)
|
||||
|
|
18
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go
generated
vendored
18
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go
generated
vendored
|
@ -7,7 +7,9 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
|
@ -23,8 +25,9 @@ var (
|
|||
PRData = GeneratePRData(dataLen)
|
||||
PRDataLong = GeneratePRData(dataLenLong)
|
||||
|
||||
server *h2quic.Server
|
||||
port string
|
||||
server *h2quic.Server
|
||||
stoppedServing chan struct{}
|
||||
port string
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -75,11 +78,16 @@ func GeneratePRData(l int) []byte {
|
|||
return res
|
||||
}
|
||||
|
||||
func StartQuicServer() {
|
||||
// StartQuicServer starts a h2quic.Server.
|
||||
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
|
||||
func StartQuicServer(versions []protocol.VersionNumber) {
|
||||
server = &h2quic.Server{
|
||||
Server: &http.Server{
|
||||
TLSConfig: testdata.GetTLSConfig(),
|
||||
},
|
||||
QuicConfig: &quic.Config{
|
||||
Versions: versions,
|
||||
},
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
|
||||
|
@ -88,14 +96,18 @@ func StartQuicServer() {
|
|||
Expect(err).NotTo(HaveOccurred())
|
||||
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
|
||||
|
||||
stoppedServing = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server.Serve(conn)
|
||||
close(stoppedServing)
|
||||
}()
|
||||
}
|
||||
|
||||
func StopQuicServer() {
|
||||
Expect(server.Close()).NotTo(HaveOccurred())
|
||||
Eventually(stoppedServing).Should(BeClosed())
|
||||
}
|
||||
|
||||
func Port() string {
|
||||
|
|
121
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
121
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
|
@ -6,23 +6,55 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// The StreamID is the ID of a QUIC stream.
|
||||
type StreamID = protocol.StreamID
|
||||
|
||||
// A VersionNumber is a QUIC version number.
|
||||
type VersionNumber = protocol.VersionNumber
|
||||
|
||||
// A Cookie can be used to verify the ownership of the client address.
|
||||
type Cookie = handshake.Cookie
|
||||
|
||||
// ConnectionState records basic details about the QUIC connection.
|
||||
type ConnectionState = handshake.ConnectionState
|
||||
|
||||
// An ErrorCode is an application-defined error code.
|
||||
type ErrorCode = protocol.ApplicationErrorCode
|
||||
|
||||
// Stream is the interface implemented by QUIC streams
|
||||
type Stream interface {
|
||||
// StreamID returns the stream ID.
|
||||
StreamID() StreamID
|
||||
// Read reads data from the stream.
|
||||
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
io.Reader
|
||||
// Write writes data to the stream.
|
||||
// Write can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
io.Writer
|
||||
// Close closes the write-direction of the stream.
|
||||
// Future calls to Write are not permitted after calling Close.
|
||||
// It must not be called concurrently with Write.
|
||||
// It must not be called after calling CancelWrite.
|
||||
io.Closer
|
||||
StreamID() protocol.StreamID
|
||||
// Reset closes the stream with an error.
|
||||
Reset(error)
|
||||
// CancelWrite aborts sending on this stream.
|
||||
// It must not be called after Close.
|
||||
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
|
||||
// Write will unblock immediately, and future calls to Write will fail.
|
||||
CancelWrite(ErrorCode) error
|
||||
// CancelRead aborts receiving on this stream.
|
||||
// It will ask the peer to stop transmitting stream data.
|
||||
// Read will unblock immediately, and future Read calls will fail.
|
||||
CancelRead(ErrorCode) error
|
||||
// The context is canceled as soon as the write-side of the stream is closed.
|
||||
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
|
@ -43,6 +75,41 @@ type Stream interface {
|
|||
SetDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A ReceiveStream is a unidirectional Receive Stream.
|
||||
type ReceiveStream interface {
|
||||
// see Stream.StreamID
|
||||
StreamID() StreamID
|
||||
// see Stream.Read
|
||||
io.Reader
|
||||
// see Stream.CancelRead
|
||||
CancelRead(ErrorCode) error
|
||||
// see Stream.SetReadDealine
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A SendStream is a unidirectional Send Stream.
|
||||
type SendStream interface {
|
||||
// see Stream.StreamID
|
||||
StreamID() StreamID
|
||||
// see Stream.Write
|
||||
io.Writer
|
||||
// see Stream.Close
|
||||
io.Closer
|
||||
// see Stream.CancelWrite
|
||||
CancelWrite(ErrorCode) error
|
||||
// see Stream.Context
|
||||
Context() context.Context
|
||||
// see Stream.SetWriteDeadline
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// StreamError is returned by Read and Write when the peer cancels the stream.
|
||||
type StreamError interface {
|
||||
error
|
||||
Canceled() bool
|
||||
ErrorCode() ErrorCode
|
||||
}
|
||||
|
||||
// A Session is a QUIC connection between two peers.
|
||||
type Session interface {
|
||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||
|
@ -64,53 +131,41 @@ type Session interface {
|
|||
// The context is cancelled when the session is closed.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
Context() context.Context
|
||||
}
|
||||
|
||||
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
|
||||
// The communication is encrypted, but not yet forward secure.
|
||||
type NonFWSession interface {
|
||||
Session
|
||||
WaitUntilHandshakeComplete() error
|
||||
}
|
||||
|
||||
// An STK is a Source Address token.
|
||||
// It is issued by the server and sent to the client. For the client, it is an opaque blob.
|
||||
// The client can send the STK in subsequent handshakes to prove ownership of its IP address.
|
||||
type STK struct {
|
||||
// The remote address this token was issued for.
|
||||
// If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String())
|
||||
// Otherwise, this is the string representation of the net.Addr (net.Addr.String())
|
||||
remoteAddr string
|
||||
// The time that the STK was issued (resolution 1 second)
|
||||
sentTime time.Time
|
||||
// ConnectionState returns basic details about the QUIC connection.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
ConnectionState() ConnectionState
|
||||
}
|
||||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
|
||||
type Config struct {
|
||||
// The QUIC versions that can be negotiated.
|
||||
// If not set, it uses all versions available.
|
||||
// Warning: This API should not be considered stable and will change soon.
|
||||
Versions []protocol.VersionNumber
|
||||
// Ask the server to truncate the connection ID sent in the Public Header.
|
||||
Versions []VersionNumber
|
||||
// Ask the server to omit the connection ID sent in the Public Header.
|
||||
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
||||
// Currently only valid for the client.
|
||||
RequestConnectionIDTruncation bool
|
||||
RequestConnectionIDOmission bool
|
||||
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
|
||||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 10 seconds.
|
||||
HandshakeTimeout time.Duration
|
||||
// AcceptSTK determines if an STK is accepted.
|
||||
// It is called with stk = nil if the client didn't send an STK.
|
||||
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours.
|
||||
// IdleTimeout is the maximum duration that may pass without any incoming network activity.
|
||||
// This value only applies after the handshake has completed.
|
||||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 30 seconds.
|
||||
IdleTimeout time.Duration
|
||||
// AcceptCookie determines if a Cookie is accepted.
|
||||
// It is called with cookie = nil if the client didn't send an Cookie.
|
||||
// If not set, it verifies that the address matches, and that the Cookie was issued within the last 24 hours.
|
||||
// This option is only valid for the server.
|
||||
AcceptSTK func(clientAddr net.Addr, stk *STK) bool
|
||||
AcceptCookie func(clientAddr net.Addr, cookie *Cookie) bool
|
||||
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
|
||||
MaxReceiveStreamFlowControlWindow protocol.ByteCount
|
||||
MaxReceiveStreamFlowControlWindow uint64
|
||||
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
|
||||
MaxReceiveConnectionFlowControlWindow protocol.ByteCount
|
||||
MaxReceiveConnectionFlowControlWindow uint64
|
||||
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
||||
KeepAlive bool
|
||||
}
|
||||
|
|
48
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
48
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
|
@ -0,0 +1,48 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// SentPacketHandler handles ACKs received for outgoing packets
|
||||
type SentPacketHandler interface {
|
||||
// SentPacket may modify the packet
|
||||
SentPacket(packet *Packet) error
|
||||
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
|
||||
SetHandshakeComplete()
|
||||
|
||||
// SendingAllowed says if a packet can be sent.
|
||||
// Sending packets might not be possible because:
|
||||
// * we're congestion limited
|
||||
// * we're tracking the maximum number of sent packets
|
||||
SendingAllowed() bool
|
||||
// TimeUntilSend is the time when the next packet should be sent.
|
||||
// It is used for pacing packets.
|
||||
TimeUntilSend() time.Time
|
||||
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
|
||||
// It always returns a number greater or equal than 1.
|
||||
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
|
||||
// Note that the number of packets is only calculated based on the pacing algorithm.
|
||||
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
|
||||
ShouldSendNumPackets() int
|
||||
|
||||
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||
DequeuePacketForRetransmission() (packet *Packet)
|
||||
GetLeastUnacked() protocol.PacketNumber
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
OnAlarm()
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
type ReceivedPacketHandler interface {
|
||||
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
|
||||
IgnoreBelow(protocol.PacketNumber)
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
GetAckFrame() *wire.AckFrame
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue