From e2997ac974d194c0fef32a2c98d07a84237876fd Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Fri, 2 Feb 2018 19:59:28 -0700 Subject: [PATCH 01/19] request_id: Allow reusing ID from header (closes #2012) --- caddyhttp/requestid/requestid.go | 24 +++++++++-- caddyhttp/requestid/requestid_test.go | 59 ++++++++++++++++++--------- caddyhttp/requestid/setup.go | 9 +++- caddyhttp/requestid/setup_test.go | 10 ++++- 4 files changed, 76 insertions(+), 26 deletions(-) diff --git a/caddyhttp/requestid/requestid.go b/caddyhttp/requestid/requestid.go index c3f69267f..b03c449f6 100644 --- a/caddyhttp/requestid/requestid.go +++ b/caddyhttp/requestid/requestid.go @@ -16,6 +16,7 @@ package requestid import ( "context" + "log" "net/http" "github.com/google/uuid" @@ -24,12 +25,29 @@ import ( // Handler is a middleware handler type Handler struct { - Next httpserver.Handler + Next httpserver.Handler + HeaderName string // (optional) header from which to read an existing ID } func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - reqid := uuid.New().String() - c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid) + var reqid uuid.UUID + + uuidFromHeader := r.Header.Get(h.HeaderName) + if h.HeaderName != "" && uuidFromHeader != "" { + // use the ID in the header field if it exists + var err error + reqid, err = uuid.Parse(uuidFromHeader) + if err != nil { + log.Printf("[NOTICE] Parsing request ID from %s header: %v", h.HeaderName, err) + reqid = uuid.New() + } + } else { + // otherwise, create a new one + reqid = uuid.New() + } + + // set the request ID on the context + c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid.String()) r = r.WithContext(c) return h.Next.ServeHTTP(w, r) diff --git a/caddyhttp/requestid/requestid_test.go b/caddyhttp/requestid/requestid_test.go index 80968221f..e68c8d2c0 100644 --- a/caddyhttp/requestid/requestid_test.go +++ b/caddyhttp/requestid/requestid_test.go @@ -15,34 +15,53 @@ package requestid import ( - "context" "net/http" + "net/http/httptest" "testing" - "github.com/google/uuid" "github.com/mholt/caddy/caddyhttp/httpserver" ) -func TestRequestID(t *testing.T) { - request, err := http.NewRequest("GET", "http://localhost/", nil) +func TestRequestIDHandler(t *testing.T) { + handler := Handler{ + Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string) + if value == "" { + t.Error("Request ID should not be empty") + } + return 0, nil + }), + } + + req, err := http.NewRequest("GET", "http://localhost/", nil) if err != nil { t.Fatal("Could not create HTTP request:", err) } + rec := httptest.NewRecorder() - reqid := uuid.New().String() - - c := context.WithValue(request.Context(), httpserver.RequestIDCtxKey, reqid) - - request = request.WithContext(c) - - // See caddyhttp/replacer.go - value, _ := request.Context().Value(httpserver.RequestIDCtxKey).(string) - - if value == "" { - t.Fatal("Request ID should not be empty") - } - - if value != reqid { - t.Fatal("Request ID does not match") - } + handler.ServeHTTP(rec, req) +} + +func TestRequestIDFromHeader(t *testing.T) { + headerName := "X-Request-ID" + headerValue := "71a75329-d9f9-4d25-957e-e689a7b68d78" + handler := Handler{ + Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string) + if value != headerValue { + t.Errorf("Request ID should be '%s' but got '%s'", headerValue, value) + } + return 0, nil + }), + HeaderName: headerName, + } + + req, err := http.NewRequest("GET", "http://localhost/", nil) + if err != nil { + t.Fatal("Could not create HTTP request:", err) + } + req.Header.Set(headerName, headerValue) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) } diff --git a/caddyhttp/requestid/setup.go b/caddyhttp/requestid/setup.go index 4da5a3683..689f99e33 100644 --- a/caddyhttp/requestid/setup.go +++ b/caddyhttp/requestid/setup.go @@ -27,14 +27,19 @@ func init() { } func setup(c *caddy.Controller) error { + var headerName string + for c.Next() { if c.NextArg() { - return c.ArgErr() //no arg expected. + headerName = c.Val() + } + if c.NextArg() { + return c.ArgErr() } } httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { - return Handler{Next: next} + return Handler{Next: next, HeaderName: headerName} }) return nil diff --git a/caddyhttp/requestid/setup_test.go b/caddyhttp/requestid/setup_test.go index aea123694..9c420787b 100644 --- a/caddyhttp/requestid/setup_test.go +++ b/caddyhttp/requestid/setup_test.go @@ -45,7 +45,15 @@ func TestSetup(t *testing.T) { } func TestSetupWithArg(t *testing.T) { - c := caddy.NewTestController("http", `requestid abc`) + c := caddy.NewTestController("http", `requestid X-Request-ID`) + err := setup(c) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } +} + +func TestSetupWithTooManyArgs(t *testing.T) { + c := caddy.NewTestController("http", `requestid foo bar`) err := setup(c) if err == nil { t.Errorf("Expected an error, got: %v", err) From fc6d62286ea3d6c3e40a5e1942c40eaa0ee64330 Mon Sep 17 00:00:00 2001 From: Tw Date: Sat, 3 Feb 2018 14:52:53 +0800 Subject: [PATCH 02/19] make eventHooks thread safe (Go 1.9) (#2009) Signed-off-by: Tw --- plugins.go | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/plugins.go b/plugins.go index f5372184e..d95177816 100644 --- a/plugins.go +++ b/plugins.go @@ -19,6 +19,7 @@ import ( "log" "net" "sort" + "sync" "github.com/mholt/caddy/caddyfile" ) @@ -38,7 +39,7 @@ var ( // eventHooks is a map of hook name to Hook. All hooks plugins // must have a name. - eventHooks = make(map[string]EventHook) + eventHooks = sync.Map{} // parsingCallbacks maps server type to map of directive // to list of callback functions. These aren't really @@ -67,12 +68,15 @@ func DescribePlugins() string { str += " " + defaultCaddyfileLoader.name + "\n" } - if len(eventHooks) > 0 { - // List the event hook plugins + // List the event hook plugins + hooks := "" + eventHooks.Range(func(k, _ interface{}) bool { + hooks += " hook." + k.(string) + "\n" + return true + }) + if hooks != "" { str += "\nEvent hook plugins:\n" - for hookPlugin := range eventHooks { - str += " hook." + hookPlugin + "\n" - } + str += hooks } // Let's alphabetize the rest of these... @@ -248,23 +252,23 @@ func RegisterEventHook(name string, hook EventHook) { if name == "" { panic("event hook must have a name") } - if _, dup := eventHooks[name]; dup { + _, dup := eventHooks.LoadOrStore(name, hook) + if dup { panic("hook named " + name + " already registered") } - eventHooks[name] = hook } // EmitEvent executes the different hooks passing the EventType as an // argument. This is a blocking function. Hook developers should // use 'go' keyword if they don't want to block Caddy. func EmitEvent(event EventName, info interface{}) { - for name, hook := range eventHooks { - err := hook(event, info) - + eventHooks.Range(func(k, v interface{}) bool { + err := v.(EventHook)(event, info) if err != nil { - log.Printf("error on '%s' hook: %v", name, err) + log.Printf("error on '%s' hook: %v", k.(string), err) } - } + return true + }) } // ParsingCallback is a function that is called after From e20779e40514b14c7e2471eaab22b8b9503fd4a5 Mon Sep 17 00:00:00 2001 From: Phillipp Engelke Date: Sat, 3 Feb 2018 07:53:40 +0100 Subject: [PATCH 03/19] Update README.md (#2004) Adding the bash command for downloading the caddy.service file from the reposetory. Because it was easy to forget where you find it. --- dist/init/linux-systemd/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/dist/init/linux-systemd/README.md b/dist/init/linux-systemd/README.md index aa3dd75a3..be548ae4d 100644 --- a/dist/init/linux-systemd/README.md +++ b/dist/init/linux-systemd/README.md @@ -91,6 +91,7 @@ Install the systemd service unit configuration file, reload the systemd daemon, and start caddy: ```bash +wget https://raw.githubusercontent.com/mholt/caddy/master/dist/init/linux-systemd/caddy.service sudo cp caddy.service /etc/systemd/system/ sudo chown root:root /etc/systemd/system/caddy.service sudo chmod 644 /etc/systemd/system/caddy.service From fd3fafa50caf0dcbe695d28b48198a1e2bf810bd Mon Sep 17 00:00:00 2001 From: magikstm Date: Sat, 3 Feb 2018 13:13:23 -0500 Subject: [PATCH 04/19] Disable PrivateDevices in systemd as it doesn't work for some devices (#1990) --- dist/init/linux-systemd/caddy.service | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dist/init/linux-systemd/caddy.service b/dist/init/linux-systemd/caddy.service index 649ec9556..61b70b1f3 100644 --- a/dist/init/linux-systemd/caddy.service +++ b/dist/init/linux-systemd/caddy.service @@ -30,8 +30,8 @@ LimitNPROC=512 ; Use private /tmp and /var/tmp, which are discarded after caddy stops. PrivateTmp=true -; Use a minimal /dev -PrivateDevices=true +; Use a minimal /dev (May bring additional security if switched to 'true', but it may not work on Raspberry Pi's or other devices, so it has been disabled in this dist.) +PrivateDevices=false ; Hide /home, /root, and /run/user. Nobody will steal your SSH-keys. ProtectHome=true ; Make /usr, /boot, /etc and possibly some more folders read-only. From a50f3a4cfe0da94801f9c2561a812025806a22eb Mon Sep 17 00:00:00 2001 From: Toby Allen Date: Sat, 3 Feb 2018 21:48:02 +0000 Subject: [PATCH 05/19] gitignore: Ignore .bat files (#2013) --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 4f3845ed4..425a29cf3 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,6 @@ Caddyfile og_static/ -.vscode/ \ No newline at end of file +.vscode/ + +*.bat \ No newline at end of file From fc2ff9155cc53393ac29885f6de83ed87093b274 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Sun, 4 Feb 2018 00:58:27 -0700 Subject: [PATCH 06/19] tls: Restructure and improve certificate management - Expose the list of Caddy instances through caddy.Instances() - Added arbitrary storage to caddy.Instance - The cache of loaded certificates is no longer global; now scoped per-instance, meaning upon reload (like SIGUSR1) the old cert cache will be discarded entirely, whereas before, aggressively reloading config that added and removed lots of sites would cause unnecessary build-up in the cache over time. - Key certificates in the cache by their SHA-256 hash instead of by their names. This means certificates will not be duplicated in memory (within each instance), making Caddy much more memory-efficient for large-scale deployments with thousands of sites sharing certs. - Perform name-to-certificate lookups scoped per caddytls.Config instead of a single global lookup. This prevents certificates from stepping on each other when they overlap in their names. - Do not allow TLS configurations keyed by the same hostname to be different; this now throws an error. - Updated relevant tests, with a stark awareness that more tests are needed. - Change the NewContext function signature to include an *Instance. - Strongly recommend (basically require) use of caddytls.NewConfig() to create a new *caddytls.Config, to ensure pointers to the instance certificate cache are initialized properly. - Update the TLS-SNI challenge solver (even though TLS-SNI is disabled currently on the CA side). Store temporary challenge cert in instance cache, but do so directly by the ACME challenge name, not the hash. Modified the getCertificate function to check the cache directly for a name match if one isn't found otherwise. This will allow any caddytls.Config to be able to help solve a TLS-SNI challenge, with one extra side-effect that might actually be kind of interesting (and useless): clients could send a certificate's hash as the SNI and Caddy would be able to serve that certificate for the handshake. - Do not attempt to match a "default" (random) certificate when SNI is present but unrecognized; return no certificate so a TLS alert happens instead. - Store an Instance in the list of instances even while the instance is still starting up (this allows access to the cert cache for performing renewals at startup, etc). Will be removed from list again if instance startup fails. - Laid groundwork for ACMEv2 and Let's Encrypt wildcard support. Server type plugins will need to be updated slightly to accommodate minor adjustments to their API (like passing in an Instance). This commit includes the changes for the HTTP server. Certain Caddyfile configurations might error out with this change, if they configured different TLS settings for the same hostname. This change trades some complexity for other complexity, but ultimately this new complexity is more correct and robust than earlier logic. Fixes #1991 Fixes #1994 Fixes #1303 --- caddy.go | 56 ++++- caddyhttp/httpserver/https.go | 7 +- caddyhttp/httpserver/plugin.go | 24 +- caddyhttp/httpserver/plugin_test.go | 8 +- caddytls/certificates.go | 323 +++++++++++++++----------- caddytls/certificates_test.go | 70 +++--- caddytls/client.go | 2 +- caddytls/config.go | 116 +++++++++- caddytls/crypto.go | 8 +- caddytls/handshake.go | 127 +++++++--- caddytls/handshake_test.go | 38 +-- caddytls/maintain.go | 346 +++++++++++++++------------- caddytls/setup.go | 17 +- caddytls/setup_test.go | 44 +++- caddytls/tls.go | 22 +- controller.go | 20 +- plugins.go | 2 +- 17 files changed, 801 insertions(+), 429 deletions(-) diff --git a/caddy.go b/caddy.go index 8da6d4db8..dd2d473a9 100644 --- a/caddy.go +++ b/caddy.go @@ -79,6 +79,8 @@ var ( // Instance contains the state of servers created as a result of // calling Start and can be used to access or control those servers. +// It is literally an instance of a server type. Instance values +// should NOT be copied. Use *Instance for safety. type Instance struct { // serverType is the name of the instance's server type serverType string @@ -89,10 +91,11 @@ type Instance struct { // wg is used to wait for all servers to shut down wg *sync.WaitGroup - // context is the context created for this instance. + // context is the context created for this instance, + // used to coordinate the setting up of the server type context Context - // servers is the list of servers with their listeners. + // servers is the list of servers with their listeners servers []ServerListener // these callbacks execute when certain events occur @@ -101,6 +104,18 @@ type Instance struct { onRestart []func() error // before restart commences onShutdown []func() error // stopping, even as part of a restart onFinalShutdown []func() error // stopping, not as part of a restart + + // storing values on an instance is preferable to + // global state because these will get garbage- + // collected after in-process reloads when the + // old instances are destroyed; use StorageMu + // to access this value safely + Storage map[interface{}]interface{} + StorageMu sync.RWMutex +} + +func Instances() []*Instance { + return instances } // Servers returns the ServerListeners in i. @@ -196,7 +211,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) { } // create new instance; if the restart fails, it is simply discarded - newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg} + newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg, Storage: make(map[interface{}]interface{})} // attempt to start new instance err := startWithListenerFds(newCaddyfile, newInst, restartFds) @@ -455,7 +470,7 @@ func (i *Instance) Caddyfile() Input { // // This function blocks until all the servers are listening. func Start(cdyfile Input) (*Instance, error) { - inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup)} + inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})} err := startWithListenerFds(cdyfile, inst, nil) if err != nil { return inst, err @@ -468,11 +483,34 @@ func Start(cdyfile Input) (*Instance, error) { } func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]restartTriple) error { + // save this instance in the list now so that + // plugins can access it if need be, for example + // the caddytls package, so it can perform cert + // renewals while starting up; we just have to + // remove the instance from the list later if + // it fails + instancesMu.Lock() + instances = append(instances, inst) + instancesMu.Unlock() + var err error + defer func() { + if err != nil { + instancesMu.Lock() + for i, otherInst := range instances { + if otherInst == inst { + instances = append(instances[:i], instances[i+1:]...) + break + } + } + instancesMu.Unlock() + } + }() + if cdyfile == nil { cdyfile = CaddyfileInput{} } - err := ValidateAndExecuteDirectives(cdyfile, inst, false) + err = ValidateAndExecuteDirectives(cdyfile, inst, false) if err != nil { return err } @@ -504,10 +542,6 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r return err } - instancesMu.Lock() - instances = append(instances, inst) - instancesMu.Unlock() - // run any AfterStartup callbacks if this is not // part of a restart; then show file descriptor notice if restartFds == nil { @@ -546,7 +580,7 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bool) error { // If parsing only inst will be nil, create an instance for this function call only. if justValidate { - inst = &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup)} + inst = &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})} } stypeName := cdyfile.ServerType() @@ -563,7 +597,7 @@ func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bo return err } - inst.context = stype.NewContext() + inst.context = stype.NewContext(inst) if inst.context == nil { return fmt.Errorf("server type %s produced a nil Context", stypeName) } diff --git a/caddyhttp/httpserver/https.go b/caddyhttp/httpserver/https.go index a1c84f11b..a12d9982c 100644 --- a/caddyhttp/httpserver/https.go +++ b/caddyhttp/httpserver/https.go @@ -27,7 +27,7 @@ func activateHTTPS(cctx caddy.Context) error { operatorPresent := !caddy.Started() if !caddy.Quiet && operatorPresent { - fmt.Print("Activating privacy features...") + fmt.Print("Activating privacy features... ") } ctx := cctx.(*httpContext) @@ -69,7 +69,7 @@ func activateHTTPS(cctx caddy.Context) error { } if !caddy.Quiet && operatorPresent { - fmt.Println(" done.") + fmt.Println("done.") } return nil @@ -163,6 +163,7 @@ func redirPlaintextHost(cfg *SiteConfig) *SiteConfig { if redirPort == DefaultHTTPSPort { redirPort = "" // default port is redundant } + redirMiddleware := func(next Handler) Handler { return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { // Construct the URL to which to redirect. Note that the Host in a request might @@ -184,9 +185,11 @@ func redirPlaintextHost(cfg *SiteConfig) *SiteConfig { return 0, nil }) } + host := cfg.Addr.Host port := HTTPPort addr := net.JoinHostPort(host, port) + return &SiteConfig{ Addr: Address{Original: addr, Host: host, Port: port}, ListenHost: cfg.ListenHost, diff --git a/caddyhttp/httpserver/plugin.go b/caddyhttp/httpserver/plugin.go index 643eea7f7..ea31a58d8 100644 --- a/caddyhttp/httpserver/plugin.go +++ b/caddyhttp/httpserver/plugin.go @@ -91,11 +91,13 @@ func hideCaddyfile(cctx caddy.Context) error { return nil } -func newContext() caddy.Context { - return &httpContext{keysToSiteConfigs: make(map[string]*SiteConfig)} +func newContext(inst *caddy.Instance) caddy.Context { + return &httpContext{instance: inst, keysToSiteConfigs: make(map[string]*SiteConfig)} } type httpContext struct { + instance *caddy.Instance + // keysToSiteConfigs maps an address at the top of a // server block (a "key") to its SiteConfig. Not all // SiteConfigs will be represented here, only ones @@ -146,15 +148,19 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd altTLSSNIPort = HTTPSPort } + // Make our caddytls.Config, which has a pointer to the + // instance's certificate cache and enough information + // to use automatic HTTPS when the time comes + caddytlsConfig := caddytls.NewConfig(h.instance) + caddytlsConfig.Hostname = addr.Host + caddytlsConfig.AltHTTPPort = altHTTPPort + caddytlsConfig.AltTLSSNIPort = altTLSSNIPort + // Save the config to our master list, and key it for lookups cfg := &SiteConfig{ - Addr: addr, - Root: Root, - TLS: &caddytls.Config{ - Hostname: addr.Host, - AltHTTPPort: altHTTPPort, - AltTLSSNIPort: altTLSSNIPort, - }, + Addr: addr, + Root: Root, + TLS: caddytlsConfig, originCaddyfile: sourceFile, IndexPages: staticfiles.DefaultIndexPages, } diff --git a/caddyhttp/httpserver/plugin_test.go b/caddyhttp/httpserver/plugin_test.go index 31eafd8f2..5a60f2e83 100644 --- a/caddyhttp/httpserver/plugin_test.go +++ b/caddyhttp/httpserver/plugin_test.go @@ -137,7 +137,7 @@ func TestAddressString(t *testing.T) { func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) { Port = "9999" filename := "Testfile" - ctx := newContext().(*httpContext) + ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext) input := strings.NewReader(`localhost`) sblocks, err := caddyfile.Parse(filename, input, nil) if err != nil { @@ -155,7 +155,7 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) { func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) { filename := "Testfile" - ctx := newContext().(*httpContext) + ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext) input := strings.NewReader("localhost {\n}\nLOCALHOST {\n}") sblocks, err := caddyfile.Parse(filename, input, nil) if err != nil { @@ -207,7 +207,7 @@ func TestDirectivesList(t *testing.T) { } func TestContextSaveConfig(t *testing.T) { - ctx := newContext().(*httpContext) + ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext) ctx.saveConfig("foo", new(SiteConfig)) if _, ok := ctx.keysToSiteConfigs["foo"]; !ok { t.Error("Expected config to be saved, but it wasn't") @@ -226,7 +226,7 @@ func TestContextSaveConfig(t *testing.T) { // Test to make sure we are correctly hiding the Caddyfile func TestHideCaddyfile(t *testing.T) { - ctx := newContext().(*httpContext) + ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext) ctx.saveConfig("test", &SiteConfig{ Root: Root, originCaddyfile: "Testfile", diff --git a/caddytls/certificates.go b/caddytls/certificates.go index 05af914fe..2df576ff3 100644 --- a/caddytls/certificates.go +++ b/caddytls/certificates.go @@ -15,9 +15,11 @@ package caddytls import ( + "crypto/sha256" "crypto/tls" "crypto/x509" "errors" + "fmt" "io/ioutil" "log" "strings" @@ -27,24 +29,14 @@ import ( "golang.org/x/crypto/ocsp" ) -// certCache stores certificates in memory, -// keying certificates by name. Certificates -// should not overlap in the names they serve, -// because a name only maps to one certificate. -var certCache = make(map[string]Certificate) -var certCacheMu sync.RWMutex - // Certificate is a tls.Certificate with associated metadata tacked on. // Even if the metadata can be obtained by parsing the certificate, -// we can be more efficient by extracting the metadata once so it's -// just there, ready to use. +// we are more efficient by extracting the metadata onto this struct. type Certificate struct { tls.Certificate // Names is the list of names this certificate is written for. // The first is the CommonName (if any), the rest are SAN. - // This should be the exact list of keys by which this cert - // is accessed in the cache, careful to avoid overlap. Names []string // NotAfter is when the certificate expires. @@ -53,59 +45,91 @@ type Certificate struct { // OCSP contains the certificate's parsed OCSP response. OCSP *ocsp.Response - // Config is the configuration with which the certificate was - // loaded or obtained and with which it should be maintained. - Config *Config + // The hex-encoded hash of this cert's chain's bytes. + Hash string + + // configs is the list of configs that use or refer to + // The first one is assumed to be the config that is + // "in charge" of this certificate (i.e. determines + // whether it is managed, how it is managed, etc). + // This field will be populated by cacheCertificate. + // Only meddle with it if you know what you're doing! + configs []*Config } -// getCertificate gets a certificate that matches name (a server name) -// from the in-memory cache. If there is no exact match for name, it -// will be checked against names of the form '*.example.com' (wildcard -// certificates) according to RFC 6125. If a match is found, matched will -// be true. If no matches are found, matched will be false and a default -// certificate will be returned with defaulted set to true. If no default -// certificate is set, defaulted will be set to false. +// certificateCache is to be an instance-wide cache of certs +// that site-specific TLS configs can refer to. Using a +// central map like this avoids duplication of certs in +// memory when the cert is used by multiple sites, and makes +// maintenance easier. Because these are not to be global, +// the cache will get garbage collected after a config reload +// (a new instance will take its place). +type certificateCache struct { + sync.RWMutex + cache map[string]Certificate // keyed by certificate hash +} + +// replaceCertificate replaces oldCert with newCert in the cache, and +// updates all configs that are pointing to the old certificate to +// point to the new one instead. newCert must already be loaded into +// the cache (this method does NOT load it into the cache). // -// The logic in this function is adapted from the Go standard library, -// which is by the Go Authors. +// Note that all the names on the old certificate will be deleted +// from the name lookup maps of each config, then all the names on +// the new certificate will be added to the lookup maps as long as +// they do not overwrite any entries. // -// This function is safe for concurrent use. -func getCertificate(name string) (cert Certificate, matched, defaulted bool) { - var ok bool +// The newCert may be modified and its cache entry updated. +// +// This method is safe for concurrent use. +func (certCache *certificateCache) replaceCertificate(oldCert, newCert Certificate) error { + certCache.Lock() + defer certCache.Unlock() - // Not going to trim trailing dots here since RFC 3546 says, - // "The hostname is represented ... without a trailing dot." - // Just normalize to lowercase. - name = strings.ToLower(name) + // have all the configs that are pointing to the old + // certificate point to the new certificate instead + for _, cfg := range oldCert.configs { + // first delete all the name lookup entries that + // pointed to the old certificate + for name, certKey := range cfg.Certificates { + if certKey == oldCert.Hash { + delete(cfg.Certificates, name) + } + } - certCacheMu.RLock() - defer certCacheMu.RUnlock() - - // exact match? great, let's use it - if cert, ok = certCache[name]; ok { - matched = true - return - } - - // try replacing labels in the name with wildcards until we get a match - labels := strings.Split(name, ".") - for i := range labels { - labels[i] = "*" - candidate := strings.Join(labels, ".") - if cert, ok = certCache[candidate]; ok { - matched = true - return + // then add name lookup entries for the names + // on the new certificate, but don't overwrite + // entries that may already exist, not only as + // a courtesy, but importantly: because if we + // overwrote a value here, and this config no + // longer pointed to a certain certificate in + // the cache, that certificate's list of configs + // referring to it would be incorrect; so just + // insert entries, don't overwrite any + for _, name := range newCert.Names { + if _, ok := cfg.Certificates[name]; !ok { + cfg.Certificates[name] = newCert.Hash + } } } - // if nothing matches, use the default certificate or bust - cert, defaulted = certCache[""] - return + // since caching a new certificate attaches only the config + // that loaded it, the new certificate needs to be given the + // list of all the configs that use it, so copy the list + // over from the old certificate to the new certificate + // in the cache + newCert.configs = oldCert.configs + certCache.cache[newCert.Hash] = newCert + + // finally, delete the old certificate from the cache + delete(certCache.cache, oldCert.Hash) + + return nil } // CacheManagedCertificate loads the certificate for domain into the -// cache, flagging it as Managed and, if onDemand is true, as "OnDemand" -// (meaning that it was obtained or loaded during a TLS handshake). +// cache, from the TLS storage for managed certificates. It returns a +// copy of the Certificate that was put into the cache. // // This method is safe for concurrent use. func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) { @@ -117,39 +141,24 @@ func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) { if err != nil { return Certificate{}, err } - cert, err := makeCertificate(siteData.Cert, siteData.Key) + cert, err := makeCertificateWithOCSP(siteData.Cert, siteData.Key) if err != nil { return cert, err } - cert.Config = cfg - cacheCertificate(cert) - return cert, nil + return cfg.cacheCertificate(cert), nil } // cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile // and keyFile, which must be in PEM format. It stores the certificate in -// memory after evicting any other entries in the cache keyed by the names -// on this certificate. In other words, it replaces existing certificates keyed -// by the names on this certificate. The Managed and OnDemand flags of the -// certificate will be set to false. +// the in-memory cache. // // This function is safe for concurrent use. -func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error { - cert, err := makeCertificateFromDisk(certFile, keyFile) +func (cfg *Config) cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error { + cert, err := makeCertificateFromDiskWithOCSP(certFile, keyFile) if err != nil { return err } - - // since this is manually managed, this call might be part of a reload after - // the owner renewed a certificate; so clear cache of any previous cert first, - // otherwise the renewed certificate may never be loaded - certCacheMu.Lock() - for _, name := range cert.Names { - delete(certCache, name) - } - certCacheMu.Unlock() - - cacheCertificate(cert) + cfg.cacheCertificate(cert) return nil } @@ -157,20 +166,20 @@ func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error { // of the certificate and key, then caches it in memory. // // This function is safe for concurrent use. -func cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error { - cert, err := makeCertificate(certBytes, keyBytes) +func (cfg *Config) cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error { + cert, err := makeCertificateWithOCSP(certBytes, keyBytes) if err != nil { return err } - cacheCertificate(cert) + cfg.cacheCertificate(cert) return nil } -// makeCertificateFromDisk makes a Certificate by loading the +// makeCertificateFromDiskWithOCSP makes a Certificate by loading the // certificate and key files. It fills out all the fields in // the certificate except for the Managed and OnDemand flags. -// (It is up to the caller to set those.) -func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) { +// (It is up to the caller to set those.) It staples OCSP. +func makeCertificateFromDiskWithOCSP(certFile, keyFile string) (Certificate, error) { certPEMBlock, err := ioutil.ReadFile(certFile) if err != nil { return Certificate{}, err @@ -179,13 +188,14 @@ func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) { if err != nil { return Certificate{}, err } - return makeCertificate(certPEMBlock, keyPEMBlock) + return makeCertificateWithOCSP(certPEMBlock, keyPEMBlock) } // makeCertificate turns a certificate PEM bundle and a key PEM block into -// a Certificate, with OCSP and other relevant metadata tagged with it, -// except for the OnDemand and Managed flags. It is up to the caller to -// set those properties. +// a Certificate with necessary metadata from parsing its bytes filled into +// its struct fields for convenience (except for the OnDemand and Managed +// flags; it is up to the caller to set those properties!). This function +// does NOT staple OCSP. func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { var cert Certificate @@ -195,16 +205,26 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { return cert, err } - // Extract relevant metadata and staple OCSP + // Extract necessary metadata err = fillCertFromLeaf(&cert, tlsCert) if err != nil { return cert, err } + + return cert, nil +} + +// makeCertificateWithOCSP is the same as makeCertificate except that it also +// staples OCSP to the certificate. +func makeCertificateWithOCSP(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { + cert, err := makeCertificate(certPEMBlock, keyPEMBlock) + if err != nil { + return cert, err + } err = stapleOCSP(&cert, certPEMBlock) if err != nil { log.Printf("[WARNING] Stapling OCSP: %v", err) } - return cert, nil } @@ -243,65 +263,104 @@ func fillCertFromLeaf(cert *Certificate, tlsCert tls.Certificate) error { return errors.New("certificate has no names") } + // save the hash of this certificate (chain) and + // expiration date, for necessity and efficiency + cert.Hash = hashCertificateChain(cert.Certificate.Certificate) cert.NotAfter = leaf.NotAfter return nil } -// cacheCertificate adds cert to the in-memory cache. If the cache is -// empty, cert will be used as the default certificate. If the cache is -// full, random entries are deleted until there is room to map all the -// names on the certificate. +// hashCertificateChain computes the unique hash of certChain, +// which is the chain of DER-encoded bytes. It returns the +// hex encoding of the hash. +func hashCertificateChain(certChain [][]byte) string { + h := sha256.New() + for _, certInChain := range certChain { + h.Write(certInChain) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} + +// managedCertInStorageExpiresSoon returns true if cert (being a +// managed certificate) is expiring within RenewDurationBefore. +// It returns false if there was an error checking the expiration +// of the certificate as found in storage, or if the certificate +// in storage is NOT expiring soon. A certificate that is expiring +// soon in our cache but is not expiring soon in storage probably +// means that another instance renewed the certificate in the +// meantime, and it would be a good idea to simply load the cert +// into our cache rather than repeating the renewal process again. +func managedCertInStorageExpiresSoon(cert Certificate) (bool, error) { + if len(cert.configs) == 0 { + return false, fmt.Errorf("no configs for certificate") + } + storage, err := cert.configs[0].StorageFor(cert.configs[0].CAUrl) + if err != nil { + return false, err + } + siteData, err := storage.LoadSite(cert.Names[0]) + if err != nil { + return false, err + } + tlsCert, err := tls.X509KeyPair(siteData.Cert, siteData.Key) + if err != nil { + return false, err + } + leaf, err := x509.ParseCertificate(tlsCert.Certificate[0]) + if err != nil { + return false, err + } + timeLeft := leaf.NotAfter.Sub(time.Now().UTC()) + return timeLeft < RenewDurationBefore, nil +} + +// cacheCertificate adds cert to the in-memory cache. If a certificate +// with the same hash is already cached, it is NOT overwritten; instead, +// cfg is added to the existing certificate's list of configs if not +// already in the list. Then all the names on cert are used to add +// entries to cfg.Certificates (the config's name lookup map). +// Then the certificate is stored/updated in the cache. It returns +// a copy of the certificate that ends up being stored in the cache. // -// This certificate will be keyed to the names in cert.Names. Any names -// already used as a cache key will NOT be replaced by this cert; in -// other words, no overlap is allowed, and this certificate will not -// service those pre-existing names. +// It is VERY important, even for some test cases, that the Hash field +// of the cert be set properly. // // This function is safe for concurrent use. -func cacheCertificate(cert Certificate) { - if cert.Config == nil { - cert.Config = new(Config) +func (cfg *Config) cacheCertificate(cert Certificate) Certificate { + cfg.certCache.Lock() + defer cfg.certCache.Unlock() + + // if this certificate already exists in the cache, + // use it instead of overwriting it -- very important! + if existingCert, ok := cfg.certCache.cache[cert.Hash]; ok { + cert = existingCert } - certCacheMu.Lock() - if _, ok := certCache[""]; !ok { - // use as default - must be *appended* to end of list, or bad things happen! - cert.Names = append(cert.Names, "") - } - for len(certCache)+len(cert.Names) > 10000 { - // for simplicity, just remove random elements - for key := range certCache { - if key == "" { // ... but not the default cert - continue - } - delete(certCache, key) + + // attach this config to the certificate so we know which + // configs are referencing/using the certificate, but don't + // duplicate entries + var found bool + for _, c := range cert.configs { + if c == cfg { + found = true break } } - for i := 0; i < len(cert.Names); i++ { - name := cert.Names[i] - if _, ok := certCache[name]; ok { - // do not allow certificates to overlap in the names they serve; - // this ambiguity causes problems because it is confusing while - // maintaining certificates; see OCSP maintenance code and - // https://caddy.community/t/random-ocsp-response-errors-for-random-clients/2473?u=matt. - log.Printf("[NOTICE] There is already a certificate loaded for %s, "+ - "so certificate for %v will not service that name", - name, cert.Names) - cert.Names = append(cert.Names[:i], cert.Names[i+1:]...) - i-- - continue - } - certCache[name] = cert + if !found { + cert.configs = append(cert.configs, cfg) } - certCacheMu.Unlock() -} -// uncacheCertificate deletes name's certificate from the -// cache. If name is not a key in the certificate cache, -// this function does nothing. -func uncacheCertificate(name string) { - certCacheMu.Lock() - delete(certCache, name) - certCacheMu.Unlock() + // key the certificate by all its names for this config only, + // this is how we find the certificate during handshakes + // (yes, if certs overlap in the names they serve, one will + // overwrite another here, but that's just how it goes) + for _, name := range cert.Names { + cfg.Certificates[name] = cert.Hash + } + + // store the certificate + cfg.certCache.cache[cert.Hash] = cert + + return cert } diff --git a/caddytls/certificates_test.go b/caddytls/certificates_test.go index ce848d10b..817d16496 100644 --- a/caddytls/certificates_test.go +++ b/caddytls/certificates_test.go @@ -17,57 +17,71 @@ package caddytls import "testing" func TestUnexportedGetCertificate(t *testing.T) { - defer func() { certCache = make(map[string]Certificate) }() + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} // When cache is empty - if _, matched, defaulted := getCertificate("example.com"); matched || defaulted { + if _, matched, defaulted := cfg.getCertificate("example.com"); matched || defaulted { t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted) } - // When cache has one certificate in it (also is default) - defaultCert := Certificate{Names: []string{"example.com", ""}} - certCache[""] = defaultCert - certCache["example.com"] = defaultCert - if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" { + // When cache has one certificate in it + firstCert := Certificate{Names: []string{"example.com"}} + certCache.cache["0xdeadbeef"] = firstCert + cfg.Certificates["example.com"] = "0xdeadbeef" + if cert, matched, defaulted := cfg.getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" { t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) } - if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" { - t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) + if cert, matched, defaulted := cfg.getCertificate("example.com"); !matched || defaulted || cert.Names[0] != "example.com" { + t.Errorf("Didn't get a cert for 'example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) } // When retrieving wildcard certificate - certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}} - if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" { + certCache.cache["0xb01dface"] = Certificate{Names: []string{"*.example.com"}} + cfg.Certificates["*.example.com"] = "0xb01dface" + if cert, matched, defaulted := cfg.getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" { t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) } - // When no certificate matches, the default is returned - if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted { + // When no certificate matches and SNI is provided, return no certificate (should be TLS alert) + if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted { + t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert) + } + + // When no certificate matches and SNI is NOT provided, a random is returned + if cert, matched, defaulted := cfg.getCertificate(""); matched || !defaulted { t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert) - } else if cert.Names[0] != "example.com" { - t.Errorf("Expected default cert, got: %v", cert) } } func TestCacheCertificate(t *testing.T) { - defer func() { certCache = make(map[string]Certificate) }() + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} - cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}}) - if _, ok := certCache["example.com"]; !ok { - t.Error("Expected first cert to be cached by key 'example.com', but it wasn't") + cfg.cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}, Hash: "foobar"}) + if len(certCache.cache) != 1 { + t.Errorf("Expected length of certificate cache to be 1") } - if _, ok := certCache["sub.example.com"]; !ok { - t.Error("Expected first cert to be cached by key 'sub.example.com', but it wasn't") + if _, ok := certCache.cache["foobar"]; !ok { + t.Error("Expected first cert to be cached by key 'foobar', but it wasn't") } - if cert, ok := certCache[""]; !ok || cert.Names[2] != "" { - t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't") + if _, ok := cfg.Certificates["example.com"]; !ok { + t.Error("Expected first cert to be keyed by 'example.com', but it wasn't") + } + if _, ok := cfg.Certificates["sub.example.com"]; !ok { + t.Error("Expected first cert to be keyed by 'sub.example.com', but it wasn't") } - cacheCertificate(Certificate{Names: []string{"example2.com"}}) - if _, ok := certCache["example2.com"]; !ok { - t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't") + // different config, but using same cache; and has cert with overlapping name, + // but different hash + cfg2 := &Config{Certificates: make(map[string]string), certCache: certCache} + cfg2.cacheCertificate(Certificate{Names: []string{"example.com"}, Hash: "barbaz"}) + if _, ok := certCache.cache["barbaz"]; !ok { + t.Error("Expected second cert to be cached by key 'barbaz.com', but it wasn't") } - if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" { - t.Error("Expected second cert to NOT be cached as default, but it was") + if hash, ok := cfg2.Certificates["example.com"]; !ok { + t.Error("Expected second cert to be keyed by 'example.com', but it wasn't") + } else if hash != "barbaz" { + t.Errorf("Expected second cert to map to 'barbaz' but it was %s instead", hash) } } diff --git a/caddytls/client.go b/caddytls/client.go index 26ef6a3c5..4775a2d18 100644 --- a/caddytls/client.go +++ b/caddytls/client.go @@ -160,7 +160,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) // See if TLS challenge needs to be handled by our own facilities if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) { - c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSniSolver{}) + c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache}) } // Disable any challenges that should not be used diff --git a/caddytls/config.go b/caddytls/config.go index d3468e348..0b64f3575 100644 --- a/caddytls/config.go +++ b/caddytls/config.go @@ -134,7 +134,12 @@ type Config struct { // Protocol Negotiation (ALPN). ALPN []string - tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig() + // The map of hostname to certificate hash. This is used to complete + // handshakes and serve the right certificate given the SNI. + Certificates map[string]string + + certCache *certificateCache // pointer to the Instance's certificate store + tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig() } // OnDemandState contains some state relevant for providing @@ -155,6 +160,25 @@ type OnDemandState struct { AskURL *url.URL } +// NewConfig returns a new Config with a pointer to the instance's +// certificate cache. You will usually need to set Other fields on +// the returned Config for successful practical use. +func NewConfig(inst *caddy.Instance) *Config { + inst.StorageMu.RLock() + certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache) + inst.StorageMu.RUnlock() + if !ok || certCache == nil { + certCache = &certificateCache{cache: make(map[string]Certificate)} + inst.StorageMu.Lock() + inst.Storage[CertCacheInstStorageKey] = certCache + inst.StorageMu.Unlock() + } + cfg := new(Config) + cfg.Certificates = make(map[string]string) + cfg.certCache = certCache + return cfg +} + // ObtainCert obtains a certificate for name using c, as long // as a certificate does not already exist in storage for that // name. The name must qualify and c must be flagged as Managed. @@ -330,7 +354,9 @@ func (c *Config) buildStandardTLSConfig() error { // MakeTLSConfig makes a tls.Config from configs. The returned // tls.Config is programmed to load the matching caddytls.Config -// based on the hostname in SNI, but that's all. +// based on the hostname in SNI, but that's all. This is used +// to create a single TLS configuration for a listener (a group +// of sites). func MakeTLSConfig(configs []*Config) (*tls.Config, error) { if len(configs) == 0 { return nil, nil @@ -358,15 +384,28 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) { configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto) } - // convert each caddytls.Config into a tls.Config + // convert this caddytls.Config into a tls.Config if err := cfg.buildStandardTLSConfig(); err != nil { return nil, err } - // Key this config by its hostname (overwriting - // configs with the same hostname pattern); during - // TLS handshakes, configs are loaded based on - // the hostname pattern, according to client's SNI. + // if an existing config with this hostname was already + // configured, then they must be identical (or at least + // compatible), otherwise that is a configuration error + if otherConfig, ok := configMap[cfg.Hostname]; ok { + if err := assertConfigsCompatible(cfg, otherConfig); err != nil { + return nil, fmt.Errorf("incompabile TLS configurations for the same SNI "+ + "name (%s) on the same listener: %v", + cfg.Hostname, err) + } + } + + // key this config by its hostname (overwrites + // configs with the same hostname pattern; should + // be OK since we already asserted they are roughly + // the same); during TLS handshakes, configs are + // loaded based on the hostname pattern, according + // to client's SNI configMap[cfg.Hostname] = cfg } @@ -383,6 +422,63 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) { }, nil } +// assertConfigsCompatible returns an error if the two Configs +// do not have the same (or roughly compatible) configurations. +// If one of the tlsConfig pointers on either Config is nil, +// an error will be returned. If both are nil, no error. +func assertConfigsCompatible(cfg1, cfg2 *Config) error { + c1, c2 := cfg1.tlsConfig, cfg2.tlsConfig + + if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) { + return fmt.Errorf("one config is not made") + } + if c1 == nil && c2 == nil { + return nil + } + + if len(c1.CipherSuites) != len(c2.CipherSuites) { + return fmt.Errorf("different number of allowed cipher suites") + } + for i, ciph := range c1.CipherSuites { + if c2.CipherSuites[i] != ciph { + return fmt.Errorf("different cipher suites or different order") + } + } + + if len(c1.CurvePreferences) != len(c2.CurvePreferences) { + return fmt.Errorf("different number of allowed cipher suites") + } + for i, curve := range c1.CurvePreferences { + if c2.CurvePreferences[i] != curve { + return fmt.Errorf("different curve preferences or different order") + } + } + + if len(c1.NextProtos) != len(c2.NextProtos) { + return fmt.Errorf("different number of ALPN (NextProtos) values") + } + for i, proto := range c1.NextProtos { + if c2.NextProtos[i] != proto { + return fmt.Errorf("different ALPN (NextProtos) values or different order") + } + } + + if c1.PreferServerCipherSuites != c2.PreferServerCipherSuites { + return fmt.Errorf("one prefers server cipher suites, the other does not") + } + if c1.MinVersion != c2.MinVersion { + return fmt.Errorf("minimum TLS version mismatch") + } + if c1.MaxVersion != c2.MaxVersion { + return fmt.Errorf("maximum TLS version mismatch") + } + if c1.ClientAuth != c2.ClientAuth { + return fmt.Errorf("client authentication policy mismatch") + } + + return nil +} + // ConfigGetter gets a Config keyed by key. type ConfigGetter func(c *caddy.Controller) *Config @@ -522,7 +618,7 @@ var supportedCurvesMap = map[string]tls.CurveID{ "P521": tls.CurveP521, } -// List of all the curves we want to use by default +// List of all the curves we want to use by default. // // This list should only include curves which are fast by design (e.g. X25519) // and those for which an optimized assembly implementation exists (e.g. P256). @@ -548,4 +644,8 @@ const ( // be capable of proxying or forwarding the request to this // alternate port. DefaultHTTPAlternatePort = "5033" + + // CertCacheInstStorageKey is the name of the key for + // accessing the certificate storage on the *caddy.Instance. + CertCacheInstStorageKey = "tls_cert_cache" ) diff --git a/caddytls/crypto.go b/caddytls/crypto.go index 3036834c4..b2107f152 100644 --- a/caddytls/crypto.go +++ b/caddytls/crypto.go @@ -237,15 +237,17 @@ func makeSelfSignedCert(config *Config) error { return fmt.Errorf("could not create certificate: %v", err) } - cacheCertificate(Certificate{ + chain := [][]byte{derBytes} + + config.cacheCertificate(Certificate{ Certificate: tls.Certificate{ - Certificate: [][]byte{derBytes}, + Certificate: chain, PrivateKey: privKey, Leaf: cert, }, Names: cert.DNSNames, NotAfter: cert.NotAfter, - Config: config, + Hash: hashCertificateChain(chain), }) return nil diff --git a/caddytls/handshake.go b/caddytls/handshake.go index c50e8ab63..2f3f34af3 100644 --- a/caddytls/handshake.go +++ b/caddytls/handshake.go @@ -59,15 +59,7 @@ func (cg configGroup) getConfig(name string) *Config { } } - // as a fallback, try a config that serves all names - if config, ok := cg[""]; ok { - return config - } - - // as a last resort, use a random config - // (even if the config isn't for that hostname, - // it should help us serve clients without SNI - // or at least defer TLS alerts to the cert) + // no matches, so just serve up a random config for _, config := range cg { return config } @@ -102,6 +94,86 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif return &cert.Certificate, err } +// getCertificate gets a certificate that matches name (a server name) +// from the in-memory cache, according to the lookup table associated with +// cfg. The lookup then points to a certificate in the Instance certificate +// cache. +// +// If there is no exact match for name, it will be checked against names of +// the form '*.example.com' (wildcard certificates) according to RFC 6125. +// If a match is found, matched will be true. If no matches are found, matched +// will be false and a "default" certificate will be returned with defaulted +// set to true. If defaulted is false, then no certificates were available. +// +// The logic in this function is adapted from the Go standard library, +// which is by the Go Authors. +// +// This function is safe for concurrent use. +func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defaulted bool) { + var certKey string + var ok bool + + // Not going to trim trailing dots here since RFC 3546 says, + // "The hostname is represented ... without a trailing dot." + // Just normalize to lowercase. + name = strings.ToLower(name) + + cfg.certCache.RLock() + defer cfg.certCache.RUnlock() + + // exact match? great, let's use it + if certKey, ok = cfg.Certificates[name]; ok { + cert = cfg.certCache.cache[certKey] + matched = true + return + } + + // try replacing labels in the name with wildcards until we get a match + labels := strings.Split(name, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if certKey, ok = cfg.Certificates[candidate]; ok { + cert = cfg.certCache.cache[certKey] + matched = true + return + } + } + + // check the certCache directly to see if the SNI name is + // already the key of the certificate it wants! this is vital + // for supporting the TLS-SNI challenge, since the tlsSNISolver + // just puts the temporary certificate in the instance cache, + // with no regard for configs; this also means that the SNI + // can contain the hash of a specific cert (chain) it wants + // and we will still be able to serve it up + // (this behavior, by the way, could be controversial as to + // whether it complies with RFC 6066 about SNI, but I think + // it does soooo...) + // NOTE/TODO: TLS-SNI challenge is changing, as of Jan. 2018 + // but what will be different, if it ever returns, is unclear + if directCert, ok := cfg.certCache.cache[name]; ok { + cert = directCert + matched = true + return + } + + // if nothing matches and SNI was not provided, use a random + // certificate; at least there's a chance this older client + // can connect, and in the future we won't need this provision + // (if SNI is present, it's probably best to just raise a TLS + // alert by not serving a certificate) + if name == "" { + for _, certKey := range cfg.Certificates { + defaulted = true + cert = cfg.certCache.cache[certKey] + return + } + } + + return +} + // getCertDuringHandshake will get a certificate for name. It first tries // the in-memory cache. If no certificate for name is in the cache, the // config most closely corresponding to name will be loaded. If that config @@ -115,7 +187,7 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif // This function is safe for concurrent use. func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { // First check our in-memory cache to see if we've already loaded it - cert, matched, defaulted := getCertificate(name) + cert, matched, defaulted := cfg.getCertificate(name) if matched { return cert, nil } @@ -258,7 +330,7 @@ func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) { obtainCertWaitChans[name] = wait obtainCertWaitChansMu.Unlock() - // do the obtain + // obtain the certificate log.Printf("[INFO] Obtaining new certificate for %s", name) err := cfg.ObtainCert(name, false) @@ -317,9 +389,9 @@ func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certific // quite common considering not all certs have issuer URLs that support it. log.Printf("[ERROR] Getting OCSP for %s: %v", name, err) } - certCacheMu.Lock() - certCache[name] = cert - certCacheMu.Unlock() + cfg.certCache.Lock() + cfg.certCache.cache[cert.Hash] = cert + cfg.certCache.Unlock() } } @@ -348,29 +420,22 @@ func (cfg *Config) renewDynamicCertificate(name string, currentCert Certificate) obtainCertWaitChans[name] = wait obtainCertWaitChansMu.Unlock() - // do the renew and reload the certificate + // renew and reload the certificate log.Printf("[INFO] Renewing certificate for %s", name) err := cfg.RenewCert(name, false) if err == nil { - // immediately flush this certificate from the cache so - // the name doesn't overlap when we try to replace it, - // which would fail, because overlapping existing cert - // names isn't allowed - certCacheMu.Lock() - for _, certName := range currentCert.Names { - delete(certCache, certName) - } - certCacheMu.Unlock() - // even though the recursive nature of the dynamic cert loading // would just call this function anyway, we do it here to - // make the replacement as atomic as possible. (TODO: similar - // to the note in maintain.go, it'd be nice if the clearing of - // the cache entries above and this load function were truly - // atomic...) - _, err := currentCert.Config.CacheManagedCertificate(name) + // make the replacement as atomic as possible. + newCert, err := currentCert.configs[0].CacheManagedCertificate(name) if err != nil { - log.Printf("[ERROR] loading renewed certificate: %v", err) + log.Printf("[ERROR] loading renewed certificate for %s: %v", name, err) + } else { + // replace the old certificate with the new one + err = cfg.certCache.replaceCertificate(currentCert, newCert) + if err != nil { + log.Printf("[ERROR] Replacing certificate for %s: %v", name, err) + } } } diff --git a/caddytls/handshake_test.go b/caddytls/handshake_test.go index 63a6c1dba..f0b8f7be2 100644 --- a/caddytls/handshake_test.go +++ b/caddytls/handshake_test.go @@ -21,9 +21,8 @@ import ( ) func TestGetCertificate(t *testing.T) { - defer func() { certCache = make(map[string]Certificate) }() - - cfg := new(Config) + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} hello := &tls.ClientHelloInfo{ServerName: "example.com"} helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"} @@ -38,33 +37,40 @@ func TestGetCertificate(t *testing.T) { t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert) } - // When cache has one certificate in it (also is default) - defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}} - certCache[""] = defaultCert - certCache["example.com"] = defaultCert + // When cache has one certificate in it + firstCert := Certificate{Names: []string{"example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}} + cfg.cacheCertificate(firstCert) if cert, err := cfg.GetCertificate(hello); err != nil { t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err) } else if cert.Leaf.DNSNames[0] != "example.com" { t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert) } - if cert, err := cfg.GetCertificate(helloNoSNI); err != nil { + if _, err := cfg.GetCertificate(helloNoSNI); err != nil { t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err) - } else if cert.Leaf.DNSNames[0] != "example.com" { - t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert) } // When retrieving wildcard certificate - certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}} + wildcardCert := Certificate{ + Names: []string{"*.example.com"}, + Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}, + Hash: "(don't overwrite the first one)", + } + cfg.cacheCertificate(wildcardCert) if cert, err := cfg.GetCertificate(helloSub); err != nil { t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err) } else if cert.Leaf.DNSNames[0] != "*.example.com" { t.Errorf("Got wrong certificate, expected wildcard: %v", cert) } - // When no certificate matches, the default is returned - if cert, err := cfg.GetCertificate(helloNoMatch); err != nil { - t.Errorf("Expected default certificate with no error when no matches, got err: %v", err) - } else if cert.Leaf.DNSNames[0] != "example.com" { - t.Errorf("Expected default cert with no matches, got: %v", cert) + // When cache is NOT empty but there's no SNI + if cert, err := cfg.GetCertificate(helloNoSNI); err != nil { + t.Errorf("Expected random certificate with no error when no SNI, got err: %v", err) + } else if cert == nil || len(cert.Leaf.DNSNames) == 0 { + t.Errorf("Expected random cert with no matches, got: %v", cert) + } + + // When no certificate matches, raise an alert + if _, err := cfg.GetCertificate(helloNoMatch); err == nil { + t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err) } } diff --git a/caddytls/maintain.go b/caddytls/maintain.go index 9e42fc87c..7ce6c5e26 100644 --- a/caddytls/maintain.go +++ b/caddytls/maintain.go @@ -87,119 +87,163 @@ func maintainAssets(stopChan chan struct{}) { // RenewManagedCertificates renews managed certificates, // including ones loaded on-demand. func RenewManagedCertificates(allowPrompts bool) (err error) { - var renewQueue, deleteQueue []Certificate - visitedNames := make(map[string]struct{}) - - certCacheMu.RLock() - for name, cert := range certCache { - if !cert.Config.Managed || cert.Config.SelfSigned { + for _, inst := range caddy.Instances() { + inst.StorageMu.RLock() + certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache) + inst.StorageMu.RUnlock() + if !ok || certCache == nil { continue } - // the list of names on this cert should never be empty... - if cert.Names == nil || len(cert.Names) == 0 { - log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v - removing from cache", name, cert.Names) - deleteQueue = append(deleteQueue, cert) - continue - } + // we use the queues for a very important reason: to do any and all + // operations that could require an exclusive write lock outside + // of the read lock! otherwise we get a deadlock, yikes. in other + // words, our first iteration through the certificate cache does NOT + // perform any operations--only queues them--so that more fine-grained + // write locks may be obtained during the actual operations. + var renewQueue, reloadQueue, deleteQueue []Certificate - // skip names whose certificate we've already renewed - if _, ok := visitedNames[name]; ok { - continue - } - for _, name := range cert.Names { - visitedNames[name] = struct{}{} - } - - // if its time is up or ending soon, we need to try to renew it - timeLeft := cert.NotAfter.Sub(time.Now().UTC()) - if timeLeft < RenewDurationBefore { - log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft) - - if cert.Config == nil { - log.Printf("[ERROR] %s: No associated TLS config; unable to renew", name) + certCache.RLock() + for certKey, cert := range certCache.cache { + if len(cert.configs) == 0 { + // this is bad if this happens, probably a programmer error (oops) + log.Printf("[ERROR] No associated TLS config for certificate with names %v; unable to manage", cert.Names) + continue + } + if !cert.configs[0].Managed || cert.configs[0].SelfSigned { continue } - // queue for renewal when we aren't in a read lock anymore - // (the TLS-SNI challenge will need a write lock in order to - // present the certificate, so we renew outside of read lock) - renewQueue = append(renewQueue, cert) - } - } - certCacheMu.RUnlock() - - // Perform renewals that are queued - for _, cert := range renewQueue { - // Get the name which we should use to renew this certificate; - // we only support managing certificates with one name per cert, - // so this should be easy. We can't rely on cert.Config.Hostname - // because it may be a wildcard value from the Caddyfile (e.g. - // *.something.com) which, as of Jan. 2017, is not supported by ACME. - var renewName string - for _, name := range cert.Names { - if name != "" { - renewName = name - break - } - } - - // perform renewal - err := cert.Config.RenewCert(renewName, allowPrompts) - if err != nil { - if allowPrompts { - // Certificate renewal failed and the operator is present. See a discussion - // about this in issue 642. For a while, we only stopped if the certificate - // was expired, but in reality, there is no difference between reporting - // it now versus later, except that there's somebody present to deal with - // it right now. - timeLeft := cert.NotAfter.Sub(time.Now().UTC()) - if timeLeft < RenewDurationBeforeAtStartup { - // See issue 1680. Only fail at startup if the certificate is dangerously - // close to expiration. - return err - } - } - log.Printf("[ERROR] %v", err) - if cert.Config.OnDemand { - // loaded dynamically, removed dynamically + // the list of names on this cert should never be empty... programmer error? + if cert.Names == nil || len(cert.Names) == 0 { + log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v - removing from cache", certKey, cert.Names) deleteQueue = append(deleteQueue, cert) + continue } - } else { + + // if time is up or expires soon, we need to try to renew it + timeLeft := cert.NotAfter.Sub(time.Now().UTC()) + if timeLeft < RenewDurationBefore { + // see if the certificate in storage has already been renewed, possibly by another + // instance of Caddy that didn't coordinate with this one; if so, just load it (this + // might happen if another instance already renewed it - kinda sloppy but checking disk + // first is a simple way to possibly drastically reduce rate limit problems) + storedCertExpiring, err := managedCertInStorageExpiresSoon(cert) + if err != nil { + // hmm, weird, but not a big deal, maybe it was deleted or something + log.Printf("[NOTICE] Error while checking if certificate for %v in storage is also expiring soon: %v", + cert.Names, err) + } else if !storedCertExpiring { + // if the certificate is NOT expiring soon and there was no error, then we + // are good to just reload the certificate from storage instead of repeating + // a likely-unnecessary renewal procedure + reloadQueue = append(reloadQueue, cert) + continue + } + + // the certificate in storage has not been renewed yet, so we will do it + // NOTE 1: This is not correct 100% of the time, if multiple Caddy instances + // happen to run their maintenance checks at approximately the same times; + // both might start renewal at about the same time and do two renewals and one + // will overwrite the other. Hence TLS storage plugins. This is sort of a TODO. + // NOTE 2: It is super-important to note that the TLS-SNI challenge requires + // a write lock on the cache in order to complete its challenge, so it is extra + // vital that this renew operation does not happen inside our read lock! + renewQueue = append(renewQueue, cert) + } + } + certCache.RUnlock() + + // Reload certificates that merely need to be updated in memory + for _, oldCert := range reloadQueue { + timeLeft := oldCert.NotAfter.Sub(time.Now().UTC()) + log.Printf("[INFO] Certificate for %v expires in %v, but is already renewed in storage; reloading stored certificate", + oldCert.Names, timeLeft) + + // get the certificate from storage and cache it + newCert, err := oldCert.configs[0].CacheManagedCertificate(oldCert.Names[0]) + if err != nil { + log.Printf("[ERROR] Unable to reload certificate for %v into cache: %v", oldCert.Names, err) + continue + } + + // and replace the old certificate with the new one + err = certCache.replaceCertificate(oldCert, newCert) + if err != nil { + log.Printf("[ERROR] Replacing certificate: %v", err) + } + } + + // Renewal queue + for _, oldCert := range renewQueue { + timeLeft := oldCert.NotAfter.Sub(time.Now().UTC()) + log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", oldCert.Names, timeLeft) + + // Get the name which we should use to renew this certificate; + // we only support managing certificates with one name per cert, + // so this should be easy. We can't rely on cert.Config.Hostname + // because it may be a wildcard value from the Caddyfile (e.g. + // *.something.com) which, as of Jan. 2017, is not supported by ACME. + // TODO: ^ ^ ^ (wildcards) + renewName := oldCert.Names[0] + + // perform renewal + err := oldCert.configs[0].RenewCert(renewName, allowPrompts) + if err != nil { + if allowPrompts { + // Certificate renewal failed and the operator is present. See a discussion + // about this in issue 642. For a while, we only stopped if the certificate + // was expired, but in reality, there is no difference between reporting + // it now versus later, except that there's somebody present to deal with + // it right now. Follow-up: See issue 1680. Only fail in this case if the + // certificate is dangerously close to expiration. + timeLeft := oldCert.NotAfter.Sub(time.Now().UTC()) + if timeLeft < RenewDurationBeforeAtStartup { + return err + } + } + log.Printf("[ERROR] %v", err) + if oldCert.configs[0].OnDemand { + // loaded dynamically, remove dynamically + deleteQueue = append(deleteQueue, oldCert) + } + continue + } + // successful renewal, so update in-memory cache by loading // renewed certificate so it will be used with handshakes - // we must delete all the names this cert services from the cache - // so that we can replace the certificate, because replacing names - // already in the cache is not allowed, to avoid later conflicts - // with renewals. - // TODO: It would be nice if this whole operation were idempotent; - // i.e. a thread-safe function to replace a certificate in the cache, - // see also handshake.go for on-demand maintenance. - certCacheMu.Lock() - for _, name := range cert.Names { - delete(certCache, name) - } - certCacheMu.Unlock() - // put the certificate in the cache - _, err := cert.Config.CacheManagedCertificate(cert.Names[0]) + newCert, err := oldCert.configs[0].CacheManagedCertificate(renewName) if err != nil { if allowPrompts { return err // operator is present, so report error immediately } log.Printf("[ERROR] %v", err) } - } - } - // Apply queued deletion changes to the cache - for _, cert := range deleteQueue { - certCacheMu.Lock() - for _, name := range cert.Names { - delete(certCache, name) + // replace the old certificate with the new one + err = certCache.replaceCertificate(oldCert, newCert) + if err != nil { + log.Printf("[ERROR] Replacing certificate: %v", err) + } + } + + // Deletion queue + for _, cert := range deleteQueue { + certCache.Lock() + // remove any pointers to this certificate from Configs + for _, cfg := range cert.configs { + for name, certKey := range cfg.Certificates { + if certKey == cert.Hash { + delete(cfg.Certificates, name) + } + } + } + // then delete the certificate from the cache + delete(certCache.cache, cert.Hash) + certCache.Unlock() } - certCacheMu.Unlock() } return nil @@ -212,91 +256,75 @@ func RenewManagedCertificates(allowPrompts bool) (err error) { // Ryan Sleevi's recommendations for good OCSP support: // https://gist.github.com/sleevi/5efe9ef98961ecfb4da8 func UpdateOCSPStaples() { - // Create a temporary place to store updates - // until we release the potentially long-lived - // read lock and use a short-lived write lock. - type ocspUpdate struct { - rawBytes []byte - parsed *ocsp.Response - } - updated := make(map[string]ocspUpdate) - - // A single SAN certificate maps to multiple names, so we use this - // set to make sure we don't waste cycles checking OCSP for the same - // certificate multiple times. - visited := make(map[string]struct{}) - - certCacheMu.RLock() - for name, cert := range certCache { - // skip this certificate if we've already visited it, - // and if not, mark all the names as visited - if _, ok := visited[name]; ok { - continue - } - for _, n := range cert.Names { - visited[n] = struct{}{} - } - - // no point in updating OCSP for expired certificates - if time.Now().After(cert.NotAfter) { + for _, inst := range caddy.Instances() { + inst.StorageMu.RLock() + certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache) + inst.StorageMu.RUnlock() + if !ok || certCache == nil { continue } - var lastNextUpdate time.Time - if cert.OCSP != nil { - lastNextUpdate = cert.OCSP.NextUpdate - if freshOCSP(cert.OCSP) { - // no need to update staple if ours is still fresh + // Create a temporary place to store updates + // until we release the potentially long-lived + // read lock and use a short-lived write lock + // on the certificate cache. + type ocspUpdate struct { + rawBytes []byte + parsed *ocsp.Response + } + updated := make(map[string]ocspUpdate) + + certCache.RLock() + for certHash, cert := range certCache.cache { + // no point in updating OCSP for expired certificates + if time.Now().After(cert.NotAfter) { continue } - } - err := stapleOCSP(&cert, nil) - if err != nil { + var lastNextUpdate time.Time if cert.OCSP != nil { - // if there was no staple before, that's fine; otherwise we should log the error - log.Printf("[ERROR] Checking OCSP: %v", err) + lastNextUpdate = cert.OCSP.NextUpdate + if freshOCSP(cert.OCSP) { + continue // no need to update staple if ours is still fresh + } } - continue - } - // By this point, we've obtained the latest OCSP response. - // If there was no staple before, or if the response is updated, make - // sure we apply the update to all names on the certificate. - if cert.OCSP != nil && (lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate) { - log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s", - cert.Names, lastNextUpdate, cert.OCSP.NextUpdate) - for _, n := range cert.Names { - // BUG: If this certificate has names on it that appear on another - // certificate in the cache, AND the other certificate is keyed by - // that name in the cache, then this method of 'queueing' the staple - // update will cause this certificate's new OCSP to be stapled to - // a different certificate! See: - // https://caddy.community/t/random-ocsp-response-errors-for-random-clients/2473?u=matt - // This problem should be avoided if names on certificates in the - // cache don't overlap with regards to the cache keys. - // (This is isn't a bug anymore, since we're careful when we add - // certificates to the cache by skipping keying when key already exists.) - updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP} + err := stapleOCSP(&cert, nil) + if err != nil { + if cert.OCSP != nil { + // if there was no staple before, that's fine; otherwise we should log the error + log.Printf("[ERROR] Checking OCSP: %v", err) + } + continue + } + + // By this point, we've obtained the latest OCSP response. + // If there was no staple before, or if the response is updated, make + // sure we apply the update to all names on the certificate. + if cert.OCSP != nil && (lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate) { + log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s", + cert.Names, lastNextUpdate, cert.OCSP.NextUpdate) + updated[certHash] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP} } } - } - certCacheMu.RUnlock() + certCache.RUnlock() - // This write lock should be brief since we have all the info we need now. - certCacheMu.Lock() - for name, update := range updated { - cert := certCache[name] - cert.OCSP = update.parsed - cert.Certificate.OCSPStaple = update.rawBytes - certCache[name] = cert + // These write locks should be brief since we have all the info we need now. + for certKey, update := range updated { + certCache.Lock() + cert := certCache.cache[certKey] + cert.OCSP = update.parsed + cert.Certificate.OCSPStaple = update.rawBytes + certCache.cache[certKey] = cert + certCache.Unlock() + } } - certCacheMu.Unlock() } // DeleteOldStapleFiles deletes cached OCSP staples that have expired. // TODO: Should we do this for certificates too? func DeleteOldStapleFiles() { + // TODO: Upgrade caddytls.Storage to support OCSP operations too files, err := ioutil.ReadDir(ocspFolder) if err != nil { // maybe just hasn't been created yet; no big deal diff --git a/caddytls/setup.go b/caddytls/setup.go index cbc2baca1..63c2a9e6d 100644 --- a/caddytls/setup.go +++ b/caddytls/setup.go @@ -38,6 +38,7 @@ func init() { // are specified by the user in the config file. All the automatic HTTPS // stuff comes later outside of this function. func setupTLS(c *caddy.Controller) error { + // obtain the configGetter, which loads the config we're, uh, configuring configGetter, ok := configGetters[c.ServerType()] if !ok { return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType()) @@ -47,6 +48,14 @@ func setupTLS(c *caddy.Controller) error { return fmt.Errorf("no caddytls.Config to set up for %s", c.Key) } + // the certificate cache is tied to the current caddy.Instance; get a pointer to it + certCache, ok := c.Get(CertCacheInstStorageKey).(*certificateCache) + if !ok || certCache == nil { + certCache = &certificateCache{cache: make(map[string]Certificate)} + c.Set(CertCacheInstStorageKey, certCache) + } + config.certCache = certCache + config.Enabled = true for c.Next() { @@ -237,7 +246,7 @@ func setupTLS(c *caddy.Controller) error { // load a single certificate and key, if specified if certificateFile != "" && keyFile != "" { - err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile) + err := config.cacheUnmanagedCertificatePEMFile(certificateFile, keyFile) if err != nil { return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err) } @@ -246,7 +255,7 @@ func setupTLS(c *caddy.Controller) error { // load a directory of certificates, if specified if loadDir != "" { - err := loadCertsInDir(c, loadDir) + err := loadCertsInDir(config, c, loadDir) if err != nil { return err } @@ -273,7 +282,7 @@ func setupTLS(c *caddy.Controller) error { // https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt // // This function may write to the log as it walks the directory tree. -func loadCertsInDir(c *caddy.Controller, dir string) error { +func loadCertsInDir(cfg *Config, c *caddy.Controller, dir string) error { return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { if err != nil { log.Printf("[WARNING] Unable to traverse into %s; skipping", path) @@ -336,7 +345,7 @@ func loadCertsInDir(c *caddy.Controller, dir string) error { return c.Errf("%s: no private key block found", path) } - err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes) + err = cfg.cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes) if err != nil { return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err) } diff --git a/caddytls/setup_test.go b/caddytls/setup_test.go index ee8a709bd..b93b1fc5f 100644 --- a/caddytls/setup_test.go +++ b/caddytls/setup_test.go @@ -46,9 +46,12 @@ func TestMain(m *testing.M) { } func TestSetupParseBasic(t *testing.T) { - cfg := new(Config) + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} + RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) c := caddy.NewTestController("", `tls `+certFile+` `+keyFile+``) + c.Set(CertCacheInstStorageKey, certCache) err := setupTLS(c) if err != nil { @@ -124,9 +127,12 @@ func TestSetupParseWithOptionalParams(t *testing.T) { must_staple alpn http/1.1 }` - cfg := new(Config) + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} + RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) c := caddy.NewTestController("", params) + c.Set(CertCacheInstStorageKey, certCache) err := setupTLS(c) if err != nil { @@ -158,9 +164,11 @@ func TestSetupDefaultWithOptionalParams(t *testing.T) { params := `tls { ciphers RSA-3DES-EDE-CBC-SHA }` - cfg := new(Config) + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) c := caddy.NewTestController("", params) + c.Set(CertCacheInstStorageKey, certCache) err := setupTLS(c) if err != nil { @@ -176,9 +184,12 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { params := `tls ` + certFile + ` ` + keyFile + ` { protocols ssl tls }` - cfg := new(Config) + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) c := caddy.NewTestController("", params) + c.Set(CertCacheInstStorageKey, certCache) + err := setupTLS(c) if err == nil { t.Errorf("Expected errors, but no error returned") @@ -191,6 +202,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { cfg = new(Config) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) c = caddy.NewTestController("", params) + c.Set(CertCacheInstStorageKey, certCache) err = setupTLS(c) if err == nil { t.Error("Expected errors, but no error returned") @@ -215,6 +227,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { cfg = new(Config) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) c = caddy.NewTestController("", params) + c.Set(CertCacheInstStorageKey, certCache) err = setupTLS(c) if err == nil { t.Error("Expected errors, but no error returned") @@ -226,7 +239,8 @@ func TestSetupParseWithClientAuth(t *testing.T) { params := `tls ` + certFile + ` ` + keyFile + ` { clients }` - cfg := new(Config) + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) c := caddy.NewTestController("", params) err := setupTLS(c) @@ -259,9 +273,11 @@ func TestSetupParseWithClientAuth(t *testing.T) { clients verify_if_given }`, tls.VerifyClientCertIfGiven, true, noCAs}, } { - cfg := new(Config) + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) c := caddy.NewTestController("", caseData.params) + c.Set(CertCacheInstStorageKey, certCache) err := setupTLS(c) if caseData.expectedErr { if err == nil { @@ -311,9 +327,11 @@ func TestSetupParseWithCAUrl(t *testing.T) { ca 1 2 }`, true, ""}, } { - cfg := new(Config) + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) c := caddy.NewTestController("", caseData.params) + c.Set(CertCacheInstStorageKey, certCache) err := setupTLS(c) if caseData.expectedErr { if err == nil { @@ -335,9 +353,11 @@ func TestSetupParseWithKeyType(t *testing.T) { params := `tls { key_type p384 }` - cfg := new(Config) + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) c := caddy.NewTestController("", params) + c.Set(CertCacheInstStorageKey, certCache) err := setupTLS(c) if err != nil { @@ -353,9 +373,11 @@ func TestSetupParseWithCurves(t *testing.T) { params := `tls { curves x25519 p256 p384 p521 }` - cfg := new(Config) + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) c := caddy.NewTestController("", params) + c.Set(CertCacheInstStorageKey, certCache) err := setupTLS(c) if err != nil { @@ -380,9 +402,11 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) { params := `tls { protocols tls1.2 }` - cfg := new(Config) + certCache := &certificateCache{cache: make(map[string]Certificate)} + cfg := &Config{Certificates: make(map[string]string), certCache: certCache} RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) c := caddy.NewTestController("", params) + c.Set(CertCacheInstStorageKey, certCache) err := setupTLS(c) if err != nil { diff --git a/caddytls/tls.go b/caddytls/tls.go index 9a17ddd3d..bf1a8301e 100644 --- a/caddytls/tls.go +++ b/caddytls/tls.go @@ -88,30 +88,38 @@ func Revoke(host string) error { return client.Revoke(host) } -// tlsSniSolver is a type that can solve tls-sni challenges using +// tlsSNISolver is a type that can solve TLS-SNI challenges using // an existing listener and our custom, in-memory certificate cache. -type tlsSniSolver struct{} +type tlsSNISolver struct { + certCache *certificateCache +} // Present adds the challenge certificate to the cache. -func (s tlsSniSolver) Present(domain, token, keyAuth string) error { +func (s tlsSNISolver) Present(domain, token, keyAuth string) error { cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth) if err != nil { return err } - cacheCertificate(Certificate{ + certHash := hashCertificateChain(cert.Certificate) + s.certCache.Lock() + s.certCache.cache[acmeDomain] = Certificate{ Certificate: cert, Names: []string{acmeDomain}, - }) + Hash: certHash, // perhaps not necesssary + } + s.certCache.Unlock() return nil } // CleanUp removes the challenge certificate from the cache. -func (s tlsSniSolver) CleanUp(domain, token, keyAuth string) error { +func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error { _, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth) if err != nil { return err } - uncacheCertificate(acmeDomain) + s.certCache.Lock() + delete(s.certCache.cache, acmeDomain) + s.certCache.Unlock() return nil } diff --git a/controller.go b/controller.go index e162280c0..6015d210f 100644 --- a/controller.go +++ b/controller.go @@ -103,6 +103,20 @@ func (c *Controller) Context() Context { return c.instance.context } +// Get safely gets a value from the Instance's storage. +func (c *Controller) Get(key interface{}) interface{} { + c.instance.StorageMu.RLock() + defer c.instance.StorageMu.RUnlock() + return c.instance.Storage[key] +} + +// Set safely sets a value on the Instance's storage. +func (c *Controller) Set(key, val interface{}) { + c.instance.StorageMu.Lock() + c.instance.Storage[key] = val + c.instance.StorageMu.Unlock() +} + // NewTestController creates a new Controller for // the server type and input specified. The filename // is "Testfile". If the server type is not empty and @@ -113,12 +127,12 @@ func (c *Controller) Context() Context { // Used only for testing, but exported so plugins can // use this for convenience. func NewTestController(serverType, input string) *Controller { - var ctx Context + testInst := &Instance{serverType: serverType, Storage: make(map[interface{}]interface{})} if stype, err := getServerType(serverType); err == nil { - ctx = stype.NewContext() + testInst.context = stype.NewContext(testInst) } return &Controller{ - instance: &Instance{serverType: serverType, context: ctx}, + instance: testInst, Dispenser: caddyfile.NewDispenser("Testfile", strings.NewReader(input)), OncePerServerBlock: func(f func() error) error { return f() }, } diff --git a/plugins.go b/plugins.go index f5372184e..f7d14f86b 100644 --- a/plugins.go +++ b/plugins.go @@ -191,7 +191,7 @@ type ServerType struct { // startup phases before this one. It's a way to keep // each set of server instances separate and to reduce // the amount of global state you need. - NewContext func() Context + NewContext func(inst *Instance) Context } // Plugin is a type which holds information about a plugin. From 592d1993150f9cede58e4c5013bc79c0ba0cdbbe Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Sun, 11 Feb 2018 13:30:01 -0700 Subject: [PATCH 07/19] staticfiles: Prevent path-based open redirects Not a huge issue, but has security implications if OAuth tokens leaked --- caddyhttp/staticfiles/fileserver.go | 8 ++++++ caddyhttp/staticfiles/fileserver_test.go | 32 ++++++++++++++++++++---- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/caddyhttp/staticfiles/fileserver.go b/caddyhttp/staticfiles/fileserver.go index 2b38212ea..91fb1a7f5 100644 --- a/caddyhttp/staticfiles/fileserver.go +++ b/caddyhttp/staticfiles/fileserver.go @@ -107,6 +107,10 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request) (int, err if d.IsDir() { // ensure there is a trailing slash if urlCopy.Path[len(urlCopy.Path)-1] != '/' { + for strings.HasPrefix(urlCopy.Path, "//") { + // prevent path-based open redirects + urlCopy.Path = strings.TrimPrefix(urlCopy.Path, "/") + } urlCopy.Path += "/" http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently) return http.StatusMovedPermanently, nil @@ -131,6 +135,10 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request) (int, err } if redir { + for strings.HasPrefix(urlCopy.Path, "//") { + // prevent path-based open redirects + urlCopy.Path = strings.TrimPrefix(urlCopy.Path, "/") + } http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently) return http.StatusMovedPermanently, nil } diff --git a/caddyhttp/staticfiles/fileserver_test.go b/caddyhttp/staticfiles/fileserver_test.go index 9cce77057..80d8f1a40 100644 --- a/caddyhttp/staticfiles/fileserver_test.go +++ b/caddyhttp/staticfiles/fileserver_test.go @@ -77,9 +77,9 @@ func TestServeHTTP(t *testing.T) { { url: "https://foo/dirwithindex/", expectedStatus: http.StatusOK, - expectedBodyContent: testFiles[webrootDirwithindexIndeHTML], + expectedBodyContent: testFiles[webrootDirwithindexIndexHTML], expectedEtag: `"2n9cw"`, - expectedContentLength: strconv.Itoa(len(testFiles[webrootDirwithindexIndeHTML])), + expectedContentLength: strconv.Itoa(len(testFiles[webrootDirwithindexIndexHTML])), }, // Test 4 - access folder with index file without trailing slash { @@ -235,16 +235,38 @@ func TestServeHTTP(t *testing.T) { expectedBodyContent: movedPermanently, }, { + // Test 27 - Check etag url: "https://foo/notindex.html", expectedStatus: http.StatusOK, expectedBodyContent: testFiles[webrootNotIndexHTML], expectedEtag: `"2n9cm"`, expectedContentLength: strconv.Itoa(len(testFiles[webrootNotIndexHTML])), }, + { + // Test 28 - Prevent path-based open redirects (directory) + url: "https://foo//example.com%2f..", + expectedStatus: http.StatusMovedPermanently, + expectedLocation: "https://foo/example.com/../", + expectedBodyContent: movedPermanently, + }, + { + // Test 29 - Prevent path-based open redirects (file) + url: "https://foo//example.com%2f../dirwithindex/index.html", + expectedStatus: http.StatusMovedPermanently, + expectedLocation: "https://foo/example.com/../dirwithindex/", + expectedBodyContent: movedPermanently, + }, + { + // Test 29 - Prevent path-based open redirects (extra leading slashes) + url: "https://foo///example.com%2f..", + expectedStatus: http.StatusMovedPermanently, + expectedLocation: "https://foo/example.com/../", + expectedBodyContent: movedPermanently, + }, } for i, test := range tests { - // set up response writer and rewuest + // set up response writer and request responseRecorder := httptest.NewRecorder() request, err := http.NewRequest("GET", test.url, nil) if err != nil { @@ -518,7 +540,7 @@ var ( webrootNotIndexHTML = filepath.Join(webrootName, "notindex.html") webrootDirFile2HTML = filepath.Join(webrootName, "dir", "file2.html") webrootDirHiddenHTML = filepath.Join(webrootName, "dir", "hidden.html") - webrootDirwithindexIndeHTML = filepath.Join(webrootName, "dirwithindex", "index.html") + webrootDirwithindexIndexHTML = filepath.Join(webrootName, "dirwithindex", "index.html") webrootSubGzippedHTML = filepath.Join(webrootName, "sub", "gzipped.html") webrootSubGzippedHTMLGz = filepath.Join(webrootName, "sub", "gzipped.html.gz") webrootSubGzippedHTMLBr = filepath.Join(webrootName, "sub", "gzipped.html.br") @@ -544,7 +566,7 @@ var testFiles = map[string]string{ webrootFile1HTML: "

file1.html

", webrootNotIndexHTML: "

notindex.html

", webrootDirFile2HTML: "

dir/file2.html

", - webrootDirwithindexIndeHTML: "

dirwithindex/index.html

", + webrootDirwithindexIndexHTML: "

dirwithindex/index.html

", webrootDirHiddenHTML: "

dir/hidden.html

", webrootSubGzippedHTML: "

gzipped.html

", webrootSubGzippedHTMLGz: "1.gzipped.html.gz", From 6a9aea04b1a939c8a3d7aa4774839ce4426562db Mon Sep 17 00:00:00 2001 From: Etienne Bruines Date: Sun, 11 Feb 2018 22:45:45 +0100 Subject: [PATCH 08/19] fastcig: GET requests send along the body (#1975) Fixes #1961 According to RFC 7231 and RFC 7230, there's no reason a GET-Request can't have a body (other than it possibly not being supported by existing software). It's use is simply not defined, and is left to the application. --- caddyhttp/fastcgi/fastcgi.go | 2 +- caddyhttp/fastcgi/fcgiclient.go | 6 +++--- caddyhttp/fastcgi/fcgiclient_test.go | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/caddyhttp/fastcgi/fastcgi.go b/caddyhttp/fastcgi/fastcgi.go index ee466a3e8..28ea55f9f 100644 --- a/caddyhttp/fastcgi/fastcgi.go +++ b/caddyhttp/fastcgi/fastcgi.go @@ -148,7 +148,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) case "HEAD": resp, err = fcgiBackend.Head(env) case "GET": - resp, err = fcgiBackend.Get(env) + resp, err = fcgiBackend.Get(env, r.Body, contentLength) case "OPTIONS": resp, err = fcgiBackend.Options(env) default: diff --git a/caddyhttp/fastcgi/fcgiclient.go b/caddyhttp/fastcgi/fcgiclient.go index adf37d09a..b5fd1d9ea 100644 --- a/caddyhttp/fastcgi/fcgiclient.go +++ b/caddyhttp/fastcgi/fcgiclient.go @@ -460,12 +460,12 @@ func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Res } // Get issues a GET request to the fcgi responder. -func (c *FCGIClient) Get(p map[string]string) (resp *http.Response, err error) { +func (c *FCGIClient) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) { p["REQUEST_METHOD"] = "GET" - p["CONTENT_LENGTH"] = "0" + p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10) - return c.Request(p, nil) + return c.Request(p, body) } // Head issues a HEAD request to the fcgi responder. diff --git a/caddyhttp/fastcgi/fcgiclient_test.go b/caddyhttp/fastcgi/fcgiclient_test.go index ef4981d48..9c5237f20 100644 --- a/caddyhttp/fastcgi/fcgiclient_test.go +++ b/caddyhttp/fastcgi/fcgiclient_test.go @@ -140,7 +140,8 @@ func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[ } resp, err = fcgi.PostForm(fcgiParams, values) } else { - resp, err = fcgi.Get(fcgiParams) + rd := bytes.NewReader(data) + resp, err = fcgi.Get(fcgiParams, rd, int64(rd.Len())) } default: From d29640699eca5e3f64ba15f0dd5a8d452c495058 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Tue, 13 Feb 2018 09:30:26 -0700 Subject: [PATCH 09/19] readme: Update logo image --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2d42d2e6c..bdbffd006 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@

- Caddy + Caddy

Every Site on HTTPS

Caddy is a general-purpose HTTP/2 web server that serves HTTPS by default.

From 08028714b57540a46209b6ab094c0a9cc103830f Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Tue, 13 Feb 2018 13:23:09 -0700 Subject: [PATCH 10/19] tls: Synchronize renewals between Caddy instances sharing file storage Also introduce caddy.OnProcessExit which is a list of functions that run before exiting the process cleanly; these do not count as shutdown callbacks, so they do not return errors and must execute quickly. --- caddy.go | 8 +++ caddytls/certificates.go | 76 ++++++++++++++-------- caddytls/client.go | 54 +++++----------- caddytls/filestorage.go | 7 +- caddytls/filestoragesync.go | 123 ++++++++++++++++++++++++++++++++++++ caddytls/maintain.go | 25 ++------ caddytls/storage.go | 4 ++ caddytls/sync_locker.go | 57 ----------------- plugins.go | 8 +++ sigtrap.go | 9 +-- sigtrap_posix.go | 8 +-- 11 files changed, 227 insertions(+), 152 deletions(-) create mode 100644 caddytls/filestoragesync.go delete mode 100644 caddytls/sync_locker.go diff --git a/caddy.go b/caddy.go index dd2d473a9..628673a73 100644 --- a/caddy.go +++ b/caddy.go @@ -77,6 +77,14 @@ var ( mu sync.Mutex ) +func init() { + OnProcessExit = append(OnProcessExit, func() { + if PidFile != "" { + os.Remove(PidFile) + } + }) +} + // Instance contains the state of servers created as a result of // calling Start and can be used to access or control those servers. // It is literally an instance of a server type. Instance values diff --git a/caddytls/certificates.go b/caddytls/certificates.go index 2df576ff3..29c0c8c21 100644 --- a/caddytls/certificates.go +++ b/caddytls/certificates.go @@ -29,34 +29,6 @@ import ( "golang.org/x/crypto/ocsp" ) -// Certificate is a tls.Certificate with associated metadata tacked on. -// Even if the metadata can be obtained by parsing the certificate, -// we are more efficient by extracting the metadata onto this struct. -type Certificate struct { - tls.Certificate - - // Names is the list of names this certificate is written for. - // The first is the CommonName (if any), the rest are SAN. - Names []string - - // NotAfter is when the certificate expires. - NotAfter time.Time - - // OCSP contains the certificate's parsed OCSP response. - OCSP *ocsp.Response - - // The hex-encoded hash of this cert's chain's bytes. - Hash string - - // configs is the list of configs that use or refer to - // The first one is assumed to be the config that is - // "in charge" of this certificate (i.e. determines - // whether it is managed, how it is managed, etc). - // This field will be populated by cacheCertificate. - // Only meddle with it if you know what you're doing! - configs []*Config -} - // certificateCache is to be an instance-wide cache of certs // that site-specific TLS configs can refer to. Using a // central map like this avoids duplication of certs in @@ -127,6 +99,54 @@ func (certCache *certificateCache) replaceCertificate(oldCert, newCert Certifica return nil } +// reloadManagedCertificate reloads the certificate corresponding to the name(s) +// on oldCert into the cache, from storage. This also replaces the old certificate +// with the new one, so that all configurations that used the old cert now point +// to the new cert. +func (certCache *certificateCache) reloadManagedCertificate(oldCert Certificate) error { + // get the certificate from storage and cache it + newCert, err := oldCert.configs[0].CacheManagedCertificate(oldCert.Names[0]) + if err != nil { + return fmt.Errorf("unable to reload certificate for %v into cache: %v", oldCert.Names, err) + } + + // and replace the old certificate with the new one + err = certCache.replaceCertificate(oldCert, newCert) + if err != nil { + return fmt.Errorf("replacing certificate %v: %v", oldCert.Names, err) + } + + return nil +} + +// Certificate is a tls.Certificate with associated metadata tacked on. +// Even if the metadata can be obtained by parsing the certificate, +// we are more efficient by extracting the metadata onto this struct. +type Certificate struct { + tls.Certificate + + // Names is the list of names this certificate is written for. + // The first is the CommonName (if any), the rest are SAN. + Names []string + + // NotAfter is when the certificate expires. + NotAfter time.Time + + // OCSP contains the certificate's parsed OCSP response. + OCSP *ocsp.Response + + // The hex-encoded hash of this cert's chain's bytes. + Hash string + + // configs is the list of configs that use or refer to + // The first one is assumed to be the config that is + // "in charge" of this certificate (i.e. determines + // whether it is managed, how it is managed, etc). + // This field will be populated by cacheCertificate. + // Only meddle with it if you know what you're doing! + configs []*Config +} + // CacheManagedCertificate loads the certificate for domain into the // cache, from the TLS storage for managed certificates. It returns a // copy of the Certificate that was put into the cache. diff --git a/caddytls/client.go b/caddytls/client.go index 4775a2d18..cdd715b97 100644 --- a/caddytls/client.go +++ b/caddytls/client.go @@ -39,7 +39,7 @@ type ACMEClient struct { AllowPrompts bool config *Config acmeClient *acme.Client - locker Locker + storage Storage } // newACMEClient creates a new ACMEClient given an email and whether @@ -121,10 +121,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) AllowPrompts: allowPrompts, config: config, acmeClient: client, - locker: &syncLock{ - nameLocks: make(map[string]*sync.WaitGroup), - nameLocksMu: sync.Mutex{}, - }, + storage: storage, } if config.DNSProvider == "" { @@ -209,13 +206,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) // Callers who have access to a Config value should use the ObtainCert // method on that instead of this lower-level method. func (c *ACMEClient) Obtain(name string) error { - // Get access to ACME storage - storage, err := c.config.StorageFor(c.config.CAUrl) - if err != nil { - return err - } - - waiter, err := c.locker.TryLock(name) + waiter, err := c.storage.TryLock(name) if err != nil { return err } @@ -225,7 +216,7 @@ func (c *ACMEClient) Obtain(name string) error { return nil // we assume the process with the lock succeeded, rather than hammering this execution path again } defer func() { - if err := c.locker.Unlock(name); err != nil { + if err := c.storage.Unlock(name); err != nil { log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err) } }() @@ -268,7 +259,7 @@ Attempts: } // Success - immediately save the certificate resource - err = saveCertResource(storage, certificate) + err = saveCertResource(c.storage, certificate) if err != nil { return fmt.Errorf("error saving assets for %v: %v", name, err) } @@ -279,35 +270,30 @@ Attempts: return nil } -// Renew renews the managed certificate for name. This function is -// safe for concurrent use. +// Renew renews the managed certificate for name. It puts the renewed +// certificate into storage (not the cache). This function is safe for +// concurrent use. // // Callers who have access to a Config value should use the RenewCert // method on that instead of this lower-level method. func (c *ACMEClient) Renew(name string) error { - // Get access to ACME storage - storage, err := c.config.StorageFor(c.config.CAUrl) - if err != nil { - return err - } - - waiter, err := c.locker.TryLock(name) + waiter, err := c.storage.TryLock(name) if err != nil { return err } if waiter != nil { log.Printf("[INFO] Certificate for %s is already being renewed elsewhere and stored; waiting", name) waiter.Wait() - return nil // we assume the process with the lock succeeded, rather than hammering this execution path again + return nil // assume that the worker that renewed the cert succeeded; avoid hammering this path over and over } defer func() { - if err := c.locker.Unlock(name); err != nil { + if err := c.storage.Unlock(name); err != nil { log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err) } }() // Prepare for renewal (load PEM cert, key, and meta) - siteData, err := storage.LoadSite(name) + siteData, err := c.storage.LoadSite(name) if err != nil { return err } @@ -350,21 +336,15 @@ func (c *ACMEClient) Renew(name string) error { return errors.New("too many renewal attempts; last error: " + err.Error()) } - // Executes Cert renew events caddy.EmitEvent(caddy.CertRenewEvent, name) - return saveCertResource(storage, newCertMeta) + return saveCertResource(c.storage, newCertMeta) } -// Revoke revokes the certificate for name and deltes +// Revoke revokes the certificate for name and deletes // it from storage. func (c *ACMEClient) Revoke(name string) error { - storage, err := c.config.StorageFor(c.config.CAUrl) - if err != nil { - return err - } - - siteExists, err := storage.SiteExists(name) + siteExists, err := c.storage.SiteExists(name) if err != nil { return err } @@ -373,7 +353,7 @@ func (c *ACMEClient) Revoke(name string) error { return errors.New("no certificate and key for " + name) } - siteData, err := storage.LoadSite(name) + siteData, err := c.storage.LoadSite(name) if err != nil { return err } @@ -383,7 +363,7 @@ func (c *ACMEClient) Revoke(name string) error { return err } - err = storage.DeleteSite(name) + err = c.storage.DeleteSite(name) if err != nil { return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error()) } diff --git a/caddytls/filestorage.go b/caddytls/filestorage.go index 67084ef45..9dd2d8494 100644 --- a/caddytls/filestorage.go +++ b/caddytls/filestorage.go @@ -38,9 +38,9 @@ var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme") // Storage instance backed by the local disk. The resulting Storage // instance is guaranteed to be non-nil if there is no error. func NewFileStorage(caURL *url.URL) (Storage, error) { - return &FileStorage{ - Path: filepath.Join(storageBasePath, caURL.Host), - }, nil + storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)} + storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage} + return storage, nil } // FileStorage facilitates forming file paths derived from a root @@ -48,6 +48,7 @@ func NewFileStorage(caURL *url.URL) (Storage, error) { // cross-platform way or persisting ACME assets on the file system. type FileStorage struct { Path string + Locker } // sites gets the directory that stores site certificate and keys. diff --git a/caddytls/filestoragesync.go b/caddytls/filestoragesync.go new file mode 100644 index 000000000..4c81ca02e --- /dev/null +++ b/caddytls/filestoragesync.go @@ -0,0 +1,123 @@ +// Copyright 2015 Light Code Labs, LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package caddytls + +import ( + "fmt" + "os" + "sync" + "time" + + "github.com/mholt/caddy" +) + +func init() { + // be sure to remove lock files when exiting the process! + caddy.OnProcessExit = append(caddy.OnProcessExit, func() { + fileStorageNameLocksMu.Lock() + defer fileStorageNameLocksMu.Unlock() + for key, fw := range fileStorageNameLocks { + os.Remove(fw.filename) + delete(fileStorageNameLocks, key) + } + }) +} + +// fileStorageLock facilitates ACME-related locking by using +// the associated FileStorage, so multiple processes can coordinate +// renewals on the certificates on a shared file system. +type fileStorageLock struct { + caURL string + storage *FileStorage +} + +// TryLock attempts to get a lock for name, otherwise it returns +// a Waiter value to wait until the other process is finished. +func (s *fileStorageLock) TryLock(name string) (Waiter, error) { + fileStorageNameLocksMu.Lock() + defer fileStorageNameLocksMu.Unlock() + + // see if lock already exists within this process + fw, ok := fileStorageNameLocks[s.caURL+name] + if ok { + // lock already created within process, let caller wait on it + return fw, nil + } + + // attempt to persist lock to disk by creating lock file + fw = &fileWaiter{ + filename: s.storage.siteCertFile(name) + ".lock", + wg: new(sync.WaitGroup), + } + lf, err := os.OpenFile(fw.filename, os.O_CREATE|os.O_EXCL, 0644) + if err != nil { + if os.IsExist(err) { + // another process has the lock; use it to wait + return fw, nil + } + // otherwise, this was some unexpected error + return nil, err + } + lf.Close() + + // looks like we get the lock + fw.wg.Add(1) + fileStorageNameLocks[s.caURL+name] = fw + + return nil, nil +} + +// Unlock unlocks name. +func (s *fileStorageLock) Unlock(name string) error { + fileStorageNameLocksMu.Lock() + defer fileStorageNameLocksMu.Unlock() + fw, ok := fileStorageNameLocks[s.caURL+name] + if !ok { + return fmt.Errorf("FileStorage: no lock to release for %s", name) + } + os.Remove(fw.filename) + fw.wg.Done() + delete(fileStorageNameLocks, s.caURL+name) + return nil +} + +// fileWaiter waits for a file to disappear; it polls +// the file system to check for the existence of a file. +// It also has a WaitGroup which will be faster than +// polling, for when locking need only happen within this +// process. +type fileWaiter struct { + filename string + wg *sync.WaitGroup +} + +// Wait waits until the lock is released. +func (fw *fileWaiter) Wait() { + start := time.Now() + fw.wg.Wait() + for time.Since(start) < 1*time.Hour { + _, err := os.Stat(fw.filename) + if os.IsNotExist(err) { + return + } + time.Sleep(1 * time.Second) + } +} + +var fileStorageNameLocks = make(map[string]*fileWaiter) // keyed by CA + name +var fileStorageNameLocksMu sync.Mutex + +var _ Locker = &fileStorageLock{} +var _ Waiter = &fileWaiter{} diff --git a/caddytls/maintain.go b/caddytls/maintain.go index 7ce6c5e26..5e867d4b8 100644 --- a/caddytls/maintain.go +++ b/caddytls/maintain.go @@ -160,17 +160,12 @@ func RenewManagedCertificates(allowPrompts bool) (err error) { log.Printf("[INFO] Certificate for %v expires in %v, but is already renewed in storage; reloading stored certificate", oldCert.Names, timeLeft) - // get the certificate from storage and cache it - newCert, err := oldCert.configs[0].CacheManagedCertificate(oldCert.Names[0]) + err = certCache.reloadManagedCertificate(oldCert) if err != nil { - log.Printf("[ERROR] Unable to reload certificate for %v into cache: %v", oldCert.Names, err) - continue - } - - // and replace the old certificate with the new one - err = certCache.replaceCertificate(oldCert, newCert) - if err != nil { - log.Printf("[ERROR] Replacing certificate: %v", err) + if allowPrompts { + return err // operator is present, so report error immediately + } + log.Printf("[ERROR] Loading renewed certificate: %v", err) } } @@ -212,21 +207,13 @@ func RenewManagedCertificates(allowPrompts bool) (err error) { // successful renewal, so update in-memory cache by loading // renewed certificate so it will be used with handshakes - - // put the certificate in the cache - newCert, err := oldCert.configs[0].CacheManagedCertificate(renewName) + err = certCache.reloadManagedCertificate(oldCert) if err != nil { if allowPrompts { return err // operator is present, so report error immediately } log.Printf("[ERROR] %v", err) } - - // replace the old certificate with the new one - err = certCache.replaceCertificate(oldCert, newCert) - if err != nil { - log.Printf("[ERROR] Replacing certificate: %v", err) - } } // Deletion queue diff --git a/caddytls/storage.go b/caddytls/storage.go index 8587dd026..05606ed92 100644 --- a/caddytls/storage.go +++ b/caddytls/storage.go @@ -107,6 +107,10 @@ type Storage interface { // in StoreUser. The result is an empty string if there are no // persisted users in storage. MostRecentUserEmail() string + + // Locker is necessary because synchronizing certificate maintenance + // depends on how storage is implemented. + Locker } // ErrNotExist is returned by Storage implementations when diff --git a/caddytls/sync_locker.go b/caddytls/sync_locker.go deleted file mode 100644 index 693f3b875..000000000 --- a/caddytls/sync_locker.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2015 Light Code Labs, LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package caddytls - -import ( - "fmt" - "sync" -) - -var _ Locker = &syncLock{} - -type syncLock struct { - nameLocks map[string]*sync.WaitGroup - nameLocksMu sync.Mutex -} - -// TryLock attempts to get a lock for name, otherwise it returns -// a Waiter value to wait until the other process is finished. -func (s *syncLock) TryLock(name string) (Waiter, error) { - s.nameLocksMu.Lock() - defer s.nameLocksMu.Unlock() - wg, ok := s.nameLocks[name] - if ok { - // lock already obtained, let caller wait on it - return wg, nil - } - // caller gets lock - wg = new(sync.WaitGroup) - wg.Add(1) - s.nameLocks[name] = wg - return nil, nil -} - -// Unlock unlocks name. -func (s *syncLock) Unlock(name string) error { - s.nameLocksMu.Lock() - defer s.nameLocksMu.Unlock() - wg, ok := s.nameLocks[name] - if !ok { - return fmt.Errorf("FileStorage: no lock to release for %s", name) - } - wg.Done() - delete(s.nameLocks, name) - return nil -} diff --git a/plugins.go b/plugins.go index f7d14f86b..ba1114034 100644 --- a/plugins.go +++ b/plugins.go @@ -383,6 +383,14 @@ func loadCaddyfileInput(serverType string) (Input, error) { return caddyfileToUse, nil } +// OnProcessExit is a list of functions to run when the process +// exits -- they are ONLY for cleanup and should not block, +// return errors, or do anything fancy. They will be run with +// every signal, even if "shutdown callbacks" are not executed. +// This variable must only be modified in the main goroutine +// from init() functions. +var OnProcessExit []func() + // caddyfileLoader pairs the name of a loader to the loader. type caddyfileLoader struct { name string diff --git a/sigtrap.go b/sigtrap.go index a10cf0f09..ac61c59c0 100644 --- a/sigtrap.go +++ b/sigtrap.go @@ -44,16 +44,17 @@ func trapSignalsCrossPlatform() { if i > 0 { log.Println("[INFO] SIGINT: Force quit") - if PidFile != "" { - os.Remove(PidFile) + for _, f := range OnProcessExit { + f() // important cleanup actions only } os.Exit(2) } log.Println("[INFO] SIGINT: Shutting down") - if PidFile != "" { - os.Remove(PidFile) + // important cleanup actions before shutdown callbacks + for _, f := range OnProcessExit { + f() } go func() { diff --git a/sigtrap_posix.go b/sigtrap_posix.go index 71b6969af..38aaa774c 100644 --- a/sigtrap_posix.go +++ b/sigtrap_posix.go @@ -33,8 +33,8 @@ func trapSignalsPosix() { switch sig { case syscall.SIGTERM: log.Println("[INFO] SIGTERM: Terminating process") - if PidFile != "" { - os.Remove(PidFile) + for _, f := range OnProcessExit { + f() // only perform important cleanup actions } os.Exit(0) @@ -46,8 +46,8 @@ func trapSignalsPosix() { log.Printf("[ERROR] SIGQUIT stop: %v", err) exitCode = 3 } - if PidFile != "" { - os.Remove(PidFile) + for _, f := range OnProcessExit { + f() // only perform important cleanup actions } os.Exit(exitCode) From 4b2e22289d9c7423eea2f9b44430bfefed074d8f Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Tue, 13 Feb 2018 13:27:08 -0700 Subject: [PATCH 11/19] sigtrap: Ensure cleanup actions happen before too many things go wrong --- sigtrap_posix.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sigtrap_posix.go b/sigtrap_posix.go index 2a0a0de57..cc65ccb46 100644 --- a/sigtrap_posix.go +++ b/sigtrap_posix.go @@ -41,14 +41,14 @@ func trapSignalsPosix() { case syscall.SIGTERM: log.Println("[INFO] SIGTERM: Shutting down servers then terminating") exitCode := executeShutdownCallbacks("SIGTERM") + for _, f := range OnProcessExit { + f() // only perform important cleanup actions + } err := Stop() if err != nil { log.Printf("[ERROR] SIGTERM stop: %v", err) exitCode = 3 } - for _, f := range OnProcessExit { - f() // only perform important cleanup actions - } os.Exit(exitCode) case syscall.SIGUSR1: From ef585ed810a6913e7174c19ed383ddba9c3949ed Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Wed, 14 Feb 2018 13:32:16 -0700 Subject: [PATCH 12/19] tls: Ensure parent dir exists before creating lock file --- caddytls/filestoragesync.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/caddytls/filestoragesync.go b/caddytls/filestoragesync.go index 4c81ca02e..a8e7b9291 100644 --- a/caddytls/filestoragesync.go +++ b/caddytls/filestoragesync.go @@ -61,6 +61,10 @@ func (s *fileStorageLock) TryLock(name string) (Waiter, error) { filename: s.storage.siteCertFile(name) + ".lock", wg: new(sync.WaitGroup), } + // parent dir must exist + if err := os.MkdirAll(s.storage.site(name), 0700); err != nil { + return nil, err + } lf, err := os.OpenFile(fw.filename, os.O_CREATE|os.O_EXCL, 0644) if err != nil { if os.IsExist(err) { From be96cc0e656de2cac5913508f1dfab79872bfd7e Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Thu, 15 Feb 2018 00:04:31 -0700 Subject: [PATCH 13/19] httpserver: Raise error when adjusted site addresses clash at startup See discussion on #2015 for how this situation was discovered. For a Caddyfile like this: localhost { ... } :2015 { ... } Running Caddy like this: caddy -host localhost Produces two sites both defined as `localhost:2015` because the flag changes the default host value to be `localhost`. This should be an error since the sites are not distinct and it is confusing. It can also cause issues with TLS handshakes loading the wrong cert, as the linked discussion shows. --- caddy.go | 2 +- caddyhttp/httpserver/plugin.go | 21 ++++++++++++++++++++- caddyhttp/httpserver/plugin_test.go | 17 +++++++++++++++++ 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/caddy.go b/caddy.go index 628673a73..917a5d069 100644 --- a/caddy.go +++ b/caddy.go @@ -612,7 +612,7 @@ func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bo sblocks, err = inst.context.InspectServerBlocks(cdyfile.Path(), sblocks) if err != nil { - return err + return fmt.Errorf("error inspecting server blocks: %v", err) } return executeDirectives(inst, cdyfile.Path(), stype.Directives(), sblocks, justValidate) diff --git a/caddyhttp/httpserver/plugin.go b/caddyhttp/httpserver/plugin.go index ea31a58d8..93811abcb 100644 --- a/caddyhttp/httpserver/plugin.go +++ b/caddyhttp/httpserver/plugin.go @@ -117,12 +117,14 @@ func (h *httpContext) saveConfig(key string, cfg *SiteConfig) { // executing directives and otherwise prepares the directives to // be parsed and executed. func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) { + siteAddrs := make(map[string]string) + // For each address in each server block, make a new config for _, sb := range serverBlocks { for _, key := range sb.Keys { key = strings.ToLower(key) if _, dup := h.keysToSiteConfigs[key]; dup { - return serverBlocks, fmt.Errorf("duplicate site address: %s", key) + return serverBlocks, fmt.Errorf("duplicate site key: %s", key) } addr, err := standardizeAddress(key) if err != nil { @@ -138,6 +140,23 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd addr.Port = Port } + // Make sure the adjusted site address is distinct + addrCopy := addr // make copy so we don't disturb the original, carefully-parsed address struct + if addrCopy.Port == "" && Port == DefaultPort { + addrCopy.Port = Port + } + addrStr := strings.ToLower(addrCopy.String()) + if otherSiteKey, dup := siteAddrs[addrStr]; dup { + err := fmt.Errorf("duplicate site address: %s", addrStr) + if (addrCopy.Host == Host && Host != DefaultHost) || + (addrCopy.Port == Port && Port != DefaultPort) { + err = fmt.Errorf("site defined as %s is a duplicate of %s because of modified "+ + "default host and/or port values (usually via -host or -port flags)", key, otherSiteKey) + } + return serverBlocks, err + } + siteAddrs[addrStr] = key + // If default HTTP or HTTPS ports have been customized, // make sure the ACME challenge ports match var altHTTPPort, altTLSSNIPort string diff --git a/caddyhttp/httpserver/plugin_test.go b/caddyhttp/httpserver/plugin_test.go index 5a60f2e83..f7b9cfc00 100644 --- a/caddyhttp/httpserver/plugin_test.go +++ b/caddyhttp/httpserver/plugin_test.go @@ -153,6 +153,23 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) { } } +// See discussion on PR #2015 +func TestInspectServerBlocksWithAdjustedAddress(t *testing.T) { + Port = DefaultPort + Host = "example.com" + filename := "Testfile" + ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext) + input := strings.NewReader("example.com {\n}\n:2015 {\n}") + sblocks, err := caddyfile.Parse(filename, input, nil) + if err != nil { + t.Fatalf("Expected no error setting up test, got: %v", err) + } + _, err = ctx.InspectServerBlocks(filename, sblocks) + if err == nil { + t.Fatalf("Expected an error because site definitions should overlap, got: %v", err) + } +} + func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) { filename := "Testfile" ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext) From 6f4cf7eec70d8d2ea8660d7212caa25695811447 Mon Sep 17 00:00:00 2001 From: Jason Daly Date: Thu, 15 Feb 2018 10:05:58 -0500 Subject: [PATCH 14/19] readme: Update minimum version to build from source (#2024) Re: #2009, 1.9 or newer is needed because of the introduction of `sync.Map` --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index bdbffd006..d8e05ec30 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ customize your build in the browser pre-built, vanilla binaries ## Build -To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.8 or newer). Follow these instruction for fast building: +To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.9 or newer). Follow these instruction for fast building: - Get source `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds` - Now `cd` to `$GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go` From 896dc6bc690bce51884bd43da91180882ce90cab Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Thu, 15 Feb 2018 08:48:05 -0700 Subject: [PATCH 15/19] tls: Try empty name if no matches for getting config during handshake See discussion on #2015; the initial change had removed this check, and I can't remember why I removed it or if it was accidental. Anyway, it's back now. --- caddytls/handshake.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/caddytls/handshake.go b/caddytls/handshake.go index 2f3f34af3..841d06cd0 100644 --- a/caddytls/handshake.go +++ b/caddytls/handshake.go @@ -59,6 +59,14 @@ func (cg configGroup) getConfig(name string) *Config { } } + // try a config that serves all names (this + // is basically the same as a config defined + // for "*" -- I think -- but the above loop + // doesn't try an empty string) + if config, ok := cg[""]; ok { + return config + } + // no matches, so just serve up a random config for _, config := range cg { return config From 8db80c4a88e53c9b710dc7753fd2561fa974ffb2 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Fri, 16 Feb 2018 12:05:34 -0700 Subject: [PATCH 16/19] tls: Fix HTTP->HTTPS redirects and HTTP challenge when using custom port --- caddyhttp/httpserver/https.go | 30 +++++++++++++++++------------- caddyhttp/httpserver/server.go | 22 ++-------------------- caddytls/config.go | 13 +++++++------ caddytls/httphandler.go | 15 ++++++++++----- caddytls/httphandler_test.go | 4 ++-- 5 files changed, 38 insertions(+), 46 deletions(-) diff --git a/caddyhttp/httpserver/https.go b/caddyhttp/httpserver/https.go index a12d9982c..3d1de7499 100644 --- a/caddyhttp/httpserver/https.go +++ b/caddyhttp/httpserver/https.go @@ -159,25 +159,29 @@ func hostHasOtherPort(allConfigs []*SiteConfig, thisConfigIdx int, otherPort str // be the HTTPS configuration. The returned configuration is set // to listen on HTTPPort. The TLS field of cfg must not be nil. func redirPlaintextHost(cfg *SiteConfig) *SiteConfig { - redirPort := cfg.Addr.Port - if redirPort == DefaultHTTPSPort { - redirPort = "" // default port is redundant - } - redirMiddleware := func(next Handler) Handler { return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - // Construct the URL to which to redirect. Note that the Host in a request might - // contain a port, but we just need the hostname; we'll set the port if needed. + // Construct the URL to which to redirect. Note that the Host in a + // request might contain a port, but we just need the hostname. toURL := "https://" requestHost, _, err := net.SplitHostPort(r.Host) if err != nil { - requestHost = r.Host // Host did not contain a port; great - } - if redirPort == "" { - toURL += requestHost - } else { - toURL += net.JoinHostPort(requestHost, redirPort) + requestHost = r.Host // host did not contain a port; okay } + + // The rest of the URL will consist of the hostname and the URI. + // We do not append a port because if the HTTPSPort is changed + // from the default value, it is probably because there is port + // forwarding going on; and we do not need to specify the default + // HTTPS port in the redirect. Serving HTTPS on a port other than + // 443 is unusual, and is considered an advanced use case. If port + // forwarding IS happening, then redirecting the external client to + // this internal port will cause the connection to fail; and it + // definitely causes ACME HTTP-01 challenges to fail, because it + // only allows redirecting to port 80 or 443 (as of Feb 2018). + // If a user wants to redirect HTTP to HTTPS on an external port + // other than 443, they can easily configure that themselves. + toURL += requestHost toURL += r.URL.RequestURI() w.Header().Set("Connection", "close") diff --git a/caddyhttp/httpserver/server.go b/caddyhttp/httpserver/server.go index 92f2b6fd7..9d3ae0389 100644 --- a/caddyhttp/httpserver/server.go +++ b/caddyhttp/httpserver/server.go @@ -389,7 +389,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) if vhost == nil { // check for ACME challenge even if vhost is nil; // could be a new host coming online soon - if caddytls.HTTPChallengeHandler(w, r, "localhost", caddytls.DefaultHTTPAlternatePort) { + if caddytls.HTTPChallengeHandler(w, r, "localhost") { return 0, nil } // otherwise, log the error and write a message to the client @@ -405,7 +405,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) // we still check for ACME challenge if the vhost exists, // because we must apply its HTTP challenge config settings - if s.proxyHTTPChallenge(vhost, w, r) { + if caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost) { return 0, nil } @@ -422,24 +422,6 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) return vhost.middlewareChain.ServeHTTP(w, r) } -// proxyHTTPChallenge solves the ACME HTTP challenge if r is the HTTP -// request for the challenge. If it is, and if the request has been -// fulfilled (response written), true is returned; false otherwise. -// If you don't have a vhost, just call the challenge handler directly. -func (s *Server) proxyHTTPChallenge(vhost *SiteConfig, w http.ResponseWriter, r *http.Request) bool { - if vhost.Addr.Port != caddytls.HTTPChallengePort { - return false - } - if vhost.TLS != nil && vhost.TLS.Manual { - return false - } - altPort := caddytls.DefaultHTTPAlternatePort - if vhost.TLS != nil && vhost.TLS.AltHTTPPort != "" { - altPort = vhost.TLS.AltHTTPPort - } - return caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost, altPort) -} - // Address returns the address s was assigned to listen on. func (s *Server) Address() string { return s.Server.Addr diff --git a/caddytls/config.go b/caddytls/config.go index 0b64f3575..938cb08ca 100644 --- a/caddytls/config.go +++ b/caddytls/config.go @@ -93,16 +93,17 @@ type Config struct { // an ACME challenge ListenHost string - // The alternate port (ONLY port, not host) - // to use for the ACME HTTP challenge; this - // port will be used if we proxy challenges - // coming in on port 80 to this alternate port + // The alternate port (ONLY port, not host) to + // use for the ACME HTTP challenge; if non-empty, + // this port will be used instead of + // HTTPChallengePort to spin up a listener for + // the HTTP challenge AltHTTPPort string // The alternate port (ONLY port, not host) // to use for the ACME TLS-SNI challenge. - // The system must forward the standard port - // for the TLS-SNI challenge to this port. + // The system must forward TLSSNIChallengePort + // to this port for challenge to succeed AltTLSSNIPort string // The string identifier of the DNS provider diff --git a/caddytls/httphandler.go b/caddytls/httphandler.go index ca356cd43..663e2eb02 100644 --- a/caddytls/httphandler.go +++ b/caddytls/httphandler.go @@ -27,10 +27,11 @@ import ( const challengeBasePath = "/.well-known/acme-challenge" // HTTPChallengeHandler proxies challenge requests to ACME client if the -// request path starts with challengeBasePath. It returns true if it -// handled the request and no more needs to be done; it returns false -// if this call was a no-op and the request still needs handling. -func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, altPort string) bool { +// request path starts with challengeBasePath, if the HTTP challenge is not +// disabled, and if we are known to be obtaining a certificate for the name. +// It returns true if it handled the request and no more needs to be done; +// it returns false if this call was a no-op and the request still needs handling. +func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost string) bool { if !strings.HasPrefix(r.URL.Path, challengeBasePath) { return false } @@ -50,7 +51,11 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, al listenHost = "localhost" } - upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, altPort)) + // always proxy to the DefaultHTTPAlternatePort because obviously the + // ACME challenge request already got into one of our HTTP handlers, so + // it means we must have started a HTTP listener on the alternate + // port instead; which is only accessible via listenHost + upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, DefaultHTTPAlternatePort)) if err != nil { w.WriteHeader(http.StatusInternalServerError) log.Printf("[ERROR] ACME proxy handler: %v", err) diff --git a/caddytls/httphandler_test.go b/caddytls/httphandler_test.go index 451c0cf4f..cae65ac8c 100644 --- a/caddytls/httphandler_test.go +++ b/caddytls/httphandler_test.go @@ -39,7 +39,7 @@ func TestHTTPChallengeHandlerNoOp(t *testing.T) { t.Fatalf("Could not craft request, got error: %v", err) } rw := httptest.NewRecorder() - if HTTPChallengeHandler(rw, req, "", DefaultHTTPAlternatePort) { + if HTTPChallengeHandler(rw, req, "") { t.Errorf("Got true with this URL, but shouldn't have: %s", url) } } @@ -76,7 +76,7 @@ func TestHTTPChallengeHandlerSuccess(t *testing.T) { } rw := httptest.NewRecorder() - HTTPChallengeHandler(rw, req, "", DefaultHTTPAlternatePort) + HTTPChallengeHandler(rw, req, "") if !proxySuccess { t.Fatal("Expected request to be proxied, but it wasn't") From a03eba6fbc2770ef6bbe7defa184ec4a6817f0c2 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Fri, 16 Feb 2018 12:36:28 -0700 Subject: [PATCH 17/19] tls: In HTTP->HTTPS redirects, preserve redir port in some circumstances Only strip the port from the Location URL value if the port is NOT the HTTPSPort (before, we compared against DefaultHTTPSPort instead of HTTPSPort). The HTTPSPort can be changed, but is done so for port forwarding, since in reality you can't 'change' the standard HTTPS port, you can only forward it. --- caddyhttp/httpserver/https.go | 39 ++++++++++++++++++------------ caddyhttp/httpserver/https_test.go | 2 +- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/caddyhttp/httpserver/https.go b/caddyhttp/httpserver/https.go index 3d1de7499..ae3c4e902 100644 --- a/caddyhttp/httpserver/https.go +++ b/caddyhttp/httpserver/https.go @@ -159,29 +159,38 @@ func hostHasOtherPort(allConfigs []*SiteConfig, thisConfigIdx int, otherPort str // be the HTTPS configuration. The returned configuration is set // to listen on HTTPPort. The TLS field of cfg must not be nil. func redirPlaintextHost(cfg *SiteConfig) *SiteConfig { + redirPort := cfg.Addr.Port + if redirPort == HTTPSPort { + // By default, HTTPSPort should be DefaultHTTPSPort, + // which of course doesn't need to be explicitly stated + // in the Location header. Even if HTTPSPort is changed + // so that it is no longer DefaultHTTPSPort, we shouldn't + // append it to the URL in the Location because changing + // the HTTPS port is assumed to be an internal-only change + // (in other words, we assume port forwarding is going on); + // but redirects go back to a presumably-external client. + // (If redirect clients are also internal, that is more + // advanced, and the user should configure HTTP->HTTPS + // redirects themselves.) + redirPort = "" + } + redirMiddleware := func(next Handler) Handler { return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { // Construct the URL to which to redirect. Note that the Host in a - // request might contain a port, but we just need the hostname. + // request might contain a port, but we just need the hostname from + // it; and we'll set the port if needed. toURL := "https://" requestHost, _, err := net.SplitHostPort(r.Host) if err != nil { - requestHost = r.Host // host did not contain a port; okay + requestHost = r.Host // Host did not contain a port, so use the whole value + } + if redirPort == "" { + toURL += requestHost + } else { + toURL += net.JoinHostPort(requestHost, redirPort) } - // The rest of the URL will consist of the hostname and the URI. - // We do not append a port because if the HTTPSPort is changed - // from the default value, it is probably because there is port - // forwarding going on; and we do not need to specify the default - // HTTPS port in the redirect. Serving HTTPS on a port other than - // 443 is unusual, and is considered an advanced use case. If port - // forwarding IS happening, then redirecting the external client to - // this internal port will cause the connection to fail; and it - // definitely causes ACME HTTP-01 challenges to fail, because it - // only allows redirecting to port 80 or 443 (as of Feb 2018). - // If a user wants to redirect HTTP to HTTPS on an external port - // other than 443, they can easily configure that themselves. - toURL += requestHost toURL += r.URL.RequestURI() w.Header().Set("Connection", "close") diff --git a/caddyhttp/httpserver/https_test.go b/caddyhttp/httpserver/https_test.go index 3e9fe915a..043249445 100644 --- a/caddyhttp/httpserver/https_test.go +++ b/caddyhttp/httpserver/https_test.go @@ -53,7 +53,7 @@ func TestRedirPlaintextHost(t *testing.T) { }, { Host: "foohost", - Port: "443", // since this is the default HTTPS port, should not be included in Location value + Port: HTTPSPort, // since this is the 'default' HTTPS port, should not be included in Location value }, { Host: "*.example.com", From faa5248d1f1c73a1b5ed61a5a321319adae04f5f Mon Sep 17 00:00:00 2001 From: Toby Allen Date: Fri, 16 Feb 2018 21:18:02 +0000 Subject: [PATCH 18/19] httpserver: Leave %2f encoded when trimming path in site address Fix #1927 (#2014) * Trim path prefix using EscapedPath() * clarify comments * Added Tests for trimPathPrefix * Ensure path with trailing slash is properly trimmed * Updated tests to match prepatch behaviour * Updated tests to match prepatch behaviour * call parse on url rather than instance * add additional tests * return unmodified url if error. Additional tests --- caddyhttp/httpserver/server.go | 20 +++++-- caddyhttp/httpserver/server_test.go | 89 +++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 4 deletions(-) diff --git a/caddyhttp/httpserver/server.go b/caddyhttp/httpserver/server.go index 9d3ae0389..ab65d955f 100644 --- a/caddyhttp/httpserver/server.go +++ b/caddyhttp/httpserver/server.go @@ -413,15 +413,27 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) // the URL path, so a request to example.com/foo/blog on the site // defined as example.com/foo appears as /blog instead of /foo/blog. if pathPrefix != "/" { - r.URL.Path = strings.TrimPrefix(r.URL.Path, pathPrefix) - if !strings.HasPrefix(r.URL.Path, "/") { - r.URL.Path = "/" + r.URL.Path - } + r.URL = trimPathPrefix(r.URL, pathPrefix) } return vhost.middlewareChain.ServeHTTP(w, r) } +func trimPathPrefix(u *url.URL, prefix string) *url.URL { + // We need to use URL.EscapedPath() when trimming the pathPrefix as + // URL.Path is ambiguous about / or %2f - see docs. See #1927 + trimmed := strings.TrimPrefix(u.EscapedPath(), prefix) + if !strings.HasPrefix(trimmed, "/") { + trimmed = "/" + trimmed + } + trimmedURL, err := url.Parse(trimmed) + if err != nil { + log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmed, err) + return u + } + return trimmedURL +} + // Address returns the address s was assigned to listen on. func (s *Server) Address() string { return s.Server.Addr diff --git a/caddyhttp/httpserver/server_test.go b/caddyhttp/httpserver/server_test.go index a781a80a4..82926851d 100644 --- a/caddyhttp/httpserver/server_test.go +++ b/caddyhttp/httpserver/server_test.go @@ -16,6 +16,7 @@ package httpserver import ( "net/http" + "net/url" "testing" "time" ) @@ -126,6 +127,94 @@ func TestMakeHTTPServerWithTimeouts(t *testing.T) { } } +func TestTrimPathPrefix(t *testing.T) { + for i, pt := range []struct { + path string + prefix string + expected string + shouldFail bool + }{ + { + path: "/my/path", + prefix: "/my", + expected: "/path", + shouldFail: false, + }, + { + path: "/my/%2f/path", + prefix: "/my", + expected: "/%2f/path", + shouldFail: false, + }, + { + path: "/my/path", + prefix: "/my/", + expected: "/path", + shouldFail: false, + }, + { + path: "/my///path", + prefix: "/my", + expected: "/path", + shouldFail: true, + }, + { + path: "/my///path", + prefix: "/my", + expected: "///path", + shouldFail: false, + }, + { + path: "/my/path///slash", + prefix: "/my", + expected: "/path///slash", + shouldFail: false, + }, + { + path: "/my/%2f/path/%2f", + prefix: "/my", + expected: "/%2f/path/%2f", + shouldFail: false, + }, { + path: "/my/%20/path", + prefix: "/my", + expected: "/%20/path", + shouldFail: false, + }, { + path: "/path", + prefix: "", + expected: "/path", + shouldFail: false, + }, { + path: "/path/my/", + prefix: "/my", + expected: "/path/my/", + shouldFail: false, + }, { + path: "", + prefix: "/my", + expected: "/", + shouldFail: false, + }, { + path: "/apath", + prefix: "", + expected: "/apath", + shouldFail: false, + }, + } { + + u, _ := url.Parse(pt.path) + if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.EscapedPath() != want { + if !pt.shouldFail { + + t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.EscapedPath()) + } + } else if pt.shouldFail { + t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.EscapedPath()) + } + } +} + func TestMakeHTTPServerWithHeaderLimit(t *testing.T) { for name, c := range map[string]struct { group []*SiteConfig From 120149222293541d9b3b03bc5fcf2cc11594a5c3 Mon Sep 17 00:00:00 2001 From: Amos Ng Date: Sat, 17 Feb 2018 13:29:53 +0800 Subject: [PATCH 19/19] vendor: Updated quic-go for QUIC 39+ (#1968) * Updated lucas-clemente/quic-go for QUIC 39+ support * Update quic-go to latest --- vendor/github.com/bifurcation/mint/LICENSE.md | 21 + vendor/github.com/bifurcation/mint/alert.go | 99 ++ .../mint/bin/mint-client-https/main.go | 42 + .../bifurcation/mint/bin/mint-client/main.go | 37 + .../mint/bin/mint-server-https/main.go | 226 +++ .../bifurcation/mint/bin/mint-server/main.go | 65 + .../bifurcation/mint/client-state-machine.go | 942 +++++++++++ vendor/github.com/bifurcation/mint/common.go | 152 ++ vendor/github.com/bifurcation/mint/conn.go | 819 ++++++++++ vendor/github.com/bifurcation/mint/crypto.go | 654 ++++++++ .../github.com/bifurcation/mint/extensions.go | 586 +++++++ vendor/github.com/bifurcation/mint/ffdhe.go | 147 ++ .../bifurcation/mint/frame-reader.go | 98 ++ .../bifurcation/mint/handshake-layer.go | 253 +++ .../bifurcation/mint/handshake-messages.go | 450 ++++++ vendor/github.com/bifurcation/mint/log.go | 55 + .../bifurcation/mint/negotiation.go | 217 +++ .../bifurcation/mint/record-layer.go | 296 ++++ .../bifurcation/mint/server-state-machine.go | 898 +++++++++++ .../bifurcation/mint/state-machine.go | 230 +++ .../bifurcation/mint/syntax/decode.go | 243 +++ .../bifurcation/mint/syntax/encode.go | 187 +++ .../bifurcation/mint/syntax/tags.go | 30 + vendor/github.com/bifurcation/mint/tls.go | 168 ++ .../quic-go/ackhandler/interfaces.go | 32 - .../lucas-clemente/quic-go/buffer_pool.go | 2 +- .../lucas-clemente/quic-go/client.go | 360 +++-- .../quic-go/crypto/aesgcm_aead.go | 58 - .../lucas-clemente/quic-go/crypto/nonce.go | 14 - .../quic-go/crypto/source_address_token.go | 76 - .../lucas-clemente/quic-go/crypto_stream.go | 41 + .../quic-go/example/client/main.go | 14 +- .../lucas-clemente/quic-go/example/main.go | 14 +- .../flowcontrol/flow_control_manager.go | 240 --- .../quic-go/flowcontrol/flow_controller.go | 198 --- .../quic-go/flowcontrol/interface.go | 26 - .../quic-go/frames/ack_range.go | 9 - .../quic-go/frames/blocked_frame.go | 44 - .../quic-go/frames/connection_close_frame.go | 73 - .../lucas-clemente/quic-go/frames/frame.go | 13 - .../lucas-clemente/quic-go/frames/log.go | 28 - .../quic-go/frames/rst_stream_frame.go | 59 - .../quic-go/frames/window_update_frame.go | 54 - .../lucas-clemente/quic-go/h2quic/client.go | 109 +- .../quic-go/h2quic/request_writer.go | 2 +- .../quic-go/h2quic/response_writer.go | 4 +- .../quic-go/h2quic/roundtrip.go | 13 +- .../lucas-clemente/quic-go/h2quic/server.go | 87 +- .../connection_parameters_manager.go | 265 ---- .../quic-go/handshake/interface.go | 24 - .../quic-go/handshake/stk_generator.go | 100 -- .../quic-go/integrationtests/chrome/chrome.go | 1 - .../quic-go/integrationtests/gquic/gquic.go | 1 - .../quic-go/integrationtests/self/self.go | 1 - .../integrationtests/tools/proxy/proxy.go | 73 +- .../integrationtests/tools/testlog/testlog.go | 2 +- .../tools/testserver/server.go | 18 +- .../lucas-clemente/quic-go/interface.go | 121 +- .../quic-go/internal/ackhandler/interfaces.go | 48 + .../{ => internal}/ackhandler/packet.go | 17 +- .../ackhandler/packet_linkedlist.go | 0 .../ackhandler/received_packet_handler.go | 56 +- .../ackhandler/received_packet_history.go | 70 +- .../ackhandler/retransmittable.go | 16 +- .../ackhandler/sent_packet_handler.go | 198 ++- .../ackhandler/stop_waiting_manager.go | 12 +- .../{ => internal}/congestion/bandwidth.go | 2 +- .../{ => internal}/congestion/clock.go | 0 .../{ => internal}/congestion/cubic.go | 2 +- .../{ => internal}/congestion/cubic_sender.go | 16 +- .../congestion/hybrid_slow_start.go | 2 +- .../{ => internal}/congestion/interface.go | 4 +- .../{ => internal}/congestion/prr_sender.go | 2 +- .../{ => internal}/congestion/rtt_stats.go | 7 +- .../{ => internal}/congestion/stats.go | 2 +- .../quic-go/{ => internal}/crypto/AEAD.go | 3 +- .../quic-go/internal/crypto/aesgcm12_aead.go | 72 + .../quic-go/internal/crypto/aesgcm_aead.go | 74 + .../{ => internal}/crypto/cert_cache.go | 2 +- .../{ => internal}/crypto/cert_chain.go | 0 .../{ => internal}/crypto/cert_compression.go | 14 +- .../{ => internal}/crypto/cert_dict.go | 0 .../{ => internal}/crypto/cert_manager.go | 5 + .../{ => internal}/crypto/cert_sets.go | 0 .../crypto/chacha20poly1305_aead.go | 14 +- .../{ => internal}/crypto/curve_25519.go | 0 .../quic-go/internal/crypto/key_derivation.go | 49 + .../crypto/key_derivation_quic_crypto.go} | 10 +- .../{ => internal}/crypto/key_exchange.go | 0 .../quic-go/internal/crypto/null_aead.go | 11 + .../internal/crypto/null_aead_aesgcm.go | 44 + .../crypto/null_aead_fnv128a.go} | 43 +- .../{ => internal}/crypto/server_proof.go | 0 .../flowcontrol/base_flow_controller.go | 108 ++ .../flowcontrol/connection_flow_controller.go | 83 + .../quic-go/internal/flowcontrol/interface.go | 42 + .../flowcontrol/stream_flow_controller.go | 147 ++ .../internal/handshake/cookie_generator.go | 101 ++ .../internal/handshake/cookie_handler.go | 43 + .../handshake/crypto_setup_client.go | 161 +- .../handshake/crypto_setup_server.go | 151 +- .../internal/handshake/crypto_setup_tls.go | 177 +++ .../internal/handshake/crypto_stream_conn.go | 101 ++ .../handshake/ephermal_cache.go | 4 +- .../handshake/handshake_message.go | 8 +- .../quic-go/internal/handshake/interface.go | 58 + .../{ => internal}/handshake/server_config.go | 10 +- .../handshake/server_config_client.go | 2 +- .../quic-go/{ => internal}/handshake/tags.go | 0 .../internal/handshake/tls_extension.go | 55 + .../handshake/tls_extension_handler_client.go | 134 ++ .../handshake/tls_extension_handler_server.go | 113 ++ .../handshake/transport_parameters.go | 176 +++ .../ackhandler/received_packet_handler.go | 83 + .../mocks/ackhandler/sent_packet_handler.go | 178 +++ .../quic-go/internal/mocks/congestion.go | 154 ++ .../mocks/connection_flow_controller.go | 102 ++ .../quic-go/internal/mocks/cpm.go | 153 -- .../quic-go/internal/mocks/crypto/aead.go | 72 + .../quic-go/internal/mocks/gen.go | 11 +- .../internal/mocks/handshake/mint_tls.go | 107 ++ .../mocks/mocks_fc/flow_control_manager.go | 140 -- .../internal/mocks/stream_flow_controller.go | 126 ++ .../internal/mocks/tls_extension_handler.go | 72 + .../protocol/encryption_level.go | 0 .../{ => internal}/protocol/packet_number.go | 12 +- .../{ => internal}/protocol/perspective.go | 11 + .../{ => internal}/protocol/protocol.go | 52 +- .../protocol/server_parameters.go | 67 +- .../quic-go/internal/protocol/stream_id.go | 36 + .../quic-go/internal/protocol/version.go | 135 ++ .../quic-go/internal/utils/byteorder.go | 25 + .../internal/utils/byteorder_big_endian.go | 157 ++ .../{utils.go => byteorder_little_endian.go} | 50 +- .../quic-go/internal/utils/connection_id.go | 2 +- .../quic-go/internal/utils/float16.go | 12 +- .../quic-go/internal/utils/log.go | 9 +- .../quic-go/internal/utils/minmax.go | 10 +- .../quic-go/internal/utils/packet_interval.go | 2 +- .../internal/utils/streamframe_interval.go | 2 +- .../quic-go/internal/utils/timer.go | 2 +- .../quic-go/internal/utils/varint.go | 101 ++ .../quic-go/internal/wire/ack_frame.go | 239 +++ .../wire/ack_frame_legacy.go} | 181 +-- .../quic-go/internal/wire/ack_range.go | 9 + .../quic-go/internal/wire/blocked_frame.go | 45 + .../internal/wire/blocked_frame_legacy.go | 37 + .../internal/wire/connection_close_frame.go | 96 ++ .../quic-go/internal/wire/frame.go | 13 + .../{frames => internal/wire}/goaway_frame.go | 35 +- .../quic-go/internal/wire/header.go | 110 ++ .../quic-go/internal/wire/ietf_header.go | 172 ++ .../quic-go/internal/wire/log.go | 28 + .../quic-go/internal/wire/max_data_frame.go | 51 + .../internal/wire/max_stream_data_frame.go | 60 + .../internal/wire/max_stream_id_frame.go | 37 + .../{frames => internal/wire}/ping_frame.go | 12 +- .../{ => internal/wire}/public_header.go | 156 +- .../quic-go/internal/wire/public_reset.go | 65 + .../quic-go/internal/wire/rst_stream_frame.go | 89 ++ .../internal/wire/stop_sending_frame.go | 47 + .../wire}/stop_waiting_frame.go | 54 +- .../internal/wire/stream_blocked_frame.go | 52 + .../quic-go/internal/wire/stream_frame.go | 182 +++ .../wire/stream_frame_legacy.go} | 88 +- .../internal/wire/stream_id_blocked_frame.go | 37 + .../internal/wire/version_negotiation.go | 59 + .../internal/wire/window_update_frame.go | 45 + .../lucas-clemente/quic-go/mint_utils.go | 160 ++ .../lucas-clemente/quic-go/mockgen.go | 12 + .../quic-go/packet_number_generator.go | 6 +- .../lucas-clemente/quic-go/packet_packer.go | 290 ++-- .../lucas-clemente/quic-go/packet_unpacker.go | 207 ++- .../quic-go/protocol/version.go | 59 - .../lucas-clemente/quic-go/public_reset.go | 62 - .../quic-go/qerr/errorcode_string.go | 7 +- .../lucas-clemente/quic-go/receive_stream.go | 286 ++++ .../lucas-clemente/quic-go/send_stream.go | 313 ++++ .../lucas-clemente/quic-go/server.go | 267 ++-- .../lucas-clemente/quic-go/server_tls.go | 220 +++ .../lucas-clemente/quic-go/session.go | 843 ++++++---- .../lucas-clemente/quic-go/stream.go | 507 ++---- .../quic-go/stream_frame_sorter.go | 14 +- .../lucas-clemente/quic-go/stream_framer.go | 221 +-- .../lucas-clemente/quic-go/streams_map.go | 461 +++--- .../quic-go/streams_map_incoming_bidi.go | 123 ++ .../quic-go/streams_map_incoming_generic.go | 121 ++ .../quic-go/streams_map_incoming_uni.go | 123 ++ .../quic-go/streams_map_legacy.go | 263 ++++ .../quic-go/streams_map_outgoing_bidi.go | 122 ++ .../quic-go/streams_map_outgoing_generic.go | 123 ++ .../quic-go/streams_map_outgoing_uni.go | 122 ++ .../github.com/bifurcation/mint/alert.go | 101 ++ .../bifurcation/mint/client-state-machine.go | 1062 +++++++++++++ .../github.com/bifurcation/mint/common.go | 252 +++ .../github.com/bifurcation/mint/conn.go | 884 +++++++++++ .../bifurcation/mint/cookie-protector.go | 86 + .../github.com/bifurcation/mint/crypto.go | 618 ++++++++ .../github.com/bifurcation/mint/dtls.go | 28 + .../github.com/bifurcation/mint/extensions.go | 626 ++++++++ .../github.com/bifurcation/mint/ffdhe.go | 147 ++ .../bifurcation/mint/frame-reader.go | 98 ++ .../bifurcation/mint/handshake-layer.go | 495 ++++++ .../bifurcation/mint/handshake-messages.go | 481 ++++++ .../vendor/github.com/bifurcation/mint/log.go | 55 + .../bifurcation/mint/negotiation.go | 217 +++ .../bifurcation/mint/record-layer.go | 393 +++++ .../bifurcation/mint/server-state-machine.go | 1102 +++++++++++++ .../bifurcation/mint/state-machine.go | 241 +++ .../bifurcation/mint/syntax/decode.go | 310 ++++ .../bifurcation/mint/syntax/encode.go | 266 ++++ .../bifurcation/mint/syntax/tags.go | 40 + .../vendor/github.com/bifurcation/mint/tls.go | 179 +++ .../cheekybits/genny/generic/doc.go | 2 + .../cheekybits/genny/generic/generic.go | 13 + .../x/crypto/curve25519/const_amd64.h | 8 + .../x/crypto/curve25519/const_amd64.s | 20 + .../x/crypto/curve25519/cswap_amd64.s | 65 + .../x/crypto/curve25519/curve25519.go | 834 ++++++++++ .../golang.org/x/crypto/curve25519/doc.go | 23 + .../x/crypto/curve25519/freeze_amd64.s | 73 + .../x/crypto/curve25519/ladderstep_amd64.s | 1377 +++++++++++++++++ .../x/crypto/curve25519/mont25519_amd64.go | 240 +++ .../x/crypto/curve25519/mul_amd64.s | 169 ++ .../x/crypto/curve25519/square_amd64.s | 132 ++ .../vendor/golang.org/x/crypto/hkdf/hkdf.go | 75 + .../quic-go/window_update_queue.go | 57 + .../x/crypto/curve25519/curve25519.go | 2 +- vendor/manifest | 12 +- 229 files changed, 26903 insertions(+), 4254 deletions(-) create mode 100644 vendor/github.com/bifurcation/mint/LICENSE.md create mode 100644 vendor/github.com/bifurcation/mint/alert.go create mode 100644 vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go create mode 100644 vendor/github.com/bifurcation/mint/bin/mint-client/main.go create mode 100644 vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go create mode 100644 vendor/github.com/bifurcation/mint/bin/mint-server/main.go create mode 100644 vendor/github.com/bifurcation/mint/client-state-machine.go create mode 100644 vendor/github.com/bifurcation/mint/common.go create mode 100644 vendor/github.com/bifurcation/mint/conn.go create mode 100644 vendor/github.com/bifurcation/mint/crypto.go create mode 100644 vendor/github.com/bifurcation/mint/extensions.go create mode 100644 vendor/github.com/bifurcation/mint/ffdhe.go create mode 100644 vendor/github.com/bifurcation/mint/frame-reader.go create mode 100644 vendor/github.com/bifurcation/mint/handshake-layer.go create mode 100644 vendor/github.com/bifurcation/mint/handshake-messages.go create mode 100644 vendor/github.com/bifurcation/mint/log.go create mode 100644 vendor/github.com/bifurcation/mint/negotiation.go create mode 100644 vendor/github.com/bifurcation/mint/record-layer.go create mode 100644 vendor/github.com/bifurcation/mint/server-state-machine.go create mode 100644 vendor/github.com/bifurcation/mint/state-machine.go create mode 100644 vendor/github.com/bifurcation/mint/syntax/decode.go create mode 100644 vendor/github.com/bifurcation/mint/syntax/encode.go create mode 100644 vendor/github.com/bifurcation/mint/syntax/tags.go create mode 100644 vendor/github.com/bifurcation/mint/tls.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/crypto_stream.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/frames/frame.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/frames/log.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/handshake/interface.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/integrationtests/chrome/chrome.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/integrationtests/gquic/gquic.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/integrationtests/self/self.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/ackhandler/packet.go (51%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/ackhandler/packet_linkedlist.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/ackhandler/received_packet_handler.go (63%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/ackhandler/received_packet_history.go (53%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/ackhandler/retransmittable.go (61%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/ackhandler/sent_packet_handler.go (60%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/ackhandler/stop_waiting_manager.go (72%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/congestion/bandwidth.go (90%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/congestion/clock.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/congestion/cubic.go (99%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/congestion/cubic_sender.go (95%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/congestion/hybrid_slow_start.go (98%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/congestion/interface.go (90%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/congestion/prr_sender.go (97%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/congestion/rtt_stats.go (97%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/congestion/stats.go (69%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/AEAD.go (79%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm12_aead.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm_aead.go rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/cert_cache.go (94%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/cert_chain.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/cert_compression.go (94%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/cert_dict.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/cert_manager.go (96%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/cert_sets.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/chacha20poly1305_aead.go (72%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/curve_25519.go (100%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go rename vendor/github.com/lucas-clemente/quic-go/{crypto/key_derivation.go => internal/crypto/key_derivation_quic_crypto.go} (84%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/key_exchange.go (100%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go rename vendor/github.com/lucas-clemente/quic-go/{crypto/null_aead.go => internal/crypto/null_aead_fnv128a.go} (55%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/server_proof.go (100%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/crypto_setup_client.go (76%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/crypto_setup_server.go (78%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/ephermal_cache.go (92%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/handshake_message.go (93%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/server_config.go (88%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/server_config_client.go (98%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/tags.go (100%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/received_packet_handler.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/sent_packet_handler.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/mocks/congestion.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/mocks/connection_flow_controller.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/mocks/cpm.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/mocks/crypto/aead.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/mocks/handshake/mint_tls.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc/flow_control_manager.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/mocks/stream_flow_controller.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/mocks/tls_extension_handler.go rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/protocol/encryption_level.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/protocol/packet_number.go (74%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/protocol/perspective.go (52%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/protocol/protocol.go (58%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/protocol/server_parameters.go (71%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream_id.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go rename vendor/github.com/lucas-clemente/quic-go/internal/utils/{utils.go => byteorder_little_endian.go} (64%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/varint.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go rename vendor/github.com/lucas-clemente/quic-go/{frames/ack_frame.go => internal/wire/ack_frame_legacy.go} (58%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go rename vendor/github.com/lucas-clemente/quic-go/{frames => internal/wire}/goaway_frame.go (54%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame.go rename vendor/github.com/lucas-clemente/quic-go/{frames => internal/wire}/ping_frame.go (57%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal/wire}/public_header.go (59%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go rename vendor/github.com/lucas-clemente/quic-go/{frames => internal/wire}/stop_waiting_frame.go (54%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go rename vendor/github.com/lucas-clemente/quic-go/{frames/stream_frame.go => internal/wire/stream_frame_legacy.go} (55%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/window_update_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/mint_utils.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/mockgen.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/protocol/version.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/public_reset.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/receive_stream.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/send_stream.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/server_tls.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_legacy.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/alert.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/client-state-machine.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/common.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/conn.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/cookie-protector.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/crypto.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/dtls.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/extensions.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/ffdhe.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/frame-reader.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/handshake-layer.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/handshake-messages.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/log.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/negotiation.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/record-layer.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/server-state-machine.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/state-machine.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/syntax/decode.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/syntax/encode.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/syntax/tags.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/tls.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/cheekybits/genny/generic/doc.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/github.com/cheekybits/genny/generic/generic.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/const_amd64.h create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/const_amd64.s create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/cswap_amd64.s create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/curve25519.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/doc.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/freeze_amd64.s create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/ladderstep_amd64.s create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/mont25519_amd64.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/mul_amd64.s create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/square_amd64.s create mode 100644 vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/hkdf/hkdf.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/window_update_queue.go diff --git a/vendor/github.com/bifurcation/mint/LICENSE.md b/vendor/github.com/bifurcation/mint/LICENSE.md new file mode 100644 index 000000000..63858124d --- /dev/null +++ b/vendor/github.com/bifurcation/mint/LICENSE.md @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Richard Barnes + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/github.com/bifurcation/mint/alert.go b/vendor/github.com/bifurcation/mint/alert.go new file mode 100644 index 000000000..5e31035af --- /dev/null +++ b/vendor/github.com/bifurcation/mint/alert.go @@ -0,0 +1,99 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mint + +import "strconv" + +type Alert uint8 + +const ( + // alert level + AlertLevelWarning = 1 + AlertLevelError = 2 +) + +const ( + AlertCloseNotify Alert = 0 + AlertUnexpectedMessage Alert = 10 + AlertBadRecordMAC Alert = 20 + AlertDecryptionFailed Alert = 21 + AlertRecordOverflow Alert = 22 + AlertDecompressionFailure Alert = 30 + AlertHandshakeFailure Alert = 40 + AlertBadCertificate Alert = 42 + AlertUnsupportedCertificate Alert = 43 + AlertCertificateRevoked Alert = 44 + AlertCertificateExpired Alert = 45 + AlertCertificateUnknown Alert = 46 + AlertIllegalParameter Alert = 47 + AlertUnknownCA Alert = 48 + AlertAccessDenied Alert = 49 + AlertDecodeError Alert = 50 + AlertDecryptError Alert = 51 + AlertProtocolVersion Alert = 70 + AlertInsufficientSecurity Alert = 71 + AlertInternalError Alert = 80 + AlertInappropriateFallback Alert = 86 + AlertUserCanceled Alert = 90 + AlertNoRenegotiation Alert = 100 + AlertMissingExtension Alert = 109 + AlertUnsupportedExtension Alert = 110 + AlertCertificateUnobtainable Alert = 111 + AlertUnrecognizedName Alert = 112 + AlertBadCertificateStatsResponse Alert = 113 + AlertBadCertificateHashValue Alert = 114 + AlertUnknownPSKIdentity Alert = 115 + AlertNoApplicationProtocol Alert = 120 + AlertWouldBlock Alert = 254 + AlertNoAlert Alert = 255 +) + +var alertText = map[Alert]string{ + AlertCloseNotify: "close notify", + AlertUnexpectedMessage: "unexpected message", + AlertBadRecordMAC: "bad record MAC", + AlertDecryptionFailed: "decryption failed", + AlertRecordOverflow: "record overflow", + AlertDecompressionFailure: "decompression failure", + AlertHandshakeFailure: "handshake failure", + AlertBadCertificate: "bad certificate", + AlertUnsupportedCertificate: "unsupported certificate", + AlertCertificateRevoked: "revoked certificate", + AlertCertificateExpired: "expired certificate", + AlertCertificateUnknown: "unknown certificate", + AlertIllegalParameter: "illegal parameter", + AlertUnknownCA: "unknown certificate authority", + AlertAccessDenied: "access denied", + AlertDecodeError: "error decoding message", + AlertDecryptError: "error decrypting message", + AlertProtocolVersion: "protocol version not supported", + AlertInsufficientSecurity: "insufficient security level", + AlertInternalError: "internal error", + AlertInappropriateFallback: "inappropriate fallback", + AlertUserCanceled: "user canceled", + AlertMissingExtension: "missing extension", + AlertUnsupportedExtension: "unsupported extension", + AlertCertificateUnobtainable: "certificate unobtainable", + AlertUnrecognizedName: "unrecognized name", + AlertBadCertificateStatsResponse: "bad certificate status response", + AlertBadCertificateHashValue: "bad certificate hash value", + AlertUnknownPSKIdentity: "unknown PSK identity", + AlertNoApplicationProtocol: "no application protocol", + AlertNoRenegotiation: "no renegotiation", + AlertWouldBlock: "would have blocked", + AlertNoAlert: "no alert", +} + +func (e Alert) String() string { + s, ok := alertText[e] + if ok { + return s + } + return "alert(" + strconv.Itoa(int(e)) + ")" +} + +func (e Alert) Error() string { + return e.String() +} diff --git a/vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go b/vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go new file mode 100644 index 000000000..4efe2f556 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go @@ -0,0 +1,42 @@ +package main + +import ( + "flag" + "fmt" + "io/ioutil" + "net" + "net/http" + "os" + + "github.com/bifurcation/mint" +) + +var url string + +func main() { + url := flag.String("url", "https://localhost:4430", "URL to send request") + flag.Parse() + mintdial := func(network, addr string) (net.Conn, error) { + return mint.Dial(network, addr, nil) + } + + tr := &http.Transport{ + DialTLS: mintdial, + DisableCompression: true, + } + client := &http.Client{Transport: tr} + + response, err := client.Get(*url) + if err != nil { + fmt.Println("err:", err) + return + } + defer response.Body.Close() + + contents, err := ioutil.ReadAll(response.Body) + if err != nil { + fmt.Printf("%s", err) + os.Exit(1) + } + fmt.Printf("%s\n", string(contents)) +} diff --git a/vendor/github.com/bifurcation/mint/bin/mint-client/main.go b/vendor/github.com/bifurcation/mint/bin/mint-client/main.go new file mode 100644 index 000000000..27b0f2539 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/bin/mint-client/main.go @@ -0,0 +1,37 @@ +package main + +import ( + "flag" + "fmt" + + "github.com/bifurcation/mint" +) + +var addr string + +func main() { + flag.StringVar(&addr, "addr", "localhost:4430", "port") + flag.Parse() + + conn, err := mint.Dial("tcp", addr, nil) + + if err != nil { + fmt.Println("TLS handshake failed:", err) + return + } + + request := "GET / HTTP/1.0\r\n\r\n" + conn.Write([]byte(request)) + + response := "" + buffer := make([]byte, 1024) + var read int + for err == nil { + read, err = conn.Read(buffer) + fmt.Println(" ~~ read: ", read) + response += string(buffer) + } + fmt.Println("err:", err) + fmt.Println("Received from server:") + fmt.Println(response) +} diff --git a/vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go b/vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go new file mode 100644 index 000000000..7ac0e60ee --- /dev/null +++ b/vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go @@ -0,0 +1,226 @@ +package main + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "flag" + "fmt" + "io/ioutil" + "log" + "net/http" + + "github.com/bifurcation/mint" + "golang.org/x/net/http2" +) + +var ( + port string + serverName string + certFile string + keyFile string + responseFile string + h2 bool + sendTickets bool +) + +type responder []byte + +func (rsp responder) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Write(rsp) +} + +// ParsePrivateKeyDER parses a PKCS #1, PKCS #8, or elliptic curve +// PEM-encoded private key. +// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module +func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) { + keyDER, _ := pem.Decode(keyPEM) + if keyDER == nil { + return nil, err + } + + generalKey, err := x509.ParsePKCS8PrivateKey(keyDER.Bytes) + if err != nil { + generalKey, err = x509.ParsePKCS1PrivateKey(keyDER.Bytes) + if err != nil { + generalKey, err = x509.ParseECPrivateKey(keyDER.Bytes) + if err != nil { + // We don't include the actual error into + // the final error. The reason might be + // we don't want to leak any info about + // the private key. + return nil, fmt.Errorf("No successful private key decoder") + } + } + } + + switch generalKey.(type) { + case *rsa.PrivateKey: + return generalKey.(*rsa.PrivateKey), nil + case *ecdsa.PrivateKey: + return generalKey.(*ecdsa.PrivateKey), nil + } + + // should never reach here + return nil, fmt.Errorf("Should be unreachable") +} + +// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object, +// either a raw x509 certificate or a PKCS #7 structure possibly containing +// multiple certificates, from the top of certsPEM, which itself may +// contain multiple PEM encoded certificate objects. +// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module +func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) { + block, rest := pem.Decode(certsPEM) + if block == nil { + return nil, rest, nil + } + + cert, err := x509.ParseCertificate(block.Bytes) + var certs = []*x509.Certificate{cert} + return certs, rest, err +} + +// ParseCertificatesPEM parses a sequence of PEM-encoded certificate and returns them, +// can handle PEM encoded PKCS #7 structures. +// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module +func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) { + var certs []*x509.Certificate + var err error + certsPEM = bytes.TrimSpace(certsPEM) + for len(certsPEM) > 0 { + var cert []*x509.Certificate + cert, certsPEM, err = ParseOneCertificateFromPEM(certsPEM) + if err != nil { + return nil, err + } else if cert == nil { + break + } + + certs = append(certs, cert...) + } + if len(certsPEM) > 0 { + return nil, fmt.Errorf("Trailing PEM data") + } + return certs, nil +} + +func main() { + flag.StringVar(&port, "port", "4430", "port") + flag.StringVar(&serverName, "host", "example.com", "hostname") + flag.StringVar(&certFile, "cert", "", "certificate chain in PEM or DER") + flag.StringVar(&keyFile, "key", "", "private key in PEM format") + flag.StringVar(&responseFile, "response", "", "file to serve") + flag.BoolVar(&h2, "h2", false, "whether to use HTTP/2 (exclusively)") + flag.BoolVar(&sendTickets, "tickets", true, "whether to send session tickets") + flag.Parse() + + var certChain []*x509.Certificate + var priv crypto.Signer + var response []byte + var err error + + // Load the key and certificate chain + if certFile != "" { + certs, err := ioutil.ReadFile(certFile) + if err != nil { + log.Fatalf("Error: %v", err) + } else { + certChain, err = ParseCertificatesPEM(certs) + if err != nil { + certChain, err = x509.ParseCertificates(certs) + if err != nil { + log.Fatalf("Error parsing certificates: %v", err) + } + } + } + } + if keyFile != "" { + keyPEM, err := ioutil.ReadFile(keyFile) + if err != nil { + log.Fatalf("Error: %v", err) + } else { + priv, err = ParsePrivateKeyPEM(keyPEM) + if priv == nil || err != nil { + log.Fatalf("Error parsing private key: %v", err) + } + } + } + if err != nil { + log.Fatalf("Error: %v", err) + } + + // Load response file + if responseFile != "" { + log.Printf("Loading response file: %v", responseFile) + response, err = ioutil.ReadFile(responseFile) + if err != nil { + log.Fatalf("Error: %v", err) + } + } else { + response = []byte("Welcome to the TLS 1.3 zone!") + } + handler := responder(response) + + config := mint.Config{ + SendSessionTickets: true, + ServerName: serverName, + NextProtos: []string{"http/1.1"}, + } + + if h2 { + config.NextProtos = []string{"h2"} + } + + config.SendSessionTickets = sendTickets + + if certChain != nil && priv != nil { + log.Printf("Loading cert: %v key: %v", certFile, keyFile) + config.Certificates = []*mint.Certificate{ + { + Chain: certChain, + PrivateKey: priv, + }, + } + } + config.Init(false) + + service := "0.0.0.0:" + port + srv := &http.Server{Handler: handler} + + log.Printf("Listening on port %v", port) + // Need the inner loop here because the h1 server errors on a dropped connection + // Need the outer loop here because the h2 server is per-connection + for { + listener, err := mint.Listen("tcp", service, &config) + if err != nil { + log.Printf("Listen Error: %v", err) + continue + } + + if !h2 { + alert := srv.Serve(listener) + if alert != mint.AlertNoAlert { + log.Printf("Serve Error: %v", err) + } + } else { + srv2 := new(http2.Server) + opts := &http2.ServeConnOpts{ + Handler: handler, + BaseConfig: srv, + } + + for { + conn, err := listener.Accept() + if err != nil { + log.Printf("Accept error: %v", err) + continue + } + go srv2.ServeConn(conn, opts) + } + } + } +} diff --git a/vendor/github.com/bifurcation/mint/bin/mint-server/main.go b/vendor/github.com/bifurcation/mint/bin/mint-server/main.go new file mode 100644 index 000000000..216f8acba --- /dev/null +++ b/vendor/github.com/bifurcation/mint/bin/mint-server/main.go @@ -0,0 +1,65 @@ +package main + +import ( + "flag" + "log" + "net" + + "github.com/bifurcation/mint" +) + +var port string + +func main() { + var config mint.Config + config.SendSessionTickets = true + config.ServerName = "localhost" + config.Init(false) + + flag.StringVar(&port, "port", "4430", "port") + flag.Parse() + + service := "0.0.0.0:" + port + listener, err := mint.Listen("tcp", service, &config) + + if err != nil { + log.Fatalf("server: listen: %s", err) + } + log.Print("server: listening") + + for { + conn, err := listener.Accept() + if err != nil { + log.Printf("server: accept: %s", err) + break + } + defer conn.Close() + log.Printf("server: accepted from %s", conn.RemoteAddr()) + go handleClient(conn) + } +} + +func handleClient(conn net.Conn) { + defer conn.Close() + buf := make([]byte, 10) + for { + log.Print("server: conn: waiting") + n, err := conn.Read(buf) + if err != nil { + if err != nil { + log.Printf("server: conn: read: %s", err) + } + break + } + + n, err = conn.Write([]byte("hello world")) + log.Printf("server: conn: wrote %d bytes", n) + + if err != nil { + log.Printf("server: write: %s", err) + break + } + break + } + log.Println("server: conn: closed") +} diff --git a/vendor/github.com/bifurcation/mint/client-state-machine.go b/vendor/github.com/bifurcation/mint/client-state-machine.go new file mode 100644 index 000000000..290a93032 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/client-state-machine.go @@ -0,0 +1,942 @@ +package mint + +import ( + "bytes" + "crypto" + "hash" + "time" +) + +// Client State Machine +// +// START <----+ +// Send ClientHello | | Recv HelloRetryRequest +// / v | +// | WAIT_SH ---+ +// Can | | Recv ServerHello +// send | V +// early | WAIT_EE +// data | | Recv EncryptedExtensions +// | +--------+--------+ +// | Using | | Using certificate +// | PSK | v +// | | WAIT_CERT_CR +// | | Recv | | Recv CertificateRequest +// | | Certificate | v +// | | | WAIT_CERT +// | | | | Recv Certificate +// | | v v +// | | WAIT_CV +// | | | Recv CertificateVerify +// | +> WAIT_FINISHED <+ +// | | Recv Finished +// \ | +// | [Send EndOfEarlyData] +// | [Send Certificate [+ CertificateVerify]] +// | Send Finished +// Can send v +// app data --> CONNECTED +// after +// here +// +// State Instructions +// START Send(CH); [RekeyOut; SendEarlyData] +// WAIT_SH Send(CH) || RekeyIn +// WAIT_EE {} +// WAIT_CERT_CR {} +// WAIT_CERT {} +// WAIT_CV {} +// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut; +// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) + +type ClientStateStart struct { + Caps Capabilities + Opts ConnectionOptions + Params ConnectionParameters + + cookie []byte + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage +} + +func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm != nil { + logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message") + return nil, nil, AlertUnexpectedMessage + } + + // key_shares + offeredDH := map[NamedGroup][]byte{} + ks := KeyShareExtension{ + HandshakeType: HandshakeTypeClientHello, + Shares: make([]KeyShareEntry, len(state.Caps.Groups)), + } + for i, group := range state.Caps.Groups { + pub, priv, err := newKeyShare(group) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err) + return nil, nil, AlertInternalError + } + + ks.Shares[i].Group = group + ks.Shares[i].KeyExchange = pub + offeredDH[group] = priv + } + + logf(logTypeHandshake, "opts: %+v", state.Opts) + + // supported_versions, supported_groups, signature_algorithms, server_name + sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}} + sni := ServerNameExtension(state.Opts.ServerName) + sg := SupportedGroupsExtension{Groups: state.Caps.Groups} + sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} + + state.Params.ServerName = state.Opts.ServerName + + // Application Layer Protocol Negotiation + var alpn *ALPNExtension + if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) { + alpn = &ALPNExtension{Protocols: state.Opts.NextProtos} + } + + // Construct base ClientHello + ch := &ClientHelloBody{ + CipherSuites: state.Caps.CipherSuites, + } + _, err := prng.Read(ch.Random[:]) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err) + return nil, nil, AlertInternalError + } + for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} { + err := ch.Extensions.Add(ext) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err) + return nil, nil, AlertInternalError + } + } + // XXX: These optional extensions can't be folded into the above because Go + // interface-typed values are never reported as nil + if alpn != nil { + err := ch.Extensions.Add(alpn) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) + return nil, nil, AlertInternalError + } + } + if state.cookie != nil { + err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie}) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Run the external extension handler. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Handle PSK and EarlyData just before transmitting, so that we can + // calculate the PSK binder value + var psk *PreSharedKeyExtension + var ed *EarlyDataExtension + var offeredPSK PreSharedKey + var earlyHash crypto.Hash + var earlySecret []byte + var clientEarlyTrafficKeys keySet + var clientHello *HandshakeMessage + if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok { + offeredPSK = key + + // Narrow ciphersuites to ones that match PSK hash + params, ok := cipherSuiteMap[key.CipherSuite] + if !ok { + logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite") + return nil, nil, AlertInternalError + } + + compatibleSuites := []CipherSuite{} + for _, suite := range ch.CipherSuites { + if cipherSuiteMap[suite].Hash == params.Hash { + compatibleSuites = append(compatibleSuites, suite) + } + } + ch.CipherSuites = compatibleSuites + + // Signal early data if we're going to do it + if len(state.Opts.EarlyData) > 0 { + state.Params.ClientSendingEarlyData = true + ed = &EarlyDataExtension{} + err = ch.Extensions.Add(ed) + if err != nil { + logf(logTypeHandshake, "Error adding early data extension: %v", err) + return nil, nil, AlertInternalError + } + } + + // Signal supported PSK key exchange modes + if len(state.Caps.PSKModes) == 0 { + logf(logTypeHandshake, "PSK selected, but no PSKModes") + return nil, nil, AlertInternalError + } + kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes} + err = ch.Extensions.Add(kem) + if err != nil { + logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err) + return nil, nil, AlertInternalError + } + + // Add the shim PSK extension to the ClientHello + logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity) + psk = &PreSharedKeyExtension{ + HandshakeType: HandshakeTypeClientHello, + Identities: []PSKIdentity{ + { + Identity: key.Identity, + ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd, + }, + }, + Binders: []PSKBinderEntry{ + // Note: Stub to get the length fields right + {Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())}, + }, + } + ch.Extensions.Add(psk) + + // Compute the binder key + h0 := params.Hash.New().Sum(nil) + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + + earlyHash = params.Hash + earlySecret = HkdfExtract(params.Hash, zero, key.Key) + logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) + + binderLabel := labelExternalBinder + if key.IsResumption { + binderLabel = labelResumptionBinder + } + binderKey := deriveSecret(params, earlySecret, binderLabel, h0) + logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey) + + // Compute the binder value + trunc, err := ch.Truncated() + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err) + return nil, nil, AlertInternalError + } + + truncHash := params.Hash.New() + truncHash.Write(trunc) + + binder := computeFinishedData(params, binderKey, truncHash.Sum(nil)) + + // Replace the PSK extension + psk.Binders[0].Binder = binder + ch.Extensions.Add(psk) + + // If we got here, the earlier marshal succeeded (in ch.Truncated()), so + // this one should too. + clientHello, _ = HandshakeMessageFromBody(ch) + + // Compute early traffic keys + h := params.Hash.New() + h.Write(clientHello.Marshal()) + chHash := h.Sum(nil) + + earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) + logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret) + clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret) + } else if len(state.Opts.EarlyData) > 0 { + logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK") + return nil, nil, AlertInternalError + } else { + clientHello, err = HandshakeMessageFromBody(ch) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err) + return nil, nil, AlertInternalError + } + } + + logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]") + nextState := ClientStateWaitSH{ + Caps: state.Caps, + Opts: state.Opts, + Params: state.Params, + OfferedDH: offeredDH, + OfferedPSK: offeredPSK, + + earlySecret: earlySecret, + earlyHash: earlyHash, + + firstClientHello: state.firstClientHello, + helloRetryRequest: state.helloRetryRequest, + clientHello: clientHello, + } + + toSend := []HandshakeAction{ + SendHandshakeMessage{clientHello}, + } + if state.Params.ClientSendingEarlyData { + toSend = append(toSend, []HandshakeAction{ + RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys}, + SendEarlyData{}, + }...) + } + + return nextState, toSend, AlertNoAlert +} + +type ClientStateWaitSH struct { + Caps Capabilities + Opts ConnectionOptions + Params ConnectionParameters + OfferedDH map[NamedGroup][]byte + OfferedPSK PreSharedKey + PSK []byte + + earlySecret []byte + earlyHash crypto.Hash + + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage + clientHello *HandshakeMessage +} + +func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil { + logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message") + return nil, nil, AlertUnexpectedMessage + } + + bodyGeneric, err := hm.ToBody() + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + switch body := bodyGeneric.(type) { + case *HelloRetryRequestBody: + hrr := body + + if state.helloRetryRequest != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest") + return nil, nil, AlertUnexpectedMessage + } + + // Check that the version sent by the server is the one we support + if hrr.Version != supportedVersion { + logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version) + return nil, nil, AlertProtocolVersion + } + + // Check that the server provided a supported ciphersuite + supportedCipherSuite := false + for _, suite := range state.Caps.CipherSuites { + supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite) + } + if !supportedCipherSuite { + logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite) + return nil, nil, AlertHandshakeFailure + } + + // Narrow the supported ciphersuites to the server-provided one + state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite} + + // Handle external extensions. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + // The only thing we know how to respond to in an HRR is the Cookie + // extension, so if there is either no Cookie extension or anything other + // than a Cookie extension, we have to fail. + serverCookie := new(CookieExtension) + foundCookie := hrr.Extensions.Find(serverCookie) + if !foundCookie || len(hrr.Extensions) != 1 { + logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions)) + return nil, nil, AlertIllegalParameter + } + + // Hash the body into a pseudo-message + // XXX: Ignoring some errors here + params := cipherSuiteMap[hrr.CipherSuite] + h := params.Hash.New() + h.Write(state.clientHello.Marshal()) + firstClientHello := &HandshakeMessage{ + msgType: HandshakeTypeMessageHash, + body: h.Sum(nil), + } + + logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]") + return ClientStateStart{ + Caps: state.Caps, + Opts: state.Opts, + cookie: serverCookie.Cookie, + firstClientHello: firstClientHello, + helloRetryRequest: hm, + }.Next(nil) + + case *ServerHelloBody: + sh := body + + // Check that the version sent by the server is the one we support + if sh.Version != supportedVersion { + logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version) + return nil, nil, AlertProtocolVersion + } + + // Check that the server provided a supported ciphersuite + supportedCipherSuite := false + for _, suite := range state.Caps.CipherSuites { + supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite) + } + if !supportedCipherSuite { + logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite) + return nil, nil, AlertHandshakeFailure + } + + // Handle external extensions. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Do PSK or key agreement depending on extensions + serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello} + serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} + + foundPSK := sh.Extensions.Find(&serverPSK) + foundKeyShare := sh.Extensions.Find(&serverKeyShare) + + if foundPSK && (serverPSK.SelectedIdentity == 0) { + state.Params.UsingPSK = true + } + + var dhSecret []byte + if foundKeyShare { + sks := serverKeyShare.Shares[0] + priv, ok := state.OfferedDH[sks.Group] + if !ok { + logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group") + return nil, nil, AlertIllegalParameter + } + + state.Params.UsingDH = true + dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv) + } + + suite := sh.CipherSuite + state.Params.CipherSuite = suite + + params, ok := cipherSuiteMap[suite] + if !ok { + logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite) + return nil, nil, AlertHandshakeFailure + } + + // Start up the handshake hash + handshakeHash := params.Hash.New() + handshakeHash.Write(state.firstClientHello.Marshal()) + handshakeHash.Write(state.helloRetryRequest.Marshal()) + handshakeHash.Write(state.clientHello.Marshal()) + handshakeHash.Write(hm.Marshal()) + + // Compute handshake secrets + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + + var earlySecret []byte + if state.Params.UsingPSK { + if params.Hash != state.earlyHash { + logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]", + state.earlyHash, suite, params.Hash) + } + + earlySecret = state.earlySecret + } else { + earlySecret = HkdfExtract(params.Hash, zero, zero) + } + + if dhSecret == nil { + dhSecret = zero + } + + h0 := params.Hash.New().Sum(nil) + h2 := handshakeHash.Sum(nil) + preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) + handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret) + clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) + serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) + preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) + masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) + + logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) + logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) + logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) + logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) + logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) + + serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) + + logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") + nextState := ClientStateWaitEE{ + Caps: state.Caps, + Params: state.Params, + cryptoParams: params, + handshakeHash: handshakeHash, + certificates: state.Caps.Certificates, + masterSecret: masterSecret, + clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: serverHandshakeTrafficSecret, + } + toSend := []HandshakeAction{ + RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys}, + } + return nextState, toSend, AlertNoAlert + } + + logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType) + return nil, nil, AlertUnexpectedMessage +} + +type ClientStateWaitEE struct { + Caps Capabilities + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + certificates []*Certificate + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions { + logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + ee := EncryptedExtensionsBody{} + _, err := ee.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + // Handle external extensions. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + serverALPN := ALPNExtension{} + serverEarlyData := EarlyDataExtension{} + + gotALPN := ee.Extensions.Find(&serverALPN) + state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData) + + if gotALPN && len(serverALPN.Protocols) > 0 { + state.Params.NextProto = serverALPN.Protocols[0] + } + + state.handshakeHash.Write(hm.Marshal()) + + if state.Params.UsingPSK { + logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]") + nextState := ClientStateWaitFinished{ + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.certificates, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert + } + + logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]") + nextState := ClientStateWaitCertCR{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.certificates, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert +} + +type ClientStateWaitCertCR struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + certificates []*Certificate + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil { + logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + bodyGeneric, err := hm.ToBody() + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + state.handshakeHash.Write(hm.Marshal()) + + switch body := bodyGeneric.(type) { + case *CertificateBody: + logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]") + nextState := ClientStateWaitCV{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.certificates, + serverCertificate: body, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert + + case *CertificateRequestBody: + // A certificate request in the handshake should have a zero-length context + if len(body.CertificateRequestContext) > 0 { + logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err) + return nil, nil, AlertIllegalParameter + } + + state.Params.UsingClientAuth = true + + logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]") + nextState := ClientStateWaitCert{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.certificates, + serverCertificateRequest: body, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert + } + + return nil, nil, AlertUnexpectedMessage +} + +type ClientStateWaitCert struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + + certificates []*Certificate + serverCertificateRequest *CertificateRequestBody + + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeCertificate { + logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + cert := &CertificateBody{} + _, err := cert.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + state.handshakeHash.Write(hm.Marshal()) + + logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]") + nextState := ClientStateWaitCV{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.certificates, + serverCertificate: cert, + serverCertificateRequest: state.serverCertificateRequest, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert +} + +type ClientStateWaitCV struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + + certificates []*Certificate + serverCertificate *CertificateBody + serverCertificateRequest *CertificateRequestBody + + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { + logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + certVerify := CertificateVerifyBody{} + _, err := certVerify.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + hcv := state.handshakeHash.Sum(nil) + logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) + + serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey + if err := certVerify.Verify(serverPublicKey, hcv); err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify") + return nil, nil, AlertHandshakeFailure + } + + if state.AuthCertificate != nil { + err := state.AuthCertificate(state.serverCertificate.CertificateList) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate") + return nil, nil, AlertBadCertificate + } + } else { + logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate") + } + + state.handshakeHash.Write(hm.Marshal()) + + logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]") + nextState := ClientStateWaitFinished{ + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.certificates, + serverCertificateRequest: state.serverCertificateRequest, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert +} + +type ClientStateWaitFinished struct { + Params ConnectionParameters + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + + certificates []*Certificate + serverCertificateRequest *CertificateRequestBody + + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeFinished { + logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + // Verify server's Finished + h3 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3) + logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3) + + serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3) + logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) + + fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)} + _, err := fin.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + if !bytes.Equal(fin.VerifyData, serverFinishedData) { + logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]", + fin.VerifyData, serverFinishedData) + return nil, nil, AlertHandshakeFailure + } + + // Update the handshake hash with the Finished + state.handshakeHash.Write(hm.Marshal()) + logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal()) + h4 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4) + + // Compute traffic secrets and keys + clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4) + serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4) + logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret) + logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) + + clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret) + serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret) + + exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4) + logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret) + + // Assemble client's second flight + toSend := []HandshakeAction{} + + if state.Params.UsingEarlyData { + // Note: We only send EOED if the server is actually going to use the early + // data. Otherwise, it will never see it, and the transcripts will + // mismatch. + // EOED marshal is infallible + eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{}) + toSend = append(toSend, SendHandshakeMessage{eoedm}) + state.handshakeHash.Write(eoedm.Marshal()) + logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal()) + } + + clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) + toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys}) + + if state.Params.UsingClientAuth { + // Extract constraints from certicateRequest + schemes := SignatureAlgorithmsExtension{} + gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes) + if !gotSchemes { + logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) + return nil, nil, AlertIllegalParameter + } + + // Select a certificate + cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates) + if err != nil { + // XXX: Signal this to the application layer? + logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) + + certificate := &CertificateBody{} + certm, err := HandshakeMessageFromBody(certificate) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, SendHandshakeMessage{certm}) + state.handshakeHash.Write(certm.Marshal()) + } else { + // Create and send Certificate, CertificateVerify + certificate := &CertificateBody{ + CertificateList: make([]CertificateEntry, len(cert.Chain)), + } + for i, entry := range cert.Chain { + certificate.CertificateList[i] = CertificateEntry{CertData: entry} + } + certm, err := HandshakeMessageFromBody(certificate) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, SendHandshakeMessage{certm}) + state.handshakeHash.Write(certm.Marshal()) + + hcv := state.handshakeHash.Sum(nil) + logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) + + certificateVerify := &CertificateVerifyBody{Algorithm: certScheme} + logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash) + + err = certificateVerify.Sign(cert.PrivateKey, hcv) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err) + return nil, nil, AlertInternalError + } + certvm, err := HandshakeMessageFromBody(certificateVerify) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, SendHandshakeMessage{certvm}) + state.handshakeHash.Write(certvm.Marshal()) + } + } + + // Compute the client's Finished message + h5 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5) + + clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5) + logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData) + + fin = &FinishedBody{ + VerifyDataLen: len(clientFinishedData), + VerifyData: clientFinishedData, + } + finm, err := HandshakeMessageFromBody(fin) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err) + return nil, nil, AlertInternalError + } + + // Compute the resumption secret + state.handshakeHash.Write(finm.Marshal()) + h6 := state.handshakeHash.Sum(nil) + + resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6) + logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) + + toSend = append(toSend, []HandshakeAction{ + SendHandshakeMessage{finm}, + RekeyIn{Label: "application", KeySet: serverTrafficKeys}, + RekeyOut{Label: "application", KeySet: clientTrafficKeys}, + }...) + + logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]") + nextState := StateConnected{ + Params: state.Params, + isClient: true, + cryptoParams: state.cryptoParams, + resumptionSecret: resumptionSecret, + clientTrafficSecret: clientTrafficSecret, + serverTrafficSecret: serverTrafficSecret, + exporterSecret: exporterSecret, + } + return nextState, toSend, AlertNoAlert +} diff --git a/vendor/github.com/bifurcation/mint/common.go b/vendor/github.com/bifurcation/mint/common.go new file mode 100644 index 000000000..dfda7c3ef --- /dev/null +++ b/vendor/github.com/bifurcation/mint/common.go @@ -0,0 +1,152 @@ +package mint + +import ( + "fmt" + "strconv" +) + +var ( + supportedVersion uint16 = 0x7f15 // draft-21 + + // Flags for some minor compat issues + allowWrongVersionNumber = true + allowPKCS1 = true +) + +// enum {...} ContentType; +type RecordType byte + +const ( + RecordTypeAlert RecordType = 21 + RecordTypeHandshake RecordType = 22 + RecordTypeApplicationData RecordType = 23 +) + +// enum {...} HandshakeType; +type HandshakeType byte + +const ( + // Omitted: *_RESERVED + HandshakeTypeClientHello HandshakeType = 1 + HandshakeTypeServerHello HandshakeType = 2 + HandshakeTypeNewSessionTicket HandshakeType = 4 + HandshakeTypeEndOfEarlyData HandshakeType = 5 + HandshakeTypeHelloRetryRequest HandshakeType = 6 + HandshakeTypeEncryptedExtensions HandshakeType = 8 + HandshakeTypeCertificate HandshakeType = 11 + HandshakeTypeCertificateRequest HandshakeType = 13 + HandshakeTypeCertificateVerify HandshakeType = 15 + HandshakeTypeServerConfiguration HandshakeType = 17 + HandshakeTypeFinished HandshakeType = 20 + HandshakeTypeKeyUpdate HandshakeType = 24 + HandshakeTypeMessageHash HandshakeType = 254 +) + +// uint8 CipherSuite[2]; +type CipherSuite uint16 + +const ( + // XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero + // value for this type so that we can detect when a field is set. + CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000 + TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301 + TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302 + TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303 + TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304 + TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305 +) + +func (c CipherSuite) String() string { + switch c { + case CIPHER_SUITE_UNKNOWN: + return "unknown" + case TLS_AES_128_GCM_SHA256: + return "TLS_AES_128_GCM_SHA256" + case TLS_AES_256_GCM_SHA384: + return "TLS_AES_256_GCM_SHA384" + case TLS_CHACHA20_POLY1305_SHA256: + return "TLS_CHACHA20_POLY1305_SHA256" + case TLS_AES_128_CCM_SHA256: + return "TLS_AES_128_CCM_SHA256" + case TLS_AES_256_CCM_8_SHA256: + return "TLS_AES_256_CCM_8_SHA256" + } + // cannot use %x here, since it calls String(), leading to infinite recursion + return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16)) +} + +// enum {...} SignatureScheme +type SignatureScheme uint16 + +const ( + // RSASSA-PKCS1-v1_5 algorithms + RSA_PKCS1_SHA1 SignatureScheme = 0x0201 + RSA_PKCS1_SHA256 SignatureScheme = 0x0401 + RSA_PKCS1_SHA384 SignatureScheme = 0x0501 + RSA_PKCS1_SHA512 SignatureScheme = 0x0601 + // ECDSA algorithms + ECDSA_P256_SHA256 SignatureScheme = 0x0403 + ECDSA_P384_SHA384 SignatureScheme = 0x0503 + ECDSA_P521_SHA512 SignatureScheme = 0x0603 + // RSASSA-PSS algorithms + RSA_PSS_SHA256 SignatureScheme = 0x0804 + RSA_PSS_SHA384 SignatureScheme = 0x0805 + RSA_PSS_SHA512 SignatureScheme = 0x0806 + // EdDSA algorithms + Ed25519 SignatureScheme = 0x0807 + Ed448 SignatureScheme = 0x0808 +) + +// enum {...} ExtensionType +type ExtensionType uint16 + +const ( + ExtensionTypeServerName ExtensionType = 0 + ExtensionTypeSupportedGroups ExtensionType = 10 + ExtensionTypeSignatureAlgorithms ExtensionType = 13 + ExtensionTypeALPN ExtensionType = 16 + ExtensionTypeKeyShare ExtensionType = 40 + ExtensionTypePreSharedKey ExtensionType = 41 + ExtensionTypeEarlyData ExtensionType = 42 + ExtensionTypeSupportedVersions ExtensionType = 43 + ExtensionTypeCookie ExtensionType = 44 + ExtensionTypePSKKeyExchangeModes ExtensionType = 45 + ExtensionTypeTicketEarlyDataInfo ExtensionType = 46 +) + +// enum {...} NamedGroup +type NamedGroup uint16 + +const ( + // Elliptic Curve Groups. + P256 NamedGroup = 23 + P384 NamedGroup = 24 + P521 NamedGroup = 25 + // ECDH functions. + X25519 NamedGroup = 29 + X448 NamedGroup = 30 + // Finite field groups. + FFDHE2048 NamedGroup = 256 + FFDHE3072 NamedGroup = 257 + FFDHE4096 NamedGroup = 258 + FFDHE6144 NamedGroup = 259 + FFDHE8192 NamedGroup = 260 +) + +// enum {...} PskKeyExchangeMode; +type PSKKeyExchangeMode uint8 + +const ( + PSKModeKE PSKKeyExchangeMode = 0 + PSKModeDHEKE PSKKeyExchangeMode = 1 +) + +// enum { +// update_not_requested(0), update_requested(1), (255) +// } KeyUpdateRequest; +type KeyUpdateRequest uint8 + +const ( + KeyUpdateNotRequested KeyUpdateRequest = 0 + KeyUpdateRequested KeyUpdateRequest = 1 +) diff --git a/vendor/github.com/bifurcation/mint/conn.go b/vendor/github.com/bifurcation/mint/conn.go new file mode 100644 index 000000000..08eb58df9 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/conn.go @@ -0,0 +1,819 @@ +package mint + +import ( + "crypto" + "crypto/x509" + "encoding/hex" + "fmt" + "io" + "net" + "reflect" + "sync" + "time" +) + +var WouldBlock = fmt.Errorf("Would have blocked") + +type Certificate struct { + Chain []*x509.Certificate + PrivateKey crypto.Signer +} + +type PreSharedKey struct { + CipherSuite CipherSuite + IsResumption bool + Identity []byte + Key []byte + NextProto string + ReceivedAt time.Time + ExpiresAt time.Time + TicketAgeAdd uint32 +} + +type PreSharedKeyCache interface { + Get(string) (PreSharedKey, bool) + Put(string, PreSharedKey) + Size() int +} + +type PSKMapCache map[string]PreSharedKey + +// A CookieHandler does two things: +// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest +// - validates this byte string echoed by the client in the ClientHello +type CookieHandler interface { + Generate(*Conn) ([]byte, error) + Validate(*Conn, []byte) bool +} + +func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) { + psk, ok = cache[key] + return +} + +func (cache *PSKMapCache) Put(key string, psk PreSharedKey) { + (*cache)[key] = psk +} + +func (cache PSKMapCache) Size() int { + return len(cache) +} + +// Config is the struct used to pass configuration settings to a TLS client or +// server instance. The settings for client and server are pretty different, +// but we just throw them all in here. +type Config struct { + // Client fields + ServerName string + + // Server fields + SendSessionTickets bool + TicketLifetime uint32 + TicketLen int + EarlyDataLifetime uint32 + AllowEarlyData bool + // Require the client to echo a cookie. + RequireCookie bool + // If cookies are required and no CookieHandler is set, a default cookie handler is used. + // The default cookie handler uses 32 random bytes as a cookie. + CookieHandler CookieHandler + RequireClientAuth bool + + // Shared fields + Certificates []*Certificate + AuthCertificate func(chain []CertificateEntry) error + CipherSuites []CipherSuite + Groups []NamedGroup + SignatureSchemes []SignatureScheme + NextProtos []string + PSKs PreSharedKeyCache + PSKModes []PSKKeyExchangeMode + NonBlocking bool + + // The same config object can be shared among different connections, so it + // needs its own mutex + mutex sync.RWMutex +} + +// Clone returns a shallow clone of c. It is safe to clone a Config that is +// being used concurrently by a TLS client or server. +func (c *Config) Clone() *Config { + c.mutex.Lock() + defer c.mutex.Unlock() + + return &Config{ + ServerName: c.ServerName, + + SendSessionTickets: c.SendSessionTickets, + TicketLifetime: c.TicketLifetime, + TicketLen: c.TicketLen, + EarlyDataLifetime: c.EarlyDataLifetime, + AllowEarlyData: c.AllowEarlyData, + RequireCookie: c.RequireCookie, + RequireClientAuth: c.RequireClientAuth, + + Certificates: c.Certificates, + AuthCertificate: c.AuthCertificate, + CipherSuites: c.CipherSuites, + Groups: c.Groups, + SignatureSchemes: c.SignatureSchemes, + NextProtos: c.NextProtos, + PSKs: c.PSKs, + PSKModes: c.PSKModes, + NonBlocking: c.NonBlocking, + } +} + +func (c *Config) Init(isClient bool) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + // Set defaults + if len(c.CipherSuites) == 0 { + c.CipherSuites = defaultSupportedCipherSuites + } + if len(c.Groups) == 0 { + c.Groups = defaultSupportedGroups + } + if len(c.SignatureSchemes) == 0 { + c.SignatureSchemes = defaultSignatureSchemes + } + if c.TicketLen == 0 { + c.TicketLen = defaultTicketLen + } + if !reflect.ValueOf(c.PSKs).IsValid() { + c.PSKs = &PSKMapCache{} + } + if len(c.PSKModes) == 0 { + c.PSKModes = defaultPSKModes + } + + // If there is no certificate, generate one + if !isClient && len(c.Certificates) == 0 { + logf(logTypeHandshake, "Generating key name=%v", c.ServerName) + priv, err := newSigningKey(RSA_PSS_SHA256) + if err != nil { + return err + } + + cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv) + if err != nil { + return err + } + + c.Certificates = []*Certificate{ + { + Chain: []*x509.Certificate{cert}, + PrivateKey: priv, + }, + } + } + + return nil +} + +func (c *Config) ValidForServer() bool { + return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) || + (len(c.Certificates) > 0 && + len(c.Certificates[0].Chain) > 0 && + c.Certificates[0].PrivateKey != nil) +} + +func (c *Config) ValidForClient() bool { + return len(c.ServerName) > 0 +} + +var ( + defaultSupportedCipherSuites = []CipherSuite{ + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + } + + defaultSupportedGroups = []NamedGroup{ + P256, + P384, + FFDHE2048, + X25519, + } + + defaultSignatureSchemes = []SignatureScheme{ + RSA_PSS_SHA256, + RSA_PSS_SHA384, + RSA_PSS_SHA512, + ECDSA_P256_SHA256, + ECDSA_P384_SHA384, + ECDSA_P521_SHA512, + } + + defaultTicketLen = 16 + + defaultPSKModes = []PSKKeyExchangeMode{ + PSKModeKE, + PSKModeDHEKE, + } +) + +type ConnectionState struct { + HandshakeState string // string representation of the handshake state. + CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) + PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement + NextProto string // Selected ALPN proto +} + +// Conn implements the net.Conn interface, as with "crypto/tls" +// * Read, Write, and Close are provided locally +// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn +type Conn struct { + config *Config + conn net.Conn + isClient bool + + EarlyData []byte + + state StateConnected + hState HandshakeState + handshakeMutex sync.Mutex + handshakeAlert Alert + handshakeComplete bool + + readBuffer []byte + in, out *RecordLayer + hIn, hOut *HandshakeLayer + + extHandler AppExtensionHandler +} + +func NewConn(conn net.Conn, config *Config, isClient bool) *Conn { + c := &Conn{conn: conn, config: config, isClient: isClient} + c.in = NewRecordLayer(c.conn) + c.out = NewRecordLayer(c.conn) + c.hIn = NewHandshakeLayer(c.in) + c.hIn.nonblocking = c.config.NonBlocking + c.hOut = NewHandshakeLayer(c.out) + return c +} + +// Read up +func (c *Conn) consumeRecord() error { + pt, err := c.in.ReadRecord() + if pt == nil { + logf(logTypeIO, "extendBuffer returns error %v", err) + return err + } + + switch pt.contentType { + case RecordTypeHandshake: + logf(logTypeHandshake, "Received post-handshake message") + // We do not support fragmentation of post-handshake handshake messages. + // TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage() + start := 0 + for start < len(pt.fragment) { + if len(pt.fragment[start:]) < handshakeHeaderLen { + return fmt.Errorf("Post-handshake handshake message too short for header") + } + + hm := &HandshakeMessage{} + hm.msgType = HandshakeType(pt.fragment[start]) + hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3]) + + if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen { + return fmt.Errorf("Post-handshake handshake message too short for body") + } + hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen] + + // Advance state machine + state, actions, alert := c.state.Next(hm) + + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error in state transition: %v", alert) + c.sendAlert(alert) + return io.EOF + } + + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return io.EOF + } + } + + // XXX: If we want to support more advanced cases, e.g., post-handshake + // authentication, we'll need to allow transitions other than + // Connected -> Connected + var connected bool + c.state, connected = state.(StateConnected) + if !connected { + logf(logTypeHandshake, "Disconnected after state transition: %v", alert) + c.sendAlert(alert) + return io.EOF + } + + start += handshakeHeaderLen + hmLen + } + case RecordTypeAlert: + logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer) + if len(pt.fragment) != 2 { + c.sendAlert(AlertUnexpectedMessage) + return io.EOF + } + if Alert(pt.fragment[1]) == AlertCloseNotify { + return io.EOF + } + + switch pt.fragment[0] { + case AlertLevelWarning: + // drop on the floor + case AlertLevelError: + return Alert(pt.fragment[1]) + default: + c.sendAlert(AlertUnexpectedMessage) + return io.EOF + } + + case RecordTypeApplicationData: + c.readBuffer = append(c.readBuffer, pt.fragment...) + logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer) + } + + return err +} + +// Read application data up to the size of buffer. Handshake and alert records +// are consumed by the Conn object directly. +func (c *Conn) Read(buffer []byte) (int, error) { + logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer)) + if alert := c.Handshake(); alert != AlertNoAlert { + return 0, alert + } + + if len(buffer) == 0 { + return 0, nil + } + + // Lock the input channel + c.in.Lock() + defer c.in.Unlock() + for len(c.readBuffer) == 0 { + err := c.consumeRecord() + + // err can be nil if consumeRecord processed a non app-data + // record. + if err != nil { + if c.config.NonBlocking || err != WouldBlock { + logf(logTypeIO, "conn.Read returns err=%v", err) + return 0, err + } + } + } + + var read int + n := len(buffer) + logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer)) + if len(c.readBuffer) <= n { + buffer = buffer[:len(c.readBuffer)] + copy(buffer, c.readBuffer) + read = len(c.readBuffer) + c.readBuffer = c.readBuffer[:0] + } else { + logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n) + copy(buffer[:n], c.readBuffer[:n]) + c.readBuffer = c.readBuffer[n:] + read = n + } + + logf(logTypeVerbose, "Returning %v", string(buffer)) + return read, nil +} + +// Write application data +func (c *Conn) Write(buffer []byte) (int, error) { + // Lock the output channel + c.out.Lock() + defer c.out.Unlock() + + // Send full-size fragments + var start int + sent := 0 + for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { + err := c.out.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeApplicationData, + fragment: buffer[start : start+maxFragmentLen], + }) + + if err != nil { + return sent, err + } + sent += maxFragmentLen + } + + // Send a final partial fragment if necessary + if start < len(buffer) { + err := c.out.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeApplicationData, + fragment: buffer[start:], + }) + + if err != nil { + return sent, err + } + sent += len(buffer[start:]) + } + return sent, nil +} + +// sendAlert sends a TLS alert message. +// c.out.Mutex <= L. +func (c *Conn) sendAlert(err Alert) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + var level int + switch err { + case AlertNoRenegotiation, AlertCloseNotify: + level = AlertLevelWarning + default: + level = AlertLevelError + } + + buf := []byte{byte(err), byte(level)} + c.out.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeAlert, + fragment: buf, + }) + + // close_notify and end_of_early_data are not actually errors + if level == AlertLevelWarning { + return &net.OpError{Op: "local error", Err: err} + } + + return c.Close() +} + +// Close closes the connection. +func (c *Conn) Close() error { + // XXX crypto/tls has an interlock with Write here. Do we need that? + + return c.conn.Close() +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated with the connection. +// A zero value for t means Read and Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying connection. +// A zero value for t means Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { + label := "[server]" + if c.isClient { + label = "[client]" + } + + switch action := actionGeneric.(type) { + case SendHandshakeMessage: + err := c.hOut.WriteMessage(action.Message) + if err != nil { + logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) + return AlertInternalError + } + + case RekeyIn: + logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet) + err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) + if err != nil { + logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err) + return AlertInternalError + } + + case RekeyOut: + logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet) + err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) + if err != nil { + logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err) + return AlertInternalError + } + + case SendEarlyData: + logf(logTypeHandshake, "%s Sending early data...", label) + _, err := c.Write(c.EarlyData) + if err != nil { + logf(logTypeHandshake, "%s Error writing early data: %v", label, err) + return AlertInternalError + } + + case ReadPastEarlyData: + logf(logTypeHandshake, "%s Reading past early data...", label) + // Scan past all records that fail to decrypt + _, err := c.in.PeekRecordType(!c.config.NonBlocking) + if err == nil { + break + } + _, ok := err.(DecryptError) + + for ok { + _, err = c.in.PeekRecordType(!c.config.NonBlocking) + if err == nil { + break + } + _, ok = err.(DecryptError) + } + + case ReadEarlyData: + logf(logTypeHandshake, "%s Reading early data...", label) + t, err := c.in.PeekRecordType(!c.config.NonBlocking) + if err != nil { + logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err) + return AlertInternalError + } + logf(logTypeHandshake, "%s Got record type(1): %v", label, t) + + for t == RecordTypeApplicationData { + // Read a record into the buffer. Note that this is safe + // in blocking mode because we read the record in in + // PeekRecordType. + pt, err := c.in.ReadRecord() + if err != nil { + logf(logTypeHandshake, "%s Error reading early data record: %v", label, err) + return AlertInternalError + } + + logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment) + c.EarlyData = append(c.EarlyData, pt.fragment...) + + t, err = c.in.PeekRecordType(!c.config.NonBlocking) + if err != nil { + logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err) + return AlertInternalError + } + logf(logTypeHandshake, "%s Got record type (2): %v", label, t) + } + logf(logTypeHandshake, "%s Done reading early data", label) + + case StorePSK: + logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity) + if c.isClient { + // Clients look up PSKs based on server name + c.config.PSKs.Put(c.config.ServerName, action.PSK) + } else { + // Servers look them up based on the identity in the extension + c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK) + } + + default: + logf(logTypeHandshake, "%s Unknown actionuction type", label) + return AlertInternalError + } + + return AlertNoAlert +} + +func (c *Conn) HandshakeSetup() Alert { + var state HandshakeState + var actions []HandshakeAction + var alert Alert + + if err := c.config.Init(c.isClient); err != nil { + logf(logTypeHandshake, "Error initializing config: %v", err) + return AlertInternalError + } + + // Set things up + caps := Capabilities{ + CipherSuites: c.config.CipherSuites, + Groups: c.config.Groups, + SignatureSchemes: c.config.SignatureSchemes, + PSKs: c.config.PSKs, + PSKModes: c.config.PSKModes, + AllowEarlyData: c.config.AllowEarlyData, + RequireCookie: c.config.RequireCookie, + CookieHandler: c.config.CookieHandler, + RequireClientAuth: c.config.RequireClientAuth, + NextProtos: c.config.NextProtos, + Certificates: c.config.Certificates, + ExtensionHandler: c.extHandler, + } + opts := ConnectionOptions{ + ServerName: c.config.ServerName, + NextProtos: c.config.NextProtos, + EarlyData: c.EarlyData, + } + + if caps.RequireCookie && caps.CookieHandler == nil { + caps.CookieHandler = &defaultCookieHandler{} + } + + if c.isClient { + state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error initializing client state: %v", alert) + return alert + } + + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + return alert + } + } + } else { + state = ServerStateStart{Caps: caps, conn: c} + } + + c.hState = state + + return AlertNoAlert +} + +// Handshake causes a TLS handshake on the connection. The `isClient` member +// determines whether a client or server handshake is performed. If a +// handshake has already been performed, then its result will be returned. +func (c *Conn) Handshake() Alert { + label := "[server]" + if c.isClient { + label = "[client]" + } + + // TODO Lock handshakeMutex + // TODO Remove CloseNotify hack + if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify { + logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert) + return c.handshakeAlert + } + if c.handshakeComplete { + return AlertNoAlert + } + + var alert Alert + if c.hState == nil { + logf(logTypeHandshake, "%s First time through handshake, setting up", label) + alert = c.HandshakeSetup() + if alert != AlertNoAlert { + return alert + } + } else { + logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState) + } + + state := c.hState + _, connected := state.(StateConnected) + + var actions []HandshakeAction + + for !connected { + // Read a handshake message + hm, err := c.hIn.ReadMessage() + if err == WouldBlock { + logf(logTypeHandshake, "%s Would block reading message: %v", label, err) + return AlertWouldBlock + } + if err != nil { + logf(logTypeHandshake, "%s Error reading message: %v", label, err) + c.sendAlert(AlertCloseNotify) + return AlertCloseNotify + } + logf(logTypeHandshake, "Read message with type: %v", hm.msgType) + + // Advance the state machine + state, actions, alert = state.Next(hm) + + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error in state transition: %v", alert) + return alert + } + + for index, action := range actions { + logf(logTypeHandshake, "%s taking next action (%d)", label, index) + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return alert + } + } + + c.hState = state + logf(logTypeHandshake, "state is now %s", c.GetHsState()) + + _, connected = state.(StateConnected) + } + + c.state = state.(StateConnected) + + // Send NewSessionTicket if acting as server + if !c.isClient && c.config.SendSessionTickets { + actions, alert := c.state.NewSessionTicket( + c.config.TicketLen, + c.config.TicketLifetime, + c.config.EarlyDataLifetime) + + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return alert + } + } + } + + c.handshakeComplete = true + return AlertNoAlert +} + +func (c *Conn) SendKeyUpdate(requestUpdate bool) error { + if !c.handshakeComplete { + return fmt.Errorf("Cannot update keys until after handshake") + } + + request := KeyUpdateNotRequested + if requestUpdate { + request = KeyUpdateRequested + } + + // Create the key update and update state + actions, alert := c.state.KeyUpdate(request) + if alert != AlertNoAlert { + c.sendAlert(alert) + return fmt.Errorf("Alert while generating key update: %v", alert) + } + + // Take actions (send key update and rekey) + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + c.sendAlert(alert) + return fmt.Errorf("Alert during key update actions: %v", alert) + } + } + + return nil +} + +func (c *Conn) GetHsState() string { + return reflect.TypeOf(c.hState).Name() +} + +func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { + _, connected := c.hState.(StateConnected) + if !connected { + return nil, fmt.Errorf("Cannot compute exporter when state is not connected") + } + + if c.state.exporterSecret == nil { + return nil, fmt.Errorf("Internal error: no exporter secret") + } + + h0 := c.state.cryptoParams.Hash.New().Sum(nil) + tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0) + + hc := c.state.cryptoParams.Hash.New().Sum(context) + return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil +} + +func (c *Conn) State() ConnectionState { + state := ConnectionState{ + HandshakeState: c.GetHsState(), + } + + if c.handshakeComplete { + state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite] + state.NextProto = c.state.Params.NextProto + } + + return state +} + +func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error { + if c.hState != nil { + return fmt.Errorf("Can't set extension handler after setup") + } + + c.extHandler = h + return nil +} diff --git a/vendor/github.com/bifurcation/mint/crypto.go b/vendor/github.com/bifurcation/mint/crypto.go new file mode 100644 index 000000000..60d343774 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/crypto.go @@ -0,0 +1,654 @@ +package mint + +import ( + "bytes" + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "fmt" + "math/big" + "time" + + "golang.org/x/crypto/curve25519" + + // Blank includes to ensure hash support + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" +) + +var prng = rand.Reader + +type aeadFactory func(key []byte) (cipher.AEAD, error) + +type CipherSuiteParams struct { + Suite CipherSuite + Cipher aeadFactory // Cipher factory + Hash crypto.Hash // Hash function + KeyLen int // Key length in octets + IvLen int // IV length in octets +} + +type signatureAlgorithm uint8 + +const ( + signatureAlgorithmUnknown = iota + signatureAlgorithmRSA_PKCS1 + signatureAlgorithmRSA_PSS + signatureAlgorithmECDSA +) + +var ( + hashMap = map[SignatureScheme]crypto.Hash{ + RSA_PKCS1_SHA1: crypto.SHA1, + RSA_PKCS1_SHA256: crypto.SHA256, + RSA_PKCS1_SHA384: crypto.SHA384, + RSA_PKCS1_SHA512: crypto.SHA512, + ECDSA_P256_SHA256: crypto.SHA256, + ECDSA_P384_SHA384: crypto.SHA384, + ECDSA_P521_SHA512: crypto.SHA512, + RSA_PSS_SHA256: crypto.SHA256, + RSA_PSS_SHA384: crypto.SHA384, + RSA_PSS_SHA512: crypto.SHA512, + } + + sigMap = map[SignatureScheme]signatureAlgorithm{ + RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1, + RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1, + RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1, + RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1, + ECDSA_P256_SHA256: signatureAlgorithmECDSA, + ECDSA_P384_SHA384: signatureAlgorithmECDSA, + ECDSA_P521_SHA512: signatureAlgorithmECDSA, + RSA_PSS_SHA256: signatureAlgorithmRSA_PSS, + RSA_PSS_SHA384: signatureAlgorithmRSA_PSS, + RSA_PSS_SHA512: signatureAlgorithmRSA_PSS, + } + + curveMap = map[SignatureScheme]NamedGroup{ + ECDSA_P256_SHA256: P256, + ECDSA_P384_SHA384: P384, + ECDSA_P521_SHA512: P521, + } + + newAESGCM = func(key []byte) (cipher.AEAD, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + // TLS always uses 12-byte nonces + return cipher.NewGCMWithNonceSize(block, 12) + } + + cipherSuiteMap = map[CipherSuite]CipherSuiteParams{ + TLS_AES_128_GCM_SHA256: { + Suite: TLS_AES_128_GCM_SHA256, + Cipher: newAESGCM, + Hash: crypto.SHA256, + KeyLen: 16, + IvLen: 12, + }, + TLS_AES_256_GCM_SHA384: { + Suite: TLS_AES_256_GCM_SHA384, + Cipher: newAESGCM, + Hash: crypto.SHA384, + KeyLen: 32, + IvLen: 12, + }, + } + + x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{ + RSA_PKCS1_SHA1: x509.SHA1WithRSA, + RSA_PKCS1_SHA256: x509.SHA256WithRSA, + RSA_PKCS1_SHA384: x509.SHA384WithRSA, + RSA_PKCS1_SHA512: x509.SHA512WithRSA, + ECDSA_P256_SHA256: x509.ECDSAWithSHA256, + ECDSA_P384_SHA384: x509.ECDSAWithSHA384, + ECDSA_P521_SHA512: x509.ECDSAWithSHA512, + } + + defaultRSAKeySize = 2048 +) + +func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) { + switch group { + case P256: + crv = elliptic.P256() + case P384: + crv = elliptic.P384() + case P521: + crv = elliptic.P521() + } + return +} + +func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) { + switch key.Curve.Params().Name { + case elliptic.P256().Params().Name: + g = P256 + case elliptic.P384().Params().Name: + g = P384 + case elliptic.P521().Params().Name: + g = P521 + } + return +} + +func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) { + size = 0 + switch group { + case X25519: + size = 32 + case P256: + size = 65 + case P384: + size = 97 + case P521: + size = 133 + case FFDHE2048: + size = 256 + case FFDHE3072: + size = 384 + case FFDHE4096: + size = 512 + case FFDHE6144: + size = 768 + case FFDHE8192: + size = 1024 + } + return +} + +func primeFromNamedGroup(group NamedGroup) (p *big.Int) { + switch group { + case FFDHE2048: + p = finiteFieldPrime2048 + case FFDHE3072: + p = finiteFieldPrime3072 + case FFDHE4096: + p = finiteFieldPrime4096 + case FFDHE6144: + p = finiteFieldPrime6144 + case FFDHE8192: + p = finiteFieldPrime8192 + } + return +} + +func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool { + sigType := sigMap[alg] + switch key.(type) { + case *rsa.PrivateKey: + return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS + case *ecdsa.PrivateKey: + return sigType == signatureAlgorithmECDSA + default: + return false + } +} + +func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) { + primeLen := len(p.Bytes()) + for { + // g = 2 for all ffdhe groups + priv, err = rand.Int(prng, p) + if err != nil { + return + } + + pub = big.NewInt(0) + pub.Exp(big.NewInt(2), priv, p) + + if len(pub.Bytes()) == primeLen { + return + } + } +} + +func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) { + switch group { + case P256, P384, P521: + var x, y *big.Int + crv := curveFromNamedGroup(group) + priv, x, y, err = elliptic.GenerateKey(crv, prng) + if err != nil { + return + } + + pub = elliptic.Marshal(crv, x, y) + return + + case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192: + p := primeFromNamedGroup(group) + x, X, err2 := ffdheKeyShareFromPrime(p) + if err2 != nil { + err = err2 + return + } + + priv = x.Bytes() + pubBytes := X.Bytes() + + numBytes := keyExchangeSizeFromNamedGroup(group) + + pub = make([]byte, numBytes) + copy(pub[numBytes-len(pubBytes):], pubBytes) + + return + + case X25519: + var private, public [32]byte + _, err = prng.Read(private[:]) + if err != nil { + return + } + + curve25519.ScalarBaseMult(&public, &private) + priv = private[:] + pub = public[:] + return + + default: + return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group) + } +} + +func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) { + switch group { + case P256, P384, P521: + if len(pub) != keyExchangeSizeFromNamedGroup(group) { + return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") + } + + crv := curveFromNamedGroup(group) + pubX, pubY := elliptic.Unmarshal(crv, pub) + x, _ := crv.Params().ScalarMult(pubX, pubY, priv) + xBytes := x.Bytes() + + numBytes := len(crv.Params().P.Bytes()) + + ret := make([]byte, numBytes) + copy(ret[numBytes-len(xBytes):], xBytes) + + return ret, nil + + case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192: + numBytes := keyExchangeSizeFromNamedGroup(group) + if len(pub) != numBytes { + return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") + } + p := primeFromNamedGroup(group) + x := big.NewInt(0).SetBytes(priv) + Y := big.NewInt(0).SetBytes(pub) + ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes() + + ret := make([]byte, numBytes) + copy(ret[numBytes-len(ZBytes):], ZBytes) + + return ret, nil + + case X25519: + if len(pub) != keyExchangeSizeFromNamedGroup(group) { + return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") + } + + var private, public, ret [32]byte + copy(private[:], priv) + copy(public[:], pub) + curve25519.ScalarMult(&ret, &private, &public) + + return ret[:], nil + + default: + return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group) + } +} + +func newSigningKey(sig SignatureScheme) (crypto.Signer, error) { + switch sig { + case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256, + RSA_PKCS1_SHA384, RSA_PKCS1_SHA512, + RSA_PSS_SHA256, RSA_PSS_SHA384, + RSA_PSS_SHA512: + return rsa.GenerateKey(prng, defaultRSAKeySize) + case ECDSA_P256_SHA256: + return ecdsa.GenerateKey(elliptic.P256(), prng) + case ECDSA_P384_SHA384: + return ecdsa.GenerateKey(elliptic.P384(), prng) + case ECDSA_P521_SHA512: + return ecdsa.GenerateKey(elliptic.P521(), prng) + default: + return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig) + } +} + +func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) { + sigAlg, ok := x509AlgMap[alg] + if !ok { + return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg) + } + if len(name) == 0 { + return nil, fmt.Errorf("tls.selfsigned: No name provided") + } + + serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0)) + if err != nil { + return nil, err + } + + template := &x509.Certificate{ + SerialNumber: serial, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(0, 0, 1), + SignatureAlgorithm: sigAlg, + Subject: pkix.Name{CommonName: name}, + DNSNames: []string{name}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv) + if err != nil { + return nil, err + } + + // It is safe to ignore the error here because we're parsing known-good data + cert, _ := x509.ParseCertificate(der) + return cert, nil +} + +// XXX(rlb): Copied from crypto/x509 +type ecdsaSignature struct { + R, S *big.Int +} + +func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) { + var opts crypto.SignerOpts + + hash := hashMap[alg] + if hash == crypto.SHA1 { + return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden") + } + + sigType := sigMap[alg] + var realInput []byte + switch key := privateKey.(type) { + case *rsa.PrivateKey: + switch { + case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: + logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size()) + opts = hash + case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: + fallthrough + case sigType == signatureAlgorithmRSA_PSS: + logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size()) + opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash} + default: + return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key") + } + + h := hash.New() + h.Write(sigInput) + realInput = h.Sum(nil) + case *ecdsa.PrivateKey: + if sigType != signatureAlgorithmECDSA { + return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key") + } + + algGroup := curveMap[alg] + keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey)) + if algGroup != keyGroup { + return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination") + } + + h := hash.New() + h.Write(sigInput) + realInput = h.Sum(nil) + default: + return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type") + } + + sig, err := privateKey.Sign(prng, realInput, opts) + logf(logTypeCrypto, "signature: %x", sig) + return sig, err +} + +func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error { + hash := hashMap[alg] + + if hash == crypto.SHA1 { + return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden") + } + + sigType := sigMap[alg] + switch pub := publicKey.(type) { + case *rsa.PublicKey: + switch { + case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: + logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size()) + + h := hash.New() + h.Write(sigInput) + realInput := h.Sum(nil) + return rsa.VerifyPKCS1v15(pub, hash, realInput, sig) + case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: + fallthrough + case sigType == signatureAlgorithmRSA_PSS: + logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size()) + opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash} + + h := hash.New() + h.Write(sigInput) + realInput := h.Sum(nil) + return rsa.VerifyPSS(pub, hash, realInput, sig, opts) + default: + return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key") + } + + case *ecdsa.PublicKey: + if sigType != signatureAlgorithmECDSA { + return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key") + } + + if curveMap[alg] != namedGroupFromECDSAKey(pub) { + return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key") + } + + ecdsaSig := new(ecdsaSignature) + if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil { + return err + } else if len(rest) != 0 { + return fmt.Errorf("tls.verify: trailing data after ECDSA signature") + } + if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { + return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values") + } + + h := hash.New() + h.Write(sigInput) + realInput := h.Sum(nil) + if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) { + return fmt.Errorf("tls.verify: ECDSA verification failure") + } + return nil + default: + return fmt.Errorf("tls.verify: Unsupported key type") + } +} + +// 0 +// | +// v +// PSK -> HKDF-Extract = Early Secret +// | +// +-----> Derive-Secret(., +// | "ext binder" | +// | "res binder", +// | "") +// | = binder_key +// | +// +-----> Derive-Secret(., "c e traffic", +// | ClientHello) +// | = client_early_traffic_secret +// | +// +-----> Derive-Secret(., "e exp master", +// | ClientHello) +// | = early_exporter_master_secret +// v +// Derive-Secret(., "derived", "") +// | +// v +// (EC)DHE -> HKDF-Extract = Handshake Secret +// | +// +-----> Derive-Secret(., "c hs traffic", +// | ClientHello...ServerHello) +// | = client_handshake_traffic_secret +// | +// +-----> Derive-Secret(., "s hs traffic", +// | ClientHello...ServerHello) +// | = server_handshake_traffic_secret +// v +// Derive-Secret(., "derived", "") +// | +// v +// 0 -> HKDF-Extract = Master Secret +// | +// +-----> Derive-Secret(., "c ap traffic", +// | ClientHello...server Finished) +// | = client_application_traffic_secret_0 +// | +// +-----> Derive-Secret(., "s ap traffic", +// | ClientHello...server Finished) +// | = server_application_traffic_secret_0 +// | +// +-----> Derive-Secret(., "exp master", +// | ClientHello...server Finished) +// | = exporter_master_secret +// | +// +-----> Derive-Secret(., "res master", +// ClientHello...client Finished) +// = resumption_master_secret + +// From RFC 5869 +// PRK = HMAC-Hash(salt, IKM) +func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte { + salt := saltIn + + // if [salt is] not provided, it is set to a string of HashLen zeros + if salt == nil { + salt = bytes.Repeat([]byte{0}, hash.Size()) + } + + h := hmac.New(hash.New, salt) + h.Write(input) + out := h.Sum(nil) + + logf(logTypeCrypto, "HKDF Extract:\n") + logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt) + logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input) + logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out) + + return out +} + +const ( + labelExternalBinder = "ext binder" + labelResumptionBinder = "res binder" + labelEarlyTrafficSecret = "c e traffic" + labelEarlyExporterSecret = "e exp master" + labelClientHandshakeTrafficSecret = "c hs traffic" + labelServerHandshakeTrafficSecret = "s hs traffic" + labelClientApplicationTrafficSecret = "c ap traffic" + labelServerApplicationTrafficSecret = "s ap traffic" + labelExporterSecret = "exp master" + labelResumptionSecret = "res master" + labelDerived = "derived" + labelFinished = "finished" + labelResumption = "resumption" +) + +// struct HkdfLabel { +// uint16 length; +// opaque label<9..255>; +// opaque hash_value<0..255>; +// }; +func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte { + label := "tls13 " + labelIn + + labelLen := len(label) + hashLen := len(hashValue) + hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen) + hkdfLabel[0] = byte(outLen >> 8) + hkdfLabel[1] = byte(outLen) + hkdfLabel[2] = byte(labelLen) + copy(hkdfLabel[3:3+labelLen], []byte(label)) + hkdfLabel[3+labelLen] = byte(hashLen) + copy(hkdfLabel[3+labelLen+1:], hashValue) + + return hkdfLabel +} + +func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte { + out := []byte{} + T := []byte{} + i := byte(1) + for len(out) < outLen { + block := append(T, info...) + block = append(block, i) + + h := hmac.New(hash.New, prk) + h.Write(block) + + T = h.Sum(nil) + out = append(out, T...) + i++ + } + return out[:outLen] +} + +func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte { + info := hkdfEncodeLabel(label, hashValue, outLen) + derived := HkdfExpand(hash, secret, info, outLen) + + logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen) + logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret) + logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue) + logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info) + logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived) + + return derived +} + +func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte { + return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size()) +} + +func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte { + macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size()) + mac := hmac.New(params.Hash.New, macKey) + mac.Write(input) + return mac.Sum(nil) +} + +type keySet struct { + cipher aeadFactory + key []byte + iv []byte +} + +func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet { + logf(logTypeCrypto, "making traffic keys: secret=%x", secret) + return keySet{ + cipher: params.Cipher, + key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen), + iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen), + } +} diff --git a/vendor/github.com/bifurcation/mint/extensions.go b/vendor/github.com/bifurcation/mint/extensions.go new file mode 100644 index 000000000..1dbe7bd2f --- /dev/null +++ b/vendor/github.com/bifurcation/mint/extensions.go @@ -0,0 +1,586 @@ +package mint + +import ( + "bytes" + "fmt" + + "github.com/bifurcation/mint/syntax" +) + +type ExtensionBody interface { + Type() ExtensionType + Marshal() ([]byte, error) + Unmarshal(data []byte) (int, error) +} + +// struct { +// ExtensionType extension_type; +// opaque extension_data<0..2^16-1>; +// } Extension; +type Extension struct { + ExtensionType ExtensionType + ExtensionData []byte `tls:"head=2"` +} + +func (ext Extension) Marshal() ([]byte, error) { + return syntax.Marshal(ext) +} + +func (ext *Extension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, ext) +} + +type ExtensionList []Extension + +type extensionListInner struct { + List []Extension `tls:"head=2"` +} + +func (el ExtensionList) Marshal() ([]byte, error) { + return syntax.Marshal(extensionListInner{el}) +} + +func (el *ExtensionList) Unmarshal(data []byte) (int, error) { + var list extensionListInner + read, err := syntax.Unmarshal(data, &list) + if err != nil { + return 0, err + } + + *el = list.List + return read, nil +} + +func (el *ExtensionList) Add(src ExtensionBody) error { + data, err := src.Marshal() + if err != nil { + return err + } + + if el == nil { + el = new(ExtensionList) + } + + // If one already exists with this type, replace it + for i := range *el { + if (*el)[i].ExtensionType == src.Type() { + (*el)[i].ExtensionData = data + return nil + } + } + + // Otherwise append + *el = append(*el, Extension{ + ExtensionType: src.Type(), + ExtensionData: data, + }) + return nil +} + +func (el ExtensionList) Find(dst ExtensionBody) bool { + for _, ext := range el { + if ext.ExtensionType == dst.Type() { + _, err := dst.Unmarshal(ext.ExtensionData) + return err == nil + } + } + return false +} + +// struct { +// NameType name_type; +// select (name_type) { +// case host_name: HostName; +// } name; +// } ServerName; +// +// enum { +// host_name(0), (255) +// } NameType; +// +// opaque HostName<1..2^16-1>; +// +// struct { +// ServerName server_name_list<1..2^16-1> +// } ServerNameList; +// +// But we only care about the case where there's a single DNS hostname. We +// will never create anything else, and throw if we receive something else +// +// 2 1 2 +// | listLen | NameType | nameLen | name | +type ServerNameExtension string + +type serverNameInner struct { + NameType uint8 + HostName []byte `tls:"head=2,min=1"` +} + +type serverNameListInner struct { + ServerNameList []serverNameInner `tls:"head=2,min=1"` +} + +func (sni ServerNameExtension) Type() ExtensionType { + return ExtensionTypeServerName +} + +func (sni ServerNameExtension) Marshal() ([]byte, error) { + list := serverNameListInner{ + ServerNameList: []serverNameInner{{ + NameType: 0x00, // host_name + HostName: []byte(sni), + }}, + } + + return syntax.Marshal(list) +} + +func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) { + var list serverNameListInner + read, err := syntax.Unmarshal(data, &list) + if err != nil { + return 0, err + } + + // Syntax requires at least one entry + // Entries beyond the first are ignored + if nameType := list.ServerNameList[0].NameType; nameType != 0x00 { + return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType) + } + + *sni = ServerNameExtension(list.ServerNameList[0].HostName) + return read, nil +} + +// struct { +// NamedGroup group; +// opaque key_exchange<1..2^16-1>; +// } KeyShareEntry; +// +// struct { +// select (Handshake.msg_type) { +// case client_hello: +// KeyShareEntry client_shares<0..2^16-1>; +// +// case hello_retry_request: +// NamedGroup selected_group; +// +// case server_hello: +// KeyShareEntry server_share; +// }; +// } KeyShare; +type KeyShareEntry struct { + Group NamedGroup + KeyExchange []byte `tls:"head=2,min=1"` +} + +func (kse KeyShareEntry) SizeValid() bool { + return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group) +} + +type KeyShareExtension struct { + HandshakeType HandshakeType + SelectedGroup NamedGroup + Shares []KeyShareEntry +} + +type KeyShareClientHelloInner struct { + ClientShares []KeyShareEntry `tls:"head=2,min=0"` +} +type KeyShareHelloRetryInner struct { + SelectedGroup NamedGroup +} +type KeyShareServerHelloInner struct { + ServerShare KeyShareEntry +} + +func (ks KeyShareExtension) Type() ExtensionType { + return ExtensionTypeKeyShare +} + +func (ks KeyShareExtension) Marshal() ([]byte, error) { + switch ks.HandshakeType { + case HandshakeTypeClientHello: + for _, share := range ks.Shares { + if !share.SizeValid() { + return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") + } + } + return syntax.Marshal(KeyShareClientHelloInner{ks.Shares}) + + case HandshakeTypeHelloRetryRequest: + if len(ks.Shares) > 0 { + return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest") + } + + return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup}) + + case HandshakeTypeServerHello: + if len(ks.Shares) != 1 { + return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share") + } + + if !ks.Shares[0].SizeValid() { + return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") + } + + return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]}) + + default: + return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed") + } +} + +func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) { + switch ks.HandshakeType { + case HandshakeTypeClientHello: + var inner KeyShareClientHelloInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + for _, share := range inner.ClientShares { + if !share.SizeValid() { + return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") + } + } + + ks.Shares = inner.ClientShares + return read, nil + + case HandshakeTypeHelloRetryRequest: + var inner KeyShareHelloRetryInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + ks.SelectedGroup = inner.SelectedGroup + return read, nil + + case HandshakeTypeServerHello: + var inner KeyShareServerHelloInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + if !inner.ServerShare.SizeValid() { + return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") + } + + ks.Shares = []KeyShareEntry{inner.ServerShare} + return read, nil + + default: + return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed") + } +} + +// struct { +// NamedGroup named_group_list<2..2^16-1>; +// } NamedGroupList; +type SupportedGroupsExtension struct { + Groups []NamedGroup `tls:"head=2,min=2"` +} + +func (sg SupportedGroupsExtension) Type() ExtensionType { + return ExtensionTypeSupportedGroups +} + +func (sg SupportedGroupsExtension) Marshal() ([]byte, error) { + return syntax.Marshal(sg) +} + +func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, sg) +} + +// struct { +// SignatureScheme supported_signature_algorithms<2..2^16-2>; +// } SignatureSchemeList +type SignatureAlgorithmsExtension struct { + Algorithms []SignatureScheme `tls:"head=2,min=2"` +} + +func (sa SignatureAlgorithmsExtension) Type() ExtensionType { + return ExtensionTypeSignatureAlgorithms +} + +func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) { + return syntax.Marshal(sa) +} + +func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, sa) +} + +// struct { +// opaque identity<1..2^16-1>; +// uint32 obfuscated_ticket_age; +// } PskIdentity; +// +// opaque PskBinderEntry<32..255>; +// +// struct { +// select (Handshake.msg_type) { +// case client_hello: +// PskIdentity identities<7..2^16-1>; +// PskBinderEntry binders<33..2^16-1>; +// +// case server_hello: +// uint16 selected_identity; +// }; +// +// } PreSharedKeyExtension; +type PSKIdentity struct { + Identity []byte `tls:"head=2,min=1"` + ObfuscatedTicketAge uint32 +} + +type PSKBinderEntry struct { + Binder []byte `tls:"head=1,min=32"` +} + +type PreSharedKeyExtension struct { + HandshakeType HandshakeType + Identities []PSKIdentity + Binders []PSKBinderEntry + SelectedIdentity uint16 +} + +type preSharedKeyClientInner struct { + Identities []PSKIdentity `tls:"head=2,min=7"` + Binders []PSKBinderEntry `tls:"head=2,min=33"` +} + +type preSharedKeyServerInner struct { + SelectedIdentity uint16 +} + +func (psk PreSharedKeyExtension) Type() ExtensionType { + return ExtensionTypePreSharedKey +} + +func (psk PreSharedKeyExtension) Marshal() ([]byte, error) { + switch psk.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Marshal(preSharedKeyClientInner{ + Identities: psk.Identities, + Binders: psk.Binders, + }) + + case HandshakeTypeServerHello: + if len(psk.Identities) > 0 || len(psk.Binders) > 0 { + return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index") + } + return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity}) + + default: + return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported") + } +} + +func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) { + switch psk.HandshakeType { + case HandshakeTypeClientHello: + var inner preSharedKeyClientInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + if len(inner.Identities) != len(inner.Binders) { + return 0, fmt.Errorf("Lengths of identities and binders not equal") + } + + psk.Identities = inner.Identities + psk.Binders = inner.Binders + return read, nil + + case HandshakeTypeServerHello: + var inner preSharedKeyServerInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + psk.SelectedIdentity = inner.SelectedIdentity + return read, nil + + default: + return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported") + } +} + +func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) { + for i, localID := range psk.Identities { + if bytes.Equal(localID.Identity, id) { + return psk.Binders[i].Binder, true + } + } + return nil, false +} + +// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode; +// +// struct { +// PskKeyExchangeMode ke_modes<1..255>; +// } PskKeyExchangeModes; +type PSKKeyExchangeModesExtension struct { + KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"` +} + +func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType { + return ExtensionTypePSKKeyExchangeModes +} + +func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) { + return syntax.Marshal(pkem) +} + +func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, pkem) +} + +// struct { +// } EarlyDataIndication; + +type EarlyDataExtension struct{} + +func (ed EarlyDataExtension) Type() ExtensionType { + return ExtensionTypeEarlyData +} + +func (ed EarlyDataExtension) Marshal() ([]byte, error) { + return []byte{}, nil +} + +func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) { + return 0, nil +} + +// struct { +// uint32 max_early_data_size; +// } TicketEarlyDataInfo; + +type TicketEarlyDataInfoExtension struct { + MaxEarlyDataSize uint32 +} + +func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType { + return ExtensionTypeTicketEarlyDataInfo +} + +func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) { + return syntax.Marshal(tedi) +} + +func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, tedi) +} + +// opaque ProtocolName<1..2^8-1>; +// +// struct { +// ProtocolName protocol_name_list<2..2^16-1> +// } ProtocolNameList; +type ALPNExtension struct { + Protocols []string +} + +type protocolNameInner struct { + Name []byte `tls:"head=1,min=1"` +} + +type alpnExtensionInner struct { + Protocols []protocolNameInner `tls:"head=2,min=2"` +} + +func (alpn ALPNExtension) Type() ExtensionType { + return ExtensionTypeALPN +} + +func (alpn ALPNExtension) Marshal() ([]byte, error) { + protocols := make([]protocolNameInner, len(alpn.Protocols)) + for i, protocol := range alpn.Protocols { + protocols[i] = protocolNameInner{[]byte(protocol)} + } + return syntax.Marshal(alpnExtensionInner{protocols}) +} + +func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) { + var inner alpnExtensionInner + read, err := syntax.Unmarshal(data, &inner) + + if err != nil { + return 0, err + } + + alpn.Protocols = make([]string, len(inner.Protocols)) + for i, protocol := range inner.Protocols { + alpn.Protocols[i] = string(protocol.Name) + } + return read, nil +} + +// struct { +// ProtocolVersion versions<2..254>; +// } SupportedVersions; +type SupportedVersionsExtension struct { + Versions []uint16 `tls:"head=1,min=2,max=254"` +} + +func (sv SupportedVersionsExtension) Type() ExtensionType { + return ExtensionTypeSupportedVersions +} + +func (sv SupportedVersionsExtension) Marshal() ([]byte, error) { + return syntax.Marshal(sv) +} + +func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, sv) +} + +// struct { +// opaque cookie<1..2^16-1>; +// } Cookie; +type CookieExtension struct { + Cookie []byte `tls:"head=2,min=1"` +} + +func (c CookieExtension) Type() ExtensionType { + return ExtensionTypeCookie +} + +func (c CookieExtension) Marshal() ([]byte, error) { + return syntax.Marshal(c) +} + +func (c *CookieExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, c) +} + +// defaultCookieLength is the default length of a cookie +const defaultCookieLength = 32 + +type defaultCookieHandler struct { + data []byte +} + +var _ CookieHandler = &defaultCookieHandler{} + +// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data +func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) { + h.data = make([]byte, defaultCookieLength) + if _, err := prng.Read(h.data); err != nil { + return nil, err + } + return h.data, nil +} + +func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool { + return bytes.Equal(h.data, data) +} diff --git a/vendor/github.com/bifurcation/mint/ffdhe.go b/vendor/github.com/bifurcation/mint/ffdhe.go new file mode 100644 index 000000000..59d1f7f9d --- /dev/null +++ b/vendor/github.com/bifurcation/mint/ffdhe.go @@ -0,0 +1,147 @@ +package mint + +import ( + "encoding/hex" + "math/big" +) + +var ( + finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B423861285C97FFFFFFFFFFFFFFFF" + finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex) + finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes) + + finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + + "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + + "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + + "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + + "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + + "3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF" + finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex) + finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes) + + finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + + "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + + "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + + "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + + "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + + "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + + "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + + "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + + "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + + "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + + "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" + + "FFFFFFFFFFFFFFFF" + finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex) + finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes) + + finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + + "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + + "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + + "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + + "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + + "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + + "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + + "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + + "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + + "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + + "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" + + "0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" + + "3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" + + "CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" + + "A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" + + "0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" + + "763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" + + "B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" + + "D72B03746AE77F5E62292C311562A846505DC82DB854338A" + + "E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" + + "5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" + + "A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF" + finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex) + finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes) + + finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + + "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + + "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + + "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + + "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + + "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + + "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + + "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + + "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + + "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + + "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" + + "0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" + + "3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" + + "CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" + + "A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" + + "0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" + + "763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" + + "B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" + + "D72B03746AE77F5E62292C311562A846505DC82DB854338A" + + "E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" + + "5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" + + "A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" + + "1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" + + "0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" + + "CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" + + "2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" + + "BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" + + "51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" + + "D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" + + "1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" + + "FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" + + "97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" + + "D68C8BB7C5C6424CFFFFFFFFFFFFFFFF" + finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex) + finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes) +) diff --git a/vendor/github.com/bifurcation/mint/frame-reader.go b/vendor/github.com/bifurcation/mint/frame-reader.go new file mode 100644 index 000000000..99ea470df --- /dev/null +++ b/vendor/github.com/bifurcation/mint/frame-reader.go @@ -0,0 +1,98 @@ +// Read a generic "framed" packet consisting of a header and a +// This is used for both TLS Records and TLS Handshake Messages +package mint + +type framing interface { + headerLen() int + defaultReadLen() int + frameLen(hdr []byte) (int, error) +} + +const ( + kFrameReaderHdr = 0 + kFrameReaderBody = 1 +) + +type frameNextAction func(f *frameReader) error + +type frameReader struct { + details framing + state uint8 + header []byte + body []byte + working []byte + writeOffset int + remainder []byte +} + +func newFrameReader(d framing) *frameReader { + hdr := make([]byte, d.headerLen()) + return &frameReader{ + d, + kFrameReaderHdr, + hdr, + nil, + hdr, + 0, + nil, + } +} + +func dup(a []byte) []byte { + r := make([]byte, len(a)) + copy(r, a) + return r +} + +func (f *frameReader) needed() int { + tmp := (len(f.working) - f.writeOffset) - len(f.remainder) + if tmp < 0 { + return 0 + } + return tmp +} + +func (f *frameReader) addChunk(in []byte) { + // Append to the buffer. + logf(logTypeFrameReader, "Appending %v", len(in)) + f.remainder = append(f.remainder, in...) +} + +func (f *frameReader) process() (hdr []byte, body []byte, err error) { + for f.needed() == 0 { + logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset) + // Fill out our working block + copied := copy(f.working[f.writeOffset:], f.remainder) + f.remainder = f.remainder[copied:] + f.writeOffset += copied + if f.writeOffset < len(f.working) { + logf(logTypeFrameReader, "Read would have blocked 1") + return nil, nil, WouldBlock + } + // Reset the write offset, because we are now full. + f.writeOffset = 0 + + // We have read a full frame + if f.state == kFrameReaderBody { + logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder)) + f.state = kFrameReaderHdr + f.working = f.header + return dup(f.header), dup(f.body), nil + } + + // We have read the header + bodyLen, err := f.details.frameLen(f.header) + if err != nil { + return nil, nil, err + } + logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen) + + f.body = make([]byte, bodyLen) + f.working = f.body + f.writeOffset = 0 + f.state = kFrameReaderBody + } + + logf(logTypeFrameReader, "Read would have blocked 2") + return nil, nil, WouldBlock +} diff --git a/vendor/github.com/bifurcation/mint/handshake-layer.go b/vendor/github.com/bifurcation/mint/handshake-layer.go new file mode 100644 index 000000000..2b04ac5cb --- /dev/null +++ b/vendor/github.com/bifurcation/mint/handshake-layer.go @@ -0,0 +1,253 @@ +package mint + +import ( + "fmt" + "io" + "net" +) + +const ( + handshakeHeaderLen = 4 // handshake message header length + maxHandshakeMessageLen = 1 << 24 // max handshake message length +) + +// struct { +// HandshakeType msg_type; /* handshake type */ +// uint24 length; /* bytes in message */ +// select (HandshakeType) { +// ... +// } body; +// } Handshake; +// +// We do the select{...} part in a different layer, so we treat the +// actual message body as opaque: +// +// struct { +// HandshakeType msg_type; +// opaque msg<0..2^24-1> +// } Handshake; +// +// TODO: File a spec bug +type HandshakeMessage struct { + // Omitted: length + msgType HandshakeType + body []byte +} + +// Note: This could be done with the `syntax` module, using the simplified +// syntax as discussed above. However, since this is so simple, there's not +// much benefit to doing so. +func (hm *HandshakeMessage) Marshal() []byte { + if hm == nil { + return []byte{} + } + + msgLen := len(hm.body) + data := make([]byte, 4+len(hm.body)) + data[0] = byte(hm.msgType) + data[1] = byte(msgLen >> 16) + data[2] = byte(msgLen >> 8) + data[3] = byte(msgLen) + copy(data[4:], hm.body) + return data +} + +func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) { + logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body) + + var body HandshakeMessageBody + switch hm.msgType { + case HandshakeTypeClientHello: + body = new(ClientHelloBody) + case HandshakeTypeServerHello: + body = new(ServerHelloBody) + case HandshakeTypeHelloRetryRequest: + body = new(HelloRetryRequestBody) + case HandshakeTypeEncryptedExtensions: + body = new(EncryptedExtensionsBody) + case HandshakeTypeCertificate: + body = new(CertificateBody) + case HandshakeTypeCertificateRequest: + body = new(CertificateRequestBody) + case HandshakeTypeCertificateVerify: + body = new(CertificateVerifyBody) + case HandshakeTypeFinished: + body = &FinishedBody{VerifyDataLen: len(hm.body)} + case HandshakeTypeNewSessionTicket: + body = new(NewSessionTicketBody) + case HandshakeTypeKeyUpdate: + body = new(KeyUpdateBody) + case HandshakeTypeEndOfEarlyData: + body = new(EndOfEarlyDataBody) + default: + return body, fmt.Errorf("tls.handshakemessage: Unsupported body type") + } + + _, err := body.Unmarshal(hm.body) + return body, err +} + +func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) { + data, err := body.Marshal() + if err != nil { + return nil, err + } + + return &HandshakeMessage{ + msgType: body.Type(), + body: data, + }, nil +} + +type HandshakeLayer struct { + nonblocking bool // Should we operate in nonblocking mode + conn *RecordLayer // Used for reading/writing records + frame *frameReader // The buffered frame reader +} + +type handshakeLayerFrameDetails struct{} + +func (d handshakeLayerFrameDetails) headerLen() int { + return handshakeHeaderLen +} + +func (d handshakeLayerFrameDetails) defaultReadLen() int { + return handshakeHeaderLen + maxFragmentLen +} + +func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { + logf(logTypeIO, "Header=%x", hdr) + return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil +} + +func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer { + h := HandshakeLayer{} + h.conn = r + h.frame = newFrameReader(&handshakeLayerFrameDetails{}) + return &h +} + +func (h *HandshakeLayer) readRecord() error { + logf(logTypeIO, "Trying to read record") + pt, err := h.conn.ReadRecord() + if err != nil { + return err + } + + if pt.contentType != RecordTypeHandshake && + pt.contentType != RecordTypeAlert { + return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType) + } + + if pt.contentType == RecordTypeAlert { + logf(logTypeIO, "read alert %v", pt.fragment[1]) + if len(pt.fragment) < 2 { + h.sendAlert(AlertUnexpectedMessage) + return io.EOF + } + return Alert(pt.fragment[1]) + } + + logf(logTypeIO, "read handshake record of len %v", len(pt.fragment)) + h.frame.addChunk(pt.fragment) + + return nil +} + +// sendAlert sends a TLS alert message. +func (h *HandshakeLayer) sendAlert(err Alert) error { + tmp := make([]byte, 2) + tmp[0] = AlertLevelError + tmp[1] = byte(err) + h.conn.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeAlert, + fragment: tmp}, + ) + + // closeNotify is a special case in that it isn't an error: + if err != AlertCloseNotify { + return &net.OpError{Op: "local error", Err: err} + } + return nil +} + +func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { + var hdr, body []byte + var err error + + for { + logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder)) + if h.frame.needed() > 0 { + logf(logTypeHandshake, "Trying to read a new record") + err = h.readRecord() + } + if err != nil && (h.nonblocking || err != WouldBlock) { + return nil, err + } + + hdr, body, err = h.frame.process() + if err == nil { + break + } + if err != nil && (h.nonblocking || err != WouldBlock) { + return nil, err + } + } + + logf(logTypeHandshake, "read handshake message") + + hm := &HandshakeMessage{} + hm.msgType = HandshakeType(hdr[0]) + + hm.body = make([]byte, len(body)) + copy(hm.body, body) + + return hm, nil +} + +func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error { + return h.WriteMessages([]*HandshakeMessage{hm}) +} + +func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error { + for _, hm := range hms { + logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body) + } + + // Write out headers and bodies + buffer := []byte{} + for _, msg := range hms { + msgLen := len(msg.body) + if msgLen > maxHandshakeMessageLen { + return fmt.Errorf("tls.handshakelayer: Message too large to send") + } + + buffer = append(buffer, msg.Marshal()...) + } + + // Send full-size fragments + var start int + for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { + err := h.conn.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeHandshake, + fragment: buffer[start : start+maxFragmentLen], + }) + + if err != nil { + return err + } + } + + // Send a final partial fragment if necessary + if start < len(buffer) { + err := h.conn.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeHandshake, + fragment: buffer[start:], + }) + + if err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/bifurcation/mint/handshake-messages.go b/vendor/github.com/bifurcation/mint/handshake-messages.go new file mode 100644 index 000000000..339bbcd09 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/handshake-messages.go @@ -0,0 +1,450 @@ +package mint + +import ( + "bytes" + "crypto" + "crypto/x509" + "encoding/binary" + "fmt" + + "github.com/bifurcation/mint/syntax" +) + +type HandshakeMessageBody interface { + Type() HandshakeType + Marshal() ([]byte, error) + Unmarshal(data []byte) (int, error) +} + +// struct { +// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */ +// Random random; +// opaque legacy_session_id<0..32>; +// CipherSuite cipher_suites<2..2^16-2>; +// opaque legacy_compression_methods<1..2^8-1>; +// Extension extensions<0..2^16-1>; +// } ClientHello; +type ClientHelloBody struct { + // Omitted: clientVersion + // Omitted: legacySessionID + // Omitted: legacyCompressionMethods + Random [32]byte + CipherSuites []CipherSuite + Extensions ExtensionList +} + +type clientHelloBodyInner struct { + LegacyVersion uint16 + Random [32]byte + LegacySessionID []byte `tls:"head=1,max=32"` + CipherSuites []CipherSuite `tls:"head=2,min=2"` + LegacyCompressionMethods []byte `tls:"head=1,min=1"` + Extensions []Extension `tls:"head=2"` +} + +func (ch ClientHelloBody) Type() HandshakeType { + return HandshakeTypeClientHello +} + +func (ch ClientHelloBody) Marshal() ([]byte, error) { + return syntax.Marshal(clientHelloBodyInner{ + LegacyVersion: 0x0303, + Random: ch.Random, + LegacySessionID: []byte{}, + CipherSuites: ch.CipherSuites, + LegacyCompressionMethods: []byte{0}, + Extensions: ch.Extensions, + }) +} + +func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) { + var inner clientHelloBodyInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + // We are strict about these things because we only support 1.3 + if inner.LegacyVersion != 0x0303 { + return 0, fmt.Errorf("tls.clienthello: Incorrect version number") + } + + if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 { + return 0, fmt.Errorf("tls.clienthello: Invalid compression method") + } + + ch.Random = inner.Random + ch.CipherSuites = inner.CipherSuites + ch.Extensions = inner.Extensions + return read, nil +} + +// TODO: File a spec bug to clarify this +func (ch ClientHelloBody) Truncated() ([]byte, error) { + if len(ch.Extensions) == 0 { + return nil, fmt.Errorf("tls.clienthello.truncate: No extensions") + } + + pskExt := ch.Extensions[len(ch.Extensions)-1] + if pskExt.ExtensionType != ExtensionTypePreSharedKey { + return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK") + } + + chm, err := HandshakeMessageFromBody(&ch) + if err != nil { + return nil, err + } + chData := chm.Marshal() + + psk := PreSharedKeyExtension{ + HandshakeType: HandshakeTypeClientHello, + } + _, err = psk.Unmarshal(pskExt.ExtensionData) + if err != nil { + return nil, err + } + + // Marshal just the binders so that we know how much to truncate + binders := struct { + Binders []PSKBinderEntry `tls:"head=2,min=33"` + }{Binders: psk.Binders} + binderData, _ := syntax.Marshal(binders) + binderLen := len(binderData) + + chLen := len(chData) + return chData[:chLen-binderLen], nil +} + +// struct { +// ProtocolVersion server_version; +// CipherSuite cipher_suite; +// Extension extensions<2..2^16-1>; +// } HelloRetryRequest; +type HelloRetryRequestBody struct { + Version uint16 + CipherSuite CipherSuite + Extensions ExtensionList `tls:"head=2,min=2"` +} + +func (hrr HelloRetryRequestBody) Type() HandshakeType { + return HandshakeTypeHelloRetryRequest +} + +func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) { + return syntax.Marshal(hrr) +} + +func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, hrr) +} + +// struct { +// ProtocolVersion version; +// Random random; +// CipherSuite cipher_suite; +// Extension extensions<0..2^16-1>; +// } ServerHello; +type ServerHelloBody struct { + Version uint16 + Random [32]byte + CipherSuite CipherSuite + Extensions ExtensionList `tls:"head=2"` +} + +func (sh ServerHelloBody) Type() HandshakeType { + return HandshakeTypeServerHello +} + +func (sh ServerHelloBody) Marshal() ([]byte, error) { + return syntax.Marshal(sh) +} + +func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, sh) +} + +// struct { +// opaque verify_data[verify_data_length]; +// } Finished; +// +// verifyDataLen is not a field in the TLS struct, but we add it here so +// that calling code can tell us how much data to expect when we marshal / +// unmarshal. (We could add this to the marshal/unmarshal methods, but let's +// try to keep the signature consistent for now.) +// +// For similar reasons, we don't use the `syntax` module here, because this +// struct doesn't map well to standard TLS presentation language concepts. +// +// TODO: File a spec bug +type FinishedBody struct { + VerifyDataLen int + VerifyData []byte +} + +func (fin FinishedBody) Type() HandshakeType { + return HandshakeTypeFinished +} + +func (fin FinishedBody) Marshal() ([]byte, error) { + if len(fin.VerifyData) != fin.VerifyDataLen { + return nil, fmt.Errorf("tls.finished: data length mismatch") + } + + body := make([]byte, len(fin.VerifyData)) + copy(body, fin.VerifyData) + return body, nil +} + +func (fin *FinishedBody) Unmarshal(data []byte) (int, error) { + if len(data) < fin.VerifyDataLen { + return 0, fmt.Errorf("tls.finished: Malformed finished; too short") + } + + fin.VerifyData = make([]byte, fin.VerifyDataLen) + copy(fin.VerifyData, data[:fin.VerifyDataLen]) + return fin.VerifyDataLen, nil +} + +// struct { +// Extension extensions<0..2^16-1>; +// } EncryptedExtensions; +// +// Marshal() and Unmarshal() are handled by ExtensionList +type EncryptedExtensionsBody struct { + Extensions ExtensionList `tls:"head=2"` +} + +func (ee EncryptedExtensionsBody) Type() HandshakeType { + return HandshakeTypeEncryptedExtensions +} + +func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) { + return syntax.Marshal(ee) +} + +func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, ee) +} + +// opaque ASN1Cert<1..2^24-1>; +// +// struct { +// ASN1Cert cert_data; +// Extension extensions<0..2^16-1> +// } CertificateEntry; +// +// struct { +// opaque certificate_request_context<0..2^8-1>; +// CertificateEntry certificate_list<0..2^24-1>; +// } Certificate; +type CertificateEntry struct { + CertData *x509.Certificate + Extensions ExtensionList +} + +type CertificateBody struct { + CertificateRequestContext []byte + CertificateList []CertificateEntry +} + +type certificateEntryInner struct { + CertData []byte `tls:"head=3,min=1"` + Extensions ExtensionList `tls:"head=2"` +} + +type certificateBodyInner struct { + CertificateRequestContext []byte `tls:"head=1"` + CertificateList []certificateEntryInner `tls:"head=3"` +} + +func (c CertificateBody) Type() HandshakeType { + return HandshakeTypeCertificate +} + +func (c CertificateBody) Marshal() ([]byte, error) { + inner := certificateBodyInner{ + CertificateRequestContext: c.CertificateRequestContext, + CertificateList: make([]certificateEntryInner, len(c.CertificateList)), + } + + for i, entry := range c.CertificateList { + inner.CertificateList[i] = certificateEntryInner{ + CertData: entry.CertData.Raw, + Extensions: entry.Extensions, + } + } + + return syntax.Marshal(inner) +} + +func (c *CertificateBody) Unmarshal(data []byte) (int, error) { + inner := certificateBodyInner{} + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return read, err + } + + c.CertificateRequestContext = inner.CertificateRequestContext + c.CertificateList = make([]CertificateEntry, len(inner.CertificateList)) + + for i, entry := range inner.CertificateList { + c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData) + if err != nil { + return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err) + } + + c.CertificateList[i].Extensions = entry.Extensions + } + + return read, nil +} + +// struct { +// SignatureScheme algorithm; +// opaque signature<0..2^16-1>; +// } CertificateVerify; +type CertificateVerifyBody struct { + Algorithm SignatureScheme + Signature []byte `tls:"head=2"` +} + +func (cv CertificateVerifyBody) Type() HandshakeType { + return HandshakeTypeCertificateVerify +} + +func (cv CertificateVerifyBody) Marshal() ([]byte, error) { + return syntax.Marshal(cv) +} + +func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, cv) +} + +func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte { + // TODO: Change context for client auth + // TODO: Put this in a const + const context = "TLS 1.3, server CertificateVerify" + sigInput := bytes.Repeat([]byte{0x20}, 64) + sigInput = append(sigInput, []byte(context)...) + sigInput = append(sigInput, []byte{0}...) + sigInput = append(sigInput, data...) + return sigInput +} + +func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) { + sigInput := cv.EncodeSignatureInput(handshakeHash) + cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput) + logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature) + return +} + +func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error { + sigInput := cv.EncodeSignatureInput(handshakeHash) + logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature) + return verify(cv.Algorithm, publicKey, sigInput, cv.Signature) +} + +// struct { +// opaque certificate_request_context<0..2^8-1>; +// Extension extensions<2..2^16-1>; +// } CertificateRequest; +type CertificateRequestBody struct { + CertificateRequestContext []byte `tls:"head=1"` + Extensions ExtensionList `tls:"head=2"` +} + +func (cr CertificateRequestBody) Type() HandshakeType { + return HandshakeTypeCertificateRequest +} + +func (cr CertificateRequestBody) Marshal() ([]byte, error) { + return syntax.Marshal(cr) +} + +func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, cr) +} + +// struct { +// uint32 ticket_lifetime; +// uint32 ticket_age_add; +// opaque ticket_nonce<1..255>; +// opaque ticket<1..2^16-1>; +// Extension extensions<0..2^16-2>; +// } NewSessionTicket; +type NewSessionTicketBody struct { + TicketLifetime uint32 + TicketAgeAdd uint32 + TicketNonce []byte `tls:"head=1,min=1"` + Ticket []byte `tls:"head=2,min=1"` + Extensions ExtensionList `tls:"head=2"` +} + +const ticketNonceLen = 16 + +func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) { + buf := make([]byte, 4+ticketNonceLen+ticketLen) + _, err := prng.Read(buf) + if err != nil { + return nil, err + } + + tkt := &NewSessionTicketBody{ + TicketLifetime: ticketLifetime, + TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]), + TicketNonce: buf[4 : 4+ticketNonceLen], + Ticket: buf[4+ticketNonceLen:], + } + + return tkt, err +} + +func (tkt NewSessionTicketBody) Type() HandshakeType { + return HandshakeTypeNewSessionTicket +} + +func (tkt NewSessionTicketBody) Marshal() ([]byte, error) { + return syntax.Marshal(tkt) +} + +func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, tkt) +} + +// enum { +// update_not_requested(0), update_requested(1), (255) +// } KeyUpdateRequest; +// +// struct { +// KeyUpdateRequest request_update; +// } KeyUpdate; +type KeyUpdateBody struct { + KeyUpdateRequest KeyUpdateRequest +} + +func (ku KeyUpdateBody) Type() HandshakeType { + return HandshakeTypeKeyUpdate +} + +func (ku KeyUpdateBody) Marshal() ([]byte, error) { + return syntax.Marshal(ku) +} + +func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, ku) +} + +// struct {} EndOfEarlyData; +type EndOfEarlyDataBody struct{} + +func (eoed EndOfEarlyDataBody) Type() HandshakeType { + return HandshakeTypeEndOfEarlyData +} + +func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) { + return []byte{}, nil +} + +func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) { + return 0, nil +} diff --git a/vendor/github.com/bifurcation/mint/log.go b/vendor/github.com/bifurcation/mint/log.go new file mode 100644 index 000000000..2fba90de7 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/log.go @@ -0,0 +1,55 @@ +package mint + +import ( + "fmt" + "log" + "os" + "strings" +) + +// We use this environment variable to control logging. It should be a +// comma-separated list of log tags (see below) or "*" to enable all logging. +const logConfigVar = "MINT_LOG" + +// Pre-defined log types +const ( + logTypeCrypto = "crypto" + logTypeHandshake = "handshake" + logTypeNegotiation = "negotiation" + logTypeIO = "io" + logTypeFrameReader = "frame" + logTypeVerbose = "verbose" +) + +var ( + logFunction = log.Printf + logAll = false + logSettings = map[string]bool{} +) + +func init() { + parseLogEnv(os.Environ()) +} + +func parseLogEnv(env []string) { + for _, stmt := range env { + if strings.HasPrefix(stmt, logConfigVar+"=") { + val := stmt[len(logConfigVar)+1:] + + if val == "*" { + logAll = true + } else { + for _, t := range strings.Split(val, ",") { + logSettings[t] = true + } + } + } + } +} + +func logf(tag string, format string, args ...interface{}) { + if logAll || logSettings[tag] { + fullFormat := fmt.Sprintf("[%s] %s", tag, format) + logFunction(fullFormat, args...) + } +} diff --git a/vendor/github.com/bifurcation/mint/negotiation.go b/vendor/github.com/bifurcation/mint/negotiation.go new file mode 100644 index 000000000..f4ead72e8 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/negotiation.go @@ -0,0 +1,217 @@ +package mint + +import ( + "bytes" + "encoding/hex" + "fmt" + "time" +) + +func VersionNegotiation(offered, supported []uint16) (bool, uint16) { + for _, offeredVersion := range offered { + for _, supportedVersion := range supported { + logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion) + if offeredVersion == supportedVersion { + // XXX: Should probably be highest supported version, but for now, we + // only support one version, so it doesn't really matter. + return true, offeredVersion + } + } + } + + return false, 0 +} + +func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) { + for _, share := range keyShares { + for _, group := range groups { + if group != share.Group { + continue + } + + pub, priv, err := newKeyShare(share.Group) + if err != nil { + // If we encounter an error, just keep looking + continue + } + + dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv) + if err != nil { + // If we encounter an error, just keep looking + continue + } + + return true, group, pub, dhSecret + } + } + + return false, 0, nil, nil +} + +const ( + ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds +) + +func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) { + logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size()) + for i, id := range identities { + identityHex := hex.EncodeToString(id.Identity) + + psk, ok := psks.Get(identityHex) + if !ok { + logf(logTypeNegotiation, "No PSK for identity %x", identityHex) + continue + } + + // For resumption, make sure the ticket age is correct + if psk.IsResumption { + extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd + knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond) + ticketAgeDelta := knownTicketAge - extTicketAge + if knownTicketAge < extTicketAge { + ticketAgeDelta = extTicketAge - knownTicketAge + } + if ticketAgeDelta > ticketAgeTolerance { + logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity) + logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]", + extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance) + return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity) + } + } + + params, ok := cipherSuiteMap[psk.CipherSuite] + if !ok { + err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite) + return false, 0, nil, CipherSuiteParams{}, err + } + + // Compute binder + binderLabel := labelExternalBinder + if psk.IsResumption { + binderLabel = labelResumptionBinder + } + + h0 := params.Hash.New().Sum(nil) + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + earlySecret := HkdfExtract(params.Hash, zero, psk.Key) + binderKey := deriveSecret(params, earlySecret, binderLabel, h0) + + // context = ClientHello[truncated] + // context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated] + ctxHash := params.Hash.New() + ctxHash.Write(context) + + binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil)) + if !bytes.Equal(binder, binders[i].Binder) { + logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder) + return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity) + } + + logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity) + return true, i, &psk, params, nil + } + + logf(logTypeNegotiation, "Failed to find a usable PSK") + return false, 0, nil, CipherSuiteParams{}, nil +} + +func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) { + logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes) + dhAllowed := false + dhRequired := true + for _, mode := range modes { + dhAllowed = dhAllowed || (mode == PSKModeDHEKE) + dhRequired = dhRequired && (mode == PSKModeDHEKE) + } + + // Use PSK if we can meet DH requirement and modes were provided + usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0) + + // Use DH if allowed + usingDH := canDoDH && (dhAllowed || !usingPSK) + + logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK) + return usingDH, usingPSK +} + +func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) { + // Select for server name if provided + candidates := certs + if serverName != nil { + candidatesByName := []*Certificate{} + for _, cert := range certs { + for _, name := range cert.Chain[0].DNSNames { + if len(*serverName) > 0 && name == *serverName { + candidatesByName = append(candidatesByName, cert) + } + } + } + + if len(candidatesByName) == 0 { + return nil, 0, fmt.Errorf("No certificates available for server name") + } + + candidates = candidatesByName + } + + // Select for signature scheme + for _, cert := range candidates { + for _, scheme := range signatureSchemes { + if !schemeValidForKey(scheme, cert.PrivateKey) { + continue + } + + return cert, scheme, nil + } + } + + return nil, 0, fmt.Errorf("No certificates compatible with signature schemes") +} + +func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool { + usingEarlyData := gotEarlyData && usingPSK && allowEarlyData + logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData) + return usingEarlyData +} + +func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) { + for _, s1 := range offered { + if psk != nil { + if s1 == psk.CipherSuite { + return s1, nil + } + continue + } + + for _, s2 := range supported { + if s1 == s2 { + return s1, nil + } + } + } + + return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil) +} + +func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) { + for _, p1 := range offered { + if psk != nil { + if p1 != psk.NextProto { + continue + } + } + + for _, p2 := range supported { + if p1 == p2 { + return p1, nil + } + } + } + + // If the client offers ALPN on resumption, it must match the earlier one + var err error + if psk != nil && psk.IsResumption && (len(offered) > 0) { + err = fmt.Errorf("ALPN for PSK not provided") + } + return "", err +} diff --git a/vendor/github.com/bifurcation/mint/record-layer.go b/vendor/github.com/bifurcation/mint/record-layer.go new file mode 100644 index 000000000..bcef61369 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/record-layer.go @@ -0,0 +1,296 @@ +package mint + +import ( + "bytes" + "crypto/cipher" + "fmt" + "io" + "sync" +) + +const ( + sequenceNumberLen = 8 // sequence number length + recordHeaderLen = 5 // record header length + maxFragmentLen = 1 << 14 // max number of bytes in a record +) + +type DecryptError string + +func (err DecryptError) Error() string { + return string(err) +} + +// struct { +// ContentType type; +// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */ +// uint16 length; +// opaque fragment[TLSPlaintext.length]; +// } TLSPlaintext; +type TLSPlaintext struct { + // Omitted: record_version (static) + // Omitted: length (computed from fragment) + contentType RecordType + fragment []byte +} + +type RecordLayer struct { + sync.Mutex + + conn io.ReadWriter // The underlying connection + frame *frameReader // The buffered frame reader + nextData []byte // The next record to send + cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" + cachedError error // Error on the last record read + + ivLength int // Length of the seq and nonce fields + seq []byte // Zero-padded sequence number + nonce []byte // Buffer for per-record nonces + cipher cipher.AEAD // AEAD cipher +} + +type recordLayerFrameDetails struct{} + +func (d recordLayerFrameDetails) headerLen() int { + return recordHeaderLen +} + +func (d recordLayerFrameDetails) defaultReadLen() int { + return recordHeaderLen + maxFragmentLen +} + +func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { + return (int(hdr[3]) << 8) | int(hdr[4]), nil +} + +func NewRecordLayer(conn io.ReadWriter) *RecordLayer { + r := RecordLayer{} + r.conn = conn + r.frame = newFrameReader(recordLayerFrameDetails{}) + r.ivLength = 0 + return &r +} + +func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error { + var err error + r.cipher, err = cipher(key) + if err != nil { + return err + } + + r.ivLength = len(iv) + r.seq = bytes.Repeat([]byte{0}, r.ivLength) + r.nonce = make([]byte, r.ivLength) + copy(r.nonce, iv) + return nil +} + +func (r *RecordLayer) incrementSequenceNumber() { + if r.ivLength == 0 { + return + } + + for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- { + r.seq[i]++ + r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i] + if r.seq[i] != 0 { + return + } + } + + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. + panic("TLS: sequence number wraparound") +} + +func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext { + // Expand the fragment to hold contentType, padding, and overhead + originalLen := len(pt.fragment) + plaintextLen := originalLen + 1 + padLen + ciphertextLen := plaintextLen + r.cipher.Overhead() + + // Assemble the revised plaintext + out := &TLSPlaintext{ + contentType: RecordTypeApplicationData, + fragment: make([]byte, ciphertextLen), + } + copy(out.fragment, pt.fragment) + out.fragment[originalLen] = byte(pt.contentType) + for i := 1; i <= padLen; i++ { + out.fragment[originalLen+i] = 0 + } + + // Encrypt the fragment + payload := out.fragment[:plaintextLen] + r.cipher.Seal(payload[:0], r.nonce, payload, nil) + return out +} + +func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) { + if len(pt.fragment) < r.cipher.Overhead() { + msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead()) + return nil, 0, DecryptError(msg) + } + + decryptLen := len(pt.fragment) - r.cipher.Overhead() + out := &TLSPlaintext{ + contentType: pt.contentType, + fragment: make([]byte, decryptLen), + } + + // Decrypt + _, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil) + if err != nil { + return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed") + } + + // Find the padding boundary + padLen := 0 + for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ { + } + + // Transfer the content type + newLen := decryptLen - padLen - 1 + out.contentType = RecordType(out.fragment[newLen]) + + // Truncate the message to remove contentType, padding, overhead + out.fragment = out.fragment[:newLen] + return out, padLen, nil +} + +func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { + var pt *TLSPlaintext + var err error + + for { + pt, err = r.nextRecord() + if err == nil { + break + } + if !block || err != WouldBlock { + return 0, err + } + } + return pt.contentType, nil +} + +func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { + pt, err := r.nextRecord() + + // Consume the cached record if there was one + r.cachedRecord = nil + r.cachedError = nil + + return pt, err +} + +func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { + if r.cachedRecord != nil { + logf(logTypeIO, "Returning cached record") + return r.cachedRecord, r.cachedError + } + + // Loop until one of three things happens: + // + // 1. We get a frame + // 2. We try to read off the socket and get nothing, in which case + // return WouldBlock + // 3. We get an error. + err := WouldBlock + var header, body []byte + + for err != nil { + if r.frame.needed() > 0 { + buf := make([]byte, recordHeaderLen+maxFragmentLen) + n, err := r.conn.Read(buf) + if err != nil { + logf(logTypeIO, "Error reading, %v", err) + return nil, err + } + + if n == 0 { + return nil, WouldBlock + } + + logf(logTypeIO, "Read %v bytes", n) + + buf = buf[:n] + r.frame.addChunk(buf) + } + + header, body, err = r.frame.process() + // Loop around on WouldBlock to see if some + // data is now available. + if err != nil && err != WouldBlock { + return nil, err + } + } + + pt := &TLSPlaintext{} + // Validate content type + switch RecordType(header[0]) { + default: + return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0]) + case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData: + pt.contentType = RecordType(header[0]) + } + + // Validate version + if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) { + return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2]) + } + + // Validate size < max + size := (int(header[3]) << 8) + int(header[4]) + if size > maxFragmentLen+256 { + return nil, fmt.Errorf("tls.record: Ciphertext size too big") + } + + pt.fragment = make([]byte, size) + copy(pt.fragment, body) + + // Attempt to decrypt fragment + if r.cipher != nil { + pt, _, err = r.decrypt(pt) + if err != nil { + return nil, err + } + } + + // Check that plaintext length is not too long + if len(pt.fragment) > maxFragmentLen { + return nil, fmt.Errorf("tls.record: Plaintext size too big") + } + + logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment) + + r.cachedRecord = pt + r.incrementSequenceNumber() + return pt, nil +} + +func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error { + return r.WriteRecordWithPadding(pt, 0) +} + +func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error { + if r.cipher != nil { + pt = r.encrypt(pt, padLen) + } else if padLen > 0 { + return fmt.Errorf("tls.record: Padding can only be done on encrypted records") + } + + if len(pt.fragment) > maxFragmentLen { + return fmt.Errorf("tls.record: Record size too big") + } + + length := len(pt.fragment) + header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)} + record := append(header, pt.fragment...) + + logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment) + + r.incrementSequenceNumber() + _, err := r.conn.Write(record) + return err +} diff --git a/vendor/github.com/bifurcation/mint/server-state-machine.go b/vendor/github.com/bifurcation/mint/server-state-machine.go new file mode 100644 index 000000000..60df9b64e --- /dev/null +++ b/vendor/github.com/bifurcation/mint/server-state-machine.go @@ -0,0 +1,898 @@ +package mint + +import ( + "bytes" + "hash" + "reflect" +) + +// Server State Machine +// +// START <-----+ +// Recv ClientHello | | Send HelloRetryRequest +// v | +// RECVD_CH ----+ +// | Select parameters +// | Send ServerHello +// v +// NEGOTIATED +// | Send EncryptedExtensions +// | [Send CertificateRequest] +// Can send | [Send Certificate + CertificateVerify] +// app data --> | Send Finished +// after +--------+--------+ +// here No 0-RTT | | 0-RTT +// | v +// | WAIT_EOED <---+ +// | Recv | | | Recv +// | EndOfEarlyData | | | early data +// | | +-----+ +// +> WAIT_FLIGHT2 <-+ +// | +// +--------+--------+ +// No auth | | Client auth +// | | +// | v +// | WAIT_CERT +// | Recv | | Recv Certificate +// | empty | v +// | Certificate | WAIT_CV +// | | | Recv +// | v | CertificateVerify +// +-> WAIT_FINISHED <---+ +// | Recv Finished +// v +// CONNECTED +// +// NB: Not using state RECVD_CH +// +// State Instructions +// START {} +// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)] +// WAIT_EOED RekeyIn; +// WAIT_FLIGHT2 {} +// WAIT_CERT_CR {} +// WAIT_CERT {} +// WAIT_CV {} +// WAIT_FINISHED RekeyIn; RekeyOut; +// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) + +type ServerStateStart struct { + Caps Capabilities + conn *Conn + + cookieSent bool + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage +} + +func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeClientHello { + logf(logTypeHandshake, "[ServerStateStart] unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + ch := &ClientHelloBody{} + _, err := ch.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + clientHello := hm + connParams := ConnectionParameters{} + + supportedVersions := new(SupportedVersionsExtension) + serverName := new(ServerNameExtension) + supportedGroups := new(SupportedGroupsExtension) + signatureAlgorithms := new(SignatureAlgorithmsExtension) + clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello} + clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello} + clientEarlyData := &EarlyDataExtension{} + clientALPN := new(ALPNExtension) + clientPSKModes := new(PSKKeyExchangeModesExtension) + clientCookie := new(CookieExtension) + + // Handle external extensions. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + gotSupportedVersions := ch.Extensions.Find(supportedVersions) + gotServerName := ch.Extensions.Find(serverName) + gotSupportedGroups := ch.Extensions.Find(supportedGroups) + gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms) + gotEarlyData := ch.Extensions.Find(clientEarlyData) + ch.Extensions.Find(clientKeyShares) + ch.Extensions.Find(clientPSK) + ch.Extensions.Find(clientALPN) + ch.Extensions.Find(clientPSKModes) + ch.Extensions.Find(clientCookie) + + if gotServerName { + connParams.ServerName = string(*serverName) + } + + // If the client didn't send supportedVersions or doesn't support 1.3, + // then we're done here. + if !gotSupportedVersions { + logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions") + return nil, nil, AlertProtocolVersion + } + versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion}) + if !versionOK { + logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version") + return nil, nil, AlertProtocolVersion + } + + if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) { + logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch") + return nil, nil, AlertAccessDenied + } + + // Figure out if we can do DH + canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups) + + // Figure out if we can do PSK + canDoPSK := false + var selectedPSK int + var psk *PreSharedKey + var params CipherSuiteParams + if len(clientPSK.Identities) > 0 { + contextBase := []byte{} + if state.helloRetryRequest != nil { + chBytes := state.firstClientHello.Marshal() + hrrBytes := state.helloRetryRequest.Marshal() + contextBase = append(chBytes, hrrBytes...) + } + + chTrunc, err := ch.Truncated() + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err) + return nil, nil, AlertDecodeError + } + + context := append(contextBase, chTrunc...) + + canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Figure out if we actually should do DH / PSK + connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes) + + // Select a ciphersuite + connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err) + return nil, nil, AlertHandshakeFailure + } + + // Send a cookie if required + // NB: Need to do this here because it's after ciphersuite selection, which + // has to be after PSK selection. + // XXX: Doing this statefully for now, could be stateless + var cookieData []byte + if state.Caps.RequireCookie && !state.cookieSent { + var err error + cookieData, err = state.Caps.CookieHandler.Generate(state.conn) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err) + return nil, nil, AlertInternalError + } + } + if cookieData != nil { + // Ignoring errors because everything here is newly constructed, so there + // shouldn't be marshal errors + hrr := &HelloRetryRequestBody{ + Version: supportedVersion, + CipherSuite: connParams.CipherSuite, + } + hrr.Extensions.Add(&CookieExtension{Cookie: cookieData}) + + // Run the external extension handler. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err) + return nil, nil, AlertInternalError + } + } + + helloRetryRequest, err := HandshakeMessageFromBody(hrr) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err) + return nil, nil, AlertInternalError + } + + params := cipherSuiteMap[connParams.CipherSuite] + h := params.Hash.New() + h.Write(clientHello.Marshal()) + firstClientHello := &HandshakeMessage{ + msgType: HandshakeTypeMessageHash, + body: h.Sum(nil), + } + + nextState := ServerStateStart{ + Caps: state.Caps, + conn: state.conn, + cookieSent: true, + firstClientHello: firstClientHello, + helloRetryRequest: helloRetryRequest, + } + toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}} + logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]") + return nextState, toSend, AlertNoAlert + } + + // If we've got no entropy to make keys from, fail + if !connParams.UsingDH && !connParams.UsingPSK { + logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated") + return nil, nil, AlertHandshakeFailure + } + + var pskSecret []byte + var cert *Certificate + var certScheme SignatureScheme + if connParams.UsingPSK { + pskSecret = psk.Key + } else { + psk = nil + + // If we're not using a PSK mode, then we need to have certain extensions + if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms { + logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)", + gotServerName, gotSupportedGroups, gotSignatureAlgorithms) + return nil, nil, AlertMissingExtension + } + + // Select a certificate + name := string(*serverName) + var err error + cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err) + return nil, nil, AlertAccessDenied + } + } + + if !connParams.UsingDH { + dhSecret = nil + } + + // Figure out if we're going to do early data + var clientEarlyTrafficSecret []byte + connParams.ClientSendingEarlyData = gotEarlyData + connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData) + if connParams.UsingEarlyData { + + h := params.Hash.New() + h.Write(clientHello.Marshal()) + chHash := h.Sum(nil) + + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + earlySecret := HkdfExtract(params.Hash, zero, pskSecret) + clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) + } + + // Select a next protocol + connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err) + return nil, nil, AlertNoApplicationProtocol + } + + logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") + return ServerStateNegotiated{ + Caps: state.Caps, + Params: connParams, + + dhGroup: dhGroup, + dhPublic: dhPublic, + dhSecret: dhSecret, + pskSecret: pskSecret, + selectedPSK: selectedPSK, + cert: cert, + certScheme: certScheme, + clientEarlyTrafficSecret: clientEarlyTrafficSecret, + + firstClientHello: state.firstClientHello, + helloRetryRequest: state.helloRetryRequest, + clientHello: clientHello, + }.Next(nil) +} + +type ServerStateNegotiated struct { + Caps Capabilities + Params ConnectionParameters + + dhGroup NamedGroup + dhPublic []byte + dhSecret []byte + pskSecret []byte + clientEarlyTrafficSecret []byte + selectedPSK int + cert *Certificate + certScheme SignatureScheme + + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage + clientHello *HandshakeMessage +} + +func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + // Create the ServerHello + sh := &ServerHelloBody{ + Version: supportedVersion, + CipherSuite: state.Params.CipherSuite, + } + _, err := prng.Read(sh.Random[:]) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err) + return nil, nil, AlertInternalError + } + if state.Params.UsingDH { + logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension") + err = sh.Extensions.Add(&KeyShareExtension{ + HandshakeType: HandshakeTypeServerHello, + Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}}, + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err) + return nil, nil, AlertInternalError + } + } + if state.Params.UsingPSK { + logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension") + err = sh.Extensions.Add(&PreSharedKeyExtension{ + HandshakeType: HandshakeTypeServerHello, + SelectedIdentity: uint16(state.selectedPSK), + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Run the external extension handler. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) + return nil, nil, AlertInternalError + } + } + + serverHello, err := HandshakeMessageFromBody(sh) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err) + return nil, nil, AlertInternalError + } + + // Look up crypto params + params, ok := cipherSuiteMap[sh.CipherSuite] + if !ok { + logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite) + return nil, nil, AlertHandshakeFailure + } + + // Start up the handshake hash + handshakeHash := params.Hash.New() + handshakeHash.Write(state.firstClientHello.Marshal()) + handshakeHash.Write(state.helloRetryRequest.Marshal()) + handshakeHash.Write(state.clientHello.Marshal()) + handshakeHash.Write(serverHello.Marshal()) + + // Compute handshake secrets + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + + var earlySecret []byte + if state.Params.UsingPSK { + earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret) + } else { + earlySecret = HkdfExtract(params.Hash, zero, zero) + } + + if state.dhSecret == nil { + state.dhSecret = zero + } + + h0 := params.Hash.New().Sum(nil) + h2 := handshakeHash.Sum(nil) + preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) + handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret) + clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) + serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) + preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) + masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) + + logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret) + logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) + logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) + logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) + logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) + + clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret) + serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) + + // Send an EncryptedExtensions message (even if it's empty) + eeList := ExtensionList{} + if state.Params.NextProto != "" { + logf(logTypeHandshake, "[server] sending ALPN extension") + err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}}) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + } + if state.Params.UsingEarlyData { + logf(logTypeHandshake, "[server] sending EDI extension") + err = eeList.Add(&EarlyDataExtension{}) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + } + ee := &EncryptedExtensionsBody{eeList} + + // Run the external extension handler. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) + return nil, nil, AlertInternalError + } + } + + eem, err := HandshakeMessageFromBody(ee) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + + handshakeHash.Write(eem.Marshal()) + + toSend := []HandshakeAction{ + SendHandshakeMessage{serverHello}, + RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys}, + SendHandshakeMessage{eem}, + } + + // Authenticate with a certificate if required + if !state.Params.UsingPSK { + // Send a CertificateRequest message if we want client auth + if state.Caps.RequireClientAuth { + state.Params.UsingClientAuth = true + + // XXX: We don't support sending any constraints besides a list of + // supported signature algorithms + cr := &CertificateRequestBody{} + schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} + err := cr.Extensions.Add(schemes) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err) + return nil, nil, AlertInternalError + } + + crm, err := HandshakeMessageFromBody(cr) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err) + return nil, nil, AlertInternalError + } + //TODO state.state.serverCertificateRequest = cr + + toSend = append(toSend, SendHandshakeMessage{crm}) + handshakeHash.Write(crm.Marshal()) + } + + // Create and send Certificate, CertificateVerify + certificate := &CertificateBody{ + CertificateList: make([]CertificateEntry, len(state.cert.Chain)), + } + for i, entry := range state.cert.Chain { + certificate.CertificateList[i] = CertificateEntry{CertData: entry} + } + certm, err := HandshakeMessageFromBody(certificate) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, SendHandshakeMessage{certm}) + handshakeHash.Write(certm.Marshal()) + + certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme} + logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash) + + hcv := handshakeHash.Sum(nil) + logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) + + err = certificateVerify.Sign(state.cert.PrivateKey, hcv) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err) + return nil, nil, AlertInternalError + } + certvm, err := HandshakeMessageFromBody(certificateVerify) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, SendHandshakeMessage{certvm}) + handshakeHash.Write(certvm.Marshal()) + } + + // Compute secrets resulting from the server's first flight + h3 := handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3) + logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3) + + serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3) + logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) + + // Assemble the Finished message + fin := &FinishedBody{ + VerifyDataLen: len(serverFinishedData), + VerifyData: serverFinishedData, + } + finm, _ := HandshakeMessageFromBody(fin) + + toSend = append(toSend, SendHandshakeMessage{finm}) + handshakeHash.Write(finm.Marshal()) + + // Compute traffic secrets + h4 := handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4) + logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4) + + clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4) + serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4) + logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret) + logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) + + serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret) + toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys}) + + exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4) + logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret) + + if state.Params.UsingEarlyData { + clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret) + + logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]") + nextState := ServerStateWaitEOED{ + AuthCertificate: state.Caps.AuthCertificate, + Params: state.Params, + cryptoParams: params, + handshakeHash: handshakeHash, + masterSecret: masterSecret, + clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, + clientTrafficSecret: clientTrafficSecret, + serverTrafficSecret: serverTrafficSecret, + exporterSecret: exporterSecret, + } + toSend = append(toSend, []HandshakeAction{ + RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys}, + ReadEarlyData{}, + }...) + return nextState, toSend, AlertNoAlert + } + + logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]") + toSend = append(toSend, []HandshakeAction{ + RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys}, + ReadPastEarlyData{}, + }...) + waitFlight2 := ServerStateWaitFlight2{ + AuthCertificate: state.Caps.AuthCertificate, + Params: state.Params, + cryptoParams: params, + handshakeHash: handshakeHash, + masterSecret: masterSecret, + clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, + clientTrafficSecret: clientTrafficSecret, + serverTrafficSecret: serverTrafficSecret, + exporterSecret: exporterSecret, + } + nextState, moreToSend, alert := waitFlight2.Next(nil) + toSend = append(toSend, moreToSend...) + return nextState, toSend, alert +} + +type ServerStateWaitEOED struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + masterSecret []byte + clientHandshakeTrafficSecret []byte + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData { + logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + if len(hm.body) > 0 { + logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]") + return nil, nil, AlertDecodeError + } + + state.handshakeHash.Write(hm.Marshal()) + + clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) + + logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]") + toSend := []HandshakeAction{ + RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys}, + } + waitFlight2 := ServerStateWaitFlight2{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + nextState, moreToSend, alert := waitFlight2.Next(nil) + toSend = append(toSend, moreToSend...) + return nextState, toSend, alert +} + +type ServerStateWaitFlight2 struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + masterSecret []byte + clientHandshakeTrafficSecret []byte + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm != nil { + logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + if state.Params.UsingClientAuth { + logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]") + nextState := ServerStateWaitCert{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert + } + + logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]") + nextState := ServerStateWaitFinished{ + Params: state.Params, + cryptoParams: state.cryptoParams, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + handshakeHash: state.handshakeHash, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert +} + +type ServerStateWaitCert struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + masterSecret []byte + clientHandshakeTrafficSecret []byte + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeCertificate { + logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + cert := &CertificateBody{} + _, err := cert.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") + return nil, nil, AlertDecodeError + } + + state.handshakeHash.Write(hm.Marshal()) + + if len(cert.CertificateList) == 0 { + logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate") + + logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]") + nextState := ServerStateWaitFinished{ + Params: state.Params, + cryptoParams: state.cryptoParams, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + handshakeHash: state.handshakeHash, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert + } + + logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]") + nextState := ServerStateWaitCV{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + handshakeHash: state.handshakeHash, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + clientCertificate: cert, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert +} + +type ServerStateWaitCV struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + + masterSecret []byte + clientHandshakeTrafficSecret []byte + + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte + + clientCertificate *CertificateBody +} + +func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { + logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm)) + return nil, nil, AlertUnexpectedMessage + } + + certVerify := &CertificateVerifyBody{} + _, err := certVerify.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err) + return nil, nil, AlertDecodeError + } + + // Verify client signature over handshake hash + hcv := state.handshakeHash.Sum(nil) + logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) + + clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey + if err := certVerify.Verify(clientPublicKey, hcv); err != nil { + logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err) + return nil, nil, AlertHandshakeFailure + } + + if state.AuthCertificate != nil { + err := state.AuthCertificate(state.clientCertificate.CertificateList) + if err != nil { + logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate") + return nil, nil, AlertBadCertificate + } + } else { + logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate") + } + + // If it passes, record the certificateVerify in the transcript hash + state.handshakeHash.Write(hm.Marshal()) + + logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]") + nextState := ServerStateWaitFinished{ + Params: state.Params, + cryptoParams: state.cryptoParams, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + handshakeHash: state.handshakeHash, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert +} + +type ServerStateWaitFinished struct { + Params ConnectionParameters + cryptoParams CipherSuiteParams + + masterSecret []byte + clientHandshakeTrafficSecret []byte + + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeFinished { + logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()} + _, err := fin.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err) + return nil, nil, AlertDecodeError + } + + // Verify client Finished data + h5 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5) + + clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5) + logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData) + + if !bytes.Equal(fin.VerifyData, clientFinishedData) { + logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify") + return nil, nil, AlertHandshakeFailure + } + + // Compute the resumption secret + state.handshakeHash.Write(hm.Marshal()) + h6 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6) + + resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6) + logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) + + // Compute client traffic keys + clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + + logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]") + nextState := StateConnected{ + Params: state.Params, + isClient: false, + cryptoParams: state.cryptoParams, + resumptionSecret: resumptionSecret, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + toSend := []HandshakeAction{ + RekeyIn{Label: "application", KeySet: clientTrafficKeys}, + } + return nextState, toSend, AlertNoAlert +} diff --git a/vendor/github.com/bifurcation/mint/state-machine.go b/vendor/github.com/bifurcation/mint/state-machine.go new file mode 100644 index 000000000..4eb468c69 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/state-machine.go @@ -0,0 +1,230 @@ +package mint + +import ( + "time" +) + +// Marker interface for actions that an implementation should take based on +// state transitions. +type HandshakeAction interface{} + +type SendHandshakeMessage struct { + Message *HandshakeMessage +} + +type SendEarlyData struct{} + +type ReadEarlyData struct{} + +type ReadPastEarlyData struct{} + +type RekeyIn struct { + Label string + KeySet keySet +} + +type RekeyOut struct { + Label string + KeySet keySet +} + +type StorePSK struct { + PSK PreSharedKey +} + +type HandshakeState interface { + Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) +} + +type AppExtensionHandler interface { + Send(hs HandshakeType, el *ExtensionList) error + Receive(hs HandshakeType, el *ExtensionList) error +} + +// Capabilities objects represent the capabilities of a TLS client or server, +// as an input to TLS negotiation +type Capabilities struct { + // For both client and server + CipherSuites []CipherSuite + Groups []NamedGroup + SignatureSchemes []SignatureScheme + PSKs PreSharedKeyCache + Certificates []*Certificate + AuthCertificate func(chain []CertificateEntry) error + ExtensionHandler AppExtensionHandler + + // For client + PSKModes []PSKKeyExchangeMode + + // For server + NextProtos []string + AllowEarlyData bool + RequireCookie bool + CookieHandler CookieHandler + RequireClientAuth bool +} + +// ConnectionOptions objects represent per-connection settings for a client +// initiating a connection +type ConnectionOptions struct { + ServerName string + NextProtos []string + EarlyData []byte +} + +// ConnectionParameters objects represent the parameters negotiated for a +// connection. +type ConnectionParameters struct { + UsingPSK bool + UsingDH bool + ClientSendingEarlyData bool + UsingEarlyData bool + UsingClientAuth bool + + CipherSuite CipherSuite + ServerName string + NextProto string +} + +// StateConnected is symmetric between client and server +type StateConnected struct { + Params ConnectionParameters + isClient bool + cryptoParams CipherSuiteParams + resumptionSecret []byte + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) { + var trafficKeys keySet + if state.isClient { + state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, + labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) + trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + } else { + state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret, + labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) + trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) + } + + kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request}) + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err) + return nil, AlertInternalError + } + + toSend := []HandshakeAction{ + SendHandshakeMessage{kum}, + RekeyOut{Label: "update", KeySet: trafficKeys}, + } + return toSend, AlertNoAlert +} + +func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) { + tkt, err := NewSessionTicket(length, lifetime) + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err) + return nil, AlertInternalError + } + + err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime}) + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err) + return nil, AlertInternalError + } + + resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, + labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size()) + + newPSK := PreSharedKey{ + CipherSuite: state.cryptoParams.Suite, + IsResumption: true, + Identity: tkt.Ticket, + Key: resumptionKey, + NextProto: state.Params.NextProto, + ReceivedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second), + TicketAgeAdd: tkt.TicketAgeAdd, + } + + tktm, err := HandshakeMessageFromBody(tkt) + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err) + return nil, AlertInternalError + } + + toSend := []HandshakeAction{ + StorePSK{newPSK}, + SendHandshakeMessage{tktm}, + } + return toSend, AlertNoAlert +} + +func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil { + logf(logTypeHandshake, "[StateConnected] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + bodyGeneric, err := hm.ToBody() + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + switch body := bodyGeneric.(type) { + case *KeyUpdateBody: + var trafficKeys keySet + if !state.isClient { + state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, + labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) + trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + } else { + state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret, + labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) + trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) + } + + toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}} + + // If requested, roll outbound keys and send a KeyUpdate + if body.KeyUpdateRequest == KeyUpdateRequested { + moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested) + if alert != AlertNoAlert { + return nil, nil, alert + } + + toSend = append(toSend, moreToSend...) + } + + return state, toSend, AlertNoAlert + + case *NewSessionTicketBody: + // XXX: Allow NewSessionTicket in both directions? + if !state.isClient { + return nil, nil, AlertUnexpectedMessage + } + + resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, + labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size()) + + psk := PreSharedKey{ + CipherSuite: state.cryptoParams.Suite, + IsResumption: true, + Identity: body.Ticket, + Key: resumptionKey, + NextProto: state.Params.NextProto, + ReceivedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second), + TicketAgeAdd: body.TicketAgeAdd, + } + + toSend := []HandshakeAction{StorePSK{psk}} + return state, toSend, AlertNoAlert + } + + logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType) + return nil, nil, AlertUnexpectedMessage +} diff --git a/vendor/github.com/bifurcation/mint/syntax/decode.go b/vendor/github.com/bifurcation/mint/syntax/decode.go new file mode 100644 index 000000000..cd5aadafa --- /dev/null +++ b/vendor/github.com/bifurcation/mint/syntax/decode.go @@ -0,0 +1,243 @@ +package syntax + +import ( + "bytes" + "fmt" + "reflect" + "runtime" +) + +func Unmarshal(data []byte, v interface{}) (int, error) { + // Check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + d := decodeState{} + d.Write(data) + return d.unmarshal(v) +} + +// These are the options that can be specified in the struct tag. Right now, +// all of them apply to variable-length vectors and nothing else +type decOpts struct { + head uint // length of length in bytes + min uint // minimum size in bytes + max uint // maximum size in bytes +} + +type decodeState struct { + bytes.Buffer +} + +func (d *decodeState) unmarshal(v interface{}) (read int, err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + if s, ok := r.(string); ok { + panic(s) + } + err = r.(error) + } + }() + + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)") + } + + read = d.value(rv) + return read, nil +} + +func (e *decodeState) value(v reflect.Value) int { + return valueDecoder(v)(e, v, decOpts{}) +} + +type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int + +func valueDecoder(v reflect.Value) decoderFunc { + return typeDecoder(v.Type().Elem()) +} + +func typeDecoder(t reflect.Type) decoderFunc { + // Note: Omits the caching / wait-group things that encoding/json uses + return newTypeDecoder(t) +} + +func newTypeDecoder(t reflect.Type) decoderFunc { + // Note: Does not support Marshaler, so don't need the allowAddr argument + + switch t.Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uintDecoder + case reflect.Array: + return newArrayDecoder(t) + case reflect.Slice: + return newSliceDecoder(t) + case reflect.Struct: + return newStructDecoder(t) + default: + panic(fmt.Errorf("Unsupported type (%s)", t)) + } +} + +///// Specific decoders below + +func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int { + var uintLen int + switch v.Elem().Kind() { + case reflect.Uint8: + uintLen = 1 + case reflect.Uint16: + uintLen = 2 + case reflect.Uint32: + uintLen = 4 + case reflect.Uint64: + uintLen = 8 + } + + buf := make([]byte, uintLen) + n, err := d.Read(buf) + if err != nil { + panic(err) + } + if n != uintLen { + panic(fmt.Errorf("Insufficient data to read uint")) + } + + val := uint64(0) + for _, b := range buf { + val = (val << 8) + uint64(b) + } + + v.Elem().SetUint(val) + return uintLen +} + +////////// + +type arrayDecoder struct { + elemDec decoderFunc +} + +func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { + n := v.Elem().Type().Len() + read := 0 + for i := 0; i < n; i += 1 { + read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts) + } + return read +} + +func newArrayDecoder(t reflect.Type) decoderFunc { + dec := &arrayDecoder{typeDecoder(t.Elem())} + return dec.decode +} + +////////// + +type sliceDecoder struct { + elementType reflect.Type + elementDec decoderFunc +} + +func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { + if opts.head == 0 { + panic(fmt.Errorf("Cannot decode a slice without a header length")) + } + + lengthBytes := make([]byte, opts.head) + n, err := d.Read(lengthBytes) + if err != nil { + panic(err) + } + if uint(n) != opts.head { + panic(fmt.Errorf("Not enough data to read header")) + } + + length := uint(0) + for _, b := range lengthBytes { + length = (length << 8) + uint(b) + } + + if opts.max > 0 && length > opts.max { + panic(fmt.Errorf("Length of vector exceeds declared max")) + } + if length < opts.min { + panic(fmt.Errorf("Length of vector below declared min")) + } + + data := make([]byte, length) + n, err = d.Read(data) + if err != nil { + panic(err) + } + if uint(n) != length { + panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length)) + } + + elemBuf := &decodeState{} + elemBuf.Write(data) + elems := []reflect.Value{} + read := int(opts.head) + for elemBuf.Len() > 0 { + elem := reflect.New(sd.elementType) + read += sd.elementDec(elemBuf, elem, opts) + elems = append(elems, elem) + } + + v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems))) + for i := 0; i < len(elems); i += 1 { + v.Elem().Index(i).Set(elems[i].Elem()) + } + return read +} + +func newSliceDecoder(t reflect.Type) decoderFunc { + dec := &sliceDecoder{ + elementType: t.Elem(), + elementDec: typeDecoder(t.Elem()), + } + return dec.decode +} + +////////// + +type structDecoder struct { + fieldOpts []decOpts + fieldDecs []decoderFunc +} + +func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { + read := 0 + for i := range sd.fieldDecs { + read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i]) + } + return read +} + +func newStructDecoder(t reflect.Type) decoderFunc { + n := t.NumField() + sd := structDecoder{ + fieldOpts: make([]decOpts, n), + fieldDecs: make([]decoderFunc, n), + } + + for i := 0; i < n; i += 1 { + f := t.Field(i) + + tag := f.Tag.Get("tls") + tagOpts := parseTag(tag) + + sd.fieldOpts[i] = decOpts{ + head: tagOpts["head"], + max: tagOpts["max"], + min: tagOpts["min"], + } + + sd.fieldDecs[i] = typeDecoder(f.Type) + } + + return sd.decode +} diff --git a/vendor/github.com/bifurcation/mint/syntax/encode.go b/vendor/github.com/bifurcation/mint/syntax/encode.go new file mode 100644 index 000000000..2874f4047 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/syntax/encode.go @@ -0,0 +1,187 @@ +package syntax + +import ( + "bytes" + "fmt" + "reflect" + "runtime" +) + +func Marshal(v interface{}) ([]byte, error) { + e := &encodeState{} + err := e.marshal(v, encOpts{}) + if err != nil { + return nil, err + } + return e.Bytes(), nil +} + +// These are the options that can be specified in the struct tag. Right now, +// all of them apply to variable-length vectors and nothing else +type encOpts struct { + head uint // length of length in bytes + min uint // minimum size in bytes + max uint // maximum size in bytes +} + +type encodeState struct { + bytes.Buffer +} + +func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + if s, ok := r.(string); ok { + panic(s) + } + err = r.(error) + } + }() + e.reflectValue(reflect.ValueOf(v), opts) + return nil +} + +func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) { + valueEncoder(v)(e, v, opts) +} + +type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts) + +func valueEncoder(v reflect.Value) encoderFunc { + if !v.IsValid() { + panic(fmt.Errorf("Cannot encode an invalid value")) + } + return typeEncoder(v.Type()) +} + +func typeEncoder(t reflect.Type) encoderFunc { + // Note: Omits the caching / wait-group things that encoding/json uses + return newTypeEncoder(t) +} + +func newTypeEncoder(t reflect.Type) encoderFunc { + // Note: Does not support Marshaler, so don't need the allowAddr argument + + switch t.Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uintEncoder + case reflect.Array: + return newArrayEncoder(t) + case reflect.Slice: + return newSliceEncoder(t) + case reflect.Struct: + return newStructEncoder(t) + default: + panic(fmt.Errorf("Unsupported type (%s)", t)) + } +} + +///// Specific encoders below + +func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) { + u := v.Uint() + switch v.Type().Kind() { + case reflect.Uint8: + e.WriteByte(byte(u)) + case reflect.Uint16: + e.Write([]byte{byte(u >> 8), byte(u)}) + case reflect.Uint32: + e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}) + case reflect.Uint64: + e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32), + byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}) + } +} + +////////// + +type arrayEncoder struct { + elemEnc encoderFunc +} + +func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { + n := v.Len() + for i := 0; i < n; i += 1 { + ae.elemEnc(e, v.Index(i), opts) + } +} + +func newArrayEncoder(t reflect.Type) encoderFunc { + enc := &arrayEncoder{typeEncoder(t.Elem())} + return enc.encode +} + +////////// + +type sliceEncoder struct { + ae *arrayEncoder +} + +func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { + if opts.head == 0 { + panic(fmt.Errorf("Cannot encode a slice without a header length")) + } + + arrayState := &encodeState{} + se.ae.encode(arrayState, v, opts) + + n := uint(arrayState.Len()) + if opts.max > 0 && n > opts.max { + panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max)) + } + if n>>(8*opts.head) > 0 { + panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head)) + } + if n < opts.min { + panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min)) + } + + for i := int(opts.head - 1); i >= 0; i -= 1 { + e.WriteByte(byte(n >> (8 * uint(i)))) + } + e.Write(arrayState.Bytes()) +} + +func newSliceEncoder(t reflect.Type) encoderFunc { + enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}} + return enc.encode +} + +////////// + +type structEncoder struct { + fieldOpts []encOpts + fieldEncs []encoderFunc +} + +func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { + for i := range se.fieldEncs { + se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i]) + } +} + +func newStructEncoder(t reflect.Type) encoderFunc { + n := t.NumField() + se := structEncoder{ + fieldOpts: make([]encOpts, n), + fieldEncs: make([]encoderFunc, n), + } + + for i := 0; i < n; i += 1 { + f := t.Field(i) + tag := f.Tag.Get("tls") + tagOpts := parseTag(tag) + + se.fieldOpts[i] = encOpts{ + head: tagOpts["head"], + max: tagOpts["max"], + min: tagOpts["min"], + } + se.fieldEncs[i] = typeEncoder(f.Type) + } + + return se.encode +} diff --git a/vendor/github.com/bifurcation/mint/syntax/tags.go b/vendor/github.com/bifurcation/mint/syntax/tags.go new file mode 100644 index 000000000..a6c9c88d2 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/syntax/tags.go @@ -0,0 +1,30 @@ +package syntax + +import ( + "strconv" + "strings" +) + +// `tls:"head=2,min=2,max=255"` + +type tagOptions map[string]uint + +// parseTag parses a struct field's "tls" tag as a comma-separated list of +// name=value pairs, where the values MUST be unsigned integers +func parseTag(tag string) tagOptions { + opts := tagOptions{} + for _, token := range strings.Split(tag, ",") { + if strings.Index(token, "=") == -1 { + continue + } + + parts := strings.Split(token, "=") + if len(parts[0]) == 0 { + continue + } + if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 { + opts[parts[0]] = uint(val) + } + } + return opts +} diff --git a/vendor/github.com/bifurcation/mint/tls.go b/vendor/github.com/bifurcation/mint/tls.go new file mode 100644 index 000000000..0c57aba5f --- /dev/null +++ b/vendor/github.com/bifurcation/mint/tls.go @@ -0,0 +1,168 @@ +package mint + +// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls + +import ( + "errors" + "net" + "strings" + "time" +) + +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Server(conn net.Conn, config *Config) *Conn { + return NewConn(conn, config, false) +} + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. +func Client(conn net.Conn, config *Config) *Conn { + return NewConn(conn, config, true) +} + +// A listener implements a network listener (net.Listener) for TLS connections. +type Listener struct { + net.Listener + config *Config +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection c is a *tls.Conn. +func (l *Listener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + if err != nil { + return + } + server := Server(c, l.config) + err = server.Handshake() + if err == AlertNoAlert { + err = nil + } + c = server + return +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func NewListener(inner net.Listener, config *Config) net.Listener { + l := new(Listener) + l.Listener = inner + l.config = config + return l +} + +// Listen creates a TLS listener accepting connections on the +// given network address using net.Listen. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Listen(network, laddr string, config *Config) (net.Listener, error) { + if config == nil || !config.ValidForServer() { + return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config") + } + l, err := net.Listen(network, laddr) + if err != nil { + return nil, err + } + return NewListener(l, config), nil +} + +type TimeoutError struct{} + +func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (TimeoutError) Timeout() bool { return true } +func (TimeoutError) Temporary() bool { return true } + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := dialer.Timeout + + if !dialer.Deadline.IsZero() { + deadlineTimeout := dialer.Deadline.Sub(time.Now()) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + var errChannel chan error + + if timeout != 0 { + errChannel = make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- TimeoutError{} + }) + } + + rawConn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + if config == nil { + config = &Config{} + } + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + // Make a copy to avoid polluting argument or default. + c := config.Clone() + c.ServerName = hostname + config = c + } + + conn := Client(rawConn, config) + + if timeout == 0 { + err = conn.Handshake() + if err == AlertNoAlert { + err = nil + } + } else { + go func() { + errChannel <- conn.Handshake() + }() + + err = <-errChannel + if err == AlertNoAlert { + err = nil + } + } + + if err != nil { + rawConn.Close() + return nil, err + } + + return conn, nil +} + +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go deleted file mode 100644 index ac14e201d..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go +++ /dev/null @@ -1,32 +0,0 @@ -package ackhandler - -import ( - "time" - - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" -) - -// SentPacketHandler handles ACKs received for outgoing packets -type SentPacketHandler interface { - // SentPacket may modify the packet - SentPacket(packet *Packet) error - ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error - - SendingAllowed() bool - GetStopWaitingFrame(force bool) *frames.StopWaitingFrame - DequeuePacketForRetransmission() (packet *Packet) - GetLeastUnacked() protocol.PacketNumber - - GetAlarmTimeout() time.Time - OnAlarm() -} - -// ReceivedPacketHandler handles ACKs needed to send for incoming packets -type ReceivedPacketHandler interface { - ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error - SetLowerLimit(protocol.PacketNumber) - - GetAlarmTimeout() time.Time - GetAckFrame() *frames.AckFrame -} diff --git a/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go b/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go index f592d475a..5032ca7f1 100644 --- a/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go +++ b/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go @@ -3,7 +3,7 @@ package quic import ( "sync" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) var bufferPool sync.Pool diff --git a/vendor/github.com/lucas-clemente/quic-go/client.go b/vendor/github.com/lucas-clemente/quic-go/client.go index 2e18de801..955c908eb 100644 --- a/vendor/github.com/lucas-clemente/quic-go/client.go +++ b/vendor/github.com/lucas-clemente/quic-go/client.go @@ -10,32 +10,39 @@ import ( "sync" "time" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) type client struct { - mutex sync.Mutex - listenErr error + mutex sync.Mutex conn connection hostname string - errorChan chan struct{} - handshakeChan <-chan handshakeEvent + versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version + versionNegotiated bool // has the server accepted our version + receivedVersionNegotiationPacket bool + negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet - tlsConf *tls.Config - config *Config - versionNegotiated bool // has version negotiation completed yet + tlsConf *tls.Config + config *Config + tls handshake.MintTLS // only used when using TLS connectionID protocol.ConnectionID - version protocol.VersionNumber + + initialVersion protocol.VersionNumber + version protocol.VersionNumber session packetHandler } var ( + // make it possible to mock connection ID generation in the tests + generateConnectionID = utils.GenerateConnectionID errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") ) @@ -53,71 +60,6 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) return Dial(udpConn, udpAddr, addr, tlsConf, config) } -// DialAddrNonFWSecure establishes a new QUIC connection to a server. -// The hostname for SNI is taken from the given address. -func DialAddrNonFWSecure( - addr string, - tlsConf *tls.Config, - config *Config, -) (NonFWSession, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return nil, err - } - return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config) -} - -// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn. -// The host parameter is used for SNI. -func DialNonFWSecure( - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (NonFWSession, error) { - connID, err := utils.GenerateConnectionID() - if err != nil { - return nil, err - } - - var hostname string - if tlsConf != nil { - hostname = tlsConf.ServerName - } - - if hostname == "" { - hostname, _, err = net.SplitHostPort(host) - if err != nil { - return nil, err - } - } - - clientConfig := populateClientConfig(config) - c := &client{ - conn: &conn{pconn: pconn, currentAddr: remoteAddr}, - connectionID: connID, - hostname: hostname, - tlsConf: tlsConf, - config: clientConfig, - version: clientConfig.Versions[0], - errorChan: make(chan struct{}), - } - - err = c.createNewSession(nil) - if err != nil { - return nil, err - } - - utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) - - return c.session.(NonFWSession), c.establishSecureConnection() -} - // Dial establishes a new QUIC connection to a server using a net.PacketConn. // The host parameter is used for SNI. func Dial( @@ -127,15 +69,39 @@ func Dial( tlsConf *tls.Config, config *Config, ) (Session, error) { - sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config) + connID, err := generateConnectionID() if err != nil { return nil, err } - err = sess.WaitUntilHandshakeComplete() - if err != nil { + + var hostname string + if tlsConf != nil { + hostname = tlsConf.ServerName + } + if hostname == "" { + hostname, _, err = net.SplitHostPort(host) + if err != nil { + return nil, err + } + } + + clientConfig := populateClientConfig(config) + c := &client{ + conn: &conn{pconn: pconn, currentAddr: remoteAddr}, + connectionID: connID, + hostname: hostname, + tlsConf: tlsConf, + config: clientConfig, + version: clientConfig.Versions[0], + versionNegotiationChan: make(chan struct{}), + } + + utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) + + if err := c.dial(); err != nil { return nil, err } - return sess, nil + return c.session, nil } // populateClientConfig populates fields in the quic.Config with their default values, if none are set @@ -153,6 +119,10 @@ func populateClientConfig(config *Config) *Config { if config.HandshakeTimeout != 0 { handshakeTimeout = config.HandshakeTimeout } + idleTimeout := protocol.DefaultIdleTimeout + if config.IdleTimeout != 0 { + idleTimeout = config.IdleTimeout + } maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow if maxReceiveStreamFlowControlWindow == 0 { @@ -166,32 +136,109 @@ func populateClientConfig(config *Config) *Config { return &Config{ Versions: versions, HandshakeTimeout: handshakeTimeout, - RequestConnectionIDTruncation: config.RequestConnectionIDTruncation, + IdleTimeout: idleTimeout, + RequestConnectionIDOmission: config.RequestConnectionIDOmission, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, KeepAlive: config.KeepAlive, } } -// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure) -func (c *client) establishSecureConnection() error { +func (c *client) dial() error { + var err error + if c.version.UsesTLS() { + err = c.dialTLS() + } else { + err = c.dialGQUIC() + } + if err == errCloseSessionForNewVersion { + return c.dial() + } + return err +} + +func (c *client) dialGQUIC() error { + if err := c.createNewGQUICSession(); err != nil { + return err + } go c.listen() + return c.establishSecureConnection() +} + +func (c *client) dialTLS() error { + params := &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + IdleTimeout: c.config.IdleTimeout, + OmitConnectionID: c.config.RequestConnectionIDOmission, + // TODO(#523): make these values configurable + MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient), + MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient), + } + csc := handshake.NewCryptoStreamConn(nil) + extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version) + mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient) + if err != nil { + return err + } + mintConf.ExtensionHandler = extHandler + mintConf.ServerName = c.hostname + c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient) + + if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { + return err + } + go c.listen() + if err := c.establishSecureConnection(); err != nil { + if err != handshake.ErrCloseSessionForRetry { + return err + } + utils.Infof("Received a Retry packet. Recreating session.") + if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { + return err + } + if err := c.establishSecureConnection(); err != nil { + return err + } + } + return nil +} + +// establishSecureConnection runs the session, and tries to establish a secure connection +// It returns: +// - errCloseSessionForNewVersion when the server sends a version negotiation packet +// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC) +// - any other error that might occur +// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC) +func (c *client) establishSecureConnection() error { + var runErr error + errorChan := make(chan struct{}) + go func() { + runErr = c.session.run() // returns as soon as the session is closed + close(errorChan) + utils.Infof("Connection %x closed.", c.connectionID) + if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion { + c.conn.Close() + } + }() + + // wait until the server accepts the QUIC version (or an error occurs) + select { + case <-errorChan: + return runErr + case <-c.versionNegotiationChan: + } select { - case <-c.errorChan: - return c.listenErr - case ev := <-c.handshakeChan: - if ev.err != nil { - return ev.err - } - if ev.encLevel != protocol.EncryptionSecure { - return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel) - } - return nil + case <-errorChan: + return runErr + case err := <-c.session.handshakeStatus(): + return err } } -// Listen listens +// Listen listens on the underlying connection and passes packets on for handling. +// It returns when the connection is closed. func (c *client) listen() { var err error @@ -205,13 +252,15 @@ func (c *client) listen() { n, addr, err = c.conn.Read(data) if err != nil { if !strings.HasSuffix(err.Error(), "use of closed network connection") { - c.session.Close(err) + c.mutex.Lock() + if c.session != nil { + c.session.Close(err) + } + c.mutex.Unlock() } break } - data = data[:n] - - c.handlePacket(addr, data) + c.handlePacket(addr, data[:n]) } } @@ -219,10 +268,14 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { rcvTime := time.Now() r := bytes.NewReader(packet) - hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer) + hdr, err := wire.ParseHeaderSentByServer(r, c.version) if err != nil { utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) - // drop this packet if we can't parse the Public Header + // drop this packet if we can't parse the header + return + } + // reject packets with truncated connection id if we didn't request truncation + if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission { return } hdr.Raw = packet[:len(packet)-r.Len()] @@ -230,6 +283,11 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { c.mutex.Lock() defer c.mutex.Unlock() + // reject packets with the wrong connection ID + if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID { + return + } + if hdr.ResetFlag { cr := c.conn.RemoteAddr() // check if the remote address and the connection ID match @@ -238,44 +296,48 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { utils.Infof("Received a spoofed Public Reset. Ignoring.") return } - pr, err := parsePublicReset(r) + pr, err := wire.ParsePublicReset(r) if err != nil { - utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.") + utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err) return } - utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber) - c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber))) + utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber) + c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber))) return } - // ignore delayed / duplicated version negotiation packets - if c.versionNegotiated && hdr.VersionFlag { - return - } + // handle Version Negotiation Packets + if hdr.IsVersionNegotiation { + // ignore delayed / duplicated version negotiation packets + if c.receivedVersionNegotiationPacket || c.versionNegotiated { + return + } - // this is the first packet after the client sent a packet with the VersionFlag set - // if the server doesn't send a version negotiation packet, it supports the suggested version - if !hdr.VersionFlag && !c.versionNegotiated { - c.versionNegotiated = true - } - - if hdr.VersionFlag { // version negotiation packets have no payload - if err := c.handlePacketWithVersionFlag(hdr); err != nil { + if err := c.handleVersionNegotiationPacket(hdr); err != nil { c.session.Close(err) } return } + // this is the first packet we are receiving + // since it is not a Version Negotiation Packet, this means the server supports the suggested version + if !c.versionNegotiated { + c.versionNegotiated = true + close(c.versionNegotiationChan) + } + + // TODO: validate packet number and connection ID on Retry packets (for IETF QUIC) + c.session.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - publicHeader: hdr, - data: packet[len(packet)-r.Len():], - rcvTime: rcvTime, + remoteAddr: remoteAddr, + header: hdr, + data: packet[len(packet)-r.Len():], + rcvTime: rcvTime, }) } -func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error { +func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { for _, v := range hdr.SupportedVersions { if v == c.version { // the version negotiation packet contains the version that we offered @@ -285,51 +347,57 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error { } } - newVersion := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) - if newVersion == protocol.VersionUnsupported { + newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) + if !ok { return qerr.InvalidVersion } + c.receivedVersionNegotiationPacket = true + c.negotiatedVersions = hdr.SupportedVersions // switch to negotiated version + c.initialVersion = c.version c.version = newVersion - c.versionNegotiated = true var err error c.connectionID, err = utils.GenerateConnectionID() if err != nil { return err } - utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID) - + utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) c.session.Close(errCloseSessionForNewVersion) - return c.createNewSession(hdr.SupportedVersions) + return nil } -func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { - var err error - c.session, c.handshakeChan, err = newClientSession( +func (c *client) createNewGQUICSession() (err error) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.session, err = newClientSession( c.conn, c.hostname, c.version, c.connectionID, c.tlsConf, c.config, - negotiatedVersions, + c.initialVersion, + c.negotiatedVersions, ) - if err != nil { - return err - } - - go func() { - // session.run() returns as soon as the session is closed - err := c.session.run() - if err == errCloseSessionForNewVersion { - return - } - c.listenErr = err - close(c.errorChan) - - utils.Infof("Connection %x closed.", c.connectionID) - c.conn.Close() - }() - return nil + return err +} + +func (c *client) createNewTLSSession( + paramsChan <-chan handshake.TransportParameters, + version protocol.VersionNumber, +) (err error) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.session, err = newTLSClientSession( + c.conn, + c.hostname, + c.version, + c.connectionID, + c.config, + c.tls, + paramsChan, + 1, + ) + return err } diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go b/vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go deleted file mode 100644 index a738cc2b1..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go +++ /dev/null @@ -1,58 +0,0 @@ -package crypto - -import ( - "crypto/cipher" - "errors" - - "github.com/lucas-clemente/aes12" - - "github.com/lucas-clemente/quic-go/protocol" -) - -type aeadAESGCM struct { - otherIV []byte - myIV []byte - encrypter cipher.AEAD - decrypter cipher.AEAD -} - -// NewAEADAESGCM creates a AEAD using AES-GCM with 12 bytes tag size -// -// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte -// tag size, and couples the cipher and aes packages closely. -// See https://github.com/lucas-clemente/aes12. -func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) { - if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 { - return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs") - } - encrypterCipher, err := aes12.NewCipher(myKey) - if err != nil { - return nil, err - } - encrypter, err := aes12.NewGCM(encrypterCipher) - if err != nil { - return nil, err - } - decrypterCipher, err := aes12.NewCipher(otherKey) - if err != nil { - return nil, err - } - decrypter, err := aes12.NewGCM(decrypterCipher) - if err != nil { - return nil, err - } - return &aeadAESGCM{ - otherIV: otherIV, - myIV: myIV, - encrypter: encrypter, - decrypter: decrypter, - }, nil -} - -func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { - return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData) -} - -func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go b/vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go deleted file mode 100644 index 9b6d41645..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go +++ /dev/null @@ -1,14 +0,0 @@ -package crypto - -import ( - "encoding/binary" - - "github.com/lucas-clemente/quic-go/protocol" -) - -func makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { - res := make([]byte, 12) - copy(res[0:4], iv) - binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber)) - return res -} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go b/vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go deleted file mode 100644 index 3dcb26a75..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go +++ /dev/null @@ -1,76 +0,0 @@ -package crypto - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/sha256" - "fmt" - "io" - - "golang.org/x/crypto/hkdf" -) - -// StkSource is used to create and verify source address tokens -type StkSource interface { - // NewToken creates a new token - NewToken([]byte) ([]byte, error) - // DecodeToken decodes a token - DecodeToken([]byte) ([]byte, error) -} - -type stkSource struct { - aead cipher.AEAD -} - -const stkKeySize = 16 - -// Chrome currently sets this to 12, but discusses changing it to 16. We start -// at 16 :) -const stkNonceSize = 16 - -// NewStkSource creates a source for source address tokens -func NewStkSource() (StkSource, error) { - secret := make([]byte, 32) - if _, err := rand.Read(secret); err != nil { - return nil, err - } - key, err := deriveKey(secret) - if err != nil { - return nil, err - } - c, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize) - if err != nil { - return nil, err - } - return &stkSource{aead: aead}, nil -} - -func (s *stkSource) NewToken(data []byte) ([]byte, error) { - nonce := make([]byte, stkNonceSize) - if _, err := rand.Read(nonce); err != nil { - return nil, err - } - return s.aead.Seal(nonce, nonce, data, nil), nil -} - -func (s *stkSource) DecodeToken(p []byte) ([]byte, error) { - if len(p) < stkNonceSize { - return nil, fmt.Errorf("STK too short: %d", len(p)) - } - nonce := p[:stkNonceSize] - return s.aead.Open(nil, nonce, p[stkNonceSize:], nil) -} - -func deriveKey(secret []byte) ([]byte, error) { - r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key")) - key := make([]byte, stkKeySize) - if _, err := io.ReadFull(r, key); err != nil { - return nil, err - } - return key, nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go b/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go new file mode 100644 index 000000000..8e96ec107 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go @@ -0,0 +1,41 @@ +package quic + +import ( + "io" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type cryptoStreamI interface { + StreamID() protocol.StreamID + io.Reader + io.Writer + handleStreamFrame(*wire.StreamFrame) error + popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool) + closeForShutdown(error) + setReadOffset(protocol.ByteCount) + // methods needed for flow control + getWindowUpdate() protocol.ByteCount + handleMaxStreamDataFrame(*wire.MaxStreamDataFrame) +} + +type cryptoStream struct { + *stream +} + +var _ cryptoStreamI = &cryptoStream{} + +func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI { + str := newStream(version.CryptoStreamID(), sender, flowController, version) + return &cryptoStream{str} +} + +// SetReadOffset sets the read offset. +// It is only needed for the crypto stream. +// It must not be called concurrently with any other stream methods, especially Read and Write. +func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) { + s.receiveStream.readOffset = offset + s.receiveStream.frameQueue.readPosition = offset +} diff --git a/vendor/github.com/lucas-clemente/quic-go/example/client/main.go b/vendor/github.com/lucas-clemente/quic-go/example/client/main.go index f4e3e57b1..2a28c1612 100644 --- a/vendor/github.com/lucas-clemente/quic-go/example/client/main.go +++ b/vendor/github.com/lucas-clemente/quic-go/example/client/main.go @@ -7,12 +7,15 @@ import ( "net/http" "sync" + quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/h2quic" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" ) func main() { verbose := flag.Bool("v", false, "verbose") + tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)") flag.Parse() urls := flag.Args() @@ -23,8 +26,17 @@ func main() { } utils.SetLogTimeFormat("") + versions := protocol.SupportedVersions + if *tls { + versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...) + } + + roundTripper := &h2quic.RoundTripper{ + QuicConfig: &quic.Config{Versions: versions}, + } + defer roundTripper.Close() hclient := &http.Client{ - Transport: &h2quic.RoundTripper{}, + Transport: roundTripper, } var wg sync.WaitGroup diff --git a/vendor/github.com/lucas-clemente/quic-go/example/main.go b/vendor/github.com/lucas-clemente/quic-go/example/main.go index d6f330fc3..35aaa85c6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/example/main.go +++ b/vendor/github.com/lucas-clemente/quic-go/example/main.go @@ -17,7 +17,9 @@ import ( _ "net/http/pprof" + quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/h2quic" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" ) @@ -121,6 +123,7 @@ func main() { certPath := flag.String("certpath", getBuildDir(), "certificate directory") www := flag.String("www", "/var/www", "www data") tcp := flag.Bool("tcp", false, "also listen on TCP") + tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)") flag.Parse() if *verbose { @@ -130,6 +133,11 @@ func main() { } utils.SetLogTimeFormat("") + versions := protocol.SupportedVersions + if *tls { + versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...) + } + certFile := *certPath + "/fullchain.pem" keyFile := *certPath + "/privkey.pem" @@ -148,7 +156,11 @@ func main() { if *tcp { err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil) } else { - err = h2quic.ListenAndServeQUIC(bCap, certFile, keyFile, nil) + server := h2quic.Server{ + Server: &http.Server{Addr: bCap}, + QuicConfig: &quic.Config{Versions: versions}, + } + err = server.ListenAndServeTLS(certFile, keyFile) } if err != nil { fmt.Println(err) diff --git a/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go b/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go deleted file mode 100644 index 9362d60a1..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go +++ /dev/null @@ -1,240 +0,0 @@ -package flowcontrol - -import ( - "errors" - "fmt" - "sync" - - "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/handshake" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/qerr" -) - -type flowControlManager struct { - connectionParameters handshake.ConnectionParametersManager - rttStats *congestion.RTTStats - - streamFlowController map[protocol.StreamID]*flowController - connFlowController *flowController - mutex sync.RWMutex -} - -var _ FlowControlManager = &flowControlManager{} - -var errMapAccess = errors.New("Error accessing the flowController map.") - -// NewFlowControlManager creates a new flow control manager -func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager { - return &flowControlManager{ - connectionParameters: connectionParameters, - rttStats: rttStats, - streamFlowController: make(map[protocol.StreamID]*flowController), - connFlowController: newFlowController(0, false, connectionParameters, rttStats), - } -} - -// NewStream creates new flow controllers for a stream -// it does nothing if the stream already exists -func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesToConnection bool) { - f.mutex.Lock() - defer f.mutex.Unlock() - - if _, ok := f.streamFlowController[streamID]; ok { - return - } - - f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connectionParameters, f.rttStats) -} - -// RemoveStream removes a closed stream from flow control -func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) { - f.mutex.Lock() - delete(f.streamFlowController, streamID) - f.mutex.Unlock() -} - -// ResetStream should be called when receiving a RstStreamFrame -// it updates the byte offset to the value in the RstStreamFrame -// streamID must not be 0 here -func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { - f.mutex.Lock() - defer f.mutex.Unlock() - - streamFlowController, err := f.getFlowController(streamID) - if err != nil { - return err - } - increment, err := streamFlowController.UpdateHighestReceived(byteOffset) - if err != nil { - return qerr.StreamDataAfterTermination - } - - if streamFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow)) - } - - if streamFlowController.ContributesToConnection() { - f.connFlowController.IncrementHighestReceived(increment) - if f.connFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow)) - } - } - - return nil -} - -// UpdateHighestReceived updates the highest received byte offset for a stream -// it adds the number of additional bytes to connection level flow control -// streamID must not be 0 here -func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { - f.mutex.Lock() - defer f.mutex.Unlock() - - streamFlowController, err := f.getFlowController(streamID) - if err != nil { - return err - } - // UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered - // this error can be ignored here - increment, _ := streamFlowController.UpdateHighestReceived(byteOffset) - - if streamFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow)) - } - - if streamFlowController.ContributesToConnection() { - f.connFlowController.IncrementHighestReceived(increment) - if f.connFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow)) - } - } - - return nil -} - -// streamID must not be 0 here -func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error { - f.mutex.Lock() - defer f.mutex.Unlock() - - fc, err := f.getFlowController(streamID) - if err != nil { - return err - } - - fc.AddBytesRead(n) - if fc.ContributesToConnection() { - f.connFlowController.AddBytesRead(n) - } - - return nil -} - -func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) { - f.mutex.Lock() - defer f.mutex.Unlock() - - // get WindowUpdates for streams - for id, fc := range f.streamFlowController { - if necessary, newIncrement, offset := fc.MaybeUpdateWindow(); necessary { - res = append(res, WindowUpdate{StreamID: id, Offset: offset}) - if fc.ContributesToConnection() && newIncrement != 0 { - f.connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier)) - } - } - } - // get a WindowUpdate for the connection - if necessary, _, offset := f.connFlowController.MaybeUpdateWindow(); necessary { - res = append(res, WindowUpdate{StreamID: 0, Offset: offset}) - } - - return -} - -func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) { - f.mutex.RLock() - defer f.mutex.RUnlock() - - // StreamID can be 0 when retransmitting - if streamID == 0 { - return f.connFlowController.receiveWindow, nil - } - - flowController, err := f.getFlowController(streamID) - if err != nil { - return 0, err - } - return flowController.receiveWindow, nil -} - -// streamID must not be 0 here -func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error { - f.mutex.Lock() - defer f.mutex.Unlock() - - fc, err := f.getFlowController(streamID) - if err != nil { - return err - } - - fc.AddBytesSent(n) - if fc.ContributesToConnection() { - f.connFlowController.AddBytesSent(n) - } - - return nil -} - -// must not be called with StreamID 0 -func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) { - f.mutex.RLock() - defer f.mutex.RUnlock() - - fc, err := f.getFlowController(streamID) - if err != nil { - return 0, err - } - res := fc.SendWindowSize() - - if fc.ContributesToConnection() { - res = utils.MinByteCount(res, f.connFlowController.SendWindowSize()) - } - - return res, nil -} - -func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount { - f.mutex.RLock() - defer f.mutex.RUnlock() - - return f.connFlowController.SendWindowSize() -} - -// streamID may be 0 here -func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) { - f.mutex.Lock() - defer f.mutex.Unlock() - - var fc *flowController - if streamID == 0 { - fc = f.connFlowController - } else { - var err error - fc, err = f.getFlowController(streamID) - if err != nil { - return false, err - } - } - - return fc.UpdateSendWindow(offset), nil -} - -func (f *flowControlManager) getFlowController(streamID protocol.StreamID) (*flowController, error) { - streamFlowController, ok := f.streamFlowController[streamID] - if !ok { - return nil, errMapAccess - } - return streamFlowController, nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go deleted file mode 100644 index 387ee05b9..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go +++ /dev/null @@ -1,198 +0,0 @@ -package flowcontrol - -import ( - "errors" - "time" - - "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/handshake" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" -) - -type flowController struct { - streamID protocol.StreamID - contributesToConnection bool // does the stream contribute to connection level flow control - - connectionParameters handshake.ConnectionParametersManager - rttStats *congestion.RTTStats - - bytesSent protocol.ByteCount - sendWindow protocol.ByteCount - - lastWindowUpdateTime time.Time - - bytesRead protocol.ByteCount - highestReceived protocol.ByteCount - receiveWindow protocol.ByteCount - receiveWindowIncrement protocol.ByteCount - maxReceiveWindowIncrement protocol.ByteCount -} - -// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously -var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset") - -// newFlowController gets a new flow controller -func newFlowController(streamID protocol.StreamID, contributesToConnection bool, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController { - fc := flowController{ - streamID: streamID, - contributesToConnection: contributesToConnection, - connectionParameters: connectionParameters, - rttStats: rttStats, - } - - if streamID == 0 { - fc.receiveWindow = connectionParameters.GetReceiveConnectionFlowControlWindow() - fc.receiveWindowIncrement = fc.receiveWindow - fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow() - } else { - fc.receiveWindow = connectionParameters.GetReceiveStreamFlowControlWindow() - fc.receiveWindowIncrement = fc.receiveWindow - fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow() - } - - return &fc -} - -func (c *flowController) ContributesToConnection() bool { - return c.contributesToConnection -} - -func (c *flowController) getSendWindow() protocol.ByteCount { - if c.sendWindow == 0 { - if c.streamID == 0 { - return c.connectionParameters.GetSendConnectionFlowControlWindow() - } - return c.connectionParameters.GetSendStreamFlowControlWindow() - } - return c.sendWindow -} - -func (c *flowController) AddBytesSent(n protocol.ByteCount) { - c.bytesSent += n -} - -// UpdateSendWindow should be called after receiving a WindowUpdateFrame -// it returns true if the window was actually updated -func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool { - if newOffset > c.sendWindow { - c.sendWindow = newOffset - return true - } - return false -} - -func (c *flowController) SendWindowSize() protocol.ByteCount { - sendWindow := c.getSendWindow() - - if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here - return 0 - } - return sendWindow - c.bytesSent -} - -func (c *flowController) SendWindowOffset() protocol.ByteCount { - return c.getSendWindow() -} - -// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher -// Should **only** be used for the stream-level FlowController -// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before -// This error occurs every time StreamFrames get reordered and has to be ignored in that case -// It should only be treated as an error when resetting a stream -func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) { - if byteOffset == c.highestReceived { - return 0, nil - } - if byteOffset > c.highestReceived { - increment := byteOffset - c.highestReceived - c.highestReceived = byteOffset - return increment, nil - } - return 0, ErrReceivedSmallerByteOffset -} - -// IncrementHighestReceived adds an increment to the highestReceived value -// Should **only** be used for the connection-level FlowController -func (c *flowController) IncrementHighestReceived(increment protocol.ByteCount) { - c.highestReceived += increment -} - -func (c *flowController) AddBytesRead(n protocol.ByteCount) { - // pretend we sent a WindowUpdate when reading the first byte - // this way auto-tuning of the window increment already works for the first WindowUpdate - if c.bytesRead == 0 { - c.lastWindowUpdateTime = time.Now() - } - c.bytesRead += n -} - -// MaybeUpdateWindow updates the receive window, if necessary -// if the receive window increment is changed, the new value is returned, otherwise a 0 -// the last return value is the new offset of the receive window -func (c *flowController) MaybeUpdateWindow() (bool, protocol.ByteCount /* new increment */, protocol.ByteCount /* new offset */) { - diff := c.receiveWindow - c.bytesRead - - // Chromium implements the same threshold - if diff < (c.receiveWindowIncrement / 2) { - var newWindowIncrement protocol.ByteCount - oldWindowIncrement := c.receiveWindowIncrement - - c.maybeAdjustWindowIncrement() - if c.receiveWindowIncrement != oldWindowIncrement { - newWindowIncrement = c.receiveWindowIncrement - } - - c.lastWindowUpdateTime = time.Now() - c.receiveWindow = c.bytesRead + c.receiveWindowIncrement - return true, newWindowIncrement, c.receiveWindow - } - - return false, 0, 0 -} - -// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often -func (c *flowController) maybeAdjustWindowIncrement() { - if c.lastWindowUpdateTime.IsZero() { - return - } - - rtt := c.rttStats.SmoothedRTT() - if rtt == 0 { - return - } - - timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime) - - // interval between the window updates is sufficiently large, no need to increase the increment - if timeSinceLastWindowUpdate >= 2*rtt { - return - } - - oldWindowSize := c.receiveWindowIncrement - c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement) - - // debug log, if the window size was actually increased - if oldWindowSize < c.receiveWindowIncrement { - newWindowSize := c.receiveWindowIncrement / (1 << 10) - if c.streamID == 0 { - utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize) - } else { - utils.Debugf("Increasing receive flow control window increment for stream %d to %d kB", c.streamID, newWindowSize) - } - } -} - -// EnsureMinimumWindowIncrement sets a minimum window increment -// it is intended be used for the connection-level flow controller -// it should make sure that the connection-level window is increased when a stream-level window grows -func (c *flowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) { - if inc > c.receiveWindowIncrement { - c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement) - c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update - } -} - -func (c *flowController) CheckFlowControlViolation() bool { - return c.highestReceived > c.receiveWindow -} diff --git a/vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go b/vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go deleted file mode 100644 index e1ea3fac6..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go +++ /dev/null @@ -1,26 +0,0 @@ -package flowcontrol - -import "github.com/lucas-clemente/quic-go/protocol" - -// WindowUpdate provides the data for WindowUpdateFrames. -type WindowUpdate struct { - StreamID protocol.StreamID - Offset protocol.ByteCount -} - -// A FlowControlManager manages the flow control -type FlowControlManager interface { - NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) - RemoveStream(streamID protocol.StreamID) - // methods needed for receiving data - ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error - UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error - AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error - GetWindowUpdates() []WindowUpdate - GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) - // methods needed for sending data - AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error - SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) - RemainingConnectionWindowSize() protocol.ByteCount - UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go b/vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go deleted file mode 100644 index ac65d33ea..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go +++ /dev/null @@ -1,9 +0,0 @@ -package frames - -import "github.com/lucas-clemente/quic-go/protocol" - -// AckRange is an ACK range -type AckRange struct { - FirstPacketNumber protocol.PacketNumber - LastPacketNumber protocol.PacketNumber -} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go deleted file mode 100644 index 44645780d..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go +++ /dev/null @@ -1,44 +0,0 @@ -package frames - -import ( - "bytes" - - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" -) - -// A BlockedFrame in QUIC -type BlockedFrame struct { - StreamID protocol.StreamID -} - -//Write writes a BlockedFrame frame -func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x05) - utils.WriteUint32(b, uint32(f.StreamID)) - return nil -} - -// MinLength of a written frame -func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4, nil -} - -// ParseBlockedFrame parses a BLOCKED frame -func ParseBlockedFrame(r *bytes.Reader) (*BlockedFrame, error) { - frame := &BlockedFrame{} - - // read the TypeByte - _, err := r.ReadByte() - if err != nil { - return nil, err - } - - sid, err := utils.ReadUint32(r) - if err != nil { - return nil, err - } - frame.StreamID = protocol.StreamID(sid) - - return frame, nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go b/vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go deleted file mode 100644 index 5a7ed04cf..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go +++ /dev/null @@ -1,73 +0,0 @@ -package frames - -import ( - "bytes" - "errors" - "io" - "math" - - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/qerr" -) - -// A ConnectionCloseFrame in QUIC -type ConnectionCloseFrame struct { - ErrorCode qerr.ErrorCode - ReasonPhrase string -} - -// ParseConnectionCloseFrame reads a CONNECTION_CLOSE frame -func ParseConnectionCloseFrame(r *bytes.Reader) (*ConnectionCloseFrame, error) { - frame := &ConnectionCloseFrame{} - - // read the TypeByte - _, err := r.ReadByte() - if err != nil { - return nil, err - } - - errorCode, err := utils.ReadUint32(r) - if err != nil { - return nil, err - } - frame.ErrorCode = qerr.ErrorCode(errorCode) - - reasonPhraseLen, err := utils.ReadUint16(r) - if err != nil { - return nil, err - } - - if reasonPhraseLen > uint16(protocol.MaxPacketSize) { - return nil, qerr.Error(qerr.InvalidConnectionCloseData, "reason phrase too long") - } - - reasonPhrase := make([]byte, reasonPhraseLen) - if _, err := io.ReadFull(r, reasonPhrase); err != nil { - return nil, err - } - frame.ReasonPhrase = string(reasonPhrase) - - return frame, nil -} - -// MinLength of a written frame -func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil -} - -// Write writes an CONNECTION_CLOSE frame. -func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x02) - utils.WriteUint32(b, uint32(f.ErrorCode)) - - if len(f.ReasonPhrase) > math.MaxUint16 { - return errors.New("ConnectionFrame: ReasonPhrase too long") - } - - reasonPhraseLen := uint16(len(f.ReasonPhrase)) - utils.WriteUint16(b, reasonPhraseLen) - b.WriteString(f.ReasonPhrase) - - return nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/frame.go b/vendor/github.com/lucas-clemente/quic-go/frames/frame.go deleted file mode 100644 index 464e6693a..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/frames/frame.go +++ /dev/null @@ -1,13 +0,0 @@ -package frames - -import ( - "bytes" - - "github.com/lucas-clemente/quic-go/protocol" -) - -// A Frame in QUIC -type Frame interface { - Write(b *bytes.Buffer, version protocol.VersionNumber) error - MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/log.go b/vendor/github.com/lucas-clemente/quic-go/frames/log.go deleted file mode 100644 index 6b7fdcec4..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/frames/log.go +++ /dev/null @@ -1,28 +0,0 @@ -package frames - -import "github.com/lucas-clemente/quic-go/internal/utils" - -// LogFrame logs a frame, either sent or received -func LogFrame(frame Frame, sent bool) { - if !utils.Debug() { - return - } - dir := "<-" - if sent { - dir = "->" - } - switch f := frame.(type) { - case *StreamFrame: - utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen()) - case *StopWaitingFrame: - if sent { - utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen) - } else { - utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked) - } - case *AckFrame: - utils.Debugf("\t%s &frames.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String()) - default: - utils.Debugf("\t%s %#v", dir, frame) - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go deleted file mode 100644 index ea2531c68..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go +++ /dev/null @@ -1,59 +0,0 @@ -package frames - -import ( - "bytes" - - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" -) - -// A RstStreamFrame in QUIC -type RstStreamFrame struct { - StreamID protocol.StreamID - ErrorCode uint32 - ByteOffset protocol.ByteCount -} - -//Write writes a RST_STREAM frame -func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x01) - utils.WriteUint32(b, uint32(f.StreamID)) - utils.WriteUint64(b, uint64(f.ByteOffset)) - utils.WriteUint32(b, f.ErrorCode) - return nil -} - -// MinLength of a written frame -func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4 + 8 + 4, nil -} - -// ParseRstStreamFrame parses a RST_STREAM frame -func ParseRstStreamFrame(r *bytes.Reader) (*RstStreamFrame, error) { - frame := &RstStreamFrame{} - - // read the TypeByte - _, err := r.ReadByte() - if err != nil { - return nil, err - } - - sid, err := utils.ReadUint32(r) - if err != nil { - return nil, err - } - frame.StreamID = protocol.StreamID(sid) - - byteOffset, err := utils.ReadUint64(r) - if err != nil { - return nil, err - } - frame.ByteOffset = protocol.ByteCount(byteOffset) - - frame.ErrorCode, err = utils.ReadUint32(r) - if err != nil { - return nil, err - } - - return frame, nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go b/vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go deleted file mode 100644 index 9b8b45986..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go +++ /dev/null @@ -1,54 +0,0 @@ -package frames - -import ( - "bytes" - - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" -) - -// A WindowUpdateFrame in QUIC -type WindowUpdateFrame struct { - StreamID protocol.StreamID - ByteOffset protocol.ByteCount -} - -//Write writes a RST_STREAM frame -func (f *WindowUpdateFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - typeByte := uint8(0x04) - b.WriteByte(typeByte) - - utils.WriteUint32(b, uint32(f.StreamID)) - utils.WriteUint64(b, uint64(f.ByteOffset)) - return nil -} - -// MinLength of a written frame -func (f *WindowUpdateFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4 + 8, nil -} - -// ParseWindowUpdateFrame parses a RST_STREAM frame -func ParseWindowUpdateFrame(r *bytes.Reader) (*WindowUpdateFrame, error) { - frame := &WindowUpdateFrame{} - - // read the TypeByte - _, err := r.ReadByte() - if err != nil { - return nil, err - } - - sid, err := utils.ReadUint32(r) - if err != nil { - return nil, err - } - frame.StreamID = protocol.StreamID(sid) - - byteOffset, err := utils.ReadUint64(r) - if err != nil { - return nil, err - } - frame.ByteOffset = protocol.ByteCount(byteOffset) - - return frame, nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go index 866b11abc..f506867dc 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go @@ -15,8 +15,8 @@ import ( "golang.org/x/net/idna" quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) @@ -34,10 +34,10 @@ type client struct { config *quic.Config opts *roundTripperOpts - hostname string - encryptionLevel protocol.EncryptionLevel - handshakeErr error - dialOnce sync.Once + hostname string + handshakeErr error + dialOnce sync.Once + dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error) session quic.Session headerStream quic.Stream @@ -51,8 +51,8 @@ type client struct { var _ http.RoundTripper = &client{} var defaultQuicConfig = &quic.Config{ - RequestConnectionIDTruncation: true, - KeepAlive: true, + RequestConnectionIDOmission: true, + KeepAlive: true, } // newClient creates a new client @@ -61,26 +61,31 @@ func newClient( tlsConfig *tls.Config, opts *roundTripperOpts, quicConfig *quic.Config, + dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error), ) *client { config := defaultQuicConfig if quicConfig != nil { config = quicConfig } return &client{ - hostname: authorityAddr("https", hostname), - responses: make(map[protocol.StreamID]chan *http.Response), - encryptionLevel: protocol.EncryptionUnencrypted, - tlsConf: tlsConfig, - config: config, - opts: opts, - headerErrored: make(chan struct{}), + hostname: authorityAddr("https", hostname), + responses: make(map[protocol.StreamID]chan *http.Response), + tlsConf: tlsConfig, + config: config, + opts: opts, + headerErrored: make(chan struct{}), + dialer: dialer, } } // dial dials the connection func (c *client) dial() error { var err error - c.session, err = dialAddr(c.hostname, c.tlsConf, c.config) + if c.dialer != nil { + c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config) + } else { + c.session, err = dialAddr(c.hostname, c.tlsConf, c.config) + } if err != nil { return err } @@ -90,9 +95,6 @@ func (c *client) dial() error { if err != nil { return err } - if c.headerStream.StreamID() != 3 { - return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3") - } c.requestWriter = newRequestWriter(c.headerStream) go c.handleHeaderStream() return nil @@ -102,45 +104,44 @@ func (c *client) handleHeaderStream() { decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) h2framer := http2.NewFramer(nil, c.headerStream) - var lastStream protocol.StreamID + var err error + for err == nil { + err = c.readResponse(h2framer, decoder) + } + utils.Debugf("Error handling header stream: %s", err) + c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error()) + // stop all running request + close(c.headerErrored) +} - for { - frame, err := h2framer.ReadFrame() - if err != nil { - c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") - break - } - lastStream = protocol.StreamID(frame.Header().StreamID) - hframe, ok := frame.(*http2.HeadersFrame) - if !ok { - c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame") - break - } - mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} - mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) - if err != nil { - c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields") - break - } - - c.mutex.RLock() - responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] - c.mutex.RUnlock() - if !ok { - c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) - break - } - - rsp, err := responseFromHeaders(mhframe) - if err != nil { - c.headerErr = qerr.Error(qerr.InternalError, err.Error()) - } - responseChan <- rsp +func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error { + frame, err := h2framer.ReadFrame() + if err != nil { + return err + } + hframe, ok := frame.(*http2.HeadersFrame) + if !ok { + return errors.New("not a headers frame") + } + mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} + mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) + if err != nil { + return fmt.Errorf("cannot read header fields: %s", err.Error()) } - // stop all running request - utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) - close(c.headerErrored) + c.mutex.RLock() + responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] + c.mutex.RUnlock() + if !ok { + return fmt.Errorf("response channel for stream %d not found", hframe.StreamID) + } + + rsp, err := responseFromHeaders(mhframe) + if err != nil { + return err + } + responseChan <- rsp + return nil } // Roundtrip executes a request and returns a response diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go index dad591cc2..3f323691f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go @@ -13,8 +13,8 @@ import ( "golang.org/x/net/lex/httplex" quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) type requestWriter struct { diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go index 246893468..1dd4e928a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go @@ -8,8 +8,8 @@ import ( "sync" quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" ) @@ -83,7 +83,7 @@ func (w *responseWriter) Write(p []byte) (int, error) { func (w *responseWriter) Flush() {} -// TODO: Implement a functional CloseNotify method. +// This is a NOP. Use http.Request.Context func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) } // test that we implement http.Flusher diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go index 9ac5f1933..f6c170b0a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go @@ -41,6 +41,11 @@ type RoundTripper struct { // If nil, reasonable default values will be used. QuicConfig *quic.Config + // Dial specifies an optional dial function for creating QUIC + // connections for requests. + // If Dial is nil, quic.DialAddr will be used. + Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error) + clients map[string]roundTripCloser } @@ -120,7 +125,13 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr if onlyCached { return nil, ErrNoCachedConn } - client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig) + client = newClient( + hostname, + r.TLSClientConfig, + &roundTripperOpts{DisableCompression: r.DisableCompression}, + r.QuicConfig, + r.Dial, + ) r.clients[hostname] = client } return client, nil diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go index 3647dc681..329edfd0d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go @@ -7,14 +7,14 @@ import ( "net" "net/http" "runtime" - "strconv" + "strings" "sync" "sync/atomic" "time" quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" @@ -50,6 +50,7 @@ type Server struct { listenerMutex sync.Mutex listener quic.Listener + closed bool supportedVersionsAsString string } @@ -88,6 +89,10 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { return errors.New("use of h2quic.Server without http.Server") } s.listenerMutex.Lock() + if s.closed { + s.listenerMutex.Unlock() + return errors.New("Server is already closed") + } if s.listener != nil { s.listenerMutex.Unlock() return errors.New("ListenAndServe may only be called once") @@ -122,29 +127,23 @@ func (s *Server) handleHeaderStream(session streamCreator) { session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error())) return } - if stream.StreamID() != 3 { - session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3")) - return - } hpackDecoder := hpack.NewDecoder(4096, nil) h2framer := http2.NewFramer(nil, stream) - go func() { - var headerStreamMutex sync.Mutex // Protects concurrent calls to Write() - for { - if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil { - // QuicErrors must originate from stream.Read() returning an error. - // In this case, the session has already logged the error, so we don't - // need to log it again. - if _, ok := err.(*qerr.QuicError); !ok { - utils.Errorf("error handling h2 request: %s", err.Error()) - } - session.Close(err) - return + var headerStreamMutex sync.Mutex // Protects concurrent calls to Write() + for { + if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil { + // QuicErrors must originate from stream.Read() returning an error. + // In this case, the session has already logged the error, so we don't + // need to log it again. + if _, ok := err.(*qerr.QuicError); !ok { + utils.Errorf("error handling h2 request: %s", err.Error()) } + session.Close(err) + return } - }() + } } func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error { @@ -170,8 +169,6 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, return err } - req.RemoteAddr = session.RemoteAddr().String() - if utils.Debug() { utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID) } else { @@ -187,19 +184,25 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, return nil } - var streamEnded bool - if h2headersFrame.StreamEnded() { - dataStream.(remoteCloser).CloseRemote(0) - streamEnded = true - _, _ = dataStream.Read([]byte{0}) // read the eof - } - - reqBody := newRequestBody(dataStream) - req.Body = reqBody - - responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID)) - + // handleRequest should be as non-blocking as possible to minimize + // head-of-line blocking. Potentially blocking code is run in a separate + // goroutine, enabling handleRequest to return before the code is executed. go func() { + streamEnded := h2headersFrame.StreamEnded() + if streamEnded { + dataStream.(remoteCloser).CloseRemote(0) + streamEnded = true + _, _ = dataStream.Read([]byte{0}) // read the eof + } + + req = req.WithContext(dataStream.Context()) + reqBody := newRequestBody(dataStream) + req.Body = reqBody + + req.RemoteAddr = session.RemoteAddr().String() + + responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID)) + handler := s.Handler if handler == nil { handler = http.DefaultServeMux @@ -225,7 +228,8 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, } if responseWriter.dataStream != nil { if !streamEnded && !reqBody.requestRead { - responseWriter.dataStream.Reset(nil) + // in gQUIC, the error code doesn't matter, so just use 0 here + responseWriter.dataStream.CancelRead(0) } responseWriter.dataStream.Close() } @@ -243,6 +247,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, func (s *Server) Close() error { s.listenerMutex.Lock() defer s.listenerMutex.Unlock() + s.closed = true if s.listener != nil { err := s.listener.Close() s.listener = nil @@ -279,12 +284,11 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error { } if s.supportedVersionsAsString == "" { - for i, v := range protocol.SupportedVersions { - s.supportedVersionsAsString += strconv.Itoa(int(v)) - if i != len(protocol.SupportedVersions)-1 { - s.supportedVersionsAsString += "," - } + var versions []string + for _, v := range protocol.SupportedVersions { + versions = append(versions, v.ToAltSvc()) } + s.supportedVersionsAsString = strings.Join(versions, ",") } hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString)) @@ -344,6 +348,9 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error } defer tcpConn.Close() + tlsConn := tls.NewListener(tcpConn, config) + defer tlsConn.Close() + // Start the servers httpServer := &http.Server{ Addr: addr, @@ -365,7 +372,7 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error hErr := make(chan error) qErr := make(chan error) go func() { - hErr <- httpServer.Serve(tcpConn) + hErr <- httpServer.Serve(tlsConn) }() go func() { qErr <- quicServer.Serve(udpConn) diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go b/vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go deleted file mode 100644 index 1ad9a3a41..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go +++ /dev/null @@ -1,265 +0,0 @@ -package handshake - -import ( - "bytes" - "sync" - "time" - - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/qerr" -) - -// ConnectionParametersManager negotiates and stores the connection parameters -// A ConnectionParametersManager can be used for a server as well as a client -// For the server: -// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation -// 2. call GetHelloMap to get the values to send in the SHLO -// For the client: -// 1. call GetHelloMap to get the values to send in a CHLO -// 2. call SetFromMap with the values received in the SHLO -type ConnectionParametersManager interface { - SetFromMap(map[Tag][]byte) error - GetHelloMap() (map[Tag][]byte, error) - - GetSendStreamFlowControlWindow() protocol.ByteCount - GetSendConnectionFlowControlWindow() protocol.ByteCount - GetReceiveStreamFlowControlWindow() protocol.ByteCount - GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount - GetReceiveConnectionFlowControlWindow() protocol.ByteCount - GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount - GetMaxOutgoingStreams() uint32 - GetMaxIncomingStreams() uint32 - GetIdleConnectionStateLifetime() time.Duration - TruncateConnectionID() bool -} - -type connectionParametersManager struct { - mutex sync.RWMutex - - version protocol.VersionNumber - perspective protocol.Perspective - - flowControlNegotiated bool - - truncateConnectionID bool - maxStreamsPerConnection uint32 - maxIncomingDynamicStreamsPerConnection uint32 - idleConnectionStateLifetime time.Duration - sendStreamFlowControlWindow protocol.ByteCount - sendConnectionFlowControlWindow protocol.ByteCount - receiveStreamFlowControlWindow protocol.ByteCount - receiveConnectionFlowControlWindow protocol.ByteCount - maxReceiveStreamFlowControlWindow protocol.ByteCount - maxReceiveConnectionFlowControlWindow protocol.ByteCount -} - -var _ ConnectionParametersManager = &connectionParametersManager{} - -// ErrMalformedTag is returned when the tag value cannot be read -var ( - ErrMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value") - ErrFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported") -) - -// NewConnectionParamatersManager creates a new connection parameters manager -func NewConnectionParamatersManager( - pers protocol.Perspective, v protocol.VersionNumber, - maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount, -) ConnectionParametersManager { - h := &connectionParametersManager{ - perspective: pers, - version: v, - sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client - sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client - receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, - receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, - maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, - maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, - } - - if h.perspective == protocol.PerspectiveServer { - h.idleConnectionStateLifetime = protocol.DefaultIdleTimeout - h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent - h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective - } else { - h.idleConnectionStateLifetime = protocol.MaxIdleTimeoutClient - h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent - h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the server's perspective - } - - return h -} - -// SetFromMap reads all params -func (h *connectionParametersManager) SetFromMap(params map[Tag][]byte) error { - h.mutex.Lock() - defer h.mutex.Unlock() - - if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer { - clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.truncateConnectionID = (clientValue == 0) - } - if value, ok := params[TagMSPC]; ok { - clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue) - } - if value, ok := params[TagMIDS]; ok { - clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.maxIncomingDynamicStreamsPerConnection = h.negotiateMaxIncomingDynamicStreamsPerConnection(clientValue) - } - if value, ok := params[TagICSL]; ok { - clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(time.Duration(clientValue) * time.Second) - } - if value, ok := params[TagSFCW]; ok { - if h.flowControlNegotiated { - return ErrFlowControlRenegotiationNotSupported - } - sendStreamFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow) - } - if value, ok := params[TagCFCW]; ok { - if h.flowControlNegotiated { - return ErrFlowControlRenegotiationNotSupported - } - sendConnectionFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow) - } - - _, containsSFCW := params[TagSFCW] - _, containsCFCW := params[TagCFCW] - if containsCFCW || containsSFCW { - h.flowControlNegotiated = true - } - - return nil -} - -func (h *connectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 { - return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection) -} - -func (h *connectionParametersManager) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 { - return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection) -} - -func (h *connectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration { - if h.perspective == protocol.PerspectiveServer { - return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutServer) - } - return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutClient) -} - -// GetHelloMap gets all parameters needed for the Hello message -func (h *connectionParametersManager) GetHelloMap() (map[Tag][]byte, error) { - sfcw := bytes.NewBuffer([]byte{}) - utils.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow())) - cfcw := bytes.NewBuffer([]byte{}) - utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow())) - mspc := bytes.NewBuffer([]byte{}) - utils.WriteUint32(mspc, h.maxStreamsPerConnection) - mids := bytes.NewBuffer([]byte{}) - utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection) - icsl := bytes.NewBuffer([]byte{}) - utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second)) - - return map[Tag][]byte{ - TagICSL: icsl.Bytes(), - TagMSPC: mspc.Bytes(), - TagMIDS: mids.Bytes(), - TagCFCW: cfcw.Bytes(), - TagSFCW: sfcw.Bytes(), - }, nil -} - -// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data -func (h *connectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.sendStreamFlowControlWindow -} - -// GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data -func (h *connectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.sendConnectionFlowControlWindow -} - -// GetReceiveStreamFlowControlWindow gets the size of the stream-level flow control window for receiving data -func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.receiveStreamFlowControlWindow -} - -// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data -func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { - return h.maxReceiveStreamFlowControlWindow -} - -// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data -func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.receiveConnectionFlowControlWindow -} - -// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data -func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { - return h.maxReceiveConnectionFlowControlWindow -} - -// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection -func (h *connectionParametersManager) GetMaxOutgoingStreams() uint32 { - h.mutex.RLock() - defer h.mutex.RUnlock() - - return h.maxIncomingDynamicStreamsPerConnection -} - -// GetMaxIncomingStreams get the maximum number of incoming streams per connection -func (h *connectionParametersManager) GetMaxIncomingStreams() uint32 { - h.mutex.RLock() - defer h.mutex.RUnlock() - - maxStreams := protocol.MaxIncomingDynamicStreamsPerConnection - return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier)) -} - -// GetIdleConnectionStateLifetime gets the idle timeout -func (h *connectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.idleConnectionStateLifetime -} - -// TruncateConnectionID determines if the client requests truncated ConnectionIDs -func (h *connectionParametersManager) TruncateConnectionID() bool { - if h.perspective == protocol.PerspectiveClient { - return false - } - - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.truncateConnectionID -} diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/interface.go b/vendor/github.com/lucas-clemente/quic-go/handshake/interface.go deleted file mode 100644 index 751aae1e5..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/interface.go +++ /dev/null @@ -1,24 +0,0 @@ -package handshake - -import "github.com/lucas-clemente/quic-go/protocol" - -// Sealer seals a packet -type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte - -// CryptoSetup is a crypto setup -type CryptoSetup interface { - Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) - HandleCryptoStream() error - // TODO: clean up this interface - DiversificationNonce() []byte // only needed for cryptoSetupServer - SetDiversificationNonce([]byte) // only needed for cryptoSetupClient - - GetSealer() (protocol.EncryptionLevel, Sealer) - GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) - GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) -} - -// TransportParameters are parameters sent to the peer during the handshake -type TransportParameters struct { - RequestConnectionIDTruncation bool -} diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go b/vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go deleted file mode 100644 index c3caea3d2..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go +++ /dev/null @@ -1,100 +0,0 @@ -package handshake - -import ( - "encoding/asn1" - "fmt" - "net" - "time" - - "github.com/lucas-clemente/quic-go/crypto" -) - -const ( - stkPrefixIP byte = iota - stkPrefixString -) - -// An STK is a source address token -type STK struct { - RemoteAddr string - SentTime time.Time -} - -// token is the struct that is used for ASN1 serialization and deserialization -type token struct { - Data []byte - Timestamp int64 -} - -// An STKGenerator generates STKs -type STKGenerator struct { - stkSource crypto.StkSource -} - -// NewSTKGenerator initializes a new STKGenerator -func NewSTKGenerator() (*STKGenerator, error) { - stkSource, err := crypto.NewStkSource() - if err != nil { - return nil, err - } - return &STKGenerator{ - stkSource: stkSource, - }, nil -} - -// NewToken generates a new STK token for a given source address -func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) { - data, err := asn1.Marshal(token{ - Data: encodeRemoteAddr(raddr), - Timestamp: time.Now().Unix(), - }) - if err != nil { - return nil, err - } - return g.stkSource.NewToken(data) -} - -// DecodeToken decodes an STK token -func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) { - // if the client didn't send any STK, DecodeToken will be called with a nil-slice - if len(encrypted) == 0 { - return nil, nil - } - - data, err := g.stkSource.DecodeToken(encrypted) - if err != nil { - return nil, err - } - t := &token{} - rest, err := asn1.Unmarshal(data, t) - if err != nil { - return nil, err - } - if len(rest) != 0 { - return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) - } - return &STK{ - RemoteAddr: decodeRemoteAddr(t.Data), - SentTime: time.Unix(t.Timestamp, 0), - }, nil -} - -// encodeRemoteAddr encodes a remote address such that it can be saved in the STK -func encodeRemoteAddr(remoteAddr net.Addr) []byte { - if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { - return append([]byte{stkPrefixIP}, udpAddr.IP...) - } - return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...) -} - -// decodeRemoteAddr decodes the remote address saved in the STK -func decodeRemoteAddr(data []byte) string { - // data will never be empty for an STK that we generated. Check it to be on the safe side - if len(data) == 0 { - return "" - } - if data[0] == stkPrefixIP { - return net.IP(data[1:]).String() - } - return string(data[1:]) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/integrationtests/chrome/chrome.go b/vendor/github.com/lucas-clemente/quic-go/integrationtests/chrome/chrome.go deleted file mode 100644 index 7b442802b..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/integrationtests/chrome/chrome.go +++ /dev/null @@ -1 +0,0 @@ -package chrome diff --git a/vendor/github.com/lucas-clemente/quic-go/integrationtests/gquic/gquic.go b/vendor/github.com/lucas-clemente/quic-go/integrationtests/gquic/gquic.go deleted file mode 100644 index 36ba02f34..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/integrationtests/gquic/gquic.go +++ /dev/null @@ -1 +0,0 @@ -package gquic diff --git a/vendor/github.com/lucas-clemente/quic-go/integrationtests/self/self.go b/vendor/github.com/lucas-clemente/quic-go/integrationtests/self/self.go deleted file mode 100644 index 51e34e221..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/integrationtests/self/self.go +++ /dev/null @@ -1 +0,0 @@ -package self diff --git a/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go b/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go index 70be88892..5b1dc21d0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go +++ b/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go @@ -1,14 +1,12 @@ package quicproxy import ( - "bytes" "net" "sync" "sync/atomic" "time" - "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // Connection is a UDP connection @@ -28,21 +26,43 @@ const ( DirectionIncoming Direction = iota // DirectionOutgoing is the direction from the server to the client. DirectionOutgoing + // DirectionBoth is both incoming and outgoing + DirectionBoth ) +func (d Direction) String() string { + switch d { + case DirectionIncoming: + return "incoming" + case DirectionOutgoing: + return "outgoing" + case DirectionBoth: + return "both" + default: + panic("unknown direction") + } +} + +func (d Direction) Is(dir Direction) bool { + if d == DirectionBoth || dir == DirectionBoth { + return true + } + return d == dir +} + // DropCallback is a callback that determines which packet gets dropped. -type DropCallback func(Direction, protocol.PacketNumber) bool +type DropCallback func(dir Direction, packetCount uint64) bool // NoDropper doesn't drop packets. -var NoDropper DropCallback = func(Direction, protocol.PacketNumber) bool { +var NoDropper DropCallback = func(Direction, uint64) bool { return false } // DelayCallback is a callback that determines how much delay to apply to a packet. -type DelayCallback func(Direction, protocol.PacketNumber) time.Duration +type DelayCallback func(dir Direction, packetCount uint64) time.Duration // NoDelay doesn't apply a delay. -var NoDelay DelayCallback = func(Direction, protocol.PacketNumber) time.Duration { +var NoDelay DelayCallback = func(Direction, uint64) time.Duration { return 0 } @@ -62,6 +82,8 @@ type Opts struct { type QuicProxy struct { mutex sync.Mutex + version protocol.VersionNumber + conn *net.UDPConn serverAddr *net.UDPAddr @@ -73,7 +95,10 @@ type QuicProxy struct { } // NewQuicProxy creates a new UDP proxy -func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) { +func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) { + if opts == nil { + opts = &Opts{} + } laddr, err := net.ResolveUDPAddr("udp", local) if err != nil { return nil, err @@ -103,6 +128,7 @@ func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) { serverAddr: raddr, dropPacket: packetDropper, delayPacket: packetDelayer, + version: version, } go p.runProxy() @@ -119,6 +145,7 @@ func (p *QuicProxy) LocalAddr() net.Addr { return p.conn.LocalAddr() } +// LocalPort is the UDP port number the proxy is listening on. func (p *QuicProxy) LocalPort() int { return p.conn.LocalAddr().(*net.UDPAddr).Port } @@ -137,7 +164,7 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) { // runProxy listens on the proxy address and handles incoming packets. func (p *QuicProxy) runProxy() error { for { - buffer := make([]byte, protocol.MaxPacketSize) + buffer := make([]byte, protocol.MaxReceivePacketSize) n, cliaddr, err := p.conn.ReadFromUDP(buffer) if err != nil { return err @@ -159,20 +186,14 @@ func (p *QuicProxy) runProxy() error { } p.mutex.Unlock() - atomic.AddUint64(&conn.incomingPacketCounter, 1) + packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1) - r := bytes.NewReader(raw) - hdr, err := quic.ParsePublicHeader(r, protocol.PerspectiveClient) - if err != nil { - return err - } - - if p.dropPacket(DirectionIncoming, hdr.PacketNumber) { + if p.dropPacket(DirectionIncoming, packetCount) { continue } // Send the packet to the server - delay := p.delayPacket(DirectionIncoming, hdr.PacketNumber) + delay := p.delayPacket(DirectionIncoming, packetCount) if delay != 0 { time.AfterFunc(delay, func() { // TODO: handle error @@ -190,28 +211,20 @@ func (p *QuicProxy) runProxy() error { // runConnection handles packets from server to a single client func (p *QuicProxy) runConnection(conn *connection) error { for { - buffer := make([]byte, protocol.MaxPacketSize) + buffer := make([]byte, protocol.MaxReceivePacketSize) n, err := conn.ServerConn.Read(buffer) if err != nil { return err } raw := buffer[0:n] - // TODO: Switch back to using the public header once Chrome properly sets the type byte. - // r := bytes.NewReader(raw) - // , err := quic.ParsePublicHeader(r, protocol.PerspectiveServer) - // if err != nil { - // return err - // } + packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1) - v := atomic.AddUint64(&conn.outgoingPacketCounter, 1) - - packetNumber := protocol.PacketNumber(v) - if p.dropPacket(DirectionOutgoing, packetNumber) { + if p.dropPacket(DirectionOutgoing, packetCount) { continue } - delay := p.delayPacket(DirectionOutgoing, packetNumber) + delay := p.delayPacket(DirectionOutgoing, packetCount) if delay != 0 { time.AfterFunc(delay, func() { // TODO: handle error diff --git a/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go b/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go index 7db09a443..783531627 100644 --- a/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go +++ b/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go @@ -27,7 +27,7 @@ var _ = BeforeEach(func() { if len(logFileName) > 0 { var err error - logFile, err = os.Create("./log.txt") + logFile, err = os.Create(logFileName) Expect(err).ToNot(HaveOccurred()) log.SetOutput(logFile) utils.SetLogLevel(utils.LogLevelDebug) diff --git a/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go b/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go index 05ca66dd9..909f560b2 100644 --- a/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go +++ b/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go @@ -7,7 +7,9 @@ import ( "net/http" "strconv" + quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/h2quic" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" . "github.com/onsi/ginkgo" @@ -23,8 +25,9 @@ var ( PRData = GeneratePRData(dataLen) PRDataLong = GeneratePRData(dataLenLong) - server *h2quic.Server - port string + server *h2quic.Server + stoppedServing chan struct{} + port string ) func init() { @@ -75,11 +78,16 @@ func GeneratePRData(l int) []byte { return res } -func StartQuicServer() { +// StartQuicServer starts a h2quic.Server. +// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used. +func StartQuicServer(versions []protocol.VersionNumber) { server = &h2quic.Server{ Server: &http.Server{ TLSConfig: testdata.GetTLSConfig(), }, + QuicConfig: &quic.Config{ + Versions: versions, + }, } addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0") @@ -88,14 +96,18 @@ func StartQuicServer() { Expect(err).NotTo(HaveOccurred()) port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port) + stoppedServing = make(chan struct{}) + go func() { defer GinkgoRecover() server.Serve(conn) + close(stoppedServing) }() } func StopQuicServer() { Expect(server.Close()).NotTo(HaveOccurred()) + Eventually(stoppedServing).Should(BeClosed()) } func Port() string { diff --git a/vendor/github.com/lucas-clemente/quic-go/interface.go b/vendor/github.com/lucas-clemente/quic-go/interface.go index 41b607ca1..b0a18293d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/interface.go @@ -6,23 +6,55 @@ import ( "net" "time" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" ) +// The StreamID is the ID of a QUIC stream. +type StreamID = protocol.StreamID + +// A VersionNumber is a QUIC version number. +type VersionNumber = protocol.VersionNumber + +// A Cookie can be used to verify the ownership of the client address. +type Cookie = handshake.Cookie + +// ConnectionState records basic details about the QUIC connection. +type ConnectionState = handshake.ConnectionState + +// An ErrorCode is an application-defined error code. +type ErrorCode = protocol.ApplicationErrorCode + // Stream is the interface implemented by QUIC streams type Stream interface { + // StreamID returns the stream ID. + StreamID() StreamID // Read reads data from the stream. // Read can be made to time out and return a net.Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetReadDeadline. + // If the stream was canceled by the peer, the error implements the StreamError + // interface, and Canceled() == true. io.Reader // Write writes data to the stream. // Write can be made to time out and return a net.Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetWriteDeadline. + // If the stream was canceled by the peer, the error implements the StreamError + // interface, and Canceled() == true. io.Writer + // Close closes the write-direction of the stream. + // Future calls to Write are not permitted after calling Close. + // It must not be called concurrently with Write. + // It must not be called after calling CancelWrite. io.Closer - StreamID() protocol.StreamID - // Reset closes the stream with an error. - Reset(error) + // CancelWrite aborts sending on this stream. + // It must not be called after Close. + // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably. + // Write will unblock immediately, and future calls to Write will fail. + CancelWrite(ErrorCode) error + // CancelRead aborts receiving on this stream. + // It will ask the peer to stop transmitting stream data. + // Read will unblock immediately, and future Read calls will fail. + CancelRead(ErrorCode) error // The context is canceled as soon as the write-side of the stream is closed. // This happens when Close() is called, or when the stream is reset (either locally or remotely). // Warning: This API should not be considered stable and might change soon. @@ -43,6 +75,41 @@ type Stream interface { SetDeadline(t time.Time) error } +// A ReceiveStream is a unidirectional Receive Stream. +type ReceiveStream interface { + // see Stream.StreamID + StreamID() StreamID + // see Stream.Read + io.Reader + // see Stream.CancelRead + CancelRead(ErrorCode) error + // see Stream.SetReadDealine + SetReadDeadline(t time.Time) error +} + +// A SendStream is a unidirectional Send Stream. +type SendStream interface { + // see Stream.StreamID + StreamID() StreamID + // see Stream.Write + io.Writer + // see Stream.Close + io.Closer + // see Stream.CancelWrite + CancelWrite(ErrorCode) error + // see Stream.Context + Context() context.Context + // see Stream.SetWriteDeadline + SetWriteDeadline(t time.Time) error +} + +// StreamError is returned by Read and Write when the peer cancels the stream. +type StreamError interface { + error + Canceled() bool + ErrorCode() ErrorCode +} + // A Session is a QUIC connection between two peers. type Session interface { // AcceptStream returns the next stream opened by the peer, blocking until one is available. @@ -64,53 +131,41 @@ type Session interface { // The context is cancelled when the session is closed. // Warning: This API should not be considered stable and might change soon. Context() context.Context -} - -// A NonFWSession is a QUIC connection between two peers half-way through the handshake. -// The communication is encrypted, but not yet forward secure. -type NonFWSession interface { - Session - WaitUntilHandshakeComplete() error -} - -// An STK is a Source Address token. -// It is issued by the server and sent to the client. For the client, it is an opaque blob. -// The client can send the STK in subsequent handshakes to prove ownership of its IP address. -type STK struct { - // The remote address this token was issued for. - // If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String()) - // Otherwise, this is the string representation of the net.Addr (net.Addr.String()) - remoteAddr string - // The time that the STK was issued (resolution 1 second) - sentTime time.Time + // ConnectionState returns basic details about the QUIC connection. + // Warning: This API should not be considered stable and might change soon. + ConnectionState() ConnectionState } // Config contains all configuration data needed for a QUIC server or client. -// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441. type Config struct { // The QUIC versions that can be negotiated. // If not set, it uses all versions available. // Warning: This API should not be considered stable and will change soon. - Versions []protocol.VersionNumber - // Ask the server to truncate the connection ID sent in the Public Header. + Versions []VersionNumber + // Ask the server to omit the connection ID sent in the Public Header. // This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated. // Currently only valid for the client. - RequestConnectionIDTruncation bool + RequestConnectionIDOmission bool // HandshakeTimeout is the maximum duration that the cryptographic handshake may take. // If the timeout is exceeded, the connection is closed. // If this value is zero, the timeout is set to 10 seconds. HandshakeTimeout time.Duration - // AcceptSTK determines if an STK is accepted. - // It is called with stk = nil if the client didn't send an STK. - // If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours. + // IdleTimeout is the maximum duration that may pass without any incoming network activity. + // This value only applies after the handshake has completed. + // If the timeout is exceeded, the connection is closed. + // If this value is zero, the timeout is set to 30 seconds. + IdleTimeout time.Duration + // AcceptCookie determines if a Cookie is accepted. + // It is called with cookie = nil if the client didn't send an Cookie. + // If not set, it verifies that the address matches, and that the Cookie was issued within the last 24 hours. // This option is only valid for the server. - AcceptSTK func(clientAddr net.Addr, stk *STK) bool + AcceptCookie func(clientAddr net.Addr, cookie *Cookie) bool // MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data. // If this value is zero, it will default to 1 MB for the server and 6 MB for the client. - MaxReceiveStreamFlowControlWindow protocol.ByteCount + MaxReceiveStreamFlowControlWindow uint64 // MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data. // If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client. - MaxReceiveConnectionFlowControlWindow protocol.ByteCount + MaxReceiveConnectionFlowControlWindow uint64 // KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive. KeepAlive bool } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go new file mode 100644 index 000000000..09b2c0172 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go @@ -0,0 +1,48 @@ +package ackhandler + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +// SentPacketHandler handles ACKs received for outgoing packets +type SentPacketHandler interface { + // SentPacket may modify the packet + SentPacket(packet *Packet) error + ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error + SetHandshakeComplete() + + // SendingAllowed says if a packet can be sent. + // Sending packets might not be possible because: + // * we're congestion limited + // * we're tracking the maximum number of sent packets + SendingAllowed() bool + // TimeUntilSend is the time when the next packet should be sent. + // It is used for pacing packets. + TimeUntilSend() time.Time + // ShouldSendNumPackets returns the number of packets that should be sent immediately. + // It always returns a number greater or equal than 1. + // A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay. + // Note that the number of packets is only calculated based on the pacing algorithm. + // Before sending any packet, SendingAllowed() must be called to learn if we can actually send it. + ShouldSendNumPackets() int + + GetStopWaitingFrame(force bool) *wire.StopWaitingFrame + GetLowestPacketNotConfirmedAcked() protocol.PacketNumber + DequeuePacketForRetransmission() (packet *Packet) + GetLeastUnacked() protocol.PacketNumber + + GetAlarmTimeout() time.Time + OnAlarm() +} + +// ReceivedPacketHandler handles ACKs needed to send for incoming packets +type ReceivedPacketHandler interface { + ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error + IgnoreBelow(protocol.PacketNumber) + + GetAlarmTimeout() time.Time + GetAckFrame() *wire.AckFrame +} diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet.go similarity index 51% rename from vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go rename to vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet.go index e9dbf6ab4..e4213a0b6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet.go @@ -3,29 +3,30 @@ package ackhandler import ( "time" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) // A Packet is a packet // +gen linkedlist type Packet struct { PacketNumber protocol.PacketNumber - Frames []frames.Frame + Frames []wire.Frame Length protocol.ByteCount EncryptionLevel protocol.EncryptionLevel - SendTime time.Time + largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK + sendTime time.Time } // GetFramesForRetransmission gets all the frames for retransmission -func (p *Packet) GetFramesForRetransmission() []frames.Frame { - var fs []frames.Frame +func (p *Packet) GetFramesForRetransmission() []wire.Frame { + var fs []wire.Frame for _, frame := range p.Frames { switch frame.(type) { - case *frames.AckFrame: + case *wire.AckFrame: continue - case *frames.StopWaitingFrame: + case *wire.StopWaitingFrame: continue } fs = append(fs, frame) diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet_linkedlist.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/ackhandler/packet_linkedlist.go rename to vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go similarity index 63% rename from vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go rename to vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go index 0661fdc29..c316af4aa 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go @@ -1,18 +1,15 @@ package ackhandler import ( - "errors" "time" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) -var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number") - type receivedPacketHandler struct { largestObserved protocol.PacketNumber - lowerLimit protocol.PacketNumber + ignoreBelow protocol.PacketNumber largestObservedReceivedTime time.Time packetHistory *receivedPacketHistory @@ -23,46 +20,45 @@ type receivedPacketHandler struct { retransmittablePacketsReceivedSinceLastAck int ackQueued bool ackAlarm time.Time - lastAck *frames.AckFrame + lastAck *wire.AckFrame + + version protocol.VersionNumber } // NewReceivedPacketHandler creates a new receivedPacketHandler -func NewReceivedPacketHandler() ReceivedPacketHandler { +func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler { return &receivedPacketHandler{ packetHistory: newReceivedPacketHistory(), ackSendDelay: protocol.AckSendDelay, + version: version, } } -func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error { - if packetNumber == 0 { - return errInvalidPacketNumber - } - +func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error { if packetNumber > h.largestObserved { h.largestObserved = packetNumber - h.largestObservedReceivedTime = time.Now() + h.largestObservedReceivedTime = rcvTime } - if packetNumber <= h.lowerLimit { + if packetNumber < h.ignoreBelow { return nil } if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil { return err } - h.maybeQueueAck(packetNumber, shouldInstigateAck) + h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck) return nil } -// SetLowerLimit sets a lower limit for acking packets. -// Packets with packet numbers smaller or equal than p will not be acked. -func (h *receivedPacketHandler) SetLowerLimit(p protocol.PacketNumber) { - h.lowerLimit = p - h.packetHistory.DeleteUpTo(p) +// IgnoreBelow sets a lower limit for acking packets. +// Packets with packet numbers smaller than p will not be acked. +func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) { + h.ignoreBelow = p + h.packetHistory.DeleteBelow(p) } -func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { +func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) { h.packetsReceivedSinceLastAck++ if shouldInstigateAck { @@ -74,12 +70,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber h.ackQueued = true } - // Always send an ack every 20 packets in order to allow the peer to discard - // information from the SentPacketManager and provide an RTT measurement. - if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend { - h.ackQueued = true - } - // if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK // note that it cannot be a duplicate because they're already filtered out by ReceivedPacket() if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked { @@ -87,7 +77,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber } // check if a new missing range above the previously was created - if h.lastAck != nil && h.packetHistory.GetHighestAckRange().FirstPacketNumber > h.lastAck.LargestAcked { + if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked { h.ackQueued = true } @@ -96,7 +86,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber h.ackQueued = true } else { if h.ackAlarm.IsZero() { - h.ackAlarm = time.Now().Add(h.ackSendDelay) + h.ackAlarm = rcvTime.Add(h.ackSendDelay) } } } @@ -107,15 +97,15 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber } } -func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame { +func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame { if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) { return nil } ackRanges := h.packetHistory.GetAckRanges() - ack := &frames.AckFrame{ + ack := &wire.AckFrame{ LargestAcked: h.largestObserved, - LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber, + LowestAcked: ackRanges[len(ackRanges)-1].First, PacketReceivedTime: h.largestObservedReceivedTime, } diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go similarity index 53% rename from vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go rename to vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go index a17cae5ea..ba1195448 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go @@ -1,9 +1,9 @@ package ackhandler import ( - "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) @@ -12,21 +12,15 @@ import ( type receivedPacketHistory struct { ranges *utils.PacketIntervalList - // the map is used as a replacement for a set here. The bool is always supposed to be set to true - receivedPacketNumbers map[protocol.PacketNumber]bool lowestInReceivedPacketNumbers protocol.PacketNumber } -var ( - errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges") - errTooManyOutstandingReceivedPackets = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received packets") -) +var errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges") // newReceivedPacketHistory creates a new received packet history func newReceivedPacketHistory() *receivedPacketHistory { return &receivedPacketHistory{ - ranges: utils.NewPacketIntervalList(), - receivedPacketNumbers: make(map[protocol.PacketNumber]bool), + ranges: utils.NewPacketIntervalList(), } } @@ -36,12 +30,6 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error { return errTooManyOutstandingReceivedAckRanges } - if len(h.receivedPacketNumbers) >= protocol.MaxTrackedReceivedPackets { - return errTooManyOutstandingReceivedPackets - } - - h.receivedPacketNumbers[p] = true - if h.ranges.Len() == 0 { h.ranges.PushBack(utils.PacketInterval{Start: p, End: p}) return nil @@ -86,23 +74,20 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error { return nil } -// DeleteUpTo deletes all entries up to (and including) p -func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) { - h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, p+1) +// DeleteBelow deletes all entries below (but not including) p +func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) { + if p <= h.lowestInReceivedPacketNumbers { + return + } + h.lowestInReceivedPacketNumbers = p nextEl := h.ranges.Front() for el := h.ranges.Front(); nextEl != nil; el = nextEl { nextEl = el.Next() - if p >= el.Value.Start && p < el.Value.End { - for i := el.Value.Start; i <= p; i++ { // adjust start value of a range - delete(h.receivedPacketNumbers, i) - } - el.Value.Start = p + 1 - } else if el.Value.End <= p { // delete a whole range - for i := el.Value.Start; i <= el.Value.End; i++ { - delete(h.receivedPacketNumbers, i) - } + if p > el.Value.Start && p <= el.Value.End { + el.Value.Start = p + } else if el.Value.End < p { // delete a whole range h.ranges.Remove(el) } else { // no ranges affected. Nothing to do return @@ -110,38 +95,27 @@ func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) { } } -// IsDuplicate determines if a packet should be regarded as a duplicate packet -// note that after receiving a StopWaitingFrame, all packets below the LeastUnacked should be regarded as duplicates, even if the packet was just delayed -func (h *receivedPacketHistory) IsDuplicate(p protocol.PacketNumber) bool { - if p < h.lowestInReceivedPacketNumbers { - return true - } - - _, ok := h.receivedPacketNumbers[p] - return ok -} - // GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame -func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange { +func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange { if h.ranges.Len() == 0 { return nil } - var ackRanges []frames.AckRange - + ackRanges := make([]wire.AckRange, h.ranges.Len()) + i := 0 for el := h.ranges.Back(); el != nil; el = el.Prev() { - ackRanges = append(ackRanges, frames.AckRange{FirstPacketNumber: el.Value.Start, LastPacketNumber: el.Value.End}) + ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End} + i++ } - return ackRanges } -func (h *receivedPacketHistory) GetHighestAckRange() frames.AckRange { - ackRange := frames.AckRange{} +func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange { + ackRange := wire.AckRange{} if h.ranges.Len() > 0 { r := h.ranges.Back().Value - ackRange.FirstPacketNumber = r.Start - ackRange.LastPacketNumber = r.End + ackRange.First = r.Start + ackRange.Last = r.End } return ackRange } diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/retransmittable.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/retransmittable.go similarity index 61% rename from vendor/github.com/lucas-clemente/quic-go/ackhandler/retransmittable.go rename to vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/retransmittable.go index 17437b8c9..e6ce46f87 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/retransmittable.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/retransmittable.go @@ -1,12 +1,10 @@ package ackhandler -import ( - "github.com/lucas-clemente/quic-go/frames" -) +import "github.com/lucas-clemente/quic-go/internal/wire" // Returns a new slice with all non-retransmittable frames deleted. -func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame { - res := make([]frames.Frame, 0, len(fs)) +func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame { + res := make([]wire.Frame, 0, len(fs)) for _, f := range fs { if IsFrameRetransmittable(f) { res = append(res, f) @@ -16,11 +14,11 @@ func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame { } // IsFrameRetransmittable returns true if the frame should be retransmitted. -func IsFrameRetransmittable(f frames.Frame) bool { +func IsFrameRetransmittable(f wire.Frame) bool { switch f.(type) { - case *frames.StopWaitingFrame: + case *wire.StopWaitingFrame: return false - case *frames.AckFrame: + case *wire.AckFrame: return false default: return true @@ -28,7 +26,7 @@ func IsFrameRetransmittable(f frames.Frame) bool { } // HasRetransmittableFrames returns true if at least one frame is retransmittable. -func HasRetransmittableFrames(fs []frames.Frame) bool { +func HasRetransmittableFrames(fs []wire.Frame) bool { for _, f := range fs { if IsFrameRetransmittable(f) { return true diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go similarity index 60% rename from vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go rename to vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go index 300b665ed..08b4ee5a3 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go @@ -3,12 +3,13 @@ package ackhandler import ( "errors" "fmt" + "math" "time" - "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/internal/congestion" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) @@ -16,33 +17,33 @@ const ( // Maximum reordering in time space before time based loss detection considers a packet lost. // In fraction of an RTT. timeReorderingFraction = 1.0 / 8 + // The default RTT used before an RTT sample is taken. + // Note: This constant is also defined in the congestion package. + defaultInitialRTT = 100 * time.Millisecond // defaultRTOTimeout is the RTO time on new connections defaultRTOTimeout = 500 * time.Millisecond + // Minimum time in the future a tail loss probe alarm may be set for. + minTPLTimeout = 10 * time.Millisecond // Minimum time in the future an RTO alarm may be set for. minRTOTimeout = 200 * time.Millisecond // maxRTOTimeout is the maximum RTO time maxRTOTimeout = 60 * time.Second ) -var ( - // ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received - ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK") - // ErrTooManyTrackedSentPackets occurs when the sentPacketHandler has to keep track of too many packets - ErrTooManyTrackedSentPackets = errors.New("Too many outstanding non-acked and non-retransmitted packets") - // ErrAckForSkippedPacket occurs when the client sent an ACK for a packet number that we intentionally skipped - ErrAckForSkippedPacket = qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number") - errAckForUnsentPacket = qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package") -) - -var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number") +// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received +var ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK") type sentPacketHandler struct { lastSentPacketNumber protocol.PacketNumber + nextPacketSendTime time.Time skippedPackets []protocol.PacketNumber - LargestAcked protocol.PacketNumber - + largestAcked protocol.PacketNumber largestReceivedPacketWithAck protocol.PacketNumber + // lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived + // example: we send an ACK for packets 90-100 with packet number 20 + // once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101 + lowestPacketNotConfirmedAcked protocol.PacketNumber packetHistory *PacketList stopWaitingManager stopWaitingManager @@ -54,6 +55,10 @@ type sentPacketHandler struct { congestion congestion.SendAlgorithm rttStats *congestion.RTTStats + handshakeComplete bool + // The number of times the handshake packets have been retransmitted without receiving an ack. + handshakeCount uint32 + // The number of times an RTO has been sent without receiving an ack. rtoCount uint32 @@ -82,20 +87,27 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler { } } -func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber { +func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber { if f := h.packetHistory.Front(); f != nil { - return f.Value.PacketNumber - 1 + return f.Value.PacketNumber } - return h.LargestAcked + return h.largestAcked + 1 +} + +func (h *sentPacketHandler) SetHandshakeComplete() { + var queue []*Packet + for _, packet := range h.retransmissionQueue { + if packet.EncryptionLevel == protocol.EncryptionForwardSecure { + queue = append(queue, packet) + } + } + h.retransmissionQueue = queue + h.handshakeComplete = true } func (h *sentPacketHandler) SentPacket(packet *Packet) error { - if packet.PacketNumber <= h.lastSentPacketNumber { - return errPacketNumberNotIncreasing - } - if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets { - return ErrTooManyTrackedSentPackets + return errors.New("Too many outstanding non-acked and non-retransmitted packets") } for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ { @@ -106,14 +118,22 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { } } - h.lastSentPacketNumber = packet.PacketNumber now := time.Now() + h.lastSentPacketNumber = packet.PacketNumber + + var largestAcked protocol.PacketNumber + if len(packet.Frames) > 0 { + if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok { + largestAcked = ackFrame.LargestAcked + } + } packet.Frames = stripNonRetransmittableFrames(packet.Frames) isRetransmittable := len(packet.Frames) != 0 if isRetransmittable { - packet.SendTime = now + packet.sendTime = now + packet.largestAcked = largestAcked h.bytesInFlight += packet.Length h.packetHistory.PushBack(*packet) } @@ -126,29 +146,32 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { isRetransmittable, ) - h.updateLossDetectionAlarm() + h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, now).Add(h.congestion.TimeUntilSend(h.bytesInFlight)) + + h.updateLossDetectionAlarm(now) return nil } -func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, rcvTime time.Time) error { +func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { if ackFrame.LargestAcked > h.lastSentPacketNumber { - return errAckForUnsentPacket + return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package") } // duplicate or out-of-order ACK + // if withPacketNumber <= h.largestReceivedPacketWithAck && withPacketNumber != 0 { if withPacketNumber <= h.largestReceivedPacketWithAck { return ErrDuplicateOrOutOfOrderAck } h.largestReceivedPacketWithAck = withPacketNumber // ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK) - if ackFrame.LargestAcked <= h.largestInOrderAcked() { + if ackFrame.LargestAcked < h.lowestUnacked() { return nil } - h.LargestAcked = ackFrame.LargestAcked + h.largestAcked = ackFrame.LargestAcked if h.skippedPacketsAcked(ackFrame) { - return ErrAckForSkippedPacket + return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number") } rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime) @@ -164,13 +187,22 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum if len(ackedPackets) > 0 { for _, p := range ackedPackets { + if encLevel < p.Value.EncryptionLevel { + return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel) + } + // largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0 + // It is safe to ignore the corner case of packets that just acked packet 0, because + // the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send. + if p.Value.largestAcked != 0 { + h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.Value.largestAcked+1) + } h.onPacketAcked(p) h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight) } } - h.detectLostPackets() - h.updateLossDetectionAlarm() + h.detectLostPackets(rcvTime) + h.updateLossDetectionAlarm(rcvTime) h.garbageCollectSkippedPackets() h.stopWaitingManager.ReceivedAck(ackFrame) @@ -178,7 +210,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum return nil } -func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame) ([]*PacketElement, error) { +func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { + return h.lowestPacketNotConfirmedAcked +} + +func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) { var ackedPackets []*PacketElement ackRangeIndex := 0 for el := h.packetHistory.Front(); el != nil; el = el.Next() { @@ -197,14 +233,14 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame if ackFrame.HasMissingRanges() { ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] - for packetNumber > ackRange.LastPacketNumber && ackRangeIndex < len(ackFrame.AckRanges)-1 { + for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 { ackRangeIndex++ ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] } - if packetNumber >= ackRange.FirstPacketNumber { // packet i contained in ACK range - if packetNumber > ackRange.LastPacketNumber { - return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.FirstPacketNumber, ackRange.LastPacketNumber) + if packetNumber >= ackRange.First { // packet i contained in ACK range + if packetNumber > ackRange.Last { + return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last) } ackedPackets = append(ackedPackets, el) } @@ -212,7 +248,6 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame ackedPackets = append(ackedPackets, el) } } - return ackedPackets, nil } @@ -220,7 +255,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a for el := h.packetHistory.Front(); el != nil; el = el.Next() { packet := el.Value if packet.PacketNumber == largestAcked { - h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now()) + h.rttStats.UpdateRTT(rcvTime.Sub(packet.sendTime), ackDelay, rcvTime) return true } // Packets are sorted by number, so we can stop searching @@ -231,27 +266,27 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a return false } -func (h *sentPacketHandler) updateLossDetectionAlarm() { +func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) { // Cancel the alarm if no packets are outstanding if h.packetHistory.Len() == 0 { h.alarm = time.Time{} return } - // TODO(#496): Handle handshake packets separately // TODO(#497): TLP - if !h.lossTime.IsZero() { + if !h.handshakeComplete { + h.alarm = now.Add(h.computeHandshakeTimeout()) + } else if !h.lossTime.IsZero() { // Early retransmit timer or time loss detection. h.alarm = h.lossTime } else { // RTO - h.alarm = time.Now().Add(h.computeRTOTimeout()) + h.alarm = now.Add(h.computeRTOTimeout()) } } -func (h *sentPacketHandler) detectLostPackets() { +func (h *sentPacketHandler) detectLostPackets(now time.Time) { h.lossTime = time.Time{} - now := time.Now() maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT) @@ -260,11 +295,11 @@ func (h *sentPacketHandler) detectLostPackets() { for el := h.packetHistory.Front(); el != nil; el = el.Next() { packet := el.Value - if packet.PacketNumber > h.LargestAcked { + if packet.PacketNumber > h.largestAcked { break } - timeSinceSent := now.Sub(packet.SendTime) + timeSinceSent := now.Sub(packet.sendTime) if timeSinceSent > delayUntilLost { lostPackets = append(lostPackets, el) } else if h.lossTime.IsZero() { @@ -282,18 +317,22 @@ func (h *sentPacketHandler) detectLostPackets() { } func (h *sentPacketHandler) OnAlarm() { - // TODO(#496): Handle handshake packets separately + now := time.Now() + // TODO(#497): TLP - if !h.lossTime.IsZero() { + if !h.handshakeComplete { + h.queueHandshakePacketsForRetransmission() + h.handshakeCount++ + } else if !h.lossTime.IsZero() { // Early retransmit or time loss detection - h.detectLostPackets() + h.detectLostPackets(now) } else { // RTO h.retransmitOldestTwoPackets() h.rtoCount++ } - h.updateLossDetectionAlarm() + h.updateLossDetectionAlarm(now) } func (h *sentPacketHandler) GetAlarmTimeout() time.Time { @@ -303,6 +342,7 @@ func (h *sentPacketHandler) GetAlarmTimeout() time.Time { func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) { h.bytesInFlight -= packetElement.Value.Length h.rtoCount = 0 + h.handshakeCount = 0 // TODO(#497): h.tlpCount = 0 h.packetHistory.Remove(packetElement) } @@ -320,20 +360,19 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet { } func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber { - return h.largestInOrderAcked() + 1 + return h.lowestUnacked() } -func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame { +func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame { return h.stopWaitingManager.GetStopWaitingFrame(force) } func (h *sentPacketHandler) SendingAllowed() bool { - congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow() + cwnd := h.congestion.GetCongestionWindow() + congestionLimited := h.bytesInFlight > cwnd maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets if congestionLimited { - utils.Debugf("Congestion limited: bytes in flight %d, window %d", - h.bytesInFlight, - h.congestion.GetCongestionWindow()) + utils.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd) } // Workaround for #555: // Always allow sending of retransmissions. This should probably be limited @@ -342,6 +381,18 @@ func (h *sentPacketHandler) SendingAllowed() bool { return !maxTrackedLimited && (!congestionLimited || haveRetransmissions) } +func (h *sentPacketHandler) TimeUntilSend() time.Time { + return h.nextPacketSendTime +} + +func (h *sentPacketHandler) ShouldSendNumPackets() int { + delay := h.congestion.TimeUntilSend(h.bytesInFlight) + if delay == 0 || delay > protocol.MinPacingDelay { + return 1 + } + return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay))) +} + func (h *sentPacketHandler) retransmitOldestTwoPackets() { if p := h.packetHistory.Front(); p != nil { h.queueRTO(p) @@ -363,6 +414,18 @@ func (h *sentPacketHandler) queueRTO(el *PacketElement) { h.congestion.OnRetransmissionTimeout(true) } +func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() { + var handshakePackets []*PacketElement + for el := h.packetHistory.Front(); el != nil; el = el.Next() { + if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure { + handshakePackets = append(handshakePackets, el) + } + } + for _, el := range handshakePackets { + h.queuePacketForRetransmission(el) + } +} + func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) { packet := &packetElement.Value h.bytesInFlight -= packet.Length @@ -371,6 +434,17 @@ func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketEl h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber) } +func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration { + duration := 2 * h.rttStats.SmoothedRTT() + if duration == 0 { + duration = 2 * defaultInitialRTT + } + duration = utils.MaxDuration(duration, minTPLTimeout) + // exponential backoff + // There's an implicit limit to this set by the handshake timeout. + return duration << h.handshakeCount +} + func (h *sentPacketHandler) computeRTOTimeout() time.Duration { rto := h.congestion.RetransmissionDelay() if rto == 0 { @@ -382,7 +456,7 @@ func (h *sentPacketHandler) computeRTOTimeout() time.Duration { return utils.MinDuration(rto, maxRTOTimeout) } -func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *frames.AckFrame) bool { +func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool { for _, p := range h.skippedPackets { if ackFrame.AcksPacket(p) { return true @@ -392,10 +466,10 @@ func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *frames.AckFrame) bool } func (h *sentPacketHandler) garbageCollectSkippedPackets() { - lioa := h.largestInOrderAcked() + lowestUnacked := h.lowestUnacked() deleteIndex := 0 for i, p := range h.skippedPackets { - if p <= lioa { + if p < lowestUnacked { deleteIndex = i + 1 } } diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/stop_waiting_manager.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go similarity index 72% rename from vendor/github.com/lucas-clemente/quic-go/ackhandler/stop_waiting_manager.go rename to vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go index dfd79ae0f..04cb61f9d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/stop_waiting_manager.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go @@ -1,8 +1,8 @@ package ackhandler import ( - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) // This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33 @@ -10,10 +10,10 @@ type stopWaitingManager struct { largestLeastUnackedSent protocol.PacketNumber nextLeastUnacked protocol.PacketNumber - lastStopWaitingFrame *frames.StopWaitingFrame + lastStopWaitingFrame *wire.StopWaitingFrame } -func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame { +func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame { if s.nextLeastUnacked <= s.largestLeastUnackedSent { if force { return s.lastStopWaitingFrame @@ -22,14 +22,14 @@ func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *frames.StopWaiting } s.largestLeastUnackedSent = s.nextLeastUnacked - swf := &frames.StopWaitingFrame{ + swf := &wire.StopWaitingFrame{ LeastUnacked: s.nextLeastUnacked, } s.lastStopWaitingFrame = swf return swf } -func (s *stopWaitingManager) ReceivedAck(ack *frames.AckFrame) { +func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) { if ack.LargestAcked >= s.nextLeastUnacked { s.nextLeastUnacked = ack.LargestAcked + 1 } diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/bandwidth.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/bandwidth.go similarity index 90% rename from vendor/github.com/lucas-clemente/quic-go/congestion/bandwidth.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/bandwidth.go index e76ea161c..54269c567 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/bandwidth.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/bandwidth.go @@ -3,7 +3,7 @@ package congestion import ( "time" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // Bandwidth of a connection diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/clock.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/clock.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/congestion/clock.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/clock.go diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/cubic.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go similarity index 99% rename from vendor/github.com/lucas-clemente/quic-go/congestion/cubic.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go index 62e735563..3922f4760 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/cubic.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go @@ -4,8 +4,8 @@ import ( "math" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) // This cubic implementation is based on the one found in Chromiums's QUIC diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/cubic_sender.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go similarity index 95% rename from vendor/github.com/lucas-clemente/quic-go/congestion/cubic_sender.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go index 02e4206b6..1ab59535f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/cubic_sender.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go @@ -3,8 +3,8 @@ package congestion import ( "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) const ( @@ -76,15 +76,19 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio } } -func (c *cubicSender) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration { +// TimeUntilSend returns when the next packet should be sent. +func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration { if c.InRecovery() { // PRR is used when in recovery. - return c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) + if c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) == 0 { + return 0 + } } - if c.GetCongestionWindow() > bytesInFlight { - return 0 + delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow()/protocol.DefaultTCPMSS) + if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt + delay = delay * 8 / 5 } - return utils.InfDuration + return delay } func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool { diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/hybrid_slow_start.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/hybrid_slow_start.go similarity index 98% rename from vendor/github.com/lucas-clemente/quic-go/congestion/hybrid_slow_start.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/hybrid_slow_start.go index 01a64f826..f41c1e5c3 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/hybrid_slow_start.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/hybrid_slow_start.go @@ -3,8 +3,8 @@ package congestion import ( "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) // Note(pwestin): the magic clamping numbers come from the original code in diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go similarity index 90% rename from vendor/github.com/lucas-clemente/quic-go/congestion/interface.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go index bbce0a637..3c09428fd 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go @@ -3,12 +3,12 @@ package congestion import ( "time" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // A SendAlgorithm performs congestion control and calculates the congestion window type SendAlgorithm interface { - TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration + TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool GetCongestionWindow() protocol.ByteCount MaybeExitSlowStart() diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/prr_sender.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/prr_sender.go similarity index 97% rename from vendor/github.com/lucas-clemente/quic-go/congestion/prr_sender.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/prr_sender.go index b8a0a10be..18a3736a8 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/prr_sender.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/prr_sender.go @@ -3,8 +3,8 @@ package congestion import ( "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) // PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937 diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go similarity index 97% rename from vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go index 546c1cb98..9e5e4541a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go @@ -7,6 +7,7 @@ import ( ) const ( + // Note: This constant is also defined in the ackhandler package. initialRTTus = 100 * 1000 rttAlpha float32 = 0.125 oneMinusAlpha float32 = (1 - rttAlpha) @@ -97,10 +98,10 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { r.updateRecentMinRTT(sendDelta, now) // Correct for ackDelay if information received from the peer results in a - // positive RTT sample. Otherwise, we use the sendDelta as a reasonable - // measure for smoothedRTT. + // an RTT sample at least as large as minRTT. Otherwise, only use the + // sendDelta. sample := sendDelta - if sample > ackDelay { + if sample-r.minRTT >= ackDelay { sample -= ackDelay } r.latestRTT = sample diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/stats.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/stats.go similarity index 69% rename from vendor/github.com/lucas-clemente/quic-go/congestion/stats.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/stats.go index 8f272b26d..ed669c146 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/stats.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/stats.go @@ -1,6 +1,6 @@ package congestion -import "github.com/lucas-clemente/quic-go/protocol" +import "github.com/lucas-clemente/quic-go/internal/protocol" type connectionStats struct { slowstartPacketsLost protocol.PacketNumber diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/AEAD.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/AEAD.go similarity index 79% rename from vendor/github.com/lucas-clemente/quic-go/crypto/AEAD.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/AEAD.go index a59ce6e8e..d1905159d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/AEAD.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/AEAD.go @@ -1,9 +1,10 @@ package crypto -import "github.com/lucas-clemente/quic-go/protocol" +import "github.com/lucas-clemente/quic-go/internal/protocol" // An AEAD implements QUIC's authenticated encryption and associated data type AEAD interface { Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte + Overhead() int } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm12_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm12_aead.go new file mode 100644 index 000000000..55e45be65 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm12_aead.go @@ -0,0 +1,72 @@ +package crypto + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + + "github.com/lucas-clemente/aes12" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type aeadAESGCM12 struct { + otherIV []byte + myIV []byte + encrypter cipher.AEAD + decrypter cipher.AEAD +} + +var _ AEAD = &aeadAESGCM12{} + +// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size +// +// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte +// tag size, and couples the cipher and aes packages closely. +// See https://github.com/lucas-clemente/aes12. +func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) { + if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 { + return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs") + } + encrypterCipher, err := aes12.NewCipher(myKey) + if err != nil { + return nil, err + } + encrypter, err := aes12.NewGCM(encrypterCipher) + if err != nil { + return nil, err + } + decrypterCipher, err := aes12.NewCipher(otherKey) + if err != nil { + return nil, err + } + decrypter, err := aes12.NewGCM(decrypterCipher) + if err != nil { + return nil, err + } + return &aeadAESGCM12{ + otherIV: otherIV, + myIV: myIV, + encrypter: encrypter, + decrypter: decrypter, + }, nil +} + +func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { + return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData) +} + +func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData) +} + +func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { + res := make([]byte, 12) + copy(res[0:4], iv) + binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber)) + return res +} + +func (aead *aeadAESGCM12) Overhead() int { + return aead.encrypter.Overhead() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm_aead.go new file mode 100644 index 000000000..d55974e62 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm_aead.go @@ -0,0 +1,74 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/binary" + "errors" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type aeadAESGCM struct { + otherIV []byte + myIV []byte + encrypter cipher.AEAD + decrypter cipher.AEAD +} + +var _ AEAD = &aeadAESGCM{} + +const ivLen = 12 + +// NewAEADAESGCM creates a AEAD using AES-GCM +func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) { + // the IVs need to be at least 8 bytes long, otherwise we can't compute the nonce + if len(otherIV) != ivLen || len(myIV) != ivLen { + return nil, errors.New("AES-GCM: expected 12 byte IVs") + } + + encrypterCipher, err := aes.NewCipher(myKey) + if err != nil { + return nil, err + } + encrypter, err := cipher.NewGCM(encrypterCipher) + if err != nil { + return nil, err + } + decrypterCipher, err := aes.NewCipher(otherKey) + if err != nil { + return nil, err + } + decrypter, err := cipher.NewGCM(decrypterCipher) + if err != nil { + return nil, err + } + + return &aeadAESGCM{ + otherIV: otherIV, + myIV: myIV, + encrypter: encrypter, + decrypter: decrypter, + }, nil +} + +func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { + return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData) +} + +func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData) +} + +func (aead *aeadAESGCM) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { + nonce := make([]byte, ivLen) + binary.BigEndian.PutUint64(nonce[ivLen-8:], uint64(packetNumber)) + for i := 0; i < ivLen; i++ { + nonce[i] ^= iv[i] + } + return nonce +} + +func (aead *aeadAESGCM) Overhead() int { + return aead.encrypter.Overhead() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_cache.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_cache.go similarity index 94% rename from vendor/github.com/lucas-clemente/quic-go/crypto/cert_cache.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_cache.go index 3ebdc1ae5..d8e8d8f36 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_cache.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_cache.go @@ -5,7 +5,7 @@ import ( "hash/fnv" "github.com/hashicorp/golang-lru" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) var ( diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_chain.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/cert_chain.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_compression.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_compression.go similarity index 94% rename from vendor/github.com/lucas-clemente/quic-go/crypto/cert_compression.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_compression.go index ea5ecff36..908b7ce98 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_compression.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_compression.go @@ -51,10 +51,10 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by res.WriteByte(uint8(e.t)) switch e.t { case entryCached: - utils.WriteUint64(res, e.h) + utils.LittleEndian.WriteUint64(res, e.h) case entryCommon: - utils.WriteUint64(res, e.h) - utils.WriteUint32(res, e.i) + utils.LittleEndian.WriteUint64(res, e.h) + utils.LittleEndian.WriteUint32(res, e.i) case entryCompressed: totalUncompressedLen += 4 + len(chain[i]) } @@ -67,7 +67,7 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by return nil, fmt.Errorf("cert compression failed: %s", err.Error()) } - utils.WriteUint32(res, uint32(totalUncompressedLen)) + utils.LittleEndian.WriteUint32(res, uint32(totalUncompressedLen)) for i, e := range entries { if e.t != entryCompressed { @@ -115,11 +115,11 @@ func decompressChain(data []byte) ([][]byte, error) { return nil, errors.New("unexpected cached certificate") case entryCommon: e := entry{t: entryCommon} - e.h, err = utils.ReadUint64(r) + e.h, err = utils.LittleEndian.ReadUint64(r) if err != nil { return nil, err } - e.i, err = utils.ReadUint32(r) + e.i, err = utils.LittleEndian.ReadUint32(r) if err != nil { return nil, err } @@ -146,7 +146,7 @@ func decompressChain(data []byte) ([][]byte, error) { } if hasCompressedCerts { - uncompressedLength, err := utils.ReadUint32(r) + uncompressedLength, err := utils.LittleEndian.ReadUint32(r) if err != nil { fmt.Println(4) return nil, err diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_dict.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_dict.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/cert_dict.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_dict.go diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_manager.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go similarity index 96% rename from vendor/github.com/lucas-clemente/quic-go/crypto/cert_manager.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go index 5aaa1877c..8b8c9faa8 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_manager.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go @@ -18,6 +18,7 @@ type CertManager interface { GetLeafCertHash() (uint64, error) VerifyServerProof(proof, chlo, serverConfigData []byte) bool Verify(hostname string) error + GetChain() []*x509.Certificate } type certManager struct { @@ -54,6 +55,10 @@ func (c *certManager) SetData(data []byte) error { return nil } +func (c *certManager) GetChain() []*x509.Certificate { + return c.chain +} + func (c *certManager) GetCommonCertificateHashes() []byte { return getCommonCertificateHashes() } diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_sets.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_sets.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/cert_sets.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_sets.go diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/chacha20poly1305_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go similarity index 72% rename from vendor/github.com/lucas-clemente/quic-go/crypto/chacha20poly1305_aead.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go index 5c58c4e3c..5d2e36f9f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/chacha20poly1305_aead.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go @@ -4,11 +4,12 @@ package crypto import ( "crypto/cipher" + "encoding/binary" "errors" "github.com/aead/chacha20" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) type aeadChacha20Poly1305 struct { @@ -45,9 +46,16 @@ func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV } func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { - return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData) + return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData) } func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData) + return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData) +} + +func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { + res := make([]byte, 12) + copy(res[0:4], iv) + binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber)) + return res } diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/curve_25519.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/curve_25519.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go new file mode 100644 index 000000000..316bd1b3b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go @@ -0,0 +1,49 @@ +package crypto + +import ( + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +const ( + clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret" + serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret" +) + +// A TLSExporter gets the negotiated ciphersuite and computes exporter +type TLSExporter interface { + GetCipherSuite() mint.CipherSuiteParams + ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) +} + +// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance +func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) { + var myLabel, otherLabel string + if pers == protocol.PerspectiveClient { + myLabel = clientExporterLabel + otherLabel = serverExporterLabel + } else { + myLabel = serverExporterLabel + otherLabel = clientExporterLabel + } + myKey, myIV, err := computeKeyAndIV(tls, myLabel) + if err != nil { + return nil, err + } + otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel) + if err != nil { + return nil, err + } + return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) +} + +func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) { + cs := tls.GetCipherSuite() + secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size()) + if err != nil { + return nil, nil, err + } + key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen) + iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen) + return key, iv, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/key_derivation.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go similarity index 84% rename from vendor/github.com/lucas-clemente/quic-go/crypto/key_derivation.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go index accdbeaa2..28f6c2cca 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/key_derivation.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go @@ -5,8 +5,8 @@ import ( "crypto/sha256" "io" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "golang.org/x/crypto/hkdf" ) @@ -20,8 +20,8 @@ import ( // return NewAEADChacha20Poly1305(otherKey, myKey, otherIV, myIV) // } -// DeriveKeysAESGCM derives the client and server keys and creates a matching AES-GCM AEAD instance -func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) { +// DeriveQuicCryptoAESKeys derives the client and server keys and creates a matching AES-GCM AEAD instance +func DeriveQuicCryptoAESKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) { var swap bool if pers == protocol.PerspectiveClient { swap = true @@ -30,7 +30,7 @@ func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID pr if err != nil { return nil, err } - return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) + return NewAEADAESGCM12(otherKey, myKey, otherIV, myIV) } // deriveKeys derives the keys and the IVs @@ -42,7 +42,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol } else { info.Write([]byte("QUIC key expansion\x00")) } - utils.WriteUint64(&info, uint64(connID)) + utils.BigEndian.WriteUint64(&info, uint64(connID)) info.Write(chlo) info.Write(scfg) info.Write(cert) diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/key_exchange.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_exchange.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/key_exchange.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_exchange.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead.go new file mode 100644 index 000000000..27158bee5 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead.go @@ -0,0 +1,11 @@ +package crypto + +import "github.com/lucas-clemente/quic-go/internal/protocol" + +// NewNullAEAD creates a NullAEAD +func NewNullAEAD(p protocol.Perspective, connID protocol.ConnectionID, v protocol.VersionNumber) (AEAD, error) { + if v.UsesTLS() { + return newNullAEADAESGCM(connID, p) + } + return &nullAEADFNV128a{perspective: p}, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go new file mode 100644 index 000000000..a647ad7f6 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go @@ -0,0 +1,44 @@ +package crypto + +import ( + "crypto" + "encoding/binary" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39} + +func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) { + clientSecret, serverSecret := computeSecrets(connectionID) + + var mySecret, otherSecret []byte + if pers == protocol.PerspectiveClient { + mySecret = clientSecret + otherSecret = serverSecret + } else { + mySecret = serverSecret + otherSecret = clientSecret + } + + myKey, myIV := computeNullAEADKeyAndIV(mySecret) + otherKey, otherIV := computeNullAEADKeyAndIV(otherSecret) + + return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) +} + +func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) { + connID := make([]byte, 8) + binary.BigEndian.PutUint64(connID, uint64(connectionID)) + cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID) + clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size()) + serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size()) + return +} + +func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) { + key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16) + iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12) + return +} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/null_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go similarity index 55% rename from vendor/github.com/lucas-clemente/quic-go/crypto/null_aead.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go index ed8566337..ecc4010bd 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/null_aead.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go @@ -5,27 +5,18 @@ import ( "errors" "github.com/lucas-clemente/fnv128a" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // nullAEAD handles not-yet encrypted packets -type nullAEAD struct { +type nullAEADFNV128a struct { perspective protocol.Perspective - version protocol.VersionNumber } -var _ AEAD = &nullAEAD{} - -// NewNullAEAD creates a NullAEAD -func NewNullAEAD(p protocol.Perspective, v protocol.VersionNumber) AEAD { - return &nullAEAD{ - perspective: p, - version: v, - } -} +var _ AEAD = &nullAEADFNV128a{} // Open and verify the ciphertext -func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { +func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { if len(src) < 12 { return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long") } @@ -33,12 +24,10 @@ func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, ass hash := fnv128a.New() hash.Write(associatedData) hash.Write(src[12:]) - if n.version >= protocol.Version37 { - if n.perspective == protocol.PerspectiveServer { - hash.Write([]byte("Client")) - } else { - hash.Write([]byte("Server")) - } + if n.perspective == protocol.PerspectiveServer { + hash.Write([]byte("Client")) + } else { + hash.Write([]byte("Server")) } testHigh, testLow := hash.Sum128() @@ -52,7 +41,7 @@ func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, ass } // Seal writes hash and ciphertext to the buffer -func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { +func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { if cap(dst) < 12+len(src) { dst = make([]byte, 12+len(src)) } else { @@ -63,12 +52,10 @@ func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, ass hash.Write(associatedData) hash.Write(src) - if n.version >= protocol.Version37 { - if n.perspective == protocol.PerspectiveServer { - hash.Write([]byte("Server")) - } else { - hash.Write([]byte("Client")) - } + if n.perspective == protocol.PerspectiveServer { + hash.Write([]byte("Server")) + } else { + hash.Write([]byte("Client")) } high, low := hash.Sum128() @@ -78,3 +65,7 @@ func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, ass binary.LittleEndian.PutUint32(dst[8:], uint32(high)) return dst } + +func (n *nullAEADFNV128a) Overhead() int { + return 12 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/server_proof.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/server_proof.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/server_proof.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/server_proof.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go new file mode 100644 index 000000000..3a5c01615 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go @@ -0,0 +1,108 @@ +package flowcontrol + +import ( + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/congestion" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type baseFlowController struct { + // for sending data + bytesSent protocol.ByteCount + sendWindow protocol.ByteCount + + // for receiving data + mutex sync.RWMutex + bytesRead protocol.ByteCount + highestReceived protocol.ByteCount + receiveWindow protocol.ByteCount + receiveWindowSize protocol.ByteCount + maxReceiveWindowSize protocol.ByteCount + + epochStartTime time.Time + epochStartOffset protocol.ByteCount + rttStats *congestion.RTTStats +} + +func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { + c.bytesSent += n +} + +// UpdateSendWindow should be called after receiving a WindowUpdateFrame +// it returns true if the window was actually updated +func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { + if offset > c.sendWindow { + c.sendWindow = offset + } +} + +func (c *baseFlowController) sendWindowSize() protocol.ByteCount { + // this only happens during connection establishment, when data is sent before we receive the peer's transport parameters + if c.bytesSent > c.sendWindow { + return 0 + } + return c.sendWindow - c.bytesSent +} + +func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) { + c.mutex.Lock() + defer c.mutex.Unlock() + + // pretend we sent a WindowUpdate when reading the first byte + // this way auto-tuning of the window size already works for the first WindowUpdate + if c.bytesRead == 0 { + c.startNewAutoTuningEpoch() + } + c.bytesRead += n +} + +func (c *baseFlowController) hasWindowUpdate() bool { + bytesRemaining := c.receiveWindow - c.bytesRead + // update the window when more than the threshold was consumed + return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold)))) +} + +// getWindowUpdate updates the receive window, if necessary +// it returns the new offset +func (c *baseFlowController) getWindowUpdate() protocol.ByteCount { + if !c.hasWindowUpdate() { + return 0 + } + + c.maybeAdjustWindowSize() + c.receiveWindow = c.bytesRead + c.receiveWindowSize + return c.receiveWindow +} + +// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often. +// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing. +func (c *baseFlowController) maybeAdjustWindowSize() { + bytesReadInEpoch := c.bytesRead - c.epochStartOffset + // don't do anything if less than half the window has been consumed + if bytesReadInEpoch <= c.receiveWindowSize/2 { + return + } + rtt := c.rttStats.SmoothedRTT() + if rtt == 0 { + return + } + + fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize) + if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) { + // window is consumed too fast, try to increase the window size + c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize) + } + c.startNewAutoTuningEpoch() +} + +func (c *baseFlowController) startNewAutoTuningEpoch() { + c.epochStartTime = time.Now() + c.epochStartOffset = c.bytesRead +} + +func (c *baseFlowController) checkFlowControlViolation() bool { + return c.highestReceived > c.receiveWindow +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go new file mode 100644 index 000000000..975cc5836 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go @@ -0,0 +1,83 @@ +package flowcontrol + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/congestion" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +type connectionFlowController struct { + lastBlockedAt protocol.ByteCount + baseFlowController +} + +var _ ConnectionFlowController = &connectionFlowController{} + +// NewConnectionFlowController gets a new flow controller for the connection +// It is created before we receive the peer's transport paramenters, thus it starts with a sendWindow of 0. +func NewConnectionFlowController( + receiveWindow protocol.ByteCount, + maxReceiveWindow protocol.ByteCount, + rttStats *congestion.RTTStats, +) ConnectionFlowController { + return &connectionFlowController{ + baseFlowController: baseFlowController{ + rttStats: rttStats, + receiveWindow: receiveWindow, + receiveWindowSize: receiveWindow, + maxReceiveWindowSize: maxReceiveWindow, + }, + } +} + +func (c *connectionFlowController) SendWindowSize() protocol.ByteCount { + return c.baseFlowController.sendWindowSize() +} + +// IsNewlyBlocked says if it is newly blocked by flow control. +// For every offset, it only returns true once. +// If it is blocked, the offset is returned. +func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { + if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt { + return false, 0 + } + c.lastBlockedAt = c.sendWindow + return true, c.sendWindow +} + +// IncrementHighestReceived adds an increment to the highestReceived value +func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.highestReceived += increment + if c.checkFlowControlViolation() { + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow)) + } + return nil +} + +func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { + c.mutex.Lock() + oldWindowSize := c.receiveWindowSize + offset := c.baseFlowController.getWindowUpdate() + if oldWindowSize < c.receiveWindowSize { + utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) + } + c.mutex.Unlock() + return offset +} + +// EnsureMinimumWindowSize sets a minimum window size +// it should make sure that the connection-level window is increased when a stream-level window grows +func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) { + c.mutex.Lock() + if inc > c.receiveWindowSize { + c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize) + c.startNewAutoTuningEpoch() + } + c.mutex.Unlock() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go new file mode 100644 index 000000000..61d57e31b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go @@ -0,0 +1,42 @@ +package flowcontrol + +import "github.com/lucas-clemente/quic-go/internal/protocol" + +type flowController interface { + // for sending + SendWindowSize() protocol.ByteCount + UpdateSendWindow(protocol.ByteCount) + AddBytesSent(protocol.ByteCount) + // for receiving + AddBytesRead(protocol.ByteCount) + GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary +} + +// A StreamFlowController is a flow controller for a QUIC stream. +type StreamFlowController interface { + flowController + // for sending + IsBlocked() (bool, protocol.ByteCount) + // for receiving + // UpdateHighestReceived should be called when a new highest offset is received + // final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame + UpdateHighestReceived(offset protocol.ByteCount, final bool) error + // HasWindowUpdate says if it is necessary to update the window + HasWindowUpdate() bool +} + +// The ConnectionFlowController is the flow controller for the connection. +type ConnectionFlowController interface { + flowController + // for sending + IsNewlyBlocked() (bool, protocol.ByteCount) +} + +type connectionFlowControllerI interface { + ConnectionFlowController + // The following two methods are not supposed to be called from outside this packet, but are needed internally + // for sending + EnsureMinimumWindowSize(protocol.ByteCount) + // for receiving + IncrementHighestReceived(protocol.ByteCount) error +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go new file mode 100644 index 000000000..51ecfe7f0 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go @@ -0,0 +1,147 @@ +package flowcontrol + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/congestion" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +type streamFlowController struct { + baseFlowController + + streamID protocol.StreamID + + connection connectionFlowControllerI + contributesToConnection bool // does the stream contribute to connection level flow control + + receivedFinalOffset bool +} + +var _ StreamFlowController = &streamFlowController{} + +// NewStreamFlowController gets a new flow controller for a stream +func NewStreamFlowController( + streamID protocol.StreamID, + contributesToConnection bool, + cfc ConnectionFlowController, + receiveWindow protocol.ByteCount, + maxReceiveWindow protocol.ByteCount, + initialSendWindow protocol.ByteCount, + rttStats *congestion.RTTStats, +) StreamFlowController { + return &streamFlowController{ + streamID: streamID, + contributesToConnection: contributesToConnection, + connection: cfc.(connectionFlowControllerI), + baseFlowController: baseFlowController{ + rttStats: rttStats, + receiveWindow: receiveWindow, + receiveWindowSize: receiveWindow, + maxReceiveWindowSize: maxReceiveWindow, + sendWindow: initialSendWindow, + }, + } +} + +// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher +// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before +func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCount, final bool) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + // when receiving a final offset, check that this final offset is consistent with a final offset we might have received earlier + if final && c.receivedFinalOffset && byteOffset != c.highestReceived { + return qerr.Error(qerr.StreamDataAfterTermination, fmt.Sprintf("Received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, byteOffset)) + } + // if we already received a final offset, check that the offset in the STREAM frames is below the final offset + if c.receivedFinalOffset && byteOffset > c.highestReceived { + return qerr.StreamDataAfterTermination + } + if final { + c.receivedFinalOffset = true + } + if byteOffset == c.highestReceived { + return nil + } + if byteOffset <= c.highestReceived { + // a STREAM_FRAME with a higher offset was received before. + if final { + // If the current byteOffset is smaller than the offset in that STREAM_FRAME, this STREAM_FRAME contained data after the end of the stream + return qerr.StreamDataAfterTermination + } + // this is a reordered STREAM_FRAME + return nil + } + + increment := byteOffset - c.highestReceived + c.highestReceived = byteOffset + if c.checkFlowControlViolation() { + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow)) + } + if c.contributesToConnection { + return c.connection.IncrementHighestReceived(increment) + } + return nil +} + +func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) { + c.baseFlowController.AddBytesRead(n) + if c.contributesToConnection { + c.connection.AddBytesRead(n) + } +} + +func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { + c.baseFlowController.AddBytesSent(n) + if c.contributesToConnection { + c.connection.AddBytesSent(n) + } +} + +func (c *streamFlowController) SendWindowSize() protocol.ByteCount { + window := c.baseFlowController.sendWindowSize() + if c.contributesToConnection { + window = utils.MinByteCount(window, c.connection.SendWindowSize()) + } + return window +} + +// IsBlocked says if it is blocked by stream-level flow control. +// If it is blocked, the offset is returned. +func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) { + if c.sendWindowSize() != 0 { + return false, 0 + } + return true, c.sendWindow +} + +func (c *streamFlowController) HasWindowUpdate() bool { + c.mutex.Lock() + hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate() + c.mutex.Unlock() + return hasWindowUpdate +} + +func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { + // don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler + c.mutex.Lock() + // if we already received the final offset for this stream, the peer won't need any additional flow control credit + if c.receivedFinalOffset { + c.mutex.Unlock() + return 0 + } + + oldWindowSize := c.receiveWindowSize + offset := c.baseFlowController.getWindowUpdate() + if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size + utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) + if c.contributesToConnection { + c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)) + } + } + c.mutex.Unlock() + return offset +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go new file mode 100644 index 000000000..97accb73d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go @@ -0,0 +1,101 @@ +package handshake + +import ( + "encoding/asn1" + "fmt" + "net" + "time" + + "github.com/bifurcation/mint" +) + +const ( + cookiePrefixIP byte = iota + cookiePrefixString +) + +// A Cookie is derived from the client address and can be used to verify the ownership of this address. +type Cookie struct { + RemoteAddr string + // The time that the STK was issued (resolution 1 second) + SentTime time.Time +} + +// token is the struct that is used for ASN1 serialization and deserialization +type token struct { + Data []byte + Timestamp int64 +} + +// A CookieGenerator generates Cookies +type CookieGenerator struct { + cookieProtector mint.CookieProtector +} + +// NewCookieGenerator initializes a new CookieGenerator +func NewCookieGenerator() (*CookieGenerator, error) { + cookieProtector, err := mint.NewDefaultCookieProtector() + if err != nil { + return nil, err + } + return &CookieGenerator{ + cookieProtector: cookieProtector, + }, nil +} + +// NewToken generates a new Cookie for a given source address +func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) { + data, err := asn1.Marshal(token{ + Data: encodeRemoteAddr(raddr), + Timestamp: time.Now().Unix(), + }) + if err != nil { + return nil, err + } + return g.cookieProtector.NewToken(data) +} + +// DecodeToken decodes a Cookie +func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) { + // if the client didn't send any Cookie, DecodeToken will be called with a nil-slice + if len(encrypted) == 0 { + return nil, nil + } + + data, err := g.cookieProtector.DecodeToken(encrypted) + if err != nil { + return nil, err + } + t := &token{} + rest, err := asn1.Unmarshal(data, t) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) + } + return &Cookie{ + RemoteAddr: decodeRemoteAddr(t.Data), + SentTime: time.Unix(t.Timestamp, 0), + }, nil +} + +// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie +func encodeRemoteAddr(remoteAddr net.Addr) []byte { + if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { + return append([]byte{cookiePrefixIP}, udpAddr.IP...) + } + return append([]byte{cookiePrefixString}, []byte(remoteAddr.String())...) +} + +// decodeRemoteAddr decodes the remote address saved in the Cookie +func decodeRemoteAddr(data []byte) string { + // data will never be empty for a Cookie that we generated. Check it to be on the safe side + if len(data) == 0 { + return "" + } + if data[0] == cookiePrefixIP { + return net.IP(data[1:]).String() + } + return string(data[1:]) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go new file mode 100644 index 000000000..1d3052c49 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go @@ -0,0 +1,43 @@ +package handshake + +import ( + "net" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type CookieHandler struct { + callback func(net.Addr, *Cookie) bool + + cookieGenerator *CookieGenerator +} + +var _ mint.CookieHandler = &CookieHandler{} + +func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) { + cookieGenerator, err := NewCookieGenerator() + if err != nil { + return nil, err + } + return &CookieHandler{ + callback: callback, + cookieGenerator: cookieGenerator, + }, nil +} + +func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) { + if h.callback(conn.RemoteAddr(), nil) { + return nil, nil + } + return h.cookieGenerator.NewToken(conn.RemoteAddr()) +} + +func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool { + data, err := h.cookieGenerator.DecodeToken(token) + if err != nil { + utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error()) + return false + } + return h.callback(conn.RemoteAddr(), data) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go similarity index 76% rename from vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_client.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go index a8d881297..cb500b583 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_client.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go @@ -11,9 +11,9 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) @@ -23,6 +23,7 @@ type cryptoSetupClient struct { hostname string connID protocol.ConnectionID version protocol.VersionNumber + initialVersion protocol.VersionNumber negotiatedVersions []protocol.VersionNumber cryptoStream io.ReadWriter @@ -42,17 +43,18 @@ type cryptoSetupClient struct { clientHelloCounter int serverVerified bool // has the certificate chain and the proof already been verified - keyDerivation KeyDerivationFunction + keyDerivation QuicCryptoKeyDerivationFunction keyExchange KeyExchangeFunction receivedSecurePacket bool nullAEAD crypto.AEAD secureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD - aeadChanged chan<- protocol.EncryptionLevel - params *TransportParameters - connectionParameters ConnectionParametersManager + paramsChan chan<- TransportParameters + handshakeEvent chan<- struct{} + + params *TransportParameters } var _ CryptoSetup = &cryptoSetupClient{} @@ -65,36 +67,42 @@ var ( // NewCryptoSetupClient creates a new CryptoSetup instance for a client func NewCryptoSetupClient( + cryptoStream io.ReadWriter, hostname string, connID protocol.ConnectionID, version protocol.VersionNumber, - cryptoStream io.ReadWriter, tlsConfig *tls.Config, - connectionParameters ConnectionParametersManager, - aeadChanged chan<- protocol.EncryptionLevel, params *TransportParameters, + paramsChan chan<- TransportParameters, + handshakeEvent chan<- struct{}, + initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, ) (CryptoSetup, error) { + nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) + if err != nil { + return nil, err + } return &cryptoSetupClient{ - hostname: hostname, - connID: connID, - version: version, - cryptoStream: cryptoStream, - certManager: crypto.NewCertManager(tlsConfig), - connectionParameters: connectionParameters, - keyDerivation: crypto.DeriveKeysAESGCM, - keyExchange: getEphermalKEX, - nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), - aeadChanged: aeadChanged, - negotiatedVersions: negotiatedVersions, - divNonceChan: make(chan []byte), - params: params, + cryptoStream: cryptoStream, + hostname: hostname, + connID: connID, + version: version, + certManager: crypto.NewCertManager(tlsConfig), + params: params, + keyDerivation: crypto.DeriveQuicCryptoAESKeys, + keyExchange: getEphermalKEX, + nullAEAD: nullAEAD, + paramsChan: paramsChan, + handshakeEvent: handshakeEvent, + initialVersion: initialVersion, + negotiatedVersions: negotiatedVersions, + divNonceChan: make(chan []byte), }, nil } func (h *cryptoSetupClient) HandleCryptoStream() error { messageChan := make(chan HandshakeMessage) - errorChan := make(chan error) + errorChan := make(chan error, 1) go func() { for { @@ -141,15 +149,21 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { utils.Debugf("Got %s", message) switch message.Tag { case TagREJ: - err = h.handleREJMessage(message.Data) + if err := h.handleREJMessage(message.Data); err != nil { + return err + } case TagSHLO: - err = h.handleSHLOMessage(message.Data) + params, err := h.handleSHLOMessage(message.Data) + if err != nil { + return err + } + // blocks until the session has received the parameters + h.paramsChan <- *params + h.handshakeEvent <- struct{}{} + close(h.handshakeEvent) default: return qerr.InvalidCryptoMessageType } - if err != nil { - return err - } } } @@ -215,12 +229,12 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { return nil } -func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { +func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) { h.mutex.Lock() defer h.mutex.Unlock() if !h.receivedSecurePacket { - return qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message") + return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message") } if sno, ok := cryptoData[TagSNO]; ok { @@ -229,22 +243,22 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { serverPubs, ok := cryptoData[TagPUBS] if !ok { - return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") + return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") } verTag, ok := cryptoData[TagVER] if !ok { - return qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list") + return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list") } if !h.validateVersionList(verTag) { - return qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") + return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") } nonce := append(h.nonc, h.sno...) ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs) if err != nil { - return err + return nil, err } leafCert := h.certManager.GetLeafCert() @@ -261,39 +275,32 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { protocol.PerspectiveClient, ) if err != nil { - return err + return nil, err } - err = h.connectionParameters.SetFromMap(cryptoData) + params, err := readHelloMap(cryptoData) if err != nil { - return qerr.InvalidCryptoMessageParameter + return nil, qerr.InvalidCryptoMessageParameter } - - h.aeadChanged <- protocol.EncryptionForwardSecure - close(h.aeadChanged) - - return nil + return params, nil } func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool { - if len(h.negotiatedVersions) == 0 { + numNegotiatedVersions := len(h.negotiatedVersions) + if numNegotiatedVersions == 0 { return true } - if len(verTags)%4 != 0 || len(verTags)/4 != len(h.negotiatedVersions) { + if len(verTags)%4 != 0 || len(verTags)/4 != numNegotiatedVersions { return false } b := bytes.NewReader(verTags) - for _, negotiatedVersion := range h.negotiatedVersions { - verTag, err := utils.ReadUint32(b) + for i := 0; i < numNegotiatedVersions; i++ { + v, err := utils.BigEndian.ReadUint32(b) if err != nil { // should never occur, since the length was already checked return false } - ver := protocol.VersionTagToNumber(verTag) - if !protocol.IsSupportedVersion(protocol.SupportedVersions, ver) { - ver = protocol.VersionUnsupported - } - if ver != negotiatedVersion { + if protocol.VersionNumber(v) != h.negotiatedVersions[i] { return false } } @@ -333,16 +340,16 @@ func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) { h.mutex.RLock() defer h.mutex.RUnlock() if h.forwardSecureAEAD != nil { - return protocol.EncryptionForwardSecure, h.sealForwardSecure + return protocol.EncryptionForwardSecure, h.forwardSecureAEAD } else if h.secureAEAD != nil { - return protocol.EncryptionSecure, h.sealSecure + return protocol.EncryptionSecure, h.secureAEAD } else { - return protocol.EncryptionUnencrypted, h.sealUnencrypted + return protocol.EncryptionUnencrypted, h.nullAEAD } } func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) { - return protocol.EncryptionUnencrypted, h.sealUnencrypted + return protocol.EncryptionUnencrypted, h.nullAEAD } func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { @@ -351,33 +358,21 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry switch encLevel { case protocol.EncryptionUnencrypted: - return h.sealUnencrypted, nil + return h.nullAEAD, nil case protocol.EncryptionSecure: if h.secureAEAD == nil { return nil, errors.New("CryptoSetupClient: no secureAEAD") } - return h.sealSecure, nil + return h.secureAEAD, nil case protocol.EncryptionForwardSecure: if h.forwardSecureAEAD == nil { return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD") } - return h.sealForwardSecure, nil + return h.forwardSecureAEAD, nil } return nil, errors.New("CryptoSetupClient: no encryption level specified") } -func (h *cryptoSetupClient) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return h.nullAEAD.Seal(dst, src, packetNumber, associatedData) -} - -func (h *cryptoSetupClient) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) -} - -func (h *cryptoSetupClient) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData) -} - func (h *cryptoSetupClient) DiversificationNonce() []byte { panic("not needed for cryptoSetupClient") } @@ -386,6 +381,15 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) { h.divNonceChan <- data } +func (h *cryptoSetupClient) ConnectionState() ConnectionState { + h.mutex.Lock() + defer h.mutex.Unlock() + return ConnectionState{ + HandshakeComplete: h.forwardSecureAEAD != nil, + PeerCertificates: h.certManager.GetChain(), + } +} + func (h *cryptoSetupClient) sendCHLO() error { h.clientHelloCounter++ if h.clientHelloCounter > protocol.MaxClientHellos { @@ -413,15 +417,11 @@ func (h *cryptoSetupClient) sendCHLO() error { } h.lastSentCHLO = b.Bytes() - return nil } func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) { - tags, err := h.connectionParameters.GetHelloMap() - if err != nil { - return nil, err - } + tags := h.params.getHelloMap() tags[TagSNI] = []byte(h.hostname) tags[TagPDMD] = []byte("X509") @@ -431,12 +431,9 @@ func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) { } versionTag := make([]byte, 4) - binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version)) + binary.BigEndian.PutUint32(versionTag, uint32(h.initialVersion)) tags[TagVER] = versionTag - if h.params.RequestConnectionIDTruncation { - tags[TagTCID] = []byte{0, 0, 0, 0} - } if len(h.stk) > 0 { tags[TagSTK] = h.stk } @@ -470,7 +467,7 @@ func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) { for _, tag := range tags { size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data } - paddingSize := protocol.ClientHelloMinimumSize - size + paddingSize := protocol.MinClientHelloSize - size if paddingSize > 0 { tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize) } @@ -508,10 +505,8 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error { if err != nil { return err } - - h.aeadChanged <- protocol.EncryptionSecure + h.handshakeEvent <- struct{}{} } - return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_server.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go similarity index 78% rename from vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_server.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go index 35dd6f05d..7d5f32ee8 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_server.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go @@ -9,47 +9,51 @@ import ( "net" "sync" - "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) -// KeyDerivationFunction is used for key derivation -type KeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) +// QuicCryptoKeyDerivationFunction is used for key derivation +type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) // KeyExchangeFunction is used to make a new KEX type KeyExchangeFunction func() crypto.KeyExchange // The CryptoSetupServer handles all things crypto for the Session type cryptoSetupServer struct { + mutex sync.RWMutex + connID protocol.ConnectionID remoteAddr net.Addr scfg *ServerConfig - stkGenerator *STKGenerator diversificationNonce []byte version protocol.VersionNumber supportedVersions []protocol.VersionNumber - acceptSTKCallback func(net.Addr, *STK) bool + acceptSTKCallback func(net.Addr, *Cookie) bool nullAEAD crypto.AEAD secureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD receivedForwardSecurePacket bool - sentSHLO bool receivedSecurePacket bool - aeadChanged chan<- protocol.EncryptionLevel + sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written - keyDerivation KeyDerivationFunction + receivedParams bool + paramsChan chan<- TransportParameters + handshakeEvent chan<- struct{} + + keyDerivation QuicCryptoKeyDerivationFunction keyExchange KeyExchangeFunction cryptoStream io.ReadWriter - connectionParameters ConnectionParametersManager + params *TransportParameters - mutex sync.RWMutex + sni string // need to fill out the ConnectionState } var _ CryptoSetup = &cryptoSetupServer{} @@ -65,35 +69,36 @@ var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP exp // NewCryptoSetup creates a new CryptoSetup instance for a server func NewCryptoSetup( + cryptoStream io.ReadWriter, connID protocol.ConnectionID, remoteAddr net.Addr, version protocol.VersionNumber, scfg *ServerConfig, - cryptoStream io.ReadWriter, - connectionParametersManager ConnectionParametersManager, + params *TransportParameters, supportedVersions []protocol.VersionNumber, - acceptSTK func(net.Addr, *STK) bool, - aeadChanged chan<- protocol.EncryptionLevel, + acceptSTK func(net.Addr, *Cookie) bool, + paramsChan chan<- TransportParameters, + handshakeEvent chan<- struct{}, ) (CryptoSetup, error) { - stkGenerator, err := NewSTKGenerator() + nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) if err != nil { return nil, err } - return &cryptoSetupServer{ - connID: connID, - remoteAddr: remoteAddr, - version: version, - supportedVersions: supportedVersions, - scfg: scfg, - stkGenerator: stkGenerator, - keyDerivation: crypto.DeriveKeysAESGCM, - keyExchange: getEphermalKEX, - nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), - cryptoStream: cryptoStream, - connectionParameters: connectionParametersManager, - acceptSTKCallback: acceptSTK, - aeadChanged: aeadChanged, + cryptoStream: cryptoStream, + connID: connID, + remoteAddr: remoteAddr, + version: version, + supportedVersions: supportedVersions, + scfg: scfg, + keyDerivation: crypto.DeriveQuicCryptoAESKeys, + keyExchange: getEphermalKEX, + nullAEAD: nullAEAD, + params: params, + acceptSTKCallback: acceptSTK, + sentSHLO: make(chan struct{}), + paramsChan: paramsChan, + handshakeEvent: handshakeEvent, }, nil } @@ -136,6 +141,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] if sni == "" { return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required") } + h.sni = sni // prevent version downgrade attacks // see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples @@ -146,8 +152,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] if len(verSlice) != 4 { return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag") } - verTag := binary.LittleEndian.Uint32(verSlice) - ver := protocol.VersionTagToNumber(verTag) + ver := protocol.VersionNumber(binary.BigEndian.Uint32(verSlice)) // If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack. if ver != h.version && protocol.IsSupportedVersion(h.supportedVersions, ver) { return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") @@ -161,16 +166,27 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] return false, err } + params, err := readHelloMap(cryptoData) + if err != nil { + return false, err + } + // blocks until the session has received the parameters + if !h.receivedParams { + h.receivedParams = true + h.paramsChan <- *params + } + if !h.isInchoateCHLO(cryptoData, certUncompressed) { // We have a CHLO with a proper server config ID, do a 0-RTT handshake reply, err = h.handleCHLO(sni, chloData, cryptoData) if err != nil { return false, err } - _, err = h.cryptoStream.Write(reply) - if err != nil { + if _, err := h.cryptoStream.Write(reply); err != nil { return false, err } + h.handshakeEvent <- struct{}{} + close(h.sentSHLO) return true, nil } @@ -193,7 +209,9 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu if err == nil { if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client h.receivedForwardSecurePacket = true - close(h.aeadChanged) + // wait for the send on the handshakeEvent chan + <-h.sentSHLO + close(h.handshakeEvent) } return res, protocol.EncryptionForwardSecure, nil } @@ -222,18 +240,18 @@ func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) { h.mutex.RLock() defer h.mutex.RUnlock() if h.forwardSecureAEAD != nil { - return protocol.EncryptionForwardSecure, h.sealForwardSecure + return protocol.EncryptionForwardSecure, h.forwardSecureAEAD } - return protocol.EncryptionUnencrypted, h.sealUnencrypted + return protocol.EncryptionUnencrypted, h.nullAEAD } func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) { h.mutex.RLock() defer h.mutex.RUnlock() if h.secureAEAD != nil { - return protocol.EncryptionSecure, h.sealSecure + return protocol.EncryptionSecure, h.secureAEAD } - return protocol.EncryptionUnencrypted, h.sealUnencrypted + return protocol.EncryptionUnencrypted, h.nullAEAD } func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { @@ -242,33 +260,21 @@ func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.Encry switch encLevel { case protocol.EncryptionUnencrypted: - return h.sealUnencrypted, nil + return h.nullAEAD, nil case protocol.EncryptionSecure: if h.secureAEAD == nil { return nil, errors.New("CryptoSetupServer: no secureAEAD") } - return h.sealSecure, nil + return h.secureAEAD, nil case protocol.EncryptionForwardSecure: if h.forwardSecureAEAD == nil { return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD") } - return h.sealForwardSecure, nil + return h.forwardSecureAEAD, nil } return nil, errors.New("CryptoSetupServer: no encryption level specified") } -func (h *cryptoSetupServer) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return h.nullAEAD.Seal(dst, src, packetNumber, associatedData) -} - -func (h *cryptoSetupServer) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) -} - -func (h *cryptoSetupServer) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData) -} - func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool { if _, ok := cryptoData[TagPUBS]; !ok { return true @@ -289,7 +295,7 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt } func (h *cryptoSetupServer) acceptSTK(token []byte) bool { - stk, err := h.stkGenerator.DecodeToken(token) + stk, err := h.scfg.cookieGenerator.DecodeToken(token) if err != nil { utils.Debugf("STK invalid: %s", err.Error()) return false @@ -298,11 +304,7 @@ func (h *cryptoSetupServer) acceptSTK(token []byte) bool { } func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) { - if len(chlo) < protocol.ClientHelloMinimumSize { - return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small") - } - - token, err := h.stkGenerator.NewToken(h.remoteAddr) + token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr) if err != nil { return nil, err } @@ -397,8 +399,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T if err != nil { return nil, err } - - h.aeadChanged <- protocol.EncryptionSecure + h.handshakeEvent <- struct{}{} // Generate a new curve instance to derive the forward secure key var fsNonce bytes.Buffer @@ -425,19 +426,11 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T return nil, err } - err = h.connectionParameters.SetFromMap(cryptoData) - if err != nil { - return nil, err - } - - replyMap, err := h.connectionParameters.GetHelloMap() - if err != nil { - return nil, err - } + replyMap := h.params.getHelloMap() // add crypto parameters verTag := &bytes.Buffer{} - for _, v := range h.supportedVersions { - utils.WriteUint32(verTag, protocol.VersionNumberToTag(v)) + for _, v := range protocol.GetGreasedVersions(h.supportedVersions) { + utils.BigEndian.WriteUint32(verTag, uint32(v)) } replyMap[TagPUBS] = ephermalKex.PublicKey() replyMap[TagSNO] = serverNonce @@ -451,9 +444,6 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T var reply bytes.Buffer message.Write(&reply) utils.Debugf("Sending %s", message) - - h.aeadChanged <- protocol.EncryptionForwardSecure - return reply.Bytes(), nil } @@ -466,6 +456,15 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) { panic("not needed for cryptoSetupServer") } +func (h *cryptoSetupServer) ConnectionState() ConnectionState { + h.mutex.Lock() + defer h.mutex.Unlock() + return ConnectionState{ + ServerName: h.sni, + HandshakeComplete: h.receivedForwardSecurePacket, + } +} + func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error { if len(nonce) != 32 { return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length") diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go new file mode 100644 index 000000000..54dfe1c03 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go @@ -0,0 +1,177 @@ +package handshake + +import ( + "errors" + "fmt" + "io" + "sync" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry +var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry") + +// KeyDerivationFunction is used for key derivation +type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) + +type cryptoSetupTLS struct { + mutex sync.RWMutex + + perspective protocol.Perspective + + keyDerivation KeyDerivationFunction + nullAEAD crypto.AEAD + aead crypto.AEAD + + tls MintTLS + cryptoStream *CryptoStreamConn + handshakeEvent chan<- struct{} +} + +// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server +func NewCryptoSetupTLSServer( + tls MintTLS, + cryptoStream *CryptoStreamConn, + nullAEAD crypto.AEAD, + handshakeEvent chan<- struct{}, + version protocol.VersionNumber, +) CryptoSetup { + return &cryptoSetupTLS{ + tls: tls, + cryptoStream: cryptoStream, + nullAEAD: nullAEAD, + perspective: protocol.PerspectiveServer, + keyDerivation: crypto.DeriveAESKeys, + handshakeEvent: handshakeEvent, + } +} + +// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client +func NewCryptoSetupTLSClient( + cryptoStream io.ReadWriter, + connID protocol.ConnectionID, + hostname string, + handshakeEvent chan<- struct{}, + tls MintTLS, + version protocol.VersionNumber, +) (CryptoSetup, error) { + nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) + if err != nil { + return nil, err + } + + return &cryptoSetupTLS{ + perspective: protocol.PerspectiveClient, + tls: tls, + nullAEAD: nullAEAD, + keyDerivation: crypto.DeriveAESKeys, + handshakeEvent: handshakeEvent, + }, nil +} + +func (h *cryptoSetupTLS) HandleCryptoStream() error { + if h.perspective == protocol.PerspectiveServer { + // mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer + // send out that data now + if _, err := h.cryptoStream.Flush(); err != nil { + return err + } + } + +handshakeLoop: + for { + if alert := h.tls.Handshake(); alert != mint.AlertNoAlert { + return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) + } + switch h.tls.State() { + case mint.StateClientStart: // this happens if a stateless retry is performed + return ErrCloseSessionForRetry + case mint.StateClientConnected, mint.StateServerConnected: + break handshakeLoop + } + } + + aead, err := h.keyDerivation(h.tls, h.perspective) + if err != nil { + return err + } + h.mutex.Lock() + h.aead = aead + h.mutex.Unlock() + + h.handshakeEvent <- struct{}{} + close(h.handshakeEvent) + return nil +} + +func (h *cryptoSetupTLS) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { + h.mutex.RLock() + defer h.mutex.RUnlock() + + if h.aead != nil { + data, err := h.aead.Open(dst, src, packetNumber, associatedData) + if err != nil { + return nil, protocol.EncryptionUnspecified, err + } + return data, protocol.EncryptionForwardSecure, nil + } + data, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData) + if err != nil { + return nil, protocol.EncryptionUnspecified, err + } + return data, protocol.EncryptionUnencrypted, nil +} + +func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) { + h.mutex.RLock() + defer h.mutex.RUnlock() + + if h.aead != nil { + return protocol.EncryptionForwardSecure, h.aead + } + return protocol.EncryptionUnencrypted, h.nullAEAD +} + +func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { + errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", encLevel.String()) + h.mutex.RLock() + defer h.mutex.RUnlock() + + switch encLevel { + case protocol.EncryptionUnencrypted: + return h.nullAEAD, nil + case protocol.EncryptionForwardSecure: + if h.aead == nil { + return nil, errNoSealer + } + return h.aead, nil + default: + return nil, errNoSealer + } +} + +func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) { + return protocol.EncryptionUnencrypted, h.nullAEAD +} + +func (h *cryptoSetupTLS) DiversificationNonce() []byte { + panic("diversification nonce not needed for TLS") +} + +func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) { + panic("diversification nonce not needed for TLS") +} + +func (h *cryptoSetupTLS) ConnectionState() ConnectionState { + h.mutex.Lock() + defer h.mutex.Unlock() + mintConnState := h.tls.ConnectionState() + return ConnectionState{ + // TODO: set the ServerName, once mint exports it + HandshakeComplete: h.aead != nil, + PeerCertificates: mintConnState.PeerCertificates, + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go new file mode 100644 index 000000000..03825c41b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go @@ -0,0 +1,101 @@ +package handshake + +import ( + "bytes" + "io" + "net" + "time" +) + +// The CryptoStreamConn is used as the net.Conn passed to mint. +// It has two operating modes: +// 1. It can read and write to bytes.Buffers. +// 2. It can use a quic.Stream for reading and writing. +// The buffer-mode is only used by the server, in order to statelessly handle retries. +type CryptoStreamConn struct { + remoteAddr net.Addr + + // the buffers are used before the session is initialized + readBuf bytes.Buffer + writeBuf bytes.Buffer + + // stream will be set once the session is initialized + stream io.ReadWriter +} + +var _ net.Conn = &CryptoStreamConn{} + +// NewCryptoStreamConn creates a new CryptoStreamConn +func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn { + return &CryptoStreamConn{remoteAddr: remoteAddr} +} + +func (c *CryptoStreamConn) Read(b []byte) (int, error) { + if c.stream != nil { + return c.stream.Read(b) + } + return c.readBuf.Read(b) +} + +// AddDataForReading adds data to the read buffer. +// This data will ONLY be read when the stream has not been set. +func (c *CryptoStreamConn) AddDataForReading(data []byte) { + c.readBuf.Write(data) +} + +func (c *CryptoStreamConn) Write(p []byte) (int, error) { + if c.stream != nil { + return c.stream.Write(p) + } + return c.writeBuf.Write(p) +} + +// GetDataForWriting returns all data currently in the write buffer, and resets this buffer. +func (c *CryptoStreamConn) GetDataForWriting() []byte { + defer c.writeBuf.Reset() + data := make([]byte, c.writeBuf.Len()) + copy(data, c.writeBuf.Bytes()) + return data +} + +// SetStream sets the stream. +// After setting the stream, the read and write buffer won't be used any more. +func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) { + c.stream = stream +} + +// Flush copies the contents of the write buffer to the stream +func (c *CryptoStreamConn) Flush() (int, error) { + n, err := io.Copy(c.stream, &c.writeBuf) + return int(n), err +} + +// Close is not implemented +func (c *CryptoStreamConn) Close() error { + return nil +} + +// LocalAddr is not implemented +func (c *CryptoStreamConn) LocalAddr() net.Addr { + return nil +} + +// RemoteAddr returns the remote address +func (c *CryptoStreamConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +// SetReadDeadline is not implemented +func (c *CryptoStreamConn) SetReadDeadline(time.Time) error { + return nil +} + +// SetWriteDeadline is not implemented +func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error { + return nil +} + +// SetDeadline is not implemented +func (c *CryptoStreamConn) SetDeadline(time.Time) error { + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/ephermal_cache.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go similarity index 92% rename from vendor/github.com/lucas-clemente/quic-go/handshake/ephermal_cache.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go index da6724f3d..3bccbef06 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/ephermal_cache.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go @@ -4,9 +4,9 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) var ( diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/handshake_message.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go similarity index 93% rename from vendor/github.com/lucas-clemente/quic-go/handshake/handshake_message.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go index 87c8b1d4b..c09db26a4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/handshake_message.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go @@ -7,8 +7,8 @@ import ( "io" "sort" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) @@ -72,9 +72,9 @@ func ParseHandshakeMessage(r io.Reader) (HandshakeMessage, error) { // Write writes a crypto message func (h HandshakeMessage) Write(b *bytes.Buffer) { data := h.Data - utils.WriteUint32(b, uint32(h.Tag)) - utils.WriteUint16(b, uint16(len(data))) - utils.WriteUint16(b, 0) + utils.LittleEndian.WriteUint32(b, uint32(h.Tag)) + utils.LittleEndian.WriteUint16(b, uint16(len(data))) + utils.LittleEndian.WriteUint16(b, 0) // Save current position in the buffer, so that we can update the index in-place later indexStart := b.Len() diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go new file mode 100644 index 000000000..0fd673313 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go @@ -0,0 +1,58 @@ +package handshake + +import ( + "crypto/x509" + "io" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// Sealer seals a packet +type Sealer interface { + Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte + Overhead() int +} + +// A TLSExtensionHandler sends and received the QUIC TLS extension. +// It provides the parameters sent by the peer on a channel. +type TLSExtensionHandler interface { + Send(mint.HandshakeType, *mint.ExtensionList) error + Receive(mint.HandshakeType, *mint.ExtensionList) error + GetPeerParams() <-chan TransportParameters +} + +// MintTLS combines some methods needed to interact with mint. +type MintTLS interface { + crypto.TLSExporter + + // additional methods + Handshake() mint.Alert + State() mint.State + ConnectionState() mint.ConnectionState + + SetCryptoStream(io.ReadWriter) +} + +// CryptoSetup is a crypto setup +type CryptoSetup interface { + Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) + HandleCryptoStream() error + // TODO: clean up this interface + DiversificationNonce() []byte // only needed for cryptoSetupServer + SetDiversificationNonce([]byte) // only needed for cryptoSetupClient + ConnectionState() ConnectionState + + GetSealer() (protocol.EncryptionLevel, Sealer) + GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) + GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) +} + +// ConnectionState records basic details about the QUIC connection. +// Warning: This API should not be considered stable and might change soon. +type ConnectionState struct { + HandshakeComplete bool // handshake is complete + ServerName string // server name requested by client, if any (server side only) + PeerCertificates []*x509.Certificate // certificate chain presented by remote peer +} diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/server_config.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go similarity index 88% rename from vendor/github.com/lucas-clemente/quic-go/handshake/server_config.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go index fce66efdf..2b7fba67b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/server_config.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go @@ -4,7 +4,7 @@ import ( "bytes" "crypto/rand" - "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/internal/crypto" ) // ServerConfig is a server config @@ -13,6 +13,7 @@ type ServerConfig struct { certChain crypto.CertChain ID []byte obit []byte + cookieGenerator *CookieGenerator } // NewServerConfig creates a new server config @@ -28,11 +29,18 @@ func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*Serve return nil, err } + cookieGenerator, err := NewCookieGenerator() + + if err != nil { + return nil, err + } + return &ServerConfig{ kex: kex, certChain: certChain, ID: id, obit: obit, + cookieGenerator: cookieGenerator, }, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/server_config_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go similarity index 98% rename from vendor/github.com/lucas-clemente/quic-go/handshake/server_config_client.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go index 420141957..eb042f6ff 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/server_config_client.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go @@ -7,7 +7,7 @@ import ( "math" "time" - "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" ) diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/tags.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tags.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/handshake/tags.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/tags.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go new file mode 100644 index 000000000..c6e8b35d3 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go @@ -0,0 +1,55 @@ +package handshake + +import ( + "github.com/bifurcation/mint" +) + +type transportParameterID uint16 + +const quicTLSExtensionType = 26 + +const ( + initialMaxStreamDataParameterID transportParameterID = 0x0 + initialMaxDataParameterID transportParameterID = 0x1 + initialMaxStreamIDBiDiParameterID transportParameterID = 0x2 + idleTimeoutParameterID transportParameterID = 0x3 + omitConnectionIDParameterID transportParameterID = 0x4 + maxPacketSizeParameterID transportParameterID = 0x5 + statelessResetTokenParameterID transportParameterID = 0x6 + initialMaxStreamIDUniParameterID transportParameterID = 0x8 +) + +type transportParameter struct { + Parameter transportParameterID + Value []byte `tls:"head=2"` +} + +type clientHelloTransportParameters struct { + InitialVersion uint32 // actually a protocol.VersionNumber + Parameters []transportParameter `tls:"head=2"` +} + +type encryptedExtensionsTransportParameters struct { + NegotiatedVersion uint32 // actually a protocol.VersionNumber + SupportedVersions []uint32 `tls:"head=1"` // actually a protocol.VersionNumber + Parameters []transportParameter `tls:"head=2"` +} + +type tlsExtensionBody struct { + data []byte +} + +var _ mint.ExtensionBody = &tlsExtensionBody{} + +func (e *tlsExtensionBody) Type() mint.ExtensionType { + return quicTLSExtensionType +} + +func (e *tlsExtensionBody) Marshal() ([]byte, error) { + return e.data, nil +} + +func (e *tlsExtensionBody) Unmarshal(data []byte) (int, error) { + e.data = data + return len(data), nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go new file mode 100644 index 000000000..20d2d06ba --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go @@ -0,0 +1,134 @@ +package handshake + +import ( + "errors" + "fmt" + "math" + + "github.com/lucas-clemente/quic-go/qerr" + + "github.com/bifurcation/mint" + "github.com/bifurcation/mint/syntax" + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type extensionHandlerClient struct { + ourParams *TransportParameters + paramsChan chan TransportParameters + + initialVersion protocol.VersionNumber + supportedVersions []protocol.VersionNumber + version protocol.VersionNumber +} + +var _ mint.AppExtensionHandler = &extensionHandlerClient{} +var _ TLSExtensionHandler = &extensionHandlerClient{} + +// NewExtensionHandlerClient creates a new extension handler for the client. +func NewExtensionHandlerClient( + params *TransportParameters, + initialVersion protocol.VersionNumber, + supportedVersions []protocol.VersionNumber, + version protocol.VersionNumber, +) TLSExtensionHandler { + paramsChan := make(chan TransportParameters, 1) + return &extensionHandlerClient{ + ourParams: params, + paramsChan: paramsChan, + initialVersion: initialVersion, + supportedVersions: supportedVersions, + version: version, + } +} + +func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.ExtensionList) error { + if hType != mint.HandshakeTypeClientHello { + return nil + } + + data, err := syntax.Marshal(clientHelloTransportParameters{ + InitialVersion: uint32(h.initialVersion), + Parameters: h.ourParams.getTransportParameters(), + }) + if err != nil { + return err + } + return el.Add(&tlsExtensionBody{data}) +} + +func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { + ext := &tlsExtensionBody{} + found, err := el.Find(ext) + if err != nil { + return err + } + + if hType != mint.HandshakeTypeEncryptedExtensions && hType != mint.HandshakeTypeNewSessionTicket { + if found { + return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType) + } + return nil + } + if hType == mint.HandshakeTypeNewSessionTicket { + // the extension it's optional in the NewSessionTicket message + // TODO: handle this + return nil + } + + // hType == mint.HandshakeTypeEncryptedExtensions + if !found { + return errors.New("EncryptedExtensions message didn't contain a QUIC extension") + } + + eetp := &encryptedExtensionsTransportParameters{} + if _, err := syntax.Unmarshal(ext.data, eetp); err != nil { + return err + } + serverSupportedVersions := make([]protocol.VersionNumber, len(eetp.SupportedVersions)) + for i, v := range eetp.SupportedVersions { + serverSupportedVersions[i] = protocol.VersionNumber(v) + } + // check that the negotiated_version is the current version + if protocol.VersionNumber(eetp.NegotiatedVersion) != h.version { + return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version") + } + // check that the current version is included in the supported versions + if !protocol.IsSupportedVersion(serverSupportedVersions, h.version) { + return qerr.Error(qerr.VersionNegotiationMismatch, "current version not included in the supported versions") + } + // if version negotiation was performed, check that we would have selected the current version based on the supported versions sent by the server + if h.version != h.initialVersion { + negotiatedVersion, ok := protocol.ChooseSupportedVersion(h.supportedVersions, serverSupportedVersions) + if !ok || h.version != negotiatedVersion { + return qerr.Error(qerr.VersionNegotiationMismatch, "would have picked a different version") + } + } + + // check that the server sent the stateless reset token + var foundStatelessResetToken bool + for _, p := range eetp.Parameters { + if p.Parameter == statelessResetTokenParameterID { + if len(p.Value) != 16 { + return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", len(p.Value)) + } + foundStatelessResetToken = true + // TODO: handle this value + } + } + if !foundStatelessResetToken { + // TODO: return the right error here + return errors.New("server didn't sent stateless_reset_token") + } + params, err := readTransportParamters(eetp.Parameters) + if err != nil { + return err + } + // TODO(#878): remove this when implementing the MAX_STREAM_ID frame + params.MaxStreams = math.MaxUint32 + h.paramsChan <- *params + return nil +} + +func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters { + return h.paramsChan +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go new file mode 100644 index 000000000..3e7e2705f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go @@ -0,0 +1,113 @@ +package handshake + +import ( + "bytes" + "errors" + "fmt" + + "github.com/lucas-clemente/quic-go/qerr" + + "github.com/bifurcation/mint" + "github.com/bifurcation/mint/syntax" + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type extensionHandlerServer struct { + ourParams *TransportParameters + paramsChan chan TransportParameters + + version protocol.VersionNumber + supportedVersions []protocol.VersionNumber +} + +var _ mint.AppExtensionHandler = &extensionHandlerServer{} +var _ TLSExtensionHandler = &extensionHandlerServer{} + +// NewExtensionHandlerServer creates a new extension handler for the server +func NewExtensionHandlerServer( + params *TransportParameters, + supportedVersions []protocol.VersionNumber, + version protocol.VersionNumber, +) TLSExtensionHandler { + paramsChan := make(chan TransportParameters, 1) + return &extensionHandlerServer{ + ourParams: params, + paramsChan: paramsChan, + supportedVersions: supportedVersions, + version: version, + } +} + +func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.ExtensionList) error { + if hType != mint.HandshakeTypeEncryptedExtensions { + return nil + } + + transportParams := append( + h.ourParams.getTransportParameters(), + // TODO(#855): generate a real token + transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)}, + ) + supportedVersions := protocol.GetGreasedVersions(h.supportedVersions) + versions := make([]uint32, len(supportedVersions)) + for i, v := range supportedVersions { + versions[i] = uint32(v) + } + data, err := syntax.Marshal(encryptedExtensionsTransportParameters{ + NegotiatedVersion: uint32(h.version), + SupportedVersions: versions, + Parameters: transportParams, + }) + if err != nil { + return err + } + return el.Add(&tlsExtensionBody{data}) +} + +func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { + ext := &tlsExtensionBody{} + found, err := el.Find(ext) + if err != nil { + return err + } + + if hType != mint.HandshakeTypeClientHello { + if found { + return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType) + } + return nil + } + + if !found { + return errors.New("ClientHello didn't contain a QUIC extension") + } + chtp := &clientHelloTransportParameters{} + if _, err := syntax.Unmarshal(ext.data, chtp); err != nil { + return err + } + initialVersion := protocol.VersionNumber(chtp.InitialVersion) + + // perform the stateless version negotiation validation: + // make sure that we would have sent a Version Negotiation Packet if the client offered the initial version + // this is the case if and only if the initial version is not contained in the supported versions + if initialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, initialVersion) { + return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version") + } + + for _, p := range chtp.Parameters { + if p.Parameter == statelessResetTokenParameterID { + // TODO: return the correct error type + return errors.New("client sent a stateless reset token") + } + } + params, err := readTransportParamters(chtp.Parameters) + if err != nil { + return err + } + h.paramsChan <- *params + return nil +} + +func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters { + return h.paramsChan +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go new file mode 100644 index 000000000..a02835823 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go @@ -0,0 +1,176 @@ +package handshake + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +// errMalformedTag is returned when the tag value cannot be read +var errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value") + +// TransportParameters are parameters sent to the peer during the handshake +type TransportParameters struct { + StreamFlowControlWindow protocol.ByteCount + ConnectionFlowControlWindow protocol.ByteCount + + MaxBidiStreamID protocol.StreamID // only used for IETF QUIC + MaxUniStreamID protocol.StreamID // only used for IETF QUIC + MaxStreams uint32 // only used for gQUIC + + OmitConnectionID bool + IdleTimeout time.Duration +} + +// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message +func readHelloMap(tags map[Tag][]byte) (*TransportParameters, error) { + params := &TransportParameters{} + if value, ok := tags[TagTCID]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.OmitConnectionID = (v == 0) + } + if value, ok := tags[TagMIDS]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.MaxStreams = v + } + if value, ok := tags[TagICSL]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(v)*time.Second) + } + if value, ok := tags[TagSFCW]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.StreamFlowControlWindow = protocol.ByteCount(v) + } + if value, ok := tags[TagCFCW]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.ConnectionFlowControlWindow = protocol.ByteCount(v) + } + return params, nil +} + +// GetHelloMap gets all parameters needed for the Hello message in the gQUIC handshake. +func (p *TransportParameters) getHelloMap() map[Tag][]byte { + sfcw := bytes.NewBuffer([]byte{}) + utils.LittleEndian.WriteUint32(sfcw, uint32(p.StreamFlowControlWindow)) + cfcw := bytes.NewBuffer([]byte{}) + utils.LittleEndian.WriteUint32(cfcw, uint32(p.ConnectionFlowControlWindow)) + mids := bytes.NewBuffer([]byte{}) + utils.LittleEndian.WriteUint32(mids, p.MaxStreams) + icsl := bytes.NewBuffer([]byte{}) + utils.LittleEndian.WriteUint32(icsl, uint32(p.IdleTimeout/time.Second)) + + tags := map[Tag][]byte{ + TagICSL: icsl.Bytes(), + TagMIDS: mids.Bytes(), + TagCFCW: cfcw.Bytes(), + TagSFCW: sfcw.Bytes(), + } + if p.OmitConnectionID { + tags[TagTCID] = []byte{0, 0, 0, 0} + } + return tags +} + +// readTransportParameters reads the transport parameters sent in the QUIC TLS extension +func readTransportParamters(paramsList []transportParameter) (*TransportParameters, error) { + params := &TransportParameters{} + + var foundInitialMaxStreamData bool + var foundInitialMaxData bool + var foundIdleTimeout bool + + for _, p := range paramsList { + switch p.Parameter { + case initialMaxStreamDataParameterID: + foundInitialMaxStreamData = true + if len(p.Value) != 4 { + return nil, fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", len(p.Value)) + } + params.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value)) + case initialMaxDataParameterID: + foundInitialMaxData = true + if len(p.Value) != 4 { + return nil, fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value)) + } + params.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value)) + case initialMaxStreamIDBiDiParameterID: + if len(p.Value) != 4 { + return nil, fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 4)", len(p.Value)) + } + // TODO(#1154): validate the stream ID + params.MaxBidiStreamID = protocol.StreamID(binary.BigEndian.Uint32(p.Value)) + case initialMaxStreamIDUniParameterID: + if len(p.Value) != 4 { + return nil, fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 4)", len(p.Value)) + } + // TODO(#1154): validate the stream ID + params.MaxUniStreamID = protocol.StreamID(binary.BigEndian.Uint32(p.Value)) + case idleTimeoutParameterID: + foundIdleTimeout = true + if len(p.Value) != 2 { + return nil, fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value)) + } + params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(p.Value))*time.Second) + case omitConnectionIDParameterID: + if len(p.Value) != 0 { + return nil, fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value)) + } + params.OmitConnectionID = true + } + } + + if !(foundInitialMaxStreamData && foundInitialMaxData && foundIdleTimeout) { + return nil, errors.New("missing parameter") + } + return params, nil +} + +// GetTransportParameters gets the parameters needed for the TLS handshake. +// It doesn't send the initial_max_stream_id_uni parameter, so the peer isn't allowed to open any unidirectional streams. +func (p *TransportParameters) getTransportParameters() []transportParameter { + initialMaxStreamData := make([]byte, 4) + binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow)) + initialMaxData := make([]byte, 4) + binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow)) + initialMaxBidiStreamID := make([]byte, 4) + binary.BigEndian.PutUint32(initialMaxBidiStreamID, uint32(p.MaxBidiStreamID)) + initialMaxUniStreamID := make([]byte, 4) + binary.BigEndian.PutUint32(initialMaxUniStreamID, uint32(p.MaxUniStreamID)) + idleTimeout := make([]byte, 2) + binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout/time.Second)) + maxPacketSize := make([]byte, 2) + binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize)) + params := []transportParameter{ + {initialMaxStreamDataParameterID, initialMaxStreamData}, + {initialMaxDataParameterID, initialMaxData}, + {initialMaxStreamIDBiDiParameterID, initialMaxBidiStreamID}, + {initialMaxStreamIDUniParameterID, initialMaxUniStreamID}, + {idleTimeoutParameterID, idleTimeout}, + {maxPacketSizeParameterID, maxPacketSize}, + } + if p.OmitConnectionID { + params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}}) + } + return params +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/received_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/received_packet_handler.go new file mode 100644 index 000000000..d3e864f64 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/received_packet_handler.go @@ -0,0 +1,83 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/ackhandler (interfaces: ReceivedPacketHandler) + +// Package mockackhandler is a generated GoMock package. +package mockackhandler + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockReceivedPacketHandler is a mock of ReceivedPacketHandler interface +type MockReceivedPacketHandler struct { + ctrl *gomock.Controller + recorder *MockReceivedPacketHandlerMockRecorder +} + +// MockReceivedPacketHandlerMockRecorder is the mock recorder for MockReceivedPacketHandler +type MockReceivedPacketHandlerMockRecorder struct { + mock *MockReceivedPacketHandler +} + +// NewMockReceivedPacketHandler creates a new mock instance +func NewMockReceivedPacketHandler(ctrl *gomock.Controller) *MockReceivedPacketHandler { + mock := &MockReceivedPacketHandler{ctrl: ctrl} + mock.recorder = &MockReceivedPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecorder { + return m.recorder +} + +// GetAckFrame mocks base method +func (m *MockReceivedPacketHandler) GetAckFrame() *wire.AckFrame { + ret := m.ctrl.Call(m, "GetAckFrame") + ret0, _ := ret[0].(*wire.AckFrame) + return ret0 +} + +// GetAckFrame indicates an expected call of GetAckFrame +func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame)) +} + +// GetAlarmTimeout mocks base method +func (m *MockReceivedPacketHandler) GetAlarmTimeout() time.Time { + ret := m.ctrl.Call(m, "GetAlarmTimeout") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// GetAlarmTimeout indicates an expected call of GetAlarmTimeout +func (mr *MockReceivedPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAlarmTimeout)) +} + +// IgnoreBelow mocks base method +func (m *MockReceivedPacketHandler) IgnoreBelow(arg0 protocol.PacketNumber) { + m.ctrl.Call(m, "IgnoreBelow", arg0) +} + +// IgnoreBelow indicates an expected call of IgnoreBelow +func (mr *MockReceivedPacketHandlerMockRecorder) IgnoreBelow(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IgnoreBelow", reflect.TypeOf((*MockReceivedPacketHandler)(nil).IgnoreBelow), arg0) +} + +// ReceivedPacket mocks base method +func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 time.Time, arg2 bool) error { + ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReceivedPacket indicates an expected call of ReceivedPacket +func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/sent_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/sent_packet_handler.go new file mode 100644 index 000000000..b28a970ad --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/sent_packet_handler.go @@ -0,0 +1,178 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/ackhandler (interfaces: SentPacketHandler) + +// Package mockackhandler is a generated GoMock package. +package mockackhandler + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + ackhandler "github.com/lucas-clemente/quic-go/internal/ackhandler" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockSentPacketHandler is a mock of SentPacketHandler interface +type MockSentPacketHandler struct { + ctrl *gomock.Controller + recorder *MockSentPacketHandlerMockRecorder +} + +// MockSentPacketHandlerMockRecorder is the mock recorder for MockSentPacketHandler +type MockSentPacketHandlerMockRecorder struct { + mock *MockSentPacketHandler +} + +// NewMockSentPacketHandler creates a new mock instance +func NewMockSentPacketHandler(ctrl *gomock.Controller) *MockSentPacketHandler { + mock := &MockSentPacketHandler{ctrl: ctrl} + mock.recorder = &MockSentPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder { + return m.recorder +} + +// DequeuePacketForRetransmission mocks base method +func (m *MockSentPacketHandler) DequeuePacketForRetransmission() *ackhandler.Packet { + ret := m.ctrl.Call(m, "DequeuePacketForRetransmission") + ret0, _ := ret[0].(*ackhandler.Packet) + return ret0 +} + +// DequeuePacketForRetransmission indicates an expected call of DequeuePacketForRetransmission +func (mr *MockSentPacketHandlerMockRecorder) DequeuePacketForRetransmission() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeuePacketForRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeuePacketForRetransmission)) +} + +// GetAlarmTimeout mocks base method +func (m *MockSentPacketHandler) GetAlarmTimeout() time.Time { + ret := m.ctrl.Call(m, "GetAlarmTimeout") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// GetAlarmTimeout indicates an expected call of GetAlarmTimeout +func (mr *MockSentPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetAlarmTimeout)) +} + +// GetLeastUnacked mocks base method +func (m *MockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber { + ret := m.ctrl.Call(m, "GetLeastUnacked") + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// GetLeastUnacked indicates an expected call of GetLeastUnacked +func (mr *MockSentPacketHandlerMockRecorder) GetLeastUnacked() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeastUnacked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLeastUnacked)) +} + +// GetLowestPacketNotConfirmedAcked mocks base method +func (m *MockSentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { + ret := m.ctrl.Call(m, "GetLowestPacketNotConfirmedAcked") + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// GetLowestPacketNotConfirmedAcked indicates an expected call of GetLowestPacketNotConfirmedAcked +func (mr *MockSentPacketHandlerMockRecorder) GetLowestPacketNotConfirmedAcked() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLowestPacketNotConfirmedAcked)) +} + +// GetStopWaitingFrame mocks base method +func (m *MockSentPacketHandler) GetStopWaitingFrame(arg0 bool) *wire.StopWaitingFrame { + ret := m.ctrl.Call(m, "GetStopWaitingFrame", arg0) + ret0, _ := ret[0].(*wire.StopWaitingFrame) + return ret0 +} + +// GetStopWaitingFrame indicates an expected call of GetStopWaitingFrame +func (mr *MockSentPacketHandlerMockRecorder) GetStopWaitingFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStopWaitingFrame", reflect.TypeOf((*MockSentPacketHandler)(nil).GetStopWaitingFrame), arg0) +} + +// OnAlarm mocks base method +func (m *MockSentPacketHandler) OnAlarm() { + m.ctrl.Call(m, "OnAlarm") +} + +// OnAlarm indicates an expected call of OnAlarm +func (mr *MockSentPacketHandlerMockRecorder) OnAlarm() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAlarm", reflect.TypeOf((*MockSentPacketHandler)(nil).OnAlarm)) +} + +// ReceivedAck mocks base method +func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.PacketNumber, arg2 protocol.EncryptionLevel, arg3 time.Time) error { + ret := m.ctrl.Call(m, "ReceivedAck", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReceivedAck indicates an expected call of ReceivedAck +func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2, arg3) +} + +// SendingAllowed mocks base method +func (m *MockSentPacketHandler) SendingAllowed() bool { + ret := m.ctrl.Call(m, "SendingAllowed") + ret0, _ := ret[0].(bool) + return ret0 +} + +// SendingAllowed indicates an expected call of SendingAllowed +func (mr *MockSentPacketHandlerMockRecorder) SendingAllowed() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendingAllowed", reflect.TypeOf((*MockSentPacketHandler)(nil).SendingAllowed)) +} + +// SentPacket mocks base method +func (m *MockSentPacketHandler) SentPacket(arg0 *ackhandler.Packet) error { + ret := m.ctrl.Call(m, "SentPacket", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SentPacket indicates an expected call of SentPacket +func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0) +} + +// SetHandshakeComplete mocks base method +func (m *MockSentPacketHandler) SetHandshakeComplete() { + m.ctrl.Call(m, "SetHandshakeComplete") +} + +// SetHandshakeComplete indicates an expected call of SetHandshakeComplete +func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeComplete() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeComplete", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeComplete)) +} + +// ShouldSendNumPackets mocks base method +func (m *MockSentPacketHandler) ShouldSendNumPackets() int { + ret := m.ctrl.Call(m, "ShouldSendNumPackets") + ret0, _ := ret[0].(int) + return ret0 +} + +// ShouldSendNumPackets indicates an expected call of ShouldSendNumPackets +func (mr *MockSentPacketHandlerMockRecorder) ShouldSendNumPackets() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendNumPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).ShouldSendNumPackets)) +} + +// TimeUntilSend mocks base method +func (m *MockSentPacketHandler) TimeUntilSend() time.Time { + ret := m.ctrl.Call(m, "TimeUntilSend") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// TimeUntilSend indicates an expected call of TimeUntilSend +func (mr *MockSentPacketHandlerMockRecorder) TimeUntilSend() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSentPacketHandler)(nil).TimeUntilSend)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/congestion.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/congestion.go new file mode 100644 index 000000000..fec7b6c14 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/congestion.go @@ -0,0 +1,154 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/congestion (interfaces: SendAlgorithm) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockSendAlgorithm is a mock of SendAlgorithm interface +type MockSendAlgorithm struct { + ctrl *gomock.Controller + recorder *MockSendAlgorithmMockRecorder +} + +// MockSendAlgorithmMockRecorder is the mock recorder for MockSendAlgorithm +type MockSendAlgorithmMockRecorder struct { + mock *MockSendAlgorithm +} + +// NewMockSendAlgorithm creates a new mock instance +func NewMockSendAlgorithm(ctrl *gomock.Controller) *MockSendAlgorithm { + mock := &MockSendAlgorithm{ctrl: ctrl} + mock.recorder = &MockSendAlgorithmMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockSendAlgorithm) EXPECT() *MockSendAlgorithmMockRecorder { + return m.recorder +} + +// GetCongestionWindow mocks base method +func (m *MockSendAlgorithm) GetCongestionWindow() protocol.ByteCount { + ret := m.ctrl.Call(m, "GetCongestionWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetCongestionWindow indicates an expected call of GetCongestionWindow +func (mr *MockSendAlgorithmMockRecorder) GetCongestionWindow() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCongestionWindow", reflect.TypeOf((*MockSendAlgorithm)(nil).GetCongestionWindow)) +} + +// MaybeExitSlowStart mocks base method +func (m *MockSendAlgorithm) MaybeExitSlowStart() { + m.ctrl.Call(m, "MaybeExitSlowStart") +} + +// MaybeExitSlowStart indicates an expected call of MaybeExitSlowStart +func (mr *MockSendAlgorithmMockRecorder) MaybeExitSlowStart() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeExitSlowStart", reflect.TypeOf((*MockSendAlgorithm)(nil).MaybeExitSlowStart)) +} + +// OnConnectionMigration mocks base method +func (m *MockSendAlgorithm) OnConnectionMigration() { + m.ctrl.Call(m, "OnConnectionMigration") +} + +// OnConnectionMigration indicates an expected call of OnConnectionMigration +func (mr *MockSendAlgorithmMockRecorder) OnConnectionMigration() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnConnectionMigration", reflect.TypeOf((*MockSendAlgorithm)(nil).OnConnectionMigration)) +} + +// OnPacketAcked mocks base method +func (m *MockSendAlgorithm) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) { + m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2) +} + +// OnPacketAcked indicates an expected call of OnPacketAcked +func (mr *MockSendAlgorithmMockRecorder) OnPacketAcked(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithm)(nil).OnPacketAcked), arg0, arg1, arg2) +} + +// OnPacketLost mocks base method +func (m *MockSendAlgorithm) OnPacketLost(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) { + m.ctrl.Call(m, "OnPacketLost", arg0, arg1, arg2) +} + +// OnPacketLost indicates an expected call of OnPacketLost +func (mr *MockSendAlgorithmMockRecorder) OnPacketLost(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketLost", reflect.TypeOf((*MockSendAlgorithm)(nil).OnPacketLost), arg0, arg1, arg2) +} + +// OnPacketSent mocks base method +func (m *MockSendAlgorithm) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) bool { + ret := m.ctrl.Call(m, "OnPacketSent", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(bool) + return ret0 +} + +// OnPacketSent indicates an expected call of OnPacketSent +func (mr *MockSendAlgorithmMockRecorder) OnPacketSent(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketSent", reflect.TypeOf((*MockSendAlgorithm)(nil).OnPacketSent), arg0, arg1, arg2, arg3, arg4) +} + +// OnRetransmissionTimeout mocks base method +func (m *MockSendAlgorithm) OnRetransmissionTimeout(arg0 bool) { + m.ctrl.Call(m, "OnRetransmissionTimeout", arg0) +} + +// OnRetransmissionTimeout indicates an expected call of OnRetransmissionTimeout +func (mr *MockSendAlgorithmMockRecorder) OnRetransmissionTimeout(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithm)(nil).OnRetransmissionTimeout), arg0) +} + +// RetransmissionDelay mocks base method +func (m *MockSendAlgorithm) RetransmissionDelay() time.Duration { + ret := m.ctrl.Call(m, "RetransmissionDelay") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// RetransmissionDelay indicates an expected call of RetransmissionDelay +func (mr *MockSendAlgorithmMockRecorder) RetransmissionDelay() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetransmissionDelay", reflect.TypeOf((*MockSendAlgorithm)(nil).RetransmissionDelay)) +} + +// SetNumEmulatedConnections mocks base method +func (m *MockSendAlgorithm) SetNumEmulatedConnections(arg0 int) { + m.ctrl.Call(m, "SetNumEmulatedConnections", arg0) +} + +// SetNumEmulatedConnections indicates an expected call of SetNumEmulatedConnections +func (mr *MockSendAlgorithmMockRecorder) SetNumEmulatedConnections(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNumEmulatedConnections", reflect.TypeOf((*MockSendAlgorithm)(nil).SetNumEmulatedConnections), arg0) +} + +// SetSlowStartLargeReduction mocks base method +func (m *MockSendAlgorithm) SetSlowStartLargeReduction(arg0 bool) { + m.ctrl.Call(m, "SetSlowStartLargeReduction", arg0) +} + +// SetSlowStartLargeReduction indicates an expected call of SetSlowStartLargeReduction +func (mr *MockSendAlgorithmMockRecorder) SetSlowStartLargeReduction(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSlowStartLargeReduction", reflect.TypeOf((*MockSendAlgorithm)(nil).SetSlowStartLargeReduction), arg0) +} + +// TimeUntilSend mocks base method +func (m *MockSendAlgorithm) TimeUntilSend(arg0 protocol.ByteCount) time.Duration { + ret := m.ctrl.Call(m, "TimeUntilSend", arg0) + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// TimeUntilSend indicates an expected call of TimeUntilSend +func (mr *MockSendAlgorithmMockRecorder) TimeUntilSend(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSendAlgorithm)(nil).TimeUntilSend), arg0) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/connection_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/connection_flow_controller.go new file mode 100644 index 000000000..ae10e785f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/connection_flow_controller.go @@ -0,0 +1,102 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/flowcontrol (interfaces: ConnectionFlowController) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockConnectionFlowController is a mock of ConnectionFlowController interface +type MockConnectionFlowController struct { + ctrl *gomock.Controller + recorder *MockConnectionFlowControllerMockRecorder +} + +// MockConnectionFlowControllerMockRecorder is the mock recorder for MockConnectionFlowController +type MockConnectionFlowControllerMockRecorder struct { + mock *MockConnectionFlowController +} + +// NewMockConnectionFlowController creates a new mock instance +func NewMockConnectionFlowController(ctrl *gomock.Controller) *MockConnectionFlowController { + mock := &MockConnectionFlowController{ctrl: ctrl} + mock.recorder = &MockConnectionFlowControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockConnectionFlowController) EXPECT() *MockConnectionFlowControllerMockRecorder { + return m.recorder +} + +// AddBytesRead mocks base method +func (m *MockConnectionFlowController) AddBytesRead(arg0 protocol.ByteCount) { + m.ctrl.Call(m, "AddBytesRead", arg0) +} + +// AddBytesRead indicates an expected call of AddBytesRead +func (mr *MockConnectionFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesRead), arg0) +} + +// AddBytesSent mocks base method +func (m *MockConnectionFlowController) AddBytesSent(arg0 protocol.ByteCount) { + m.ctrl.Call(m, "AddBytesSent", arg0) +} + +// AddBytesSent indicates an expected call of AddBytesSent +func (mr *MockConnectionFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesSent), arg0) +} + +// GetWindowUpdate mocks base method +func (m *MockConnectionFlowController) GetWindowUpdate() protocol.ByteCount { + ret := m.ctrl.Call(m, "GetWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetWindowUpdate indicates an expected call of GetWindowUpdate +func (mr *MockConnectionFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockConnectionFlowController)(nil).GetWindowUpdate)) +} + +// IsNewlyBlocked mocks base method +func (m *MockConnectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { + ret := m.ctrl.Call(m, "IsNewlyBlocked") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 +} + +// IsNewlyBlocked indicates an expected call of IsNewlyBlocked +func (mr *MockConnectionFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockConnectionFlowController)(nil).IsNewlyBlocked)) +} + +// SendWindowSize mocks base method +func (m *MockConnectionFlowController) SendWindowSize() protocol.ByteCount { + ret := m.ctrl.Call(m, "SendWindowSize") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// SendWindowSize indicates an expected call of SendWindowSize +func (mr *MockConnectionFlowControllerMockRecorder) SendWindowSize() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockConnectionFlowController)(nil).SendWindowSize)) +} + +// UpdateSendWindow mocks base method +func (m *MockConnectionFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { + m.ctrl.Call(m, "UpdateSendWindow", arg0) +} + +// UpdateSendWindow indicates an expected call of UpdateSendWindow +func (mr *MockConnectionFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockConnectionFlowController)(nil).UpdateSendWindow), arg0) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/cpm.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/cpm.go deleted file mode 100644 index 686928f63..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/cpm.go +++ /dev/null @@ -1,153 +0,0 @@ -// Automatically generated by MockGen. DO NOT EDIT! -// Source: github.com/lucas-clemente/quic-go/handshake (interfaces: ConnectionParametersManager) - -package mocks - -import ( - gomock "github.com/golang/mock/gomock" - handshake "github.com/lucas-clemente/quic-go/handshake" - protocol "github.com/lucas-clemente/quic-go/protocol" - time "time" -) - -// Mock of ConnectionParametersManager interface -type MockConnectionParametersManager struct { - ctrl *gomock.Controller - recorder *_MockConnectionParametersManagerRecorder -} - -// Recorder for MockConnectionParametersManager (not exported) -type _MockConnectionParametersManagerRecorder struct { - mock *MockConnectionParametersManager -} - -func NewMockConnectionParametersManager(ctrl *gomock.Controller) *MockConnectionParametersManager { - mock := &MockConnectionParametersManager{ctrl: ctrl} - mock.recorder = &_MockConnectionParametersManagerRecorder{mock} - return mock -} - -func (_m *MockConnectionParametersManager) EXPECT() *_MockConnectionParametersManagerRecorder { - return _m.recorder -} - -func (_m *MockConnectionParametersManager) GetHelloMap() (map[handshake.Tag][]byte, error) { - ret := _m.ctrl.Call(_m, "GetHelloMap") - ret0, _ := ret[0].(map[handshake.Tag][]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -func (_mr *_MockConnectionParametersManagerRecorder) GetHelloMap() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetHelloMap") -} - -func (_m *MockConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { - ret := _m.ctrl.Call(_m, "GetIdleConnectionStateLifetime") - ret0, _ := ret[0].(time.Duration) - return ret0 -} - -func (_mr *_MockConnectionParametersManagerRecorder) GetIdleConnectionStateLifetime() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetIdleConnectionStateLifetime") -} - -func (_m *MockConnectionParametersManager) GetMaxIncomingStreams() uint32 { - ret := _m.ctrl.Call(_m, "GetMaxIncomingStreams") - ret0, _ := ret[0].(uint32) - return ret0 -} - -func (_mr *_MockConnectionParametersManagerRecorder) GetMaxIncomingStreams() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxIncomingStreams") -} - -func (_m *MockConnectionParametersManager) GetMaxOutgoingStreams() uint32 { - ret := _m.ctrl.Call(_m, "GetMaxOutgoingStreams") - ret0, _ := ret[0].(uint32) - return ret0 -} - -func (_mr *_MockConnectionParametersManagerRecorder) GetMaxOutgoingStreams() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxOutgoingStreams") -} - -func (_m *MockConnectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetMaxReceiveConnectionFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -func (_mr *_MockConnectionParametersManagerRecorder) GetMaxReceiveConnectionFlowControlWindow() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxReceiveConnectionFlowControlWindow") -} - -func (_m *MockConnectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetMaxReceiveStreamFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -func (_mr *_MockConnectionParametersManagerRecorder) GetMaxReceiveStreamFlowControlWindow() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxReceiveStreamFlowControlWindow") -} - -func (_m *MockConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetReceiveConnectionFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -func (_mr *_MockConnectionParametersManagerRecorder) GetReceiveConnectionFlowControlWindow() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveConnectionFlowControlWindow") -} - -func (_m *MockConnectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetReceiveStreamFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -func (_mr *_MockConnectionParametersManagerRecorder) GetReceiveStreamFlowControlWindow() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveStreamFlowControlWindow") -} - -func (_m *MockConnectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetSendConnectionFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -func (_mr *_MockConnectionParametersManagerRecorder) GetSendConnectionFlowControlWindow() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetSendConnectionFlowControlWindow") -} - -func (_m *MockConnectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetSendStreamFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -func (_mr *_MockConnectionParametersManagerRecorder) GetSendStreamFlowControlWindow() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetSendStreamFlowControlWindow") -} - -func (_m *MockConnectionParametersManager) SetFromMap(_param0 map[handshake.Tag][]byte) error { - ret := _m.ctrl.Call(_m, "SetFromMap", _param0) - ret0, _ := ret[0].(error) - return ret0 -} - -func (_mr *_MockConnectionParametersManagerRecorder) SetFromMap(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "SetFromMap", arg0) -} - -func (_m *MockConnectionParametersManager) TruncateConnectionID() bool { - ret := _m.ctrl.Call(_m, "TruncateConnectionID") - ret0, _ := ret[0].(bool) - return ret0 -} - -func (_mr *_MockConnectionParametersManagerRecorder) TruncateConnectionID() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "TruncateConnectionID") -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/crypto/aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/crypto/aead.go new file mode 100644 index 000000000..324a00cee --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/crypto/aead.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/crypto (interfaces: AEAD) + +// Package mockcrypto is a generated GoMock package. +package mockcrypto + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockAEAD is a mock of AEAD interface +type MockAEAD struct { + ctrl *gomock.Controller + recorder *MockAEADMockRecorder +} + +// MockAEADMockRecorder is the mock recorder for MockAEAD +type MockAEADMockRecorder struct { + mock *MockAEAD +} + +// NewMockAEAD creates a new mock instance +func NewMockAEAD(ctrl *gomock.Controller) *MockAEAD { + mock := &MockAEAD{ctrl: ctrl} + mock.recorder = &MockAEADMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockAEAD) EXPECT() *MockAEADMockRecorder { + return m.recorder +} + +// Open mocks base method +func (m *MockAEAD) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { + ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Open indicates an expected call of Open +func (mr *MockAEADMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockAEAD)(nil).Open), arg0, arg1, arg2, arg3) +} + +// Overhead mocks base method +func (m *MockAEAD) Overhead() int { + ret := m.ctrl.Call(m, "Overhead") + ret0, _ := ret[0].(int) + return ret0 +} + +// Overhead indicates an expected call of Overhead +func (mr *MockAEADMockRecorder) Overhead() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockAEAD)(nil).Overhead)) +} + +// Seal mocks base method +func (m *MockAEAD) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) []byte { + ret := m.ctrl.Call(m, "Seal", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]byte) + return ret0 +} + +// Seal indicates an expected call of Seal +func (mr *MockAEADMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockAEAD)(nil).Seal), arg0, arg1, arg2, arg3) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/gen.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/gen.go index e77844769..bd33e7d01 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/gen.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/gen.go @@ -1,4 +1,11 @@ package mocks -//go:generate mockgen -destination mocks_fc/flow_control_manager.go -package mocks_fc github.com/lucas-clemente/quic-go/flowcontrol FlowControlManager -//go:generate mockgen -destination cpm.go -package mocks github.com/lucas-clemente/quic-go/handshake ConnectionParametersManager +//go:generate sh -c "./mockgen_internal.sh mockhandshake handshake/mint_tls.go github.com/lucas-clemente/quic-go/internal/handshake MintTLS" +//go:generate sh -c "./mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler" +//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" +//go:generate sh -c "./mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler" +//go:generate sh -c "./mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler" +//go:generate sh -c "./mockgen_internal.sh mocks congestion.go github.com/lucas-clemente/quic-go/internal/congestion SendAlgorithm" +//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController" +//go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD" +//go:generate sh -c "goimports -w ." diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/handshake/mint_tls.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/handshake/mint_tls.go new file mode 100644 index 000000000..0a0714dbd --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/handshake/mint_tls.go @@ -0,0 +1,107 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: MintTLS) + +// Package mockhandshake is a generated GoMock package. +package mockhandshake + +import ( + io "io" + reflect "reflect" + + mint "github.com/bifurcation/mint" + gomock "github.com/golang/mock/gomock" +) + +// MockMintTLS is a mock of MintTLS interface +type MockMintTLS struct { + ctrl *gomock.Controller + recorder *MockMintTLSMockRecorder +} + +// MockMintTLSMockRecorder is the mock recorder for MockMintTLS +type MockMintTLSMockRecorder struct { + mock *MockMintTLS +} + +// NewMockMintTLS creates a new mock instance +func NewMockMintTLS(ctrl *gomock.Controller) *MockMintTLS { + mock := &MockMintTLS{ctrl: ctrl} + mock.recorder = &MockMintTLSMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockMintTLS) EXPECT() *MockMintTLSMockRecorder { + return m.recorder +} + +// ComputeExporter mocks base method +func (m *MockMintTLS) ComputeExporter(arg0 string, arg1 []byte, arg2 int) ([]byte, error) { + ret := m.ctrl.Call(m, "ComputeExporter", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ComputeExporter indicates an expected call of ComputeExporter +func (mr *MockMintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ComputeExporter", reflect.TypeOf((*MockMintTLS)(nil).ComputeExporter), arg0, arg1, arg2) +} + +// ConnectionState mocks base method +func (m *MockMintTLS) ConnectionState() mint.ConnectionState { + ret := m.ctrl.Call(m, "ConnectionState") + ret0, _ := ret[0].(mint.ConnectionState) + return ret0 +} + +// ConnectionState indicates an expected call of ConnectionState +func (mr *MockMintTLSMockRecorder) ConnectionState() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockMintTLS)(nil).ConnectionState)) +} + +// GetCipherSuite mocks base method +func (m *MockMintTLS) GetCipherSuite() mint.CipherSuiteParams { + ret := m.ctrl.Call(m, "GetCipherSuite") + ret0, _ := ret[0].(mint.CipherSuiteParams) + return ret0 +} + +// GetCipherSuite indicates an expected call of GetCipherSuite +func (mr *MockMintTLSMockRecorder) GetCipherSuite() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCipherSuite", reflect.TypeOf((*MockMintTLS)(nil).GetCipherSuite)) +} + +// Handshake mocks base method +func (m *MockMintTLS) Handshake() mint.Alert { + ret := m.ctrl.Call(m, "Handshake") + ret0, _ := ret[0].(mint.Alert) + return ret0 +} + +// Handshake indicates an expected call of Handshake +func (mr *MockMintTLSMockRecorder) Handshake() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handshake", reflect.TypeOf((*MockMintTLS)(nil).Handshake)) +} + +// SetCryptoStream mocks base method +func (m *MockMintTLS) SetCryptoStream(arg0 io.ReadWriter) { + m.ctrl.Call(m, "SetCryptoStream", arg0) +} + +// SetCryptoStream indicates an expected call of SetCryptoStream +func (mr *MockMintTLSMockRecorder) SetCryptoStream(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCryptoStream", reflect.TypeOf((*MockMintTLS)(nil).SetCryptoStream), arg0) +} + +// State mocks base method +func (m *MockMintTLS) State() mint.State { + ret := m.ctrl.Call(m, "State") + ret0, _ := ret[0].(mint.State) + return ret0 +} + +// State indicates an expected call of State +func (mr *MockMintTLSMockRecorder) State() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockMintTLS)(nil).State)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc/flow_control_manager.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc/flow_control_manager.go deleted file mode 100644 index d18bf48fc..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc/flow_control_manager.go +++ /dev/null @@ -1,140 +0,0 @@ -// Automatically generated by MockGen. DO NOT EDIT! -// Source: github.com/lucas-clemente/quic-go/flowcontrol (interfaces: FlowControlManager) - -package mocks_fc - -import ( - gomock "github.com/golang/mock/gomock" - flowcontrol "github.com/lucas-clemente/quic-go/flowcontrol" - protocol "github.com/lucas-clemente/quic-go/protocol" -) - -// Mock of FlowControlManager interface -type MockFlowControlManager struct { - ctrl *gomock.Controller - recorder *_MockFlowControlManagerRecorder -} - -// Recorder for MockFlowControlManager (not exported) -type _MockFlowControlManagerRecorder struct { - mock *MockFlowControlManager -} - -func NewMockFlowControlManager(ctrl *gomock.Controller) *MockFlowControlManager { - mock := &MockFlowControlManager{ctrl: ctrl} - mock.recorder = &_MockFlowControlManagerRecorder{mock} - return mock -} - -func (_m *MockFlowControlManager) EXPECT() *_MockFlowControlManagerRecorder { - return _m.recorder -} - -func (_m *MockFlowControlManager) AddBytesRead(_param0 protocol.StreamID, _param1 protocol.ByteCount) error { - ret := _m.ctrl.Call(_m, "AddBytesRead", _param0, _param1) - ret0, _ := ret[0].(error) - return ret0 -} - -func (_mr *_MockFlowControlManagerRecorder) AddBytesRead(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "AddBytesRead", arg0, arg1) -} - -func (_m *MockFlowControlManager) AddBytesSent(_param0 protocol.StreamID, _param1 protocol.ByteCount) error { - ret := _m.ctrl.Call(_m, "AddBytesSent", _param0, _param1) - ret0, _ := ret[0].(error) - return ret0 -} - -func (_mr *_MockFlowControlManagerRecorder) AddBytesSent(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "AddBytesSent", arg0, arg1) -} - -func (_m *MockFlowControlManager) GetReceiveWindow(_param0 protocol.StreamID) (protocol.ByteCount, error) { - ret := _m.ctrl.Call(_m, "GetReceiveWindow", _param0) - ret0, _ := ret[0].(protocol.ByteCount) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -func (_mr *_MockFlowControlManagerRecorder) GetReceiveWindow(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveWindow", arg0) -} - -func (_m *MockFlowControlManager) GetWindowUpdates() []flowcontrol.WindowUpdate { - ret := _m.ctrl.Call(_m, "GetWindowUpdates") - ret0, _ := ret[0].([]flowcontrol.WindowUpdate) - return ret0 -} - -func (_mr *_MockFlowControlManagerRecorder) GetWindowUpdates() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetWindowUpdates") -} - -func (_m *MockFlowControlManager) NewStream(_param0 protocol.StreamID, _param1 bool) { - _m.ctrl.Call(_m, "NewStream", _param0, _param1) -} - -func (_mr *_MockFlowControlManagerRecorder) NewStream(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "NewStream", arg0, arg1) -} - -func (_m *MockFlowControlManager) RemainingConnectionWindowSize() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "RemainingConnectionWindowSize") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -func (_mr *_MockFlowControlManagerRecorder) RemainingConnectionWindowSize() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "RemainingConnectionWindowSize") -} - -func (_m *MockFlowControlManager) RemoveStream(_param0 protocol.StreamID) { - _m.ctrl.Call(_m, "RemoveStream", _param0) -} - -func (_mr *_MockFlowControlManagerRecorder) RemoveStream(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "RemoveStream", arg0) -} - -func (_m *MockFlowControlManager) ResetStream(_param0 protocol.StreamID, _param1 protocol.ByteCount) error { - ret := _m.ctrl.Call(_m, "ResetStream", _param0, _param1) - ret0, _ := ret[0].(error) - return ret0 -} - -func (_mr *_MockFlowControlManagerRecorder) ResetStream(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "ResetStream", arg0, arg1) -} - -func (_m *MockFlowControlManager) SendWindowSize(_param0 protocol.StreamID) (protocol.ByteCount, error) { - ret := _m.ctrl.Call(_m, "SendWindowSize", _param0) - ret0, _ := ret[0].(protocol.ByteCount) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -func (_mr *_MockFlowControlManagerRecorder) SendWindowSize(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "SendWindowSize", arg0) -} - -func (_m *MockFlowControlManager) UpdateHighestReceived(_param0 protocol.StreamID, _param1 protocol.ByteCount) error { - ret := _m.ctrl.Call(_m, "UpdateHighestReceived", _param0, _param1) - ret0, _ := ret[0].(error) - return ret0 -} - -func (_mr *_MockFlowControlManagerRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "UpdateHighestReceived", arg0, arg1) -} - -func (_m *MockFlowControlManager) UpdateWindow(_param0 protocol.StreamID, _param1 protocol.ByteCount) (bool, error) { - ret := _m.ctrl.Call(_m, "UpdateWindow", _param0, _param1) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -func (_mr *_MockFlowControlManagerRecorder) UpdateWindow(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "UpdateWindow", arg0, arg1) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/stream_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/stream_flow_controller.go new file mode 100644 index 000000000..a69e73f19 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/stream_flow_controller.go @@ -0,0 +1,126 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/flowcontrol (interfaces: StreamFlowController) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockStreamFlowController is a mock of StreamFlowController interface +type MockStreamFlowController struct { + ctrl *gomock.Controller + recorder *MockStreamFlowControllerMockRecorder +} + +// MockStreamFlowControllerMockRecorder is the mock recorder for MockStreamFlowController +type MockStreamFlowControllerMockRecorder struct { + mock *MockStreamFlowController +} + +// NewMockStreamFlowController creates a new mock instance +func NewMockStreamFlowController(ctrl *gomock.Controller) *MockStreamFlowController { + mock := &MockStreamFlowController{ctrl: ctrl} + mock.recorder = &MockStreamFlowControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStreamFlowController) EXPECT() *MockStreamFlowControllerMockRecorder { + return m.recorder +} + +// AddBytesRead mocks base method +func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) { + m.ctrl.Call(m, "AddBytesRead", arg0) +} + +// AddBytesRead indicates an expected call of AddBytesRead +func (mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesRead), arg0) +} + +// AddBytesSent mocks base method +func (m *MockStreamFlowController) AddBytesSent(arg0 protocol.ByteCount) { + m.ctrl.Call(m, "AddBytesSent", arg0) +} + +// AddBytesSent indicates an expected call of AddBytesSent +func (mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesSent), arg0) +} + +// GetWindowUpdate mocks base method +func (m *MockStreamFlowController) GetWindowUpdate() protocol.ByteCount { + ret := m.ctrl.Call(m, "GetWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetWindowUpdate indicates an expected call of GetWindowUpdate +func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate)) +} + +// HasWindowUpdate mocks base method +func (m *MockStreamFlowController) HasWindowUpdate() bool { + ret := m.ctrl.Call(m, "HasWindowUpdate") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasWindowUpdate indicates an expected call of HasWindowUpdate +func (mr *MockStreamFlowControllerMockRecorder) HasWindowUpdate() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).HasWindowUpdate)) +} + +// IsBlocked mocks base method +func (m *MockStreamFlowController) IsBlocked() (bool, protocol.ByteCount) { + ret := m.ctrl.Call(m, "IsBlocked") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 +} + +// IsBlocked indicates an expected call of IsBlocked +func (mr *MockStreamFlowControllerMockRecorder) IsBlocked() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsBlocked)) +} + +// SendWindowSize mocks base method +func (m *MockStreamFlowController) SendWindowSize() protocol.ByteCount { + ret := m.ctrl.Call(m, "SendWindowSize") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// SendWindowSize indicates an expected call of SendWindowSize +func (mr *MockStreamFlowControllerMockRecorder) SendWindowSize() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockStreamFlowController)(nil).SendWindowSize)) +} + +// UpdateHighestReceived mocks base method +func (m *MockStreamFlowController) UpdateHighestReceived(arg0 protocol.ByteCount, arg1 bool) error { + ret := m.ctrl.Call(m, "UpdateHighestReceived", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateHighestReceived indicates an expected call of UpdateHighestReceived +func (mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateHighestReceived), arg0, arg1) +} + +// UpdateSendWindow mocks base method +func (m *MockStreamFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { + m.ctrl.Call(m, "UpdateSendWindow", arg0) +} + +// UpdateSendWindow indicates an expected call of UpdateSendWindow +func (mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateSendWindow), arg0) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/tls_extension_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/tls_extension_handler.go new file mode 100644 index 000000000..fcceee2eb --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/tls_extension_handler.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: TLSExtensionHandler) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + mint "github.com/bifurcation/mint" + gomock "github.com/golang/mock/gomock" + handshake "github.com/lucas-clemente/quic-go/internal/handshake" +) + +// MockTLSExtensionHandler is a mock of TLSExtensionHandler interface +type MockTLSExtensionHandler struct { + ctrl *gomock.Controller + recorder *MockTLSExtensionHandlerMockRecorder +} + +// MockTLSExtensionHandlerMockRecorder is the mock recorder for MockTLSExtensionHandler +type MockTLSExtensionHandlerMockRecorder struct { + mock *MockTLSExtensionHandler +} + +// NewMockTLSExtensionHandler creates a new mock instance +func NewMockTLSExtensionHandler(ctrl *gomock.Controller) *MockTLSExtensionHandler { + mock := &MockTLSExtensionHandler{ctrl: ctrl} + mock.recorder = &MockTLSExtensionHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTLSExtensionHandler) EXPECT() *MockTLSExtensionHandlerMockRecorder { + return m.recorder +} + +// GetPeerParams mocks base method +func (m *MockTLSExtensionHandler) GetPeerParams() <-chan handshake.TransportParameters { + ret := m.ctrl.Call(m, "GetPeerParams") + ret0, _ := ret[0].(<-chan handshake.TransportParameters) + return ret0 +} + +// GetPeerParams indicates an expected call of GetPeerParams +func (mr *MockTLSExtensionHandlerMockRecorder) GetPeerParams() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerParams", reflect.TypeOf((*MockTLSExtensionHandler)(nil).GetPeerParams)) +} + +// Receive mocks base method +func (m *MockTLSExtensionHandler) Receive(arg0 mint.HandshakeType, arg1 *mint.ExtensionList) error { + ret := m.ctrl.Call(m, "Receive", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Receive indicates an expected call of Receive +func (mr *MockTLSExtensionHandlerMockRecorder) Receive(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Receive", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Receive), arg0, arg1) +} + +// Send mocks base method +func (m *MockTLSExtensionHandler) Send(arg0 mint.HandshakeType, arg1 *mint.ExtensionList) error { + ret := m.ctrl.Call(m, "Send", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Send indicates an expected call of Send +func (mr *MockTLSExtensionHandlerMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Send), arg0, arg1) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/protocol/encryption_level.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/protocol/encryption_level.go rename to vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go diff --git a/vendor/github.com/lucas-clemente/quic-go/protocol/packet_number.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go similarity index 74% rename from vendor/github.com/lucas-clemente/quic-go/protocol/packet_number.go rename to vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go index c4f468ad5..4bc8bfc9a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/protocol/packet_number.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go @@ -27,18 +27,14 @@ func delta(a, b PacketNumber) PacketNumber { return a - b } -// GetPacketNumberLengthForPublicHeader gets the length of the packet number for the public header +// GetPacketNumberLengthForHeader gets the length of the packet number for the public header // it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances -func GetPacketNumberLengthForPublicHeader(packetNumber PacketNumber, leastUnacked PacketNumber) PacketNumberLen { +func GetPacketNumberLengthForHeader(packetNumber PacketNumber, leastUnacked PacketNumber) PacketNumberLen { diff := uint64(packetNumber - leastUnacked) - if diff < (2 << (uint8(PacketNumberLen2)*8 - 2)) { + if diff < (1 << (uint8(PacketNumberLen2)*8 - 1)) { return PacketNumberLen2 } - if diff < (2 << (uint8(PacketNumberLen4)*8 - 2)) { - return PacketNumberLen4 - } - // we do not check if there are less than 2^46 packets in flight, since flow control and congestion control will limit this number *a lot* sooner - return PacketNumberLen6 + return PacketNumberLen4 } // GetPacketNumberLength gets the minimum length needed to fully represent the packet number diff --git a/vendor/github.com/lucas-clemente/quic-go/protocol/perspective.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go similarity index 52% rename from vendor/github.com/lucas-clemente/quic-go/protocol/perspective.go rename to vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go index 6aa3b70c3..948e371ae 100644 --- a/vendor/github.com/lucas-clemente/quic-go/protocol/perspective.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go @@ -8,3 +8,14 @@ const ( PerspectiveServer Perspective = 1 PerspectiveClient Perspective = 2 ) + +func (p Perspective) String() string { + switch p { + case PerspectiveServer: + return "Server" + case PerspectiveClient: + return "Client" + default: + return "invalid perspective" + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/protocol/protocol.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go similarity index 58% rename from vendor/github.com/lucas-clemente/quic-go/protocol/protocol.go rename to vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go index cf9cf056f..0901b19eb 100644 --- a/vendor/github.com/lucas-clemente/quic-go/protocol/protocol.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go @@ -1,6 +1,8 @@ package protocol -import "math" +import ( + "fmt" +) // A PacketNumber in QUIC type PacketNumber uint64 @@ -21,17 +23,46 @@ const ( PacketNumberLen6 PacketNumberLen = 6 ) +// The PacketType is the Long Header Type (only used for the IETF draft header format) +type PacketType uint8 + +const ( + // PacketTypeInitial is the packet type of a Initial packet + PacketTypeInitial PacketType = 2 + // PacketTypeRetry is the packet type of a Retry packet + PacketTypeRetry PacketType = 3 + // PacketTypeHandshake is the packet type of a Cleartext packet + PacketTypeHandshake PacketType = 4 + // PacketType0RTT is the packet type of a 0-RTT packet + PacketType0RTT PacketType = 5 +) + +func (t PacketType) String() string { + switch t { + case PacketTypeInitial: + return "Initial" + case PacketTypeRetry: + return "Retry" + case PacketTypeHandshake: + return "Handshake" + case PacketType0RTT: + return "0-RTT Protected" + default: + return fmt.Sprintf("unknown packet type: %d", t) + } +} + // A ConnectionID in QUIC type ConnectionID uint64 -// A StreamID in QUIC -type StreamID uint32 - // A ByteCount in QUIC type ByteCount uint64 // MaxByteCount is the maximum value of a ByteCount -const MaxByteCount = ByteCount(math.MaxUint64) +const MaxByteCount = ByteCount(1<<62 - 1) + +// An ApplicationErrorCode is an application-defined error code. +type ApplicationErrorCode uint16 // MaxReceivePacketSize maximum packet size of any QUIC packet, based on // ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, @@ -43,14 +74,11 @@ const MaxReceivePacketSize ByteCount = 1452 // Used in QUIC for congestion window computations in bytes. const DefaultTCPMSS ByteCount = 1460 -// InitialStreamFlowControlWindow is the initial stream-level flow control window for sending -const InitialStreamFlowControlWindow ByteCount = (1 << 14) // 16 kB +// MinClientHelloSize is the minimum size the server expects an inchoate CHLO to have (in gQUIC) +const MinClientHelloSize = 1024 -// InitialConnectionFlowControlWindow is the initial connection-level flow control window for sending -const InitialConnectionFlowControlWindow ByteCount = (1 << 14) // 16 kB - -// ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have. -const ClientHelloMinimumSize = 1024 +// MinInitialPacketSize is the minimum size an Initial packet (in IETF QUIC) is requried to have. +const MinInitialPacketSize = 1200 // MaxClientHellos is the maximum number of times we'll send a client hello // The value 3 accounts for: diff --git a/vendor/github.com/lucas-clemente/quic-go/protocol/server_parameters.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go similarity index 71% rename from vendor/github.com/lucas-clemente/quic-go/protocol/server_parameters.go rename to vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go index 8e632cc13..61e5a2dfc 100644 --- a/vendor/github.com/lucas-clemente/quic-go/protocol/server_parameters.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go @@ -2,12 +2,9 @@ package protocol import "time" -// MaxPacketSize is the maximum packet size, including the public header, that we use for sending packets -// This is the value used by Chromium for a QUIC packet sent using IPv6 (for IPv4 it would be 1370) -const MaxPacketSize ByteCount = 1350 - -// MaxFrameAndPublicHeaderSize is the maximum size of a QUIC frame plus PublicHeader -const MaxFrameAndPublicHeaderSize = MaxPacketSize - 12 /*crypto signature*/ +// MaxPacketSize is the maximum packet size that we use for sending packets. +// It includes the QUIC packet header, but excludes the UDP and IP header. +const MaxPacketSize ByteCount = 1200 // NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet // This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames @@ -33,37 +30,37 @@ const AckSendDelay = 25 * time.Millisecond // ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data // This is the value that Google servers are using -const ReceiveStreamFlowControlWindow ByteCount = (1 << 10) * 32 // 32 kB +const ReceiveStreamFlowControlWindow = (1 << 10) * 32 // 32 kB // ReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data // This is the value that Google servers are using -const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB +const ReceiveConnectionFlowControlWindow = (1 << 10) * 48 // 48 kB // DefaultMaxReceiveStreamFlowControlWindowServer is the default maximum stream-level flow control window for receiving data, for the server // This is the value that Google servers are using -const DefaultMaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB +const DefaultMaxReceiveStreamFlowControlWindowServer = 1 * (1 << 20) // 1 MB // DefaultMaxReceiveConnectionFlowControlWindowServer is the default connection-level flow control window for receiving data, for the server // This is the value that Google servers are using -const DefaultMaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB +const DefaultMaxReceiveConnectionFlowControlWindowServer = 1.5 * (1 << 20) // 1.5 MB // DefaultMaxReceiveStreamFlowControlWindowClient is the default maximum stream-level flow control window for receiving data, for the client // This is the value that Chromium is using -const DefaultMaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB +const DefaultMaxReceiveStreamFlowControlWindowClient = 6 * (1 << 20) // 6 MB // DefaultMaxReceiveConnectionFlowControlWindowClient is the default connection-level flow control window for receiving data, for the client // This is the value that Google servers are using -const DefaultMaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB +const DefaultMaxReceiveConnectionFlowControlWindowClient = 15 * (1 << 20) // 15 MB // ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window // This is the value that Chromium is using const ConnectionFlowControlMultiplier = 1.5 -// MaxStreamsPerConnection is the maximum value accepted for the number of streams per connection -const MaxStreamsPerConnection = 100 +// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client +const WindowUpdateThreshold = 0.25 -// MaxIncomingDynamicStreamsPerConnection is the maximum value accepted for the incoming number of dynamic streams per connection -const MaxIncomingDynamicStreamsPerConnection = 100 +// MaxIncomingStreams is the maximum number of streams that a peer may open +const MaxIncomingStreams = 100 // MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used. const MaxStreamsMultiplier = 1.1 @@ -73,7 +70,7 @@ const MaxStreamsMinimumIncrement = 10 // MaxNewStreamIDDelta is the maximum difference between and a newly opened Stream and the highest StreamID that a client has ever opened // note that the number of streams is half this value, since the client can only open streams with open StreamID -const MaxNewStreamIDDelta = 4 * MaxStreamsPerConnection +const MaxNewStreamIDDelta = 4 * MaxIncomingStreams // MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed. const MaxSessionUnprocessedPackets = DefaultMaxCongestionWindow @@ -84,23 +81,20 @@ const SkipPacketAveragePeriodLength PacketNumber = 500 // MaxTrackedSkippedPackets is the maximum number of skipped packet numbers the SentPacketHandler keep track of for Optimistic ACK attack mitigation const MaxTrackedSkippedPackets = 10 -// STKExpiryTime is the valid time of a source address token -const STKExpiryTime = 24 * time.Hour +// CookieExpiryTime is the valid time of a cookie +const CookieExpiryTime = 24 * time.Hour // MaxTrackedSentPackets is maximum number of sent packets saved for either later retransmission or entropy calculation const MaxTrackedSentPackets = 2 * DefaultMaxCongestionWindow -// MaxTrackedReceivedPackets is the maximum number of received packets saved for doing the entropy calculations -const MaxTrackedReceivedPackets = 2 * DefaultMaxCongestionWindow - // MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow -// MaxPacketsReceivedBeforeAckSend is the number of packets that can be received before an ACK frame is sent -const MaxPacketsReceivedBeforeAckSend = 20 +// MaxNonRetransmittableAcks is the maximum number of packets containing an ACK, but no retransmittable frames, that we send in a row +const MaxNonRetransmittableAcks = 19 // RetransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for -const RetransmittablePacketsBeforeAck = 2 +const RetransmittablePacketsBeforeAck = 10 // MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames // prevents DoS attacks against the streamFrameSorter @@ -116,18 +110,12 @@ const CryptoParameterMaxLength = 4000 // EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX. const EphermalKeyLifetime = time.Minute -// InitialIdleTimeout is the timeout before the handshake succeeds. -const InitialIdleTimeout = 5 * time.Second +// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout +const MinRemoteIdleTimeout = 5 * time.Second -// DefaultIdleTimeout is the default idle timeout, for the server +// DefaultIdleTimeout is the default idle timeout const DefaultIdleTimeout = 30 * time.Second -// MaxIdleTimeoutServer is the maximum idle timeout that can be negotiated, for the server -const MaxIdleTimeoutServer = 1 * time.Minute - -// MaxIdleTimeoutClient is the idle timeout that the client suggests to the server -const MaxIdleTimeoutClient = 2 * time.Minute - // DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds. const DefaultHandshakeTimeout = 10 * time.Second @@ -137,3 +125,14 @@ const ClosedSessionDeleteTimeout = time.Minute // NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space const NumCachedCertificates = 128 + +// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame. +// This avoids splitting up STREAM frames into small pieces, which has 2 advantages: +// 1. it reduces the framing overhead +// 2. it reduces the head-of-line blocking, when a packet is lost +const MinStreamFrameSize ByteCount = 128 + +// MinPacingDelay is the minimum duration that is used for packet pacing +// If the packet packing frequency is higher, multiple packets might be sent at once. +// Example: For a packet pacing delay of 20 microseconds, we would send 5 packets at once, wait for 100 microseconds, and so forth. +const MinPacingDelay time.Duration = 100 * time.Microsecond diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream_id.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream_id.go new file mode 100644 index 000000000..a0dced0ce --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream_id.go @@ -0,0 +1,36 @@ +package protocol + +// A StreamID in QUIC +type StreamID uint64 + +// MaxBidiStreamID is the highest stream ID that the peer is allowed to open, +// when it is allowed to open numStreams bidirectional streams. +// It is only valid for IETF QUIC. +func MaxBidiStreamID(numStreams int, pers Perspective) StreamID { + if numStreams == 0 { + return 0 + } + var first StreamID + if pers == PerspectiveClient { + first = 1 + } else { + first = 4 + } + return first + 4*StreamID(numStreams-1) +} + +// MaxUniStreamID is the highest stream ID that the peer is allowed to open, +// when it is allowed to open numStreams unidirectional streams. +// It is only valid for IETF QUIC. +func MaxUniStreamID(numStreams int, pers Perspective) StreamID { + if numStreams == 0 { + return 0 + } + var first StreamID + if pers == PerspectiveClient { + first = 3 + } else { + first = 2 + } + return first + 4*StreamID(numStreams-1) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go new file mode 100644 index 000000000..3135ca853 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go @@ -0,0 +1,135 @@ +package protocol + +import ( + "crypto/rand" + "encoding/binary" + "fmt" +) + +// VersionNumber is a version number as int +type VersionNumber int32 + +// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions +const ( + gquicVersion0 = 0x51303030 + maxGquicVersion = 0x51303439 +) + +// The version numbers, making grepping easier +const ( + Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9 + iota + VersionTLS VersionNumber = 101 + VersionWhatever VersionNumber = 0 // for when the version doesn't matter + VersionUnknown VersionNumber = -1 +) + +// SupportedVersions lists the versions that the server supports +// must be in sorted descending order +var SupportedVersions = []VersionNumber{ + Version39, +} + +// UsesTLS says if this QUIC version uses TLS 1.3 for the handshake +func (vn VersionNumber) UsesTLS() bool { + return vn == VersionTLS +} + +func (vn VersionNumber) String() string { + switch vn { + case VersionWhatever: + return "whatever" + case VersionUnknown: + return "unknown" + case VersionTLS: + return "TLS dev version (WIP)" + default: + if vn.isGQUIC() { + return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion()) + } + return fmt.Sprintf("%d", vn) + } +} + +// ToAltSvc returns the representation of the version for the H2 Alt-Svc parameters +func (vn VersionNumber) ToAltSvc() string { + if vn.isGQUIC() { + return fmt.Sprintf("%d", vn.toGQUICVersion()) + } + return fmt.Sprintf("%d", vn) +} + +// CryptoStreamID gets the Stream ID of the crypto stream +func (vn VersionNumber) CryptoStreamID() StreamID { + if vn.isGQUIC() { + return 1 + } + return 0 +} + +// UsesIETFFrameFormat tells if this version uses the IETF frame format +func (vn VersionNumber) UsesIETFFrameFormat() bool { + return vn != Version39 +} + +// StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control +func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool { + if id == vn.CryptoStreamID() { + return false + } + if vn.isGQUIC() && id == 3 { + return false + } + return true +} + +func (vn VersionNumber) isGQUIC() bool { + return vn > gquicVersion0 && vn <= maxGquicVersion +} + +func (vn VersionNumber) toGQUICVersion() int { + return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10) +} + +// IsSupportedVersion returns true if the server supports this version +func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { + for _, t := range supported { + if t == v { + return true + } + } + return false +} + +// ChooseSupportedVersion finds the best version in the overlap of ours and theirs +// ours is a slice of versions that we support, sorted by our preference (descending) +// theirs is a slice of versions offered by the peer. The order does not matter. +// The bool returned indicates if a matching version was found. +func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) { + for _, ourVer := range ours { + for _, theirVer := range theirs { + if ourVer == theirVer { + return ourVer, true + } + } + } + return 0, false +} + +// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a) +func generateReservedVersion() VersionNumber { + b := make([]byte, 4) + _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything + return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa) +} + +// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position +func GetGreasedVersions(supported []VersionNumber) []VersionNumber { + b := make([]byte, 1) + _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything + randPos := int(b[0]) % (len(supported) + 1) + greased := make([]VersionNumber, len(supported)+1) + copy(greased, supported[:randPos]) + greased[randPos] = generateReservedVersion() + copy(greased[randPos+1:], supported[randPos:]) + return greased +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go new file mode 100644 index 000000000..b45800a37 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go @@ -0,0 +1,25 @@ +package utils + +import ( + "bytes" + "io" +) + +// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. +type ByteOrder interface { + ReadUintN(b io.ByteReader, length uint8) (uint64, error) + ReadUint64(io.ByteReader) (uint64, error) + ReadUint32(io.ByteReader) (uint32, error) + ReadUint16(io.ByteReader) (uint16, error) + + WriteUint64(*bytes.Buffer, uint64) + WriteUint56(*bytes.Buffer, uint64) + WriteUint48(*bytes.Buffer, uint64) + WriteUint40(*bytes.Buffer, uint64) + WriteUint32(*bytes.Buffer, uint32) + WriteUint24(*bytes.Buffer, uint32) + WriteUint16(*bytes.Buffer, uint16) + + ReadUfloat16(io.ByteReader) (uint64, error) + WriteUfloat16(*bytes.Buffer, uint64) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go new file mode 100644 index 000000000..9f6c9a617 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go @@ -0,0 +1,157 @@ +package utils + +import ( + "bytes" + "fmt" + "io" +) + +// BigEndian is the big-endian implementation of ByteOrder. +var BigEndian ByteOrder = bigEndian{} + +type bigEndian struct{} + +var _ ByteOrder = &bigEndian{} + +// ReadUintN reads N bytes +func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) { + var res uint64 + for i := uint8(0); i < length; i++ { + bt, err := b.ReadByte() + if err != nil { + return 0, err + } + res ^= uint64(bt) << ((length - 1 - i) * 8) + } + return res, nil +} + +// ReadUint64 reads a uint64 +func (bigEndian) ReadUint64(b io.ByteReader) (uint64, error) { + var b1, b2, b3, b4, b5, b6, b7, b8 uint8 + var err error + if b8, err = b.ReadByte(); err != nil { + return 0, err + } + if b7, err = b.ReadByte(); err != nil { + return 0, err + } + if b6, err = b.ReadByte(); err != nil { + return 0, err + } + if b5, err = b.ReadByte(); err != nil { + return 0, err + } + if b4, err = b.ReadByte(); err != nil { + return 0, err + } + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil +} + +// ReadUint32 reads a uint32 +func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) { + var b1, b2, b3, b4 uint8 + var err error + if b4, err = b.ReadByte(); err != nil { + return 0, err + } + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil +} + +// ReadUint16 reads a uint16 +func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) { + var b1, b2 uint8 + var err error + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint16(b1) + uint16(b2)<<8, nil +} + +// WriteUint64 writes a uint64 +func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) { + b.Write([]byte{ + uint8(i >> 56), uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) +} + +// WriteUint56 writes 56 bit of a uint64 +func (bigEndian) WriteUint56(b *bytes.Buffer, i uint64) { + if i >= (1 << 56) { + panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i)) + } + b.Write([]byte{ + uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) +} + +// WriteUint48 writes 48 bit of a uint64 +func (bigEndian) WriteUint48(b *bytes.Buffer, i uint64) { + if i >= (1 << 48) { + panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i)) + } + b.Write([]byte{ + uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) +} + +// WriteUint40 writes 40 bit of a uint64 +func (bigEndian) WriteUint40(b *bytes.Buffer, i uint64) { + if i >= (1 << 40) { + panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i)) + } + b.Write([]byte{ + uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) +} + +// WriteUint32 writes a uint32 +func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) { + b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)}) +} + +// WriteUint24 writes 24 bit of a uint32 +func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) { + if i >= (1 << 24) { + panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i)) + } + b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)}) +} + +// WriteUint16 writes a uint16 +func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) { + b.Write([]byte{uint8(i >> 8), uint8(i)}) +} + +func (l bigEndian) ReadUfloat16(b io.ByteReader) (uint64, error) { + return readUfloat16(b, l) +} + +func (l bigEndian) WriteUfloat16(b *bytes.Buffer, val uint64) { + writeUfloat16(b, l, val) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/utils.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_little_endian.go similarity index 64% rename from vendor/github.com/lucas-clemente/quic-go/internal/utils/utils.go rename to vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_little_endian.go index 5a987e61a..71ff95d5b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/utils.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_little_endian.go @@ -2,11 +2,19 @@ package utils import ( "bytes" + "fmt" "io" ) +// LittleEndian is the little-endian implementation of ByteOrder. +var LittleEndian ByteOrder = littleEndian{} + +type littleEndian struct{} + +var _ ByteOrder = &littleEndian{} + // ReadUintN reads N bytes -func ReadUintN(b io.ByteReader, length uint8) (uint64, error) { +func (littleEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) { var res uint64 for i := uint8(0); i < length; i++ { bt, err := b.ReadByte() @@ -19,7 +27,7 @@ func ReadUintN(b io.ByteReader, length uint8) (uint64, error) { } // ReadUint64 reads a uint64 -func ReadUint64(b io.ByteReader) (uint64, error) { +func (littleEndian) ReadUint64(b io.ByteReader) (uint64, error) { var b1, b2, b3, b4, b5, b6, b7, b8 uint8 var err error if b1, err = b.ReadByte(); err != nil { @@ -50,7 +58,7 @@ func ReadUint64(b io.ByteReader) (uint64, error) { } // ReadUint32 reads a uint32 -func ReadUint32(b io.ByteReader) (uint32, error) { +func (littleEndian) ReadUint32(b io.ByteReader) (uint32, error) { var b1, b2, b3, b4 uint8 var err error if b1, err = b.ReadByte(); err != nil { @@ -69,7 +77,7 @@ func ReadUint32(b io.ByteReader) (uint32, error) { } // ReadUint16 reads a uint16 -func ReadUint16(b io.ByteReader) (uint16, error) { +func (littleEndian) ReadUint16(b io.ByteReader) (uint16, error) { var b1, b2 uint8 var err error if b1, err = b.ReadByte(); err != nil { @@ -82,7 +90,7 @@ func ReadUint16(b io.ByteReader) (uint16, error) { } // WriteUint64 writes a uint64 -func WriteUint64(b *bytes.Buffer, i uint64) { +func (littleEndian) WriteUint64(b *bytes.Buffer, i uint64) { b.Write([]byte{ uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24), uint8(i >> 32), uint8(i >> 40), uint8(i >> 48), uint8(i >> 56), @@ -90,7 +98,10 @@ func WriteUint64(b *bytes.Buffer, i uint64) { } // WriteUint56 writes 56 bit of a uint64 -func WriteUint56(b *bytes.Buffer, i uint64) { +func (littleEndian) WriteUint56(b *bytes.Buffer, i uint64) { + if i >= (1 << 56) { + panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i)) + } b.Write([]byte{ uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24), uint8(i >> 32), uint8(i >> 40), uint8(i >> 48), @@ -98,7 +109,10 @@ func WriteUint56(b *bytes.Buffer, i uint64) { } // WriteUint48 writes 48 bit of a uint64 -func WriteUint48(b *bytes.Buffer, i uint64) { +func (littleEndian) WriteUint48(b *bytes.Buffer, i uint64) { + if i >= (1 << 48) { + panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i)) + } b.Write([]byte{ uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24), uint8(i >> 32), uint8(i >> 40), @@ -106,7 +120,10 @@ func WriteUint48(b *bytes.Buffer, i uint64) { } // WriteUint40 writes 40 bit of a uint64 -func WriteUint40(b *bytes.Buffer, i uint64) { +func (littleEndian) WriteUint40(b *bytes.Buffer, i uint64) { + if i >= (1 << 40) { + panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i)) + } b.Write([]byte{ uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24), uint8(i >> 32), @@ -114,16 +131,27 @@ func WriteUint40(b *bytes.Buffer, i uint64) { } // WriteUint32 writes a uint32 -func WriteUint32(b *bytes.Buffer, i uint32) { +func (littleEndian) WriteUint32(b *bytes.Buffer, i uint32) { b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24)}) } // WriteUint24 writes 24 bit of a uint32 -func WriteUint24(b *bytes.Buffer, i uint32) { +func (littleEndian) WriteUint24(b *bytes.Buffer, i uint32) { + if i >= (1 << 24) { + panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i)) + } b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16)}) } // WriteUint16 writes a uint16 -func WriteUint16(b *bytes.Buffer, i uint16) { +func (littleEndian) WriteUint16(b *bytes.Buffer, i uint16) { b.Write([]byte{uint8(i), uint8(i >> 8)}) } + +func (l littleEndian) ReadUfloat16(b io.ByteReader) (uint64, error) { + return readUfloat16(b, l) +} + +func (l littleEndian) WriteUfloat16(b *bytes.Buffer, val uint64) { + writeUfloat16(b, l, val) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go index c2252e6ed..b4af4e780 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go @@ -4,7 +4,7 @@ import ( "crypto/rand" "encoding/binary" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // GenerateConnectionID generates a connection ID using cryptographic random diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/float16.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/float16.go index 8abdb51d8..8e2ca1bca 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/float16.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/float16.go @@ -21,9 +21,9 @@ const uFloat16MantissaBits = 16 - uFloat16ExponentBits const uFloat16MantissaEffectiveBits = uFloat16MantissaBits + 1 // 12 const uFloat16MaxValue = ((uint64(1) << uFloat16MantissaEffectiveBits) - 1) << uFloat16MaxExponent // 0x3FFC0000000 -// ReadUfloat16 reads a float in the QUIC-float16 format and returns its uint64 representation -func ReadUfloat16(b io.ByteReader) (uint64, error) { - val, err := ReadUint16(b) +// readUfloat16 reads a float in the QUIC-float16 format and returns its uint64 representation +func readUfloat16(b io.ByteReader, byteOrder ByteOrder) (uint64, error) { + val, err := byteOrder.ReadUint16(b) if err != nil { return 0, err } @@ -50,8 +50,8 @@ func ReadUfloat16(b io.ByteReader) (uint64, error) { return res, nil } -// WriteUfloat16 writes a float in the QUIC-float16 format from its uint64 representation -func WriteUfloat16(b *bytes.Buffer, value uint64) { +// writeUfloat16 writes a float in the QUIC-float16 format from its uint64 representation +func writeUfloat16(b *bytes.Buffer, byteOrder ByteOrder, value uint64) { var result uint16 if value < (uint64(1) << uFloat16MantissaEffectiveBits) { // Fast path: either the value is denormalized, or has exponent zero. @@ -82,5 +82,5 @@ func WriteUfloat16(b *bytes.Buffer, value uint64) { result = (uint16(value) + (exponent << uFloat16MantissaBits)) } - WriteUint16(b, result) + byteOrder.WriteUint16(b, result) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go index 9128510e6..342d8ddca 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "os" + "strings" "time" ) @@ -79,14 +80,14 @@ func init() { } func readLoggingEnv() { - switch os.Getenv(logEnv) { + switch strings.ToLower(os.Getenv(logEnv)) { case "": return - case "DEBUG": + case "debug": logLevel = LogLevelDebug - case "INFO": + case "info": logLevel = LogLevelInfo - case "ERROR": + case "error": logLevel = LogLevelError default: fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging") diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go index 6e23df5a5..ef71c7fa8 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go @@ -4,7 +4,7 @@ import ( "math" "time" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // InfDuration is a duration of infinite length @@ -114,6 +114,14 @@ func MinTime(a, b time.Time) time.Time { return a } +// MaxTime returns the later time +func MaxTime(a, b time.Time) time.Time { + if a.After(b) { + return a + } + return b +} + // MaxPacketNumber returns the max packet number func MaxPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { if a > b { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go index 09800b6b6..f49b0c426 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go @@ -1,6 +1,6 @@ package utils -import "github.com/lucas-clemente/quic-go/protocol" +import "github.com/lucas-clemente/quic-go/internal/protocol" // PacketInterval is an interval from one PacketNumber to the other // +gen linkedlist diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go index c918b62eb..3c8325b25 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go @@ -1,6 +1,6 @@ package utils -import "github.com/lucas-clemente/quic-go/protocol" +import "github.com/lucas-clemente/quic-go/internal/protocol" // ByteInterval is an interval from one ByteCount to the other // +gen linkedlist diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go index 695ad3e75..7f8ffc7a0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go @@ -21,7 +21,7 @@ func (t *Timer) Chan() <-chan time.Time { // Reset the timer, no matter whether the value was read or not func (t *Timer) Reset(deadline time.Time) { - if deadline.Equal(t.deadline) { + if deadline.Equal(t.deadline) && !t.read { // No need to reset the timer return } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/varint.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/varint.go new file mode 100644 index 000000000..35e8674e2 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/varint.go @@ -0,0 +1,101 @@ +package utils + +import ( + "bytes" + "fmt" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// taken from the QUIC draft +const ( + maxVarInt1 = 63 + maxVarInt2 = 16383 + maxVarInt4 = 1073741823 + maxVarInt8 = 4611686018427387903 +) + +// ReadVarInt reads a number in the QUIC varint format +func ReadVarInt(b io.ByteReader) (uint64, error) { + firstByte, err := b.ReadByte() + if err != nil { + return 0, err + } + // the first two bits of the first byte encode the length + len := 1 << ((firstByte & 0xc0) >> 6) + b1 := firstByte & (0xff - 0xc0) + if len == 1 { + return uint64(b1), nil + } + b2, err := b.ReadByte() + if err != nil { + return 0, err + } + if len == 2 { + return uint64(b2) + uint64(b1)<<8, nil + } + b3, err := b.ReadByte() + if err != nil { + return 0, err + } + b4, err := b.ReadByte() + if err != nil { + return 0, err + } + if len == 4 { + return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil + } + b5, err := b.ReadByte() + if err != nil { + return 0, err + } + b6, err := b.ReadByte() + if err != nil { + return 0, err + } + b7, err := b.ReadByte() + if err != nil { + return 0, err + } + b8, err := b.ReadByte() + if err != nil { + return 0, err + } + return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil +} + +// WriteVarInt writes a number in the QUIC varint format +func WriteVarInt(b *bytes.Buffer, i uint64) { + if i <= maxVarInt1 { + b.WriteByte(uint8(i)) + } else if i <= maxVarInt2 { + b.Write([]byte{uint8(i>>8) | 0x40, uint8(i)}) + } else if i <= maxVarInt4 { + b.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}) + } else if i <= maxVarInt8 { + b.Write([]byte{ + uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) + } else { + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) + } +} + +// VarIntLen determines the number of bytes that will be needed to write a number +func VarIntLen(i uint64) protocol.ByteCount { + if i <= maxVarInt1 { + return 1 + } + if i <= maxVarInt2 { + return 2 + } + if i <= maxVarInt4 { + return 4 + } + if i <= maxVarInt8 { + return 8 + } + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go new file mode 100644 index 000000000..996b771b4 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go @@ -0,0 +1,239 @@ +package wire + +import ( + "bytes" + "errors" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// TODO: use the value sent in the transport parameters +const ackDelayExponent = 3 + +// An AckFrame is an ACK frame +type AckFrame struct { + LargestAcked protocol.PacketNumber + LowestAcked protocol.PacketNumber + AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last + + // time when the LargestAcked was receiveid + // this field will not be set for received ACKs frames + PacketReceivedTime time.Time + DelayTime time.Duration +} + +// ParseAckFrame reads an ACK frame +func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) { + if !version.UsesIETFFrameFormat() { + return parseAckFrameLegacy(r, version) + } + + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + frame := &AckFrame{} + + largestAcked, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + frame.LargestAcked = protocol.PacketNumber(largestAcked) + delay, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + frame.DelayTime = time.Duration(delay*1< frame.LargestAcked { + return nil, errors.New("invalid first ACK range") + } + smallest := frame.LargestAcked - protocol.PacketNumber(ackBlock) + + // read all the other ACK ranges + if numBlocks > 0 { + frame.AckRanges = append(frame.AckRanges, AckRange{First: smallest, Last: frame.LargestAcked}) + } + for i := uint64(0); i < numBlocks; i++ { + g, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + gap := protocol.PacketNumber(g) + if smallest < gap+2 { + return nil, errInvalidAckRanges + } + largest := smallest - gap - 2 + + ab, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + ackBlock := protocol.PacketNumber(ab) + + if ackBlock > largest { + return nil, errInvalidAckRanges + } + smallest = largest - protocol.PacketNumber(ackBlock) + frame.AckRanges = append(frame.AckRanges, AckRange{First: smallest, Last: largest}) + } + + frame.LowestAcked = smallest + if !frame.validateAckRanges() { + return nil, errInvalidAckRanges + } + + return frame, nil +} + +// Write writes an ACK frame. +func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesIETFFrameFormat() { + return f.writeLegacy(b, version) + } + + b.WriteByte(0xe) + utils.WriteVarInt(b, uint64(f.LargestAcked)) + utils.WriteVarInt(b, encodeAckDelay(f.DelayTime)) + + // TODO: limit the number of ACK ranges, such that the frame doesn't grow larger than an upper bound + var lowestInFirstRange protocol.PacketNumber + if f.HasMissingRanges() { + utils.WriteVarInt(b, uint64(len(f.AckRanges)-1)) + lowestInFirstRange = f.AckRanges[0].First + } else { + utils.WriteVarInt(b, 0) + lowestInFirstRange = f.LowestAcked + } + + // write the first range + utils.WriteVarInt(b, uint64(f.LargestAcked-lowestInFirstRange)) + + // write all the other range + if !f.HasMissingRanges() { + return nil + } + var lowest protocol.PacketNumber + for i, ackRange := range f.AckRanges { + if i == 0 { + lowest = lowestInFirstRange + continue + } + utils.WriteVarInt(b, uint64(lowest-ackRange.Last-2)) + utils.WriteVarInt(b, uint64(ackRange.Last-ackRange.First)) + lowest = ackRange.First + } + return nil +} + +// Length of a written frame +func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { + return f.lengthLegacy(version) + } + + length := 1 + utils.VarIntLen(uint64(f.LargestAcked)) + utils.VarIntLen(uint64(encodeAckDelay(f.DelayTime))) + + var lowestInFirstRange protocol.PacketNumber + if f.HasMissingRanges() { + length += utils.VarIntLen(uint64(len(f.AckRanges) - 1)) + lowestInFirstRange = f.AckRanges[0].First + } else { + length += utils.VarIntLen(0) + lowestInFirstRange = f.LowestAcked + } + length += utils.VarIntLen(uint64(f.LargestAcked - lowestInFirstRange)) + + if !f.HasMissingRanges() { + return length + } + var lowest protocol.PacketNumber + for i, ackRange := range f.AckRanges { + if i == 0 { + lowest = ackRange.First + continue + } + length += utils.VarIntLen(uint64(lowest - ackRange.Last - 2)) + length += utils.VarIntLen(uint64(ackRange.Last - ackRange.First)) + lowest = ackRange.First + } + return length +} + +// HasMissingRanges returns if this frame reports any missing packets +func (f *AckFrame) HasMissingRanges() bool { + return len(f.AckRanges) > 0 +} + +func (f *AckFrame) validateAckRanges() bool { + if len(f.AckRanges) == 0 { + return true + } + + // if there are missing packets, there will always be at least 2 ACK ranges + if len(f.AckRanges) == 1 { + return false + } + + if f.AckRanges[0].Last != f.LargestAcked { + return false + } + + // check the validity of every single ACK range + for _, ackRange := range f.AckRanges { + if ackRange.First > ackRange.Last { + return false + } + } + + // check the consistency for ACK with multiple NACK ranges + for i, ackRange := range f.AckRanges { + if i == 0 { + continue + } + lastAckRange := f.AckRanges[i-1] + if lastAckRange.First <= ackRange.First { + return false + } + if lastAckRange.First <= ackRange.Last+1 { + return false + } + } + + return true +} + +// AcksPacket determines if this ACK frame acks a certain packet number +func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { + if p < f.LowestAcked || p > f.LargestAcked { // this is just a performance optimization + return false + } + + if f.HasMissingRanges() { + // TODO: this could be implemented as a binary search + for _, ackRange := range f.AckRanges { + if p >= ackRange.First && p <= ackRange.Last { + return true + } + } + return false + } + // if packet doesn't have missing ranges + return (p >= f.LowestAcked && p <= f.LargestAcked) +} + +func encodeAckDelay(delay time.Duration) uint64 { + return uint64(delay.Nanoseconds() / (1000 * (1 << ackDelayExponent))) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/ack_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go similarity index 58% rename from vendor/github.com/lucas-clemente/quic-go/frames/ack_frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go index ceeba48c7..1f1c22e99 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/ack_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go @@ -1,40 +1,21 @@ -package frames +package wire import ( "bytes" "errors" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" -) - -var ( - // ErrInvalidAckRanges occurs when a client sends inconsistent ACK ranges - ErrInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") - // ErrInvalidFirstAckRange occurs when the first ACK range contains no packets - ErrInvalidFirstAckRange = errors.New("AckFrame: ACK frame has invalid first ACK range") ) var ( errInconsistentAckLargestAcked = errors.New("internal inconsistency: LargestAcked does not match ACK ranges") errInconsistentAckLowestAcked = errors.New("internal inconsistency: LowestAcked does not match ACK ranges") + errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") ) -// An AckFrame is an ACK frame in QUIC -type AckFrame struct { - LargestAcked protocol.PacketNumber - LowestAcked protocol.PacketNumber - AckRanges []AckRange // has to be ordered. The ACK range with the highest FirstPacketNumber goes first, the ACK range with the lowest FirstPacketNumber goes last - - // time when the LargestAcked was receiveid - // this field Will not be set for received ACKs frames - PacketReceivedTime time.Time - DelayTime time.Duration -} - -// ParseAckFrame reads an ACK frame -func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) { +func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, error) { frame := &AckFrame{} typeByte, err := r.ReadByte() @@ -57,13 +38,13 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, missingSequenceNumberDeltaLen = 1 } - largestAcked, err := utils.ReadUintN(r, largestAckedLen) + largestAcked, err := utils.BigEndian.ReadUintN(r, largestAckedLen) if err != nil { return nil, err } frame.LargestAcked = protocol.PacketNumber(largestAcked) - delay, err := utils.ReadUfloat16(r) + delay, err := utils.BigEndian.ReadUfloat16(r) if err != nil { return nil, err } @@ -78,25 +59,25 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, } if hasMissingRanges && numAckBlocks == 0 { - return nil, ErrInvalidAckRanges + return nil, errInvalidAckRanges } - ackBlockLength, err := utils.ReadUintN(r, missingSequenceNumberDeltaLen) + ackBlockLength, err := utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen) if err != nil { return nil, err } if frame.LargestAcked > 0 && ackBlockLength < 1 { - return nil, ErrInvalidFirstAckRange + return nil, errors.New("invalid first ACK range") } - if ackBlockLength > largestAcked { - return nil, ErrInvalidAckRanges + if ackBlockLength > largestAcked+1 { + return nil, errInvalidAckRanges } if hasMissingRanges { ackRange := AckRange{ - FirstPacketNumber: protocol.PacketNumber(largestAcked-ackBlockLength) + 1, - LastPacketNumber: frame.LargestAcked, + First: protocol.PacketNumber(largestAcked-ackBlockLength) + 1, + Last: frame.LargestAcked, } frame.AckRanges = append(frame.AckRanges, ackRange) @@ -109,7 +90,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, return nil, err } - ackBlockLength, err = utils.ReadUintN(r, missingSequenceNumberDeltaLen) + ackBlockLength, err = utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen) if err != nil { return nil, err } @@ -117,14 +98,14 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, length := protocol.PacketNumber(ackBlockLength) if inLongBlock { - frame.AckRanges[len(frame.AckRanges)-1].FirstPacketNumber -= protocol.PacketNumber(gap) + length - frame.AckRanges[len(frame.AckRanges)-1].LastPacketNumber -= protocol.PacketNumber(gap) + frame.AckRanges[len(frame.AckRanges)-1].First -= protocol.PacketNumber(gap) + length + frame.AckRanges[len(frame.AckRanges)-1].Last -= protocol.PacketNumber(gap) } else { lastRangeComplete = false ackRange := AckRange{ - LastPacketNumber: frame.AckRanges[len(frame.AckRanges)-1].FirstPacketNumber - protocol.PacketNumber(gap) - 1, + Last: frame.AckRanges[len(frame.AckRanges)-1].First - protocol.PacketNumber(gap) - 1, } - ackRange.FirstPacketNumber = ackRange.LastPacketNumber - length + 1 + ackRange.First = ackRange.Last - length + 1 frame.AckRanges = append(frame.AckRanges, ackRange) } @@ -135,13 +116,13 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, inLongBlock = (ackBlockLength == 0) } - // if the last range was not complete, FirstPacketNumber and LastPacketNumber make no sense + // if the last range was not complete, First and Last make no sense // remove the range from frame.AckRanges if !lastRangeComplete { frame.AckRanges = frame.AckRanges[:len(frame.AckRanges)-1] } - frame.LowestAcked = frame.AckRanges[len(frame.AckRanges)-1].FirstPacketNumber + frame.LowestAcked = frame.AckRanges[len(frame.AckRanges)-1].First } else { if frame.LargestAcked == 0 { frame.LowestAcked = 0 @@ -151,7 +132,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, } if !frame.validateAckRanges() { - return nil, ErrInvalidAckRanges + return nil, errInvalidAckRanges } var numTimestamp byte @@ -167,7 +148,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, return nil, err } // First Timestamp - _, err = utils.ReadUint32(r) + _, err = utils.BigEndian.ReadUint32(r) if err != nil { return nil, err } @@ -180,18 +161,16 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, } // Time Since Previous Timestamp - _, err = utils.ReadUint16(r) + _, err = utils.BigEndian.ReadUint16(r) if err != nil { return nil, err } } } - return frame, nil } -// Write writes an ACK frame. -func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { +func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error { largestAckedLen := protocol.GetPacketNumberLength(f.LargestAcked) typeByte := uint8(0x40) @@ -215,15 +194,15 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error case protocol.PacketNumberLen1: b.WriteByte(uint8(f.LargestAcked)) case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(f.LargestAcked)) + utils.BigEndian.WriteUint16(b, uint16(f.LargestAcked)) case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(f.LargestAcked)) + utils.BigEndian.WriteUint32(b, uint32(f.LargestAcked)) case protocol.PacketNumberLen6: - utils.WriteUint48(b, uint64(f.LargestAcked)) + utils.BigEndian.WriteUint48(b, uint64(f.LargestAcked)&(1<<48-1)) } f.DelayTime = time.Since(f.PacketReceivedTime) - utils.WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond)) + utils.BigEndian.WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond)) var numRanges uint64 var numRangesWritten uint64 @@ -239,13 +218,13 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error if !f.HasMissingRanges() { firstAckBlockLength = f.LargestAcked - f.LowestAcked + 1 } else { - if f.LargestAcked != f.AckRanges[0].LastPacketNumber { + if f.LargestAcked != f.AckRanges[0].Last { return errInconsistentAckLargestAcked } - if f.LowestAcked != f.AckRanges[len(f.AckRanges)-1].FirstPacketNumber { + if f.LowestAcked != f.AckRanges[len(f.AckRanges)-1].First { return errInconsistentAckLowestAcked } - firstAckBlockLength = f.LargestAcked - f.AckRanges[0].FirstPacketNumber + 1 + firstAckBlockLength = f.LargestAcked - f.AckRanges[0].First + 1 numRangesWritten++ } @@ -253,11 +232,11 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error case protocol.PacketNumberLen1: b.WriteByte(uint8(firstAckBlockLength)) case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(firstAckBlockLength)) + utils.BigEndian.WriteUint16(b, uint16(firstAckBlockLength)) case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(firstAckBlockLength)) + utils.BigEndian.WriteUint32(b, uint32(firstAckBlockLength)) case protocol.PacketNumberLen6: - utils.WriteUint48(b, uint64(firstAckBlockLength)) + utils.BigEndian.WriteUint48(b, uint64(firstAckBlockLength)&(1<<48-1)) } for i, ackRange := range f.AckRanges { @@ -265,8 +244,8 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error continue } - length := ackRange.LastPacketNumber - ackRange.FirstPacketNumber + 1 - gap := f.AckRanges[i-1].FirstPacketNumber - ackRange.LastPacketNumber - 1 + length := ackRange.Last - ackRange.First + 1 + gap := f.AckRanges[i-1].First - ackRange.Last - 1 num := gap/0xFF + 1 if gap%0xFF == 0 { @@ -279,11 +258,11 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error case protocol.PacketNumberLen1: b.WriteByte(uint8(length)) case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(length)) + utils.BigEndian.WriteUint16(b, uint16(length)) case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(length)) + utils.BigEndian.WriteUint32(b, uint32(length)) case protocol.PacketNumberLen6: - utils.WriteUint48(b, uint64(length)) + utils.BigEndian.WriteUint48(b, uint64(length)&(1<<48-1)) } numRangesWritten++ } else { @@ -304,11 +283,11 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error case protocol.PacketNumberLen1: b.WriteByte(uint8(lengthWritten)) case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(lengthWritten)) + utils.BigEndian.WriteUint16(b, uint16(lengthWritten)) case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(lengthWritten)) + utils.BigEndian.WriteUint32(b, uint32(lengthWritten)) case protocol.PacketNumberLen6: - utils.WriteUint48(b, lengthWritten) + utils.BigEndian.WriteUint48(b, lengthWritten&(1<<48-1)) } numRangesWritten++ @@ -326,12 +305,10 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error } b.WriteByte(0) // no timestamps - return nil } -// MinLength of a written frame -func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { +func (f *AckFrame) lengthLegacy(_ protocol.VersionNumber) protocol.ByteCount { length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked)) @@ -342,53 +319,8 @@ func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount } else { length += missingSequenceNumberDeltaLen } - - length += (1 + 2) * 0 /* TODO: num_timestamps */ - - return length, nil -} - -// HasMissingRanges returns if this frame reports any missing packets -func (f *AckFrame) HasMissingRanges() bool { - return len(f.AckRanges) > 0 -} - -func (f *AckFrame) validateAckRanges() bool { - if len(f.AckRanges) == 0 { - return true - } - - // if there are missing packets, there will always be at least 2 ACK ranges - if len(f.AckRanges) == 1 { - return false - } - - if f.AckRanges[0].LastPacketNumber != f.LargestAcked { - return false - } - - // check the validity of every single ACK range - for _, ackRange := range f.AckRanges { - if ackRange.FirstPacketNumber > ackRange.LastPacketNumber { - return false - } - } - - // check the consistency for ACK with multiple NACK ranges - for i, ackRange := range f.AckRanges { - if i == 0 { - continue - } - lastAckRange := f.AckRanges[i-1] - if lastAckRange.FirstPacketNumber <= ackRange.FirstPacketNumber { - return false - } - if lastAckRange.FirstPacketNumber <= ackRange.LastPacketNumber+1 { - return false - } - } - - return true + // we don't write + return length } // numWritableNackRanges calculates the number of ACK blocks that are about to be written @@ -405,7 +337,7 @@ func (f *AckFrame) numWritableNackRanges() uint64 { } lastAckRange := f.AckRanges[i-1] - gap := lastAckRange.FirstPacketNumber - ackRange.LastPacketNumber - 1 + gap := lastAckRange.First - ackRange.Last - 1 rangeLength := 1 + uint64(gap)/0xFF if uint64(gap)%0xFF == 0 { rangeLength-- @@ -426,7 +358,7 @@ func (f *AckFrame) getMissingSequenceNumberDeltaLen() protocol.PacketNumberLen { if f.HasMissingRanges() { for _, ackRange := range f.AckRanges { - rangeLength := ackRange.LastPacketNumber - ackRange.FirstPacketNumber + 1 + rangeLength := ackRange.Last - ackRange.First + 1 if rangeLength > maxRangeLength { maxRangeLength = rangeLength } @@ -447,22 +379,3 @@ func (f *AckFrame) getMissingSequenceNumberDeltaLen() protocol.PacketNumberLen { return protocol.PacketNumberLen6 } - -// AcksPacket determines if this ACK frame acks a certain packet number -func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { - if p < f.LowestAcked || p > f.LargestAcked { // this is just a performance optimization - return false - } - - if f.HasMissingRanges() { - // TODO: this could be implemented as a binary search - for _, ackRange := range f.AckRanges { - if p >= ackRange.FirstPacketNumber && p <= ackRange.LastPacketNumber { - return true - } - } - return false - } - // if packet doesn't have missing ranges - return (p >= f.LowestAcked && p <= f.LargestAcked) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go new file mode 100644 index 000000000..c561762d3 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go @@ -0,0 +1,9 @@ +package wire + +import "github.com/lucas-clemente/quic-go/internal/protocol" + +// AckRange is an ACK range +type AckRange struct { + First protocol.PacketNumber + Last protocol.PacketNumber +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go new file mode 100644 index 000000000..e4cad2d6d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go @@ -0,0 +1,45 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A BlockedFrame is a BLOCKED frame +type BlockedFrame struct { + Offset protocol.ByteCount +} + +// ParseBlockedFrame parses a BLOCKED frame +func ParseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + offset, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &BlockedFrame{ + Offset: protocol.ByteCount(offset), + }, nil +} + +func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesIETFFrameFormat() { + return (&blockedFrameLegacy{}).Write(b, version) + } + typeByte := uint8(0x08) + b.WriteByte(typeByte) + utils.WriteVarInt(b, uint64(f.Offset)) + return nil +} + +// Length of a written frame +func (f *BlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { + return 1 + 4 + } + return 1 + utils.VarIntLen(uint64(f.Offset)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go new file mode 100644 index 000000000..41cf0ee7e --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go @@ -0,0 +1,37 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type blockedFrameLegacy struct { + StreamID protocol.StreamID +} + +// ParseBlockedFrameLegacy parses a BLOCKED frame (in gQUIC format) +// The frame returned is +// * a STREAM_BLOCKED frame, if the BLOCKED applies to a stream +// * a BLOCKED frame, if the BLOCKED applies to the connection +func ParseBlockedFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (Frame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + streamID, err := utils.BigEndian.ReadUint32(r) + if err != nil { + return nil, err + } + if streamID == 0 { + return &BlockedFrame{}, nil + } + return &StreamBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil +} + +//Write writes a BLOCKED frame +func (f *blockedFrameLegacy) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x05) + utils.BigEndian.WriteUint32(b, uint32(f.StreamID)) + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go new file mode 100644 index 000000000..a2a7e9667 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go @@ -0,0 +1,96 @@ +package wire + +import ( + "bytes" + "errors" + "io" + "math" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +// A ConnectionCloseFrame in QUIC +type ConnectionCloseFrame struct { + ErrorCode qerr.ErrorCode + ReasonPhrase string +} + +// ParseConnectionCloseFrame reads a CONNECTION_CLOSE frame +func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) (*ConnectionCloseFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + + var errorCode qerr.ErrorCode + var reasonPhraseLen uint64 + if version.UsesIETFFrameFormat() { + ec, err := utils.BigEndian.ReadUint16(r) + if err != nil { + return nil, err + } + errorCode = qerr.ErrorCode(ec) + reasonPhraseLen, err = utils.ReadVarInt(r) + if err != nil { + return nil, err + } + } else { + ec, err := utils.BigEndian.ReadUint32(r) + if err != nil { + return nil, err + } + errorCode = qerr.ErrorCode(ec) + length, err := utils.BigEndian.ReadUint16(r) + if err != nil { + return nil, err + } + reasonPhraseLen = uint64(length) + } + + // shortcut to prevent the unneccessary allocation of dataLen bytes + // if the dataLen is larger than the remaining length of the packet + // reading the whole reason phrase would result in EOF when attempting to READ + if int(reasonPhraseLen) > r.Len() { + return nil, io.EOF + } + + reasonPhrase := make([]byte, reasonPhraseLen) + if _, err := io.ReadFull(r, reasonPhrase); err != nil { + // this should never happen, since we already checked the reasonPhraseLen earlier + return nil, err + } + + return &ConnectionCloseFrame{ + ErrorCode: qerr.ErrorCode(errorCode), + ReasonPhrase: string(reasonPhrase), + }, nil +} + +// Length of a written frame +func (f *ConnectionCloseFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if version.UsesIETFFrameFormat() { + return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) + } + return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)) +} + +// Write writes an CONNECTION_CLOSE frame. +func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x02) + + if len(f.ReasonPhrase) > math.MaxUint16 { + return errors.New("ConnectionFrame: ReasonPhrase too long") + } + + if version.UsesIETFFrameFormat() { + utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode)) + utils.WriteVarInt(b, uint64(len(f.ReasonPhrase))) + } else { + utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode)) + utils.BigEndian.WriteUint16(b, uint16(len(f.ReasonPhrase))) + } + b.WriteString(f.ReasonPhrase) + + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go new file mode 100644 index 000000000..835905a41 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go @@ -0,0 +1,13 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A Frame in QUIC +type Frame interface { + Write(b *bytes.Buffer, version protocol.VersionNumber) error + Length(version protocol.VersionNumber) protocol.ByteCount +} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/goaway_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go similarity index 54% rename from vendor/github.com/lucas-clemente/quic-go/frames/goaway_frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go index e00a6cf5a..fd5aca921 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/goaway_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go @@ -1,11 +1,11 @@ -package frames +package wire import ( "bytes" "io" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) @@ -17,27 +17,26 @@ type GoawayFrame struct { } // ParseGoawayFrame parses a GOAWAY frame -func ParseGoawayFrame(r *bytes.Reader) (*GoawayFrame, error) { +func ParseGoawayFrame(r *bytes.Reader, _ protocol.VersionNumber) (*GoawayFrame, error) { frame := &GoawayFrame{} - _, err := r.ReadByte() - if err != nil { + if _, err := r.ReadByte(); err != nil { return nil, err } - errorCode, err := utils.ReadUint32(r) + errorCode, err := utils.BigEndian.ReadUint32(r) if err != nil { return nil, err } frame.ErrorCode = qerr.ErrorCode(errorCode) - lastGoodStream, err := utils.ReadUint32(r) + lastGoodStream, err := utils.BigEndian.ReadUint32(r) if err != nil { return nil, err } frame.LastGoodStream = protocol.StreamID(lastGoodStream) - reasonPhraseLen, err := utils.ReadUint16(r) + reasonPhraseLen, err := utils.BigEndian.ReadUint16(r) if err != nil { return nil, err } @@ -51,23 +50,19 @@ func ParseGoawayFrame(r *bytes.Reader) (*GoawayFrame, error) { return nil, err } frame.ReasonPhrase = string(reasonPhrase) - return frame, nil } -func (f *GoawayFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - typeByte := uint8(0x03) - b.WriteByte(typeByte) - - utils.WriteUint32(b, uint32(f.ErrorCode)) - utils.WriteUint32(b, uint32(f.LastGoodStream)) - utils.WriteUint16(b, uint16(len(f.ReasonPhrase))) +func (f *GoawayFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x03) + utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode)) + utils.BigEndian.WriteUint32(b, uint32(f.LastGoodStream)) + utils.BigEndian.WriteUint16(b, uint16(len(f.ReasonPhrase))) b.WriteString(f.ReasonPhrase) - return nil } -// MinLength of a written frame -func (f *GoawayFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase)), nil +// Length of a written frame +func (f *GoawayFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go new file mode 100644 index 000000000..19c45c3ee --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go @@ -0,0 +1,110 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// Header is the header of a QUIC packet. +// It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header. +type Header struct { + Raw []byte + ConnectionID protocol.ConnectionID + OmitConnectionID bool + PacketNumberLen protocol.PacketNumberLen + PacketNumber protocol.PacketNumber + Version protocol.VersionNumber // VersionNumber sent by the client + + IsVersionNegotiation bool + SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server + + // only needed for the gQUIC Public Header + VersionFlag bool + ResetFlag bool + DiversificationNonce []byte + + // only needed for the IETF Header + Type protocol.PacketType + IsLongHeader bool + KeyPhase int + + // only needed for logging + isPublicHeader bool +} + +// ParseHeaderSentByServer parses the header for a packet that was sent by the server. +func ParseHeaderSentByServer(b *bytes.Reader, version protocol.VersionNumber) (*Header, error) { + typeByte, err := b.ReadByte() + if err != nil { + return nil, err + } + _ = b.UnreadByte() // unread the type byte + + var isPublicHeader bool + if typeByte&0x80 > 0 { // gQUIC always has 0x80 unset. IETF Long Header or Version Negotiation + isPublicHeader = false + } else if typeByte&0xcf == 0x9 { // gQUIC Version Negotiation Packet + isPublicHeader = true + } else { + // the client knows the version that this packet was sent with + isPublicHeader = !version.UsesTLS() + } + return parsePacketHeader(b, protocol.PerspectiveServer, isPublicHeader) +} + +// ParseHeaderSentByClient parses the header for a packet that was sent by the client. +func ParseHeaderSentByClient(b *bytes.Reader) (*Header, error) { + typeByte, err := b.ReadByte() + if err != nil { + return nil, err + } + _ = b.UnreadByte() // unread the type byte + + // If this is a gQUIC header 0x80 and 0x40 will be set to 0. + // If this is an IETF QUIC header there are two options: + // * either 0x80 will be 1 (for the Long Header) + // * or 0x40 (the Connection ID Flag) will be 0 (for the Short Header), since we don't the client to omit it + isPublicHeader := typeByte&0xc0 == 0 + + return parsePacketHeader(b, protocol.PerspectiveClient, isPublicHeader) +} + +func parsePacketHeader(b *bytes.Reader, sentBy protocol.Perspective, isPublicHeader bool) (*Header, error) { + // This is a gQUIC Public Header. + if isPublicHeader { + hdr, err := parsePublicHeader(b, sentBy) + if err != nil { + return nil, err + } + hdr.isPublicHeader = true // save that this is a Public Header, so we can log it correctly later + return hdr, nil + } + return parseHeader(b, sentBy) +} + +// Write writes the Header. +func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error { + if !version.UsesTLS() { + h.isPublicHeader = true // save that this is a Public Header, so we can log it correctly later + return h.writePublicHeader(b, pers, version) + } + return h.writeHeader(b) +} + +// GetLength determines the length of the Header. +func (h *Header) GetLength(pers protocol.Perspective, version protocol.VersionNumber) (protocol.ByteCount, error) { + if !version.UsesTLS() { + return h.getPublicHeaderLength(pers) + } + return h.getHeaderLength() +} + +// Log logs the Header +func (h *Header) Log() { + if h.isPublicHeader { + h.logPublicHeader() + } else { + h.logHeader() + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go new file mode 100644 index 000000000..88bd139fb --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go @@ -0,0 +1,172 @@ +package wire + +import ( + "bytes" + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +// parseHeader parses the header. +func parseHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, error) { + typeByte, err := b.ReadByte() + if err != nil { + return nil, err + } + if typeByte&0x80 > 0 { + return parseLongHeader(b, packetSentBy, typeByte) + } + return parseShortHeader(b, typeByte) +} + +// parse long header and version negotiation packets +func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte) (*Header, error) { + connID, err := utils.BigEndian.ReadUint64(b) + if err != nil { + return nil, err + } + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, err + } + pn, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, err + } + h := &Header{ + ConnectionID: protocol.ConnectionID(connID), + PacketNumber: protocol.PacketNumber(pn), + PacketNumberLen: protocol.PacketNumberLen4, + Version: protocol.VersionNumber(v), + } + if v == 0 { // version negotiation packet + if sentBy == protocol.PerspectiveClient { + return nil, qerr.InvalidVersion + } + if b.Len() == 0 { + return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list") + } + h.IsVersionNegotiation = true + h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4) + for i := 0; b.Len() > 0; i++ { + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, qerr.InvalidVersionNegotiationPacket + } + h.SupportedVersions[i] = protocol.VersionNumber(v) + } + return h, nil + } + h.IsLongHeader = true + h.Type = protocol.PacketType(typeByte & 0x7f) + if sentBy == protocol.PerspectiveClient && (h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeHandshake && h.Type != protocol.PacketType0RTT) { + return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type)) + } + if sentBy == protocol.PerspectiveServer && (h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketTypeHandshake) { + return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type)) + } + return h, nil +} + +func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) { + hasConnID := typeByte&0x40 > 0 + var connID uint64 + if hasConnID { + var err error + connID, err = utils.BigEndian.ReadUint64(b) + if err != nil { + return nil, err + } + } + pnLen := 1 << ((typeByte & 0x3) - 1) + pn, err := utils.BigEndian.ReadUintN(b, uint8(pnLen)) + if err != nil { + return nil, err + } + return &Header{ + KeyPhase: int(typeByte&0x20) >> 5, + OmitConnectionID: !hasConnID, + ConnectionID: protocol.ConnectionID(connID), + PacketNumber: protocol.PacketNumber(pn), + PacketNumberLen: protocol.PacketNumberLen(pnLen), + }, nil +} + +// writeHeader writes the Header. +func (h *Header) writeHeader(b *bytes.Buffer) error { + if h.IsLongHeader { + return h.writeLongHeader(b) + } + return h.writeShortHeader(b) +} + +// TODO: add support for the key phase +func (h *Header) writeLongHeader(b *bytes.Buffer) error { + b.WriteByte(byte(0x80 | h.Type)) + utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) + utils.BigEndian.WriteUint32(b, uint32(h.Version)) + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) + return nil +} + +func (h *Header) writeShortHeader(b *bytes.Buffer) error { + typeByte := byte(h.KeyPhase << 5) + if !h.OmitConnectionID { + typeByte ^= 0x40 + } + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + typeByte ^= 0x1 + case protocol.PacketNumberLen2: + typeByte ^= 0x2 + case protocol.PacketNumberLen4: + typeByte ^= 0x3 + default: + return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + b.WriteByte(typeByte) + + if !h.OmitConnectionID { + utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) + } + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(h.PacketNumber)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) + } + return nil +} + +// getHeaderLength gets the length of the Header in bytes. +func (h *Header) getHeaderLength() (protocol.ByteCount, error) { + if h.IsLongHeader { + return 1 + 8 + 4 + 4, nil + } + + length := protocol.ByteCount(1) // type byte + if !h.OmitConnectionID { + length += 8 + } + if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 { + return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + length += protocol.ByteCount(h.PacketNumberLen) + return length, nil +} + +func (h *Header) logHeader() { + if h.IsLongHeader { + utils.Debugf(" Long Header{Type: %s, ConnectionID: %#x, PacketNumber: %#x, Version: %s}", h.Type, h.ConnectionID, h.PacketNumber, h.Version) + } else { + connID := "(omitted)" + if !h.OmitConnectionID { + connID = fmt.Sprintf("%#x", h.ConnectionID) + } + utils.Debugf(" Short Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", connID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go new file mode 100644 index 000000000..0e72ea98a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go @@ -0,0 +1,28 @@ +package wire + +import "github.com/lucas-clemente/quic-go/internal/utils" + +// LogFrame logs a frame, either sent or received +func LogFrame(frame Frame, sent bool) { + if !utils.Debug() { + return + } + dir := "<-" + if sent { + dir = "->" + } + switch f := frame.(type) { + case *StreamFrame: + utils.Debugf("\t%s &wire.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen()) + case *StopWaitingFrame: + if sent { + utils.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen) + } else { + utils.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked) + } + case *AckFrame: + utils.Debugf("\t%s &wire.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String()) + default: + utils.Debugf("\t%s %#v", dir, frame) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go new file mode 100644 index 000000000..8ba4fc09a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go @@ -0,0 +1,51 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A MaxDataFrame carries flow control information for the connection +type MaxDataFrame struct { + ByteOffset protocol.ByteCount +} + +// ParseMaxDataFrame parses a MAX_DATA frame +func ParseMaxDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxDataFrame, error) { + // read the TypeByte + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + frame := &MaxDataFrame{} + byteOffset, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + frame.ByteOffset = protocol.ByteCount(byteOffset) + return frame, nil +} + +//Write writes a MAX_STREAM_DATA frame +func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesIETFFrameFormat() { + // write a gQUIC WINDOW_UPDATE frame (with stream ID 0, which means connection-level there) + return (&windowUpdateFrame{ + StreamID: 0, + ByteOffset: f.ByteOffset, + }).Write(b, version) + } + b.WriteByte(0x4) + utils.WriteVarInt(b, uint64(f.ByteOffset)) + return nil +} + +// Length of a written frame +func (f *MaxDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which is longer + return 1 + 4 + 8 + } + return 1 + utils.VarIntLen(uint64(f.ByteOffset)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go new file mode 100644 index 000000000..e88f245b5 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go @@ -0,0 +1,60 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A MaxStreamDataFrame carries flow control information for a stream +type MaxStreamDataFrame struct { + StreamID protocol.StreamID + ByteOffset protocol.ByteCount +} + +// ParseMaxStreamDataFrame parses a MAX_STREAM_DATA frame +func ParseMaxStreamDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxStreamDataFrame, error) { + frame := &MaxStreamDataFrame{} + + // read the TypeByte + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + sid, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + frame.StreamID = protocol.StreamID(sid) + + byteOffset, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + frame.ByteOffset = protocol.ByteCount(byteOffset) + return frame, nil +} + +// Write writes a MAX_STREAM_DATA frame +func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesIETFFrameFormat() { + return (&windowUpdateFrame{ + StreamID: f.StreamID, + ByteOffset: f.ByteOffset, + }).Write(b, version) + } + b.WriteByte(0x5) + utils.WriteVarInt(b, uint64(f.StreamID)) + utils.WriteVarInt(b, uint64(f.ByteOffset)) + return nil +} + +// Length of a written frame +func (f *MaxStreamDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which has a different length + if !version.UsesIETFFrameFormat() { + return 1 + 4 + 8 + } + return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.ByteOffset)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame.go new file mode 100644 index 000000000..31e51ae3d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame.go @@ -0,0 +1,37 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A MaxStreamIDFrame is a MAX_STREAM_ID frame +type MaxStreamIDFrame struct { + StreamID protocol.StreamID +} + +// ParseMaxStreamIDFrame parses a MAX_STREAM_ID frame +func ParseMaxStreamIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamIDFrame, error) { + // read the Type byte + if _, err := r.ReadByte(); err != nil { + return nil, err + } + streamID, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &MaxStreamIDFrame{StreamID: protocol.StreamID(streamID)}, nil +} + +func (f *MaxStreamIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x6) + utils.WriteVarInt(b, uint64(f.StreamID)) + return nil +} + +// Length of a written frame +func (f *MaxStreamIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { + return 1 + utils.VarIntLen(uint64(f.StreamID)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/ping_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go similarity index 57% rename from vendor/github.com/lucas-clemente/quic-go/frames/ping_frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go index 8486af57b..ac4fd7d25 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/ping_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go @@ -1,16 +1,16 @@ -package frames +package wire import ( "bytes" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // A PingFrame is a ping frame type PingFrame struct{} // ParsePingFrame parses a Ping frame -func ParsePingFrame(r *bytes.Reader) (*PingFrame, error) { +func ParsePingFrame(r *bytes.Reader, version protocol.VersionNumber) (*PingFrame, error) { frame := &PingFrame{} _, err := r.ReadByte() @@ -27,7 +27,7 @@ func (f *PingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error return nil } -// MinLength of a written frame -func (f *PingFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1, nil +// Length of a written frame +func (f *PingFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 } diff --git a/vendor/github.com/lucas-clemente/quic-go/public_header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go similarity index 59% rename from vendor/github.com/lucas-clemente/quic-go/public_header.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go index 4af66ab16..e4c997557 100644 --- a/vendor/github.com/lucas-clemente/quic-go/public_header.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go @@ -1,62 +1,45 @@ -package quic +package wire import ( "bytes" "errors" + "fmt" "io" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) var ( - errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set") errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") - errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported") + errReceivedOmittedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with omitted ConnectionID is not supported") errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets") ) -// The PublicHeader of a QUIC packet. Warning: This struct should not be considered stable and will change soon. -type PublicHeader struct { - Raw []byte - ConnectionID protocol.ConnectionID - VersionFlag bool - ResetFlag bool - TruncateConnectionID bool - PacketNumberLen protocol.PacketNumberLen - PacketNumber protocol.PacketNumber - VersionNumber protocol.VersionNumber // VersionNumber sent by the client - SupportedVersions []protocol.VersionNumber // VersionNumbers sent by the server - DiversificationNonce []byte -} - -// Write writes a public header. Warning: This API should not be considered stable and will change soon. -func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pers protocol.Perspective) error { - publicFlagByte := uint8(0x00) - +// writePublicHeader writes a Public Header. +func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error { if h.VersionFlag && h.ResetFlag { return errResetAndVersionFlagSet } + publicFlagByte := uint8(0x00) if h.VersionFlag { publicFlagByte |= 0x01 } if h.ResetFlag { publicFlagByte |= 0x02 } - if !h.TruncateConnectionID { + if !h.OmitConnectionID { publicFlagByte |= 0x08 } - if len(h.DiversificationNonce) > 0 { if len(h.DiversificationNonce) != 32 { return errors.New("invalid diversification nonce length") } publicFlagByte |= 0x04 } - // only set PacketNumberLen bits if a packet number will be written if h.hasPacketNumber(pers) { switch h.PacketNumberLen { @@ -70,59 +53,50 @@ func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pe publicFlagByte |= 0x30 } } - b.WriteByte(publicFlagByte) - if !h.TruncateConnectionID { - utils.WriteUint64(b, uint64(h.ConnectionID)) + if !h.OmitConnectionID { + utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) } - if h.VersionFlag && pers == protocol.PerspectiveClient { - utils.WriteUint32(b, protocol.VersionNumberToTag(h.VersionNumber)) + utils.BigEndian.WriteUint32(b, uint32(h.Version)) } - if len(h.DiversificationNonce) > 0 { b.Write(h.DiversificationNonce) } - // if we're a server, and the VersionFlag is set, we must not include anything else in the packet if !h.hasPacketNumber(pers) { return nil } - if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { - return errPacketNumberLenNotSet - } - switch h.PacketNumberLen { case protocol.PacketNumberLen1: b.WriteByte(uint8(h.PacketNumber)) case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(h.PacketNumber)) + utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(h.PacketNumber)) + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) case protocol.PacketNumberLen6: - utils.WriteUint48(b, uint64(h.PacketNumber)) + utils.BigEndian.WriteUint48(b, uint64(h.PacketNumber)&(1<<48-1)) default: - return errPacketNumberLenNotSet + return errors.New("PublicHeader: PacketNumberLen not set") } return nil } -// ParsePublicHeader parses a QUIC packet's public header. +// parsePublicHeader parses a QUIC packet's Public Header. // The packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient. -// Warning: This API should not be considered stable and will change soon. -func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*PublicHeader, error) { - header := &PublicHeader{} +func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, error) { + header := &Header{} // First byte publicFlagByte, err := b.ReadByte() if err != nil { return nil, err } - header.VersionFlag = publicFlagByte&0x01 > 0 header.ResetFlag = publicFlagByte&0x02 > 0 + header.VersionFlag = publicFlagByte&0x01 > 0 // TODO: activate this check once Chrome sends the correct value // see https://github.com/lucas-clemente/quic-go/issues/232 @@ -130,11 +104,10 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub // return nil, errors.New("diversification nonces should only be sent by servers") // } - header.TruncateConnectionID = publicFlagByte&0x08 == 0 - if header.TruncateConnectionID && packetSentBy == protocol.PerspectiveClient { - return nil, errReceivedTruncatedConnectionID + header.OmitConnectionID = publicFlagByte&0x08 == 0 + if header.OmitConnectionID && packetSentBy == protocol.PerspectiveClient { + return nil, errReceivedOmittedConnectionID } - if header.hasPacketNumber(packetSentBy) { switch publicFlagByte & 0x30 { case 0x30: @@ -149,9 +122,9 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub } // Connection ID - if !header.TruncateConnectionID { + if !header.OmitConnectionID { var connID uint64 - connID, err = utils.ReadUint64(b) + connID, err = utils.BigEndian.ReadUint64(b) if err != nil { return nil, err } @@ -174,82 +147,79 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub } // Version (optional) - if !header.ResetFlag { - if header.VersionFlag { - if packetSentBy == protocol.PerspectiveClient { - var versionTag uint32 - versionTag, err = utils.ReadUint32(b) - if err != nil { - return nil, err - } - header.VersionNumber = protocol.VersionTagToNumber(versionTag) - } else { // parse the version negotiaton packet - if b.Len()%4 != 0 { - return nil, qerr.InvalidVersionNegotiationPacket - } - header.SupportedVersions = make([]protocol.VersionNumber, 0) - for { - var versionTag uint32 - versionTag, err = utils.ReadUint32(b) - if err != nil { - break - } - v := protocol.VersionTagToNumber(versionTag) - header.SupportedVersions = append(header.SupportedVersions, v) - } + if !header.ResetFlag && header.VersionFlag { + if packetSentBy == protocol.PerspectiveServer { // parse the version negotiaton packet + if b.Len() == 0 { + return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list") } + if b.Len()%4 != 0 { + return nil, qerr.InvalidVersionNegotiationPacket + } + header.IsVersionNegotiation = true + header.SupportedVersions = make([]protocol.VersionNumber, 0) + for { + var versionTag uint32 + versionTag, err = utils.BigEndian.ReadUint32(b) + if err != nil { + break + } + v := protocol.VersionNumber(versionTag) + header.SupportedVersions = append(header.SupportedVersions, v) + } + // a version negotiation packet doesn't have a packet number + return header, nil } + // packet was sent by the client. Read the version number + var versionTag uint32 + versionTag, err = utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, err + } + header.Version = protocol.VersionNumber(versionTag) } // Packet number if header.hasPacketNumber(packetSentBy) { - packetNumber, err := utils.ReadUintN(b, uint8(header.PacketNumberLen)) + packetNumber, err := utils.BigEndian.ReadUintN(b, uint8(header.PacketNumberLen)) if err != nil { return nil, err } header.PacketNumber = protocol.PacketNumber(packetNumber) } - return header, nil } -// GetLength gets the length of the publicHeader in bytes. +// getPublicHeaderLength gets the length of the publicHeader in bytes. // It can only be called for regular packets. -func (h *PublicHeader) GetLength(pers protocol.Perspective) (protocol.ByteCount, error) { +func (h *Header) getPublicHeaderLength(pers protocol.Perspective) (protocol.ByteCount, error) { if h.VersionFlag && h.ResetFlag { return 0, errResetAndVersionFlagSet } - if h.VersionFlag && pers == protocol.PerspectiveServer { return 0, errGetLengthNotForVersionNegotiation } length := protocol.ByteCount(1) // 1 byte for public flags - if h.hasPacketNumber(pers) { if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { return 0, errPacketNumberLenNotSet } length += protocol.ByteCount(h.PacketNumberLen) } - - if !h.TruncateConnectionID { + if !h.OmitConnectionID { length += 8 // 8 bytes for the connection ID } - // Version Number in packets sent by the client if h.VersionFlag { length += 4 } - length += protocol.ByteCount(len(h.DiversificationNonce)) - return length, nil } -// hasPacketNumber determines if this PublicHeader will contain a packet number +// hasPacketNumber determines if this Public Header will contain a packet number // this depends on the ResetFlag, the VersionFlag and who sent the packet -func (h *PublicHeader) hasPacketNumber(packetSentBy protocol.Perspective) bool { +func (h *Header) hasPacketNumber(packetSentBy protocol.Perspective) bool { if h.ResetFlag { return false } @@ -258,3 +228,15 @@ func (h *PublicHeader) hasPacketNumber(packetSentBy protocol.Perspective) bool { } return true } + +func (h *Header) logPublicHeader() { + connID := "(omitted)" + if !h.OmitConnectionID { + connID = fmt.Sprintf("%#x", h.ConnectionID) + } + ver := "(unset)" + if h.Version != 0 { + ver = fmt.Sprintf("%s", h.Version) + } + utils.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", connID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go new file mode 100644 index 000000000..6adc9f690 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go @@ -0,0 +1,65 @@ +package wire + +import ( + "bytes" + "encoding/binary" + "errors" + + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A PublicReset is a PUBLIC_RESET +type PublicReset struct { + RejectedPacketNumber protocol.PacketNumber + Nonce uint64 +} + +// WritePublicReset writes a Public Reset +func WritePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte { + b := &bytes.Buffer{} + b.WriteByte(0x0a) + utils.BigEndian.WriteUint64(b, uint64(connectionID)) + utils.LittleEndian.WriteUint32(b, uint32(handshake.TagPRST)) + utils.LittleEndian.WriteUint32(b, 2) + utils.LittleEndian.WriteUint32(b, uint32(handshake.TagRNON)) + utils.LittleEndian.WriteUint32(b, 8) + utils.LittleEndian.WriteUint32(b, uint32(handshake.TagRSEQ)) + utils.LittleEndian.WriteUint32(b, 16) + utils.LittleEndian.WriteUint64(b, nonceProof) + utils.LittleEndian.WriteUint64(b, uint64(rejectedPacketNumber)) + return b.Bytes() +} + +// ParsePublicReset parses a Public Reset +func ParsePublicReset(r *bytes.Reader) (*PublicReset, error) { + pr := PublicReset{} + msg, err := handshake.ParseHandshakeMessage(r) + if err != nil { + return nil, err + } + if msg.Tag != handshake.TagPRST { + return nil, errors.New("wrong public reset tag") + } + + // The RSEQ tag is mandatory according to the gQUIC wire spec. + // However, Google doesn't send RSEQ in their Public Resets. + // Therefore, we'll treat RSEQ as an optional field. + if rseq, ok := msg.Data[handshake.TagRSEQ]; ok { + if len(rseq) != 8 { + return nil, errors.New("invalid RSEQ tag") + } + pr.RejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq)) + } + + rnon, ok := msg.Data[handshake.TagRNON] + if !ok { + return nil, errors.New("RNON missing") + } + if len(rnon) != 8 { + return nil, errors.New("invalid RNON tag") + } + pr.Nonce = binary.LittleEndian.Uint64(rnon) + return &pr, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go new file mode 100644 index 000000000..ea25f381d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go @@ -0,0 +1,89 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A RstStreamFrame is a RST_STREAM frame in QUIC +type RstStreamFrame struct { + StreamID protocol.StreamID + // The error code is a uint32 in gQUIC, but a uint16 in IETF QUIC. + // protocol.ApplicaitonErrorCode is a uint16, so larger values in gQUIC frames will be truncated. + ErrorCode protocol.ApplicationErrorCode + ByteOffset protocol.ByteCount +} + +// ParseRstStreamFrame parses a RST_STREAM frame +func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstStreamFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + + var streamID protocol.StreamID + var errorCode uint16 + var byteOffset protocol.ByteCount + if version.UsesIETFFrameFormat() { + sid, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + streamID = protocol.StreamID(sid) + errorCode, err = utils.BigEndian.ReadUint16(r) + if err != nil { + return nil, err + } + bo, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + byteOffset = protocol.ByteCount(bo) + } else { + sid, err := utils.BigEndian.ReadUint32(r) + if err != nil { + return nil, err + } + streamID = protocol.StreamID(sid) + bo, err := utils.BigEndian.ReadUint64(r) + if err != nil { + return nil, err + } + byteOffset = protocol.ByteCount(bo) + ec, err := utils.BigEndian.ReadUint32(r) + if err != nil { + return nil, err + } + errorCode = uint16(ec) + } + + return &RstStreamFrame{ + StreamID: streamID, + ErrorCode: protocol.ApplicationErrorCode(errorCode), + ByteOffset: byteOffset, + }, nil +} + +//Write writes a RST_STREAM frame +func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x01) + if version.UsesIETFFrameFormat() { + utils.WriteVarInt(b, uint64(f.StreamID)) + utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode)) + utils.WriteVarInt(b, uint64(f.ByteOffset)) + } else { + utils.BigEndian.WriteUint32(b, uint32(f.StreamID)) + utils.BigEndian.WriteUint64(b, uint64(f.ByteOffset)) + utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode)) + } + return nil +} + +// Length of a written frame +func (f *RstStreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if version.UsesIETFFrameFormat() { + return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 + utils.VarIntLen(uint64(f.ByteOffset)) + } + return 1 + 4 + 8 + 4 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go new file mode 100644 index 000000000..2a33756bd --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go @@ -0,0 +1,47 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A StopSendingFrame is a STOP_SENDING frame +type StopSendingFrame struct { + StreamID protocol.StreamID + ErrorCode protocol.ApplicationErrorCode +} + +// ParseStopSendingFrame parses a STOP_SENDING frame +func ParseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + + streamID, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + errorCode, err := utils.BigEndian.ReadUint16(r) + if err != nil { + return nil, err + } + + return &StopSendingFrame{ + StreamID: protocol.StreamID(streamID), + ErrorCode: protocol.ApplicationErrorCode(errorCode), + }, nil +} + +// Length of a written frame +func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 +} + +func (f *StopSendingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x0c) + utils.WriteVarInt(b, uint64(f.StreamID)) + utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode)) + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/stop_waiting_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go similarity index 54% rename from vendor/github.com/lucas-clemente/quic-go/frames/stop_waiting_frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go index 5b54154ca..4ee9578e6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/stop_waiting_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go @@ -1,19 +1,19 @@ -package frames +package wire import ( "bytes" "errors" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/qerr" ) // A StopWaitingFrame in QUIC type StopWaitingFrame struct { LeastUnacked protocol.PacketNumber PacketNumberLen protocol.PacketNumberLen - PacketNumber protocol.PacketNumber + // PacketNumber is the packet number of the packet that this StopWaitingFrame will be sent with + PacketNumber protocol.PacketNumber } var ( @@ -22,70 +22,56 @@ var ( errPacketNumberLenNotSet = errors.New("StopWaitingFrame: PacketNumberLen not set") ) -func (f *StopWaitingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - // packetNumber is the packet number of the packet that this StopWaitingFrame will be sent with - typeByte := uint8(0x06) - b.WriteByte(typeByte) - +func (f *StopWaitingFrame) Write(b *bytes.Buffer, v protocol.VersionNumber) error { + if v.UsesIETFFrameFormat() { + return errors.New("STOP_WAITING not defined in IETF QUIC") + } // make sure the PacketNumber was set if f.PacketNumber == protocol.PacketNumber(0) { return errPacketNumberNotSet } - if f.LeastUnacked > f.PacketNumber { return errLeastUnackedHigherThanPacketNumber } + b.WriteByte(0x06) leastUnackedDelta := uint64(f.PacketNumber - f.LeastUnacked) - switch f.PacketNumberLen { case protocol.PacketNumberLen1: b.WriteByte(uint8(leastUnackedDelta)) case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(leastUnackedDelta)) + utils.BigEndian.WriteUint16(b, uint16(leastUnackedDelta)) case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(leastUnackedDelta)) + utils.BigEndian.WriteUint32(b, uint32(leastUnackedDelta)) case protocol.PacketNumberLen6: - utils.WriteUint48(b, leastUnackedDelta) + utils.BigEndian.WriteUint48(b, leastUnackedDelta&(1<<48-1)) default: return errPacketNumberLenNotSet } - return nil } -// MinLength of a written frame -func (f *StopWaitingFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - minLength := protocol.ByteCount(1) // typeByte - - if f.PacketNumberLen == protocol.PacketNumberLenInvalid { - return 0, errPacketNumberLenNotSet - } - minLength += protocol.ByteCount(f.PacketNumberLen) - - return minLength, nil +// Length of a written frame +func (f *StopWaitingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + protocol.ByteCount(f.PacketNumberLen) } // ParseStopWaitingFrame parses a StopWaiting frame -func ParseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen, version protocol.VersionNumber) (*StopWaitingFrame, error) { +func ParseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen, _ protocol.VersionNumber) (*StopWaitingFrame, error) { frame := &StopWaitingFrame{} // read the TypeByte - _, err := r.ReadByte() - if err != nil { + if _, err := r.ReadByte(); err != nil { return nil, err } - leastUnackedDelta, err := utils.ReadUintN(r, uint8(packetNumberLen)) + leastUnackedDelta, err := utils.BigEndian.ReadUintN(r, uint8(packetNumberLen)) if err != nil { return nil, err } - - if leastUnackedDelta >= uint64(packetNumber) { - return nil, qerr.Error(qerr.InvalidStopWaitingData, "invalid LeastUnackedDelta") + if leastUnackedDelta > uint64(packetNumber) { + return nil, errors.New("invalid LeastUnackedDelta") } - frame.LeastUnacked = protocol.PacketNumber(uint64(packetNumber) - leastUnackedDelta) - return frame, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go new file mode 100644 index 000000000..625698cd1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go @@ -0,0 +1,52 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A StreamBlockedFrame in QUIC +type StreamBlockedFrame struct { + StreamID protocol.StreamID + Offset protocol.ByteCount +} + +// ParseStreamBlockedFrame parses a STREAM_BLOCKED frame +func ParseStreamBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamBlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + sid, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + offset, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &StreamBlockedFrame{ + StreamID: protocol.StreamID(sid), + Offset: protocol.ByteCount(offset), + }, nil +} + +// Write writes a STREAM_BLOCKED frame +func (f *StreamBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesIETFFrameFormat() { + return (&blockedFrameLegacy{StreamID: f.StreamID}).Write(b, version) + } + b.WriteByte(0x09) + utils.WriteVarInt(b, uint64(f.StreamID)) + utils.WriteVarInt(b, uint64(f.Offset)) + return nil +} + +// Length of a written frame +func (f *StreamBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { + return 1 + 4 + } + return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.Offset)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go new file mode 100644 index 000000000..5168e315c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go @@ -0,0 +1,182 @@ +package wire + +import ( + "bytes" + "errors" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +// A StreamFrame of QUIC +type StreamFrame struct { + StreamID protocol.StreamID + FinBit bool + DataLenPresent bool + Offset protocol.ByteCount + Data []byte +} + +// ParseStreamFrame reads a STREAM frame +func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamFrame, error) { + if !version.UsesIETFFrameFormat() { + return parseLegacyStreamFrame(r, version) + } + + frame := &StreamFrame{} + + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + frame.FinBit = typeByte&0x1 > 0 + frame.DataLenPresent = typeByte&0x2 > 0 + hasOffset := typeByte&0x4 > 0 + + streamID, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + frame.StreamID = protocol.StreamID(streamID) + if hasOffset { + offset, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + frame.Offset = protocol.ByteCount(offset) + } + + var dataLen uint64 + if frame.DataLenPresent { + var err error + dataLen, err = utils.ReadVarInt(r) + if err != nil { + return nil, err + } + // shortcut to prevent the unneccessary allocation of dataLen bytes + // if the dataLen is larger than the remaining length of the packet + // reading the packet contents would result in EOF when attempting to READ + if dataLen > uint64(r.Len()) { + return nil, io.EOF + } + } else { + // The rest of the packet is data + dataLen = uint64(r.Len()) + } + if dataLen != 0 { + frame.Data = make([]byte, dataLen) + if _, err := io.ReadFull(r, frame.Data); err != nil { + // this should never happen, since we already checked the dataLen earlier + return nil, err + } + } + if frame.Offset+frame.DataLen() > protocol.MaxByteCount { + return nil, qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset") + } + if !frame.FinBit && frame.DataLen() == 0 { + return nil, qerr.EmptyStreamFrameNoFin + } + return frame, nil +} + +// Write writes a STREAM frame +func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesIETFFrameFormat() { + return f.writeLegacy(b, version) + } + + if len(f.Data) == 0 && !f.FinBit { + return errors.New("StreamFrame: attempting to write empty frame without FIN") + } + + typeByte := byte(0x10) + if f.FinBit { + typeByte ^= 0x1 + } + hasOffset := f.Offset != 0 + if f.DataLenPresent { + typeByte ^= 0x2 + } + if hasOffset { + typeByte ^= 0x4 + } + b.WriteByte(typeByte) + utils.WriteVarInt(b, uint64(f.StreamID)) + if hasOffset { + utils.WriteVarInt(b, uint64(f.Offset)) + } + if f.DataLenPresent { + utils.WriteVarInt(b, uint64(f.DataLen())) + } + b.Write(f.Data) + return nil +} + +// Length returns the total length of the STREAM frame +func (f *StreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { + return f.lengthLegacy(version) + } + length := 1 + utils.VarIntLen(uint64(f.StreamID)) + if f.Offset != 0 { + length += utils.VarIntLen(uint64(f.Offset)) + } + if f.DataLenPresent { + length += utils.VarIntLen(uint64(f.DataLen())) + } + return length + f.DataLen() +} + +// MaxDataLen returns the maximum data length +// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data). +func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { + return f.maxDataLenLegacy(maxSize, version) + } + + headerLen := 1 + utils.VarIntLen(uint64(f.StreamID)) + if f.Offset != 0 { + headerLen += utils.VarIntLen(uint64(f.Offset)) + } + if f.DataLenPresent { + // pretend that the data size will be 1 bytes + // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards + headerLen++ + } + if headerLen > maxSize { + return 0 + } + maxDataLen := maxSize - headerLen + if f.DataLenPresent && utils.VarIntLen(uint64(maxDataLen)) != 1 { + maxDataLen-- + } + return maxDataLen +} + +// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. +// If n >= len(frame), nil is returned and nothing is modified. +func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*StreamFrame, error) { + if maxSize >= f.Length(version) { + return nil, nil + } + + n := f.MaxDataLen(maxSize, version) + if n == 0 { + return nil, errors.New("too small") + } + newFrame := &StreamFrame{ + FinBit: false, + StreamID: f.StreamID, + Offset: f.Offset, + Data: f.Data[:n], + DataLenPresent: f.DataLenPresent, + } + + f.Data = f.Data[n:] + f.Offset += n + + return newFrame, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy.go similarity index 55% rename from vendor/github.com/lucas-clemente/quic-go/frames/stream_frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy.go index f4a7bf739..a01618e1a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/stream_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy.go @@ -1,31 +1,22 @@ -package frames +package wire import ( "bytes" "errors" "io" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) -// A StreamFrame of QUIC -type StreamFrame struct { - StreamID protocol.StreamID - FinBit bool - DataLenPresent bool - Offset protocol.ByteCount - Data []byte -} - var ( errInvalidStreamIDLen = errors.New("StreamFrame: Invalid StreamID length") errInvalidOffsetLen = errors.New("StreamFrame: Invalid offset length") ) -// ParseStreamFrame reads a stream frame. The type byte must not have been read yet. -func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { +// parseLegacyStreamFrame reads a stream frame. The type byte must not have been read yet. +func parseLegacyStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamFrame, error) { frame := &StreamFrame{} typeByte, err := r.ReadByte() @@ -35,19 +26,19 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { frame.FinBit = typeByte&0x40 > 0 frame.DataLenPresent = typeByte&0x20 > 0 - offsetLen := typeByte & 0x1C >> 2 + offsetLen := typeByte & 0x1c >> 2 if offsetLen != 0 { offsetLen++ } - streamIDLen := typeByte&0x03 + 1 + streamIDLen := typeByte&0x3 + 1 - sid, err := utils.ReadUintN(r, streamIDLen) + sid, err := utils.BigEndian.ReadUintN(r, streamIDLen) if err != nil { return nil, err } frame.StreamID = protocol.StreamID(sid) - offset, err := utils.ReadUintN(r, offsetLen) + offset, err := utils.BigEndian.ReadUintN(r, offsetLen) if err != nil { return nil, err } @@ -55,14 +46,17 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { var dataLen uint16 if frame.DataLenPresent { - dataLen, err = utils.ReadUint16(r) + dataLen, err = utils.BigEndian.ReadUint16(r) if err != nil { return nil, err } } - if dataLen > uint16(protocol.MaxPacketSize) { - return nil, qerr.Error(qerr.InvalidStreamData, "data len too large") + // shortcut to prevent the unneccessary allocation of dataLen bytes + // if the dataLen is larger than the remaining length of the packet + // reading the packet contents would result in EOF when attempting to READ + if int(dataLen) > r.Len() { + return nil, io.EOF } if !frame.DataLenPresent { @@ -72,39 +66,37 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { if dataLen != 0 { frame.Data = make([]byte, dataLen) if _, err := io.ReadFull(r, frame.Data); err != nil { + // this should never happen, since we already checked the dataLen earlier return nil, err } } - if frame.Offset+frame.DataLen() < frame.Offset { + // MaxByteCount is the highest value that can be encoded with the IETF QUIC variable integer encoding (2^62-1). + // Note that this value is smaller than the maximum value that could be encoded in the gQUIC STREAM frame (2^64-1). + if frame.Offset+frame.DataLen() > protocol.MaxByteCount { return nil, qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset") } - if !frame.FinBit && frame.DataLen() == 0 { return nil, qerr.EmptyStreamFrameNoFin } - return frame, nil } -// WriteStreamFrame writes a stream frame. -func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { +// writeLegacy writes a stream frame. +func (f *StreamFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error { if len(f.Data) == 0 && !f.FinBit { return errors.New("StreamFrame: attempting to write empty frame without FIN") } typeByte := uint8(0x80) // sets the leftmost bit to 1 - if f.FinBit { typeByte ^= 0x40 } - if f.DataLenPresent { typeByte ^= 0x20 } offsetLength := f.getOffsetLength() - if offsetLength > 0 { typeByte ^= (uint8(offsetLength) - 1) << 2 } @@ -118,11 +110,11 @@ func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) err case 1: b.WriteByte(uint8(f.StreamID)) case 2: - utils.WriteUint16(b, uint16(f.StreamID)) + utils.BigEndian.WriteUint16(b, uint16(f.StreamID)) case 3: - utils.WriteUint24(b, uint32(f.StreamID)) + utils.BigEndian.WriteUint24(b, uint32(f.StreamID)) case 4: - utils.WriteUint32(b, uint32(f.StreamID)) + utils.BigEndian.WriteUint32(b, uint32(f.StreamID)) default: return errInvalidStreamIDLen } @@ -130,29 +122,28 @@ func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) err switch offsetLength { case 0: case 2: - utils.WriteUint16(b, uint16(f.Offset)) + utils.BigEndian.WriteUint16(b, uint16(f.Offset)) case 3: - utils.WriteUint24(b, uint32(f.Offset)) + utils.BigEndian.WriteUint24(b, uint32(f.Offset)) case 4: - utils.WriteUint32(b, uint32(f.Offset)) + utils.BigEndian.WriteUint32(b, uint32(f.Offset)) case 5: - utils.WriteUint40(b, uint64(f.Offset)) + utils.BigEndian.WriteUint40(b, uint64(f.Offset)) case 6: - utils.WriteUint48(b, uint64(f.Offset)) + utils.BigEndian.WriteUint48(b, uint64(f.Offset)) case 7: - utils.WriteUint56(b, uint64(f.Offset)) + utils.BigEndian.WriteUint56(b, uint64(f.Offset)) case 8: - utils.WriteUint64(b, uint64(f.Offset)) + utils.BigEndian.WriteUint64(b, uint64(f.Offset)) default: return errInvalidOffsetLen } if f.DataLenPresent { - utils.WriteUint16(b, uint16(len(f.Data))) + utils.BigEndian.WriteUint16(b, uint16(len(f.Data))) } b.Write(f.Data) - return nil } @@ -192,15 +183,24 @@ func (f *StreamFrame) getOffsetLength() protocol.ByteCount { return 8 } -// MinLength returns the length of the header of a StreamFrame -// the total length of the StreamFrame is frame.MinLength() + frame.DataLen() -func (f *StreamFrame) MinLength(protocol.VersionNumber) (protocol.ByteCount, error) { +func (f *StreamFrame) headerLengthLegacy(_ protocol.VersionNumber) protocol.ByteCount { length := protocol.ByteCount(1) + protocol.ByteCount(f.calculateStreamIDLength()) + f.getOffsetLength() if f.DataLenPresent { length += 2 } + return length +} - return length, nil +func (f *StreamFrame) lengthLegacy(version protocol.VersionNumber) protocol.ByteCount { + return f.headerLengthLegacy(version) + f.DataLen() +} + +func (f *StreamFrame) maxDataLenLegacy(maxFrameSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { + headerLen := f.headerLengthLegacy(version) + if headerLen > maxFrameSize { + return 0 + } + return maxFrameSize - headerLen } // DataLen gives the length of data in bytes diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame.go new file mode 100644 index 000000000..7b390a4d9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame.go @@ -0,0 +1,37 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A StreamIDBlockedFrame is a STREAM_ID_BLOCKED frame +type StreamIDBlockedFrame struct { + StreamID protocol.StreamID +} + +// ParseStreamIDBlockedFrame parses a STREAM_ID_BLOCKED frame +func ParseStreamIDBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamIDBlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + streamID, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &StreamIDBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil +} + +func (f *StreamIDBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + typeByte := uint8(0x0a) + b.WriteByte(typeByte) + utils.WriteVarInt(b, uint64(f.StreamID)) + return nil +} + +// Length of a written frame +func (f *StreamIDBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + utils.VarIntLen(uint64(f.StreamID)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go new file mode 100644 index 000000000..b20c43c2d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go @@ -0,0 +1,59 @@ +package wire + +import ( + "bytes" + "crypto/rand" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// ComposeGQUICVersionNegotiation composes a Version Negotiation Packet for gQUIC +func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { + fullReply := &bytes.Buffer{} + ph := Header{ + ConnectionID: connID, + PacketNumber: 1, + VersionFlag: true, + IsVersionNegotiation: true, + } + if err := ph.writePublicHeader(fullReply, protocol.PerspectiveServer, protocol.VersionWhatever); err != nil { + utils.Errorf("error composing version negotiation packet: %s", err.Error()) + return nil + } + writeVersions(fullReply, versions) + return fullReply.Bytes() +} + +// ComposeVersionNegotiation composes a Version Negotiation according to the IETF draft +func ComposeVersionNegotiation( + connID protocol.ConnectionID, + pn protocol.PacketNumber, + versions []protocol.VersionNumber, +) []byte { + fullReply := &bytes.Buffer{} + r := make([]byte, 1) + _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. + h := Header{ + IsLongHeader: true, + Type: protocol.PacketType(r[0] | 0x80), + ConnectionID: connID, + PacketNumber: pn, + Version: 0, + IsVersionNegotiation: true, + } + if err := h.writeHeader(fullReply); err != nil { + utils.Errorf("error composing version negotiation packet: %s", err.Error()) + return nil + } + writeVersions(fullReply, versions) + return fullReply.Bytes() +} + +// writeVersions writes the versions for a Version Negotiation Packet. +// It inserts one reserved version number at a random position. +func writeVersions(buf *bytes.Buffer, supported []protocol.VersionNumber) { + for _, v := range protocol.GetGreasedVersions(supported) { + utils.BigEndian.WriteUint32(buf, uint32(v)) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/window_update_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/window_update_frame.go new file mode 100644 index 000000000..8f7556e75 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/window_update_frame.go @@ -0,0 +1,45 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type windowUpdateFrame struct { + StreamID protocol.StreamID + ByteOffset protocol.ByteCount +} + +// ParseWindowUpdateFrame parses a WINDOW_UPDATE frame +// The frame returned is +// * a MAX_STREAM_DATA frame, if the WINDOW_UPDATE applies to a stream +// * a MAX_DATA frame, if the WINDOW_UPDATE applies to the connection +func ParseWindowUpdateFrame(r *bytes.Reader, _ protocol.VersionNumber) (Frame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + streamID, err := utils.BigEndian.ReadUint32(r) + if err != nil { + return nil, err + } + offset, err := utils.BigEndian.ReadUint64(r) + if err != nil { + return nil, err + } + if streamID == 0 { + return &MaxDataFrame{ByteOffset: protocol.ByteCount(offset)}, nil + } + return &MaxStreamDataFrame{ + StreamID: protocol.StreamID(streamID), + ByteOffset: protocol.ByteCount(offset), + }, nil +} + +func (f *windowUpdateFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x4) + utils.BigEndian.WriteUint32(b, uint32(f.StreamID)) + utils.BigEndian.WriteUint64(b, uint64(f.ByteOffset)) + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/mint_utils.go b/vendor/github.com/lucas-clemente/quic-go/mint_utils.go new file mode 100644 index 000000000..578aecca9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mint_utils.go @@ -0,0 +1,160 @@ +package quic + +import ( + "bytes" + gocrypto "crypto" + "crypto/tls" + "crypto/x509" + "errors" + "io" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type mintController struct { + csc *handshake.CryptoStreamConn + conn *mint.Conn +} + +var _ handshake.MintTLS = &mintController{} + +func newMintController( + csc *handshake.CryptoStreamConn, + mconf *mint.Config, + pers protocol.Perspective, +) handshake.MintTLS { + var conn *mint.Conn + if pers == protocol.PerspectiveClient { + conn = mint.Client(csc, mconf) + } else { + conn = mint.Server(csc, mconf) + } + return &mintController{ + csc: csc, + conn: conn, + } +} + +func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams { + return mc.conn.ConnectionState().CipherSuite +} + +func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { + return mc.conn.ComputeExporter(label, context, keyLength) +} + +func (mc *mintController) Handshake() mint.Alert { + return mc.conn.Handshake() +} + +func (mc *mintController) State() mint.State { + return mc.conn.ConnectionState().HandshakeState +} + +func (mc *mintController) ConnectionState() mint.ConnectionState { + return mc.conn.ConnectionState() +} + +func (mc *mintController) SetCryptoStream(stream io.ReadWriter) { + mc.csc.SetStream(stream) +} + +func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) { + mconf := &mint.Config{ + NonBlocking: true, + CipherSuites: []mint.CipherSuite{ + mint.TLS_AES_128_GCM_SHA256, + mint.TLS_AES_256_GCM_SHA384, + }, + } + if tlsConf != nil { + mconf.ServerName = tlsConf.ServerName + mconf.InsecureSkipVerify = tlsConf.InsecureSkipVerify + mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates)) + mconf.VerifyPeerCertificate = tlsConf.VerifyPeerCertificate + for i, certChain := range tlsConf.Certificates { + mconf.Certificates[i] = &mint.Certificate{ + Chain: make([]*x509.Certificate, len(certChain.Certificate)), + PrivateKey: certChain.PrivateKey.(gocrypto.Signer), + } + for j, cert := range certChain.Certificate { + c, err := x509.ParseCertificate(cert) + if err != nil { + return nil, err + } + mconf.Certificates[i].Chain[j] = c + } + } + switch tlsConf.ClientAuth { + case tls.NoClientCert: + case tls.RequireAnyClientCert: + mconf.RequireClientAuth = true + default: + return nil, errors.New("mint currently only support ClientAuthType RequireAnyClientCert") + } + } + if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil { + return nil, err + } + return mconf, nil +} + +// unpackInitialOrRetryPacket unpacks packets Initial and Retry packets +// These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0. +func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.StreamFrame, error) { + unpacker := &packetUnpacker{aead: &nullAEAD{aead}, version: version} + packet, err := unpacker.Unpack(hdr.Raw, hdr, data) + if err != nil { + return nil, err + } + var frame *wire.StreamFrame + for _, f := range packet.frames { + var ok bool + frame, ok = f.(*wire.StreamFrame) + if ok { + break + } + } + if frame == nil { + return nil, errors.New("Packet doesn't contain a STREAM_FRAME") + } + // We don't need a check for the stream ID here. + // The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream. + if frame.Offset != 0 { + return nil, errors.New("received stream data with non-zero offset") + } + if utils.Debug() { + utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) + hdr.Log() + wire.LogFrame(frame, false) + } + return frame, nil +} + +// packUnencryptedPacket provides a low-overhead way to pack a packet. +// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available. +func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective) ([]byte, error) { + raw := getPacketBuffer() + buffer := bytes.NewBuffer(raw) + if err := hdr.Write(buffer, pers, hdr.Version); err != nil { + return nil, err + } + payloadStartIndex := buffer.Len() + if err := f.Write(buffer, hdr.Version); err != nil { + return nil, err + } + raw = raw[0:buffer.Len()] + _ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex]) + raw = raw[0 : buffer.Len()+aead.Overhead()] + if utils.Debug() { + utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted) + hdr.Log() + wire.LogFrame(f, true) + } + return raw, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/mockgen.go new file mode 100644 index 000000000..3802a8633 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mockgen.go @@ -0,0 +1,12 @@ +package quic + +//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI ReceiveStreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI SendStreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource" +//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager" +//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go" +//go:generate sh -c "goimports -w mock*_test.go" diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_number_generator.go b/vendor/github.com/lucas-clemente/quic-go/packet_number_generator.go index 71ca9a3c4..ac6357765 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_number_generator.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_number_generator.go @@ -4,7 +4,7 @@ import ( "crypto/rand" "math" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // The packetNumberGenerator generates the packet number for the next packet @@ -17,9 +17,9 @@ type packetNumberGenerator struct { nextToSkip protocol.PacketNumber } -func newPacketNumberGenerator(averagePeriod protocol.PacketNumber) *packetNumberGenerator { +func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator { return &packetNumberGenerator{ - next: 1, + next: initial, averagePeriod: averagePeriod, } } diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go index 28c29ace0..1a8e688c8 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go @@ -4,20 +4,27 @@ import ( "bytes" "errors" "fmt" + "sync" - "github.com/lucas-clemente/quic-go/ackhandler" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/handshake" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/ackhandler" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) type packedPacket struct { - number protocol.PacketNumber + header *wire.Header raw []byte - frames []frames.Frame + frames []wire.Frame encryptionLevel protocol.EncryptionLevel } +type streamFrameSource interface { + HasCryptoStreamData() bool + PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame + PopStreamFrames(protocol.ByteCount) []*wire.StreamFrame +} + type packetPacker struct { connectionID protocol.ConnectionID perspective protocol.Perspective @@ -25,41 +32,44 @@ type packetPacker struct { cryptoSetup handshake.CryptoSetup packetNumberGenerator *packetNumberGenerator - connectionParameters handshake.ConnectionParametersManager - streamFramer *streamFramer + streams streamFrameSource - controlFrames []frames.Frame - stopWaiting *frames.StopWaitingFrame - ackFrame *frames.AckFrame - leastUnacked protocol.PacketNumber + controlFrameMutex sync.Mutex + controlFrames []wire.Frame + + stopWaiting *wire.StopWaitingFrame + ackFrame *wire.AckFrame + leastUnacked protocol.PacketNumber + omitConnectionID bool + hasSentPacket bool // has the packetPacker already sent a packet + numNonRetransmittableAcks int } func newPacketPacker(connectionID protocol.ConnectionID, + initialPacketNumber protocol.PacketNumber, cryptoSetup handshake.CryptoSetup, - connectionParameters handshake.ConnectionParametersManager, - streamFramer *streamFramer, + streamFramer streamFrameSource, perspective protocol.Perspective, version protocol.VersionNumber, ) *packetPacker { return &packetPacker{ cryptoSetup: cryptoSetup, connectionID: connectionID, - connectionParameters: connectionParameters, perspective: perspective, version: version, - streamFramer: streamFramer, - packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength), + streams: streamFramer, + packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), } } // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame -func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame) (*packedPacket, error) { - frames := []frames.Frame{ccf} +func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) { + frames := []wire.Frame{ccf} encLevel, sealer := p.cryptoSetup.GetSealer() - ph := p.getPublicHeader(encLevel) - raw, err := p.writeAndSealPacket(ph, frames, sealer) + header := p.getHeader(encLevel) + raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - number: ph.PacketNumber, + header: header, raw: raw, frames: frames, encryptionLevel: encLevel, @@ -71,18 +81,18 @@ func (p *packetPacker) PackAckPacket() (*packedPacket, error) { return nil, errors.New("packet packer BUG: no ack frame queued") } encLevel, sealer := p.cryptoSetup.GetSealer() - ph := p.getPublicHeader(encLevel) - frames := []frames.Frame{p.ackFrame} - if p.stopWaiting != nil { - p.stopWaiting.PacketNumber = ph.PacketNumber - p.stopWaiting.PacketNumberLen = ph.PacketNumberLen + header := p.getHeader(encLevel) + frames := []wire.Frame{p.ackFrame} + if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC + p.stopWaiting.PacketNumber = header.PacketNumber + p.stopWaiting.PacketNumberLen = header.PacketNumberLen frames = append(frames, p.stopWaiting) p.stopWaiting = nil } p.ackFrame = nil - raw, err := p.writeAndSealPacket(ph, frames, sealer) + raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - number: ph.PacketNumber, + header: header, raw: raw, frames: frames, encryptionLevel: encLevel, @@ -98,17 +108,23 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (* if err != nil { return nil, err } - if p.stopWaiting == nil { - return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame") + header := p.getHeader(packet.EncryptionLevel) + var frames []wire.Frame + if !p.version.UsesIETFFrameFormat() { // for gQUIC: pack a STOP_WAITING first + if p.stopWaiting == nil { + return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame") + } + swf := p.stopWaiting + swf.PacketNumber = header.PacketNumber + swf.PacketNumberLen = header.PacketNumberLen + p.stopWaiting = nil + frames = append([]wire.Frame{swf}, packet.Frames...) + } else { + frames = packet.Frames } - ph := p.getPublicHeader(packet.EncryptionLevel) - p.stopWaiting.PacketNumber = ph.PacketNumber - p.stopWaiting.PacketNumberLen = ph.PacketNumberLen - frames := append([]frames.Frame{p.stopWaiting}, packet.Frames...) - p.stopWaiting = nil - raw, err := p.writeAndSealPacket(ph, frames, sealer) + raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - number: ph.PacketNumber, + header: header, raw: raw, frames: frames, encryptionLevel: packet.EncryptionLevel, @@ -118,23 +134,28 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (* // PackPacket packs a new packet // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise func (p *packetPacker) PackPacket() (*packedPacket, error) { - if p.streamFramer.HasCryptoStreamFrame() { + hasCryptoStreamFrame := p.streams.HasCryptoStreamData() + // if this is the first packet to be send, make sure it contains stream data + if !p.hasSentPacket && !hasCryptoStreamFrame { + return nil, nil + } + if hasCryptoStreamFrame { return p.packCryptoPacket() } encLevel, sealer := p.cryptoSetup.GetSealer() - publicHeader := p.getPublicHeader(encLevel) - publicHeaderLength, err := publicHeader.GetLength(p.perspective) + header := p.getHeader(encLevel) + headerLength, err := header.GetLength(p.perspective, p.version) if err != nil { return nil, err } if p.stopWaiting != nil { - p.stopWaiting.PacketNumber = publicHeader.PacketNumber - p.stopWaiting.PacketNumberLen = publicHeader.PacketNumberLen + p.stopWaiting.PacketNumber = header.PacketNumber + p.stopWaiting.PacketNumberLen = header.PacketNumberLen } - maxSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength + maxSize := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength payloadFrames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel)) if err != nil { return nil, err @@ -148,15 +169,28 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { if len(payloadFrames) == 1 && p.stopWaiting != nil { return nil, nil } + if p.ackFrame != nil { + // check if this packet only contains an ACK (and maybe a STOP_WAITING) + if len(payloadFrames) == 1 || (p.stopWaiting != nil && len(payloadFrames) == 2) { + if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks { + payloadFrames = append(payloadFrames, &wire.PingFrame{}) + p.numNonRetransmittableAcks = 0 + } else { + p.numNonRetransmittableAcks++ + } + } else { + p.numNonRetransmittableAcks = 0 + } + } p.stopWaiting = nil p.ackFrame = nil - raw, err := p.writeAndSealPacket(publicHeader, payloadFrames, sealer) + raw, err := p.writeAndSealPacket(header, payloadFrames, sealer) if err != nil { return nil, err } return &packedPacket{ - number: publicHeader.PacketNumber, + header: header, raw: raw, frames: payloadFrames, encryptionLevel: encLevel, @@ -165,19 +199,21 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream() - publicHeader := p.getPublicHeader(encLevel) - publicHeaderLength, err := publicHeader.GetLength(p.perspective) + header := p.getHeader(encLevel) + headerLength, err := header.GetLength(p.perspective, p.version) if err != nil { return nil, err } - maxLen := protocol.MaxFrameAndPublicHeaderSize - protocol.NonForwardSecurePacketSizeReduction - publicHeaderLength - frames := []frames.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)} - raw, err := p.writeAndSealPacket(publicHeader, frames, sealer) + maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength + sf := p.streams.PopCryptoStreamFrame(maxLen) + sf.DataLenPresent = false + frames := []wire.Frame{sf} + raw, err := p.writeAndSealPacket(header, frames, sealer) if err != nil { return nil, err } return &packedPacket{ - number: publicHeader.PacketNumber, + header: header, raw: raw, frames: frames, encryptionLevel: encLevel, @@ -187,41 +223,33 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { func (p *packetPacker) composeNextPacket( maxFrameSize protocol.ByteCount, canSendStreamFrames bool, -) ([]frames.Frame, error) { +) ([]wire.Frame, error) { var payloadLength protocol.ByteCount - var payloadFrames []frames.Frame + var payloadFrames []wire.Frame // STOP_WAITING and ACK will always fit - if p.stopWaiting != nil { - payloadFrames = append(payloadFrames, p.stopWaiting) - l, err := p.stopWaiting.MinLength(p.version) - if err != nil { - return nil, err - } + if p.ackFrame != nil { // ACKs need to go first, so that the sentPacketHandler will recognize them + payloadFrames = append(payloadFrames, p.ackFrame) + l := p.ackFrame.Length(p.version) payloadLength += l } - if p.ackFrame != nil { - payloadFrames = append(payloadFrames, p.ackFrame) - l, err := p.ackFrame.MinLength(p.version) - if err != nil { - return nil, err - } - payloadLength += l + if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC + payloadFrames = append(payloadFrames, p.stopWaiting) + payloadLength += p.stopWaiting.Length(p.version) } + p.controlFrameMutex.Lock() for len(p.controlFrames) > 0 { frame := p.controlFrames[len(p.controlFrames)-1] - minLength, err := frame.MinLength(p.version) - if err != nil { - return nil, err - } - if payloadLength+minLength > maxFrameSize { + length := frame.Length(p.version) + if payloadLength+length > maxFrameSize { break } payloadFrames = append(payloadFrames, frame) - payloadLength += minLength + payloadLength += length p.controlFrames = p.controlFrames[:len(p.controlFrames)-1] } + p.controlFrameMutex.Unlock() if payloadLength > maxFrameSize { return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize) @@ -231,91 +259,127 @@ func (p *packetPacker) composeNextPacket( return payloadFrames, nil } - // temporarily increase the maxFrameSize by 2 bytes + // temporarily increase the maxFrameSize by the (minimum) length of the DataLen field // this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set - // however, for the last StreamFrame in the packet, we can omit the DataLen, thus saving 2 bytes and yielding a packet of exactly the correct size - maxFrameSize += 2 + // however, for the last StreamFrame in the packet, we can omit the DataLen, thus yielding a packet of exactly the correct size + // for gQUIC STREAM frames, DataLen is always 2 bytes + // for IETF draft style STREAM frames, the length is encoded to either 1 or 2 bytes + if p.version.UsesIETFFrameFormat() { + maxFrameSize++ + } else { + maxFrameSize += 2 + } - fs := p.streamFramer.PopStreamFrames(maxFrameSize - payloadLength) + fs := p.streams.PopStreamFrames(maxFrameSize - payloadLength) if len(fs) != 0 { fs[len(fs)-1].DataLenPresent = false } - // TODO: Simplify for _, f := range fs { payloadFrames = append(payloadFrames, f) } - - for b := p.streamFramer.PopBlockedFrame(); b != nil; b = p.streamFramer.PopBlockedFrame() { - p.controlFrames = append(p.controlFrames, b) - } - return payloadFrames, nil } -func (p *packetPacker) QueueControlFrame(frame frames.Frame) { +func (p *packetPacker) QueueControlFrame(frame wire.Frame) { switch f := frame.(type) { - case *frames.StopWaitingFrame: + case *wire.StopWaitingFrame: p.stopWaiting = f - case *frames.AckFrame: + case *wire.AckFrame: p.ackFrame = f default: + p.controlFrameMutex.Lock() p.controlFrames = append(p.controlFrames, f) + p.controlFrameMutex.Unlock() } } -func (p *packetPacker) getPublicHeader(encLevel protocol.EncryptionLevel) *PublicHeader { +func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header { pnum := p.packetNumberGenerator.Peek() - packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(pnum, p.leastUnacked) - publicHeader := &PublicHeader{ - ConnectionID: p.connectionID, - PacketNumber: pnum, - PacketNumberLen: packetNumberLen, - TruncateConnectionID: p.connectionParameters.TruncateConnectionID(), + packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked) + + header := &wire.Header{ + ConnectionID: p.connectionID, + PacketNumber: pnum, + PacketNumberLen: packetNumberLen, } - if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { - publicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce() - } - if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure { - publicHeader.VersionFlag = true - publicHeader.VersionNumber = p.version + if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure { + header.PacketNumberLen = protocol.PacketNumberLen4 + header.IsLongHeader = true + if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient { + header.Type = protocol.PacketTypeInitial + } else { + header.Type = protocol.PacketTypeHandshake + } } - return publicHeader + if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure { + header.OmitConnectionID = true + } + if !p.version.UsesTLS() { + if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { + header.DiversificationNonce = p.cryptoSetup.DiversificationNonce() + } + if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure { + header.VersionFlag = true + header.Version = p.version + } + } else { + if encLevel != protocol.EncryptionForwardSecure { + header.Version = p.version + } + } + return header } func (p *packetPacker) writeAndSealPacket( - publicHeader *PublicHeader, - payloadFrames []frames.Frame, + header *wire.Header, + payloadFrames []wire.Frame, sealer handshake.Sealer, ) ([]byte, error) { raw := getPacketBuffer() buffer := bytes.NewBuffer(raw) - if err := publicHeader.Write(buffer, p.version, p.perspective); err != nil { + if err := header.Write(buffer, p.perspective, p.version); err != nil { return nil, err } payloadStartIndex := buffer.Len() + + // the Initial packet needs to be padded, so the last STREAM frame must have the data length present + if header.Type == protocol.PacketTypeInitial { + lastFrame := payloadFrames[len(payloadFrames)-1] + if sf, ok := lastFrame.(*wire.StreamFrame); ok { + sf.DataLenPresent = true + } + } for _, frame := range payloadFrames { - err := frame.Write(buffer, p.version) - if err != nil { + if err := frame.Write(buffer, p.version); err != nil { return nil, err } } - if protocol.ByteCount(buffer.Len()+12) > protocol.MaxPacketSize { - return nil, errors.New("PacketPacker BUG: packet too large") + // if this is an IETF QUIC Initial packet, we need to pad it to fulfill the minimum size requirement + // in gQUIC, padding is handled in the CHLO + if header.Type == protocol.PacketTypeInitial { + paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len() + if paddingLen > 0 { + buffer.Write(bytes.Repeat([]byte{0}, paddingLen)) + } + } + + if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > protocol.MaxPacketSize { + return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, protocol.MaxPacketSize) } raw = raw[0:buffer.Len()] - _ = sealer(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], publicHeader.PacketNumber, raw[:payloadStartIndex]) - raw = raw[0 : buffer.Len()+12] + _ = sealer.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], header.PacketNumber, raw[:payloadStartIndex]) + raw = raw[0 : buffer.Len()+sealer.Overhead()] num := p.packetNumberGenerator.Pop() - if num != publicHeader.PacketNumber { + if num != header.PacketNumber { return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } - + p.hasSentPacket = true return raw, nil } @@ -329,3 +393,7 @@ func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool { func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) { p.leastUnacked = leastUnacked } + +func (p *packetPacker) SetOmitConnectionID() { + p.omitConnectionID = true +} diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go index c92e6a53f..45bdc0fa8 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go @@ -2,17 +2,16 @@ package quic import ( "bytes" - "errors" "fmt" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) type unpackedPacket struct { encryptionLevel protocol.EncryptionLevel - frames []frames.Frame + frames []wire.Frame } type quicAEAD interface { @@ -24,10 +23,10 @@ type packetUnpacker struct { aead quicAEAD } -func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *PublicHeader, data []byte) (*unpackedPacket, error) { +func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) { buf := getPacketBuffer() defer putPacketBuffer(buf) - decrypted, encryptionLevel, err := u.aead.Open(buf, data, hdr.PacketNumber, publicHeaderBinary) + decrypted, encryptionLevel, err := u.aead.Open(buf, data, hdr.PacketNumber, headerBinary) if err != nil { // Wrap err in quicError so that public reset is sent by session return nil, qerr.Error(qerr.DecryptionFailure, err.Error()) @@ -38,7 +37,7 @@ func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *PublicHeader, da return nil, qerr.MissingPayload } - fs := make([]frames.Frame, 0, 2) + fs := make([]wire.Frame, 0, 2) // Read all frames in the packet for r.Len() > 0 { @@ -48,65 +47,15 @@ func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *PublicHeader, da } r.UnreadByte() - var frame frames.Frame - if typeByte&0x80 == 0x80 { - frame, err = frames.ParseStreamFrame(r) - if err != nil { - err = qerr.Error(qerr.InvalidStreamData, err.Error()) - } else { - streamID := frame.(*frames.StreamFrame).StreamID - if streamID != 1 && encryptionLevel <= protocol.EncryptionUnencrypted { - err = qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", streamID)) - } - } - } else if typeByte&0xc0 == 0x40 { - frame, err = frames.ParseAckFrame(r, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidAckData, err.Error()) - } - } else if typeByte&0xe0 == 0x20 { - err = errors.New("unimplemented: CONGESTION_FEEDBACK") - } else { - switch typeByte { - case 0x01: - frame, err = frames.ParseRstStreamFrame(r) - if err != nil { - err = qerr.Error(qerr.InvalidRstStreamData, err.Error()) - } - case 0x02: - frame, err = frames.ParseConnectionCloseFrame(r) - if err != nil { - err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error()) - } - case 0x03: - frame, err = frames.ParseGoawayFrame(r) - if err != nil { - err = qerr.Error(qerr.InvalidGoawayData, err.Error()) - } - case 0x04: - frame, err = frames.ParseWindowUpdateFrame(r) - if err != nil { - err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) - } - case 0x05: - frame, err = frames.ParseBlockedFrame(r) - if err != nil { - err = qerr.Error(qerr.InvalidBlockedData, err.Error()) - } - case 0x06: - frame, err = frames.ParseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidStopWaitingData, err.Error()) - } - case 0x07: - frame, err = frames.ParsePingFrame(r) - default: - err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte)) - } - } + frame, err := u.parseFrame(r, typeByte, hdr) if err != nil { return nil, err } + if sf, ok := frame.(*wire.StreamFrame); ok { + if sf.StreamID != u.version.CryptoStreamID() && encryptionLevel <= protocol.EncryptionUnencrypted { + return nil, qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", sf.StreamID)) + } + } if frame != nil { fs = append(fs, frame) } @@ -117,3 +66,135 @@ func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *PublicHeader, da frames: fs, }, nil } + +func (u *packetUnpacker) parseFrame(r *bytes.Reader, typeByte byte, hdr *wire.Header) (wire.Frame, error) { + if u.version.UsesIETFFrameFormat() { + return u.parseIETFFrame(r, typeByte, hdr) + } + return u.parseGQUICFrame(r, typeByte, hdr) +} + +func (u *packetUnpacker) parseIETFFrame(r *bytes.Reader, typeByte byte, hdr *wire.Header) (wire.Frame, error) { + var frame wire.Frame + var err error + if typeByte&0xf8 == 0x10 { + frame, err = wire.ParseStreamFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidStreamData, err.Error()) + } + return frame, err + } + // TODO: implement all IETF QUIC frame types + switch typeByte { + case 0x1: + frame, err = wire.ParseRstStreamFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidRstStreamData, err.Error()) + } + case 0x2: + frame, err = wire.ParseConnectionCloseFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error()) + } + case 0x4: + frame, err = wire.ParseMaxDataFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) + } + case 0x5: + frame, err = wire.ParseMaxStreamDataFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) + } + case 0x6: + frame, err = wire.ParseMaxStreamIDFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } + case 0x7: + frame, err = wire.ParsePingFrame(r, u.version) + case 0x8: + frame, err = wire.ParseBlockedFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidBlockedData, err.Error()) + } + case 0x9: + frame, err = wire.ParseStreamBlockedFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidBlockedData, err.Error()) + } + case 0xa: + frame, err = wire.ParseStreamIDBlockedFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } + case 0xc: + frame, err = wire.ParseStopSendingFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } + case 0xe: + frame, err = wire.ParseAckFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidAckData, err.Error()) + } + default: + err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte)) + } + return frame, err +} + +func (u *packetUnpacker) parseGQUICFrame(r *bytes.Reader, typeByte byte, hdr *wire.Header) (wire.Frame, error) { + var frame wire.Frame + var err error + if typeByte&0x80 == 0x80 { + frame, err = wire.ParseStreamFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidStreamData, err.Error()) + } + return frame, err + } else if typeByte&0xc0 == 0x40 { + frame, err = wire.ParseAckFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidAckData, err.Error()) + } + return frame, err + } + switch typeByte { + case 0x1: + frame, err = wire.ParseRstStreamFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidRstStreamData, err.Error()) + } + case 0x2: + frame, err = wire.ParseConnectionCloseFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error()) + } + case 0x3: + frame, err = wire.ParseGoawayFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidGoawayData, err.Error()) + } + case 0x4: + frame, err = wire.ParseWindowUpdateFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) + } + case 0x5: + frame, err = wire.ParseBlockedFrameLegacy(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidBlockedData, err.Error()) + } + case 0x6: + frame, err = wire.ParseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidStopWaitingData, err.Error()) + } + case 0x7: + frame, err = wire.ParsePingFrame(r, u.version) + default: + err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte)) + } + return frame, err +} diff --git a/vendor/github.com/lucas-clemente/quic-go/protocol/version.go b/vendor/github.com/lucas-clemente/quic-go/protocol/version.go deleted file mode 100644 index c250cab6a..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/protocol/version.go +++ /dev/null @@ -1,59 +0,0 @@ -package protocol - -// VersionNumber is a version number as int -type VersionNumber int - -// The version numbers, making grepping easier -const ( - Version35 VersionNumber = 35 + iota - Version36 - Version37 - Version38 - VersionWhatever VersionNumber = 0 // for when the version doesn't matter - VersionUnsupported VersionNumber = -1 -) - -// SupportedVersions lists the versions that the server supports -// must be in sorted descending order -var SupportedVersions = []VersionNumber{ - Version38, - Version37, - Version36, - Version35, -} - -// VersionNumberToTag maps version numbers ('32') to tags ('Q032') -func VersionNumberToTag(vn VersionNumber) uint32 { - v := uint32(vn) - return 'Q' + ((v/100%10)+'0')<<8 + ((v/10%10)+'0')<<16 + ((v%10)+'0')<<24 -} - -// VersionTagToNumber is built from VersionNumberToTag in init() -func VersionTagToNumber(v uint32) VersionNumber { - return VersionNumber(((v>>8)&0xff-'0')*100 + ((v>>16)&0xff-'0')*10 + ((v>>24)&0xff - '0')) -} - -// IsSupportedVersion returns true if the server supports this version -func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { - for _, t := range supported { - if t == v { - return true - } - } - return false -} - -// ChooseSupportedVersion finds the best version in the overlap of ours and theirs -// ours is a slice of versions that we support, sorted by our preference (descending) -// theirs is a slice of versions offered by the peer. The order does not matter -// if no suitable version is found, it returns VersionUnsupported -func ChooseSupportedVersion(ours, theirs []VersionNumber) VersionNumber { - for _, ourVer := range ours { - for _, theirVer := range theirs { - if ourVer == theirVer { - return ourVer - } - } - } - return VersionUnsupported -} diff --git a/vendor/github.com/lucas-clemente/quic-go/public_reset.go b/vendor/github.com/lucas-clemente/quic-go/public_reset.go deleted file mode 100644 index 958db9cc4..000000000 --- a/vendor/github.com/lucas-clemente/quic-go/public_reset.go +++ /dev/null @@ -1,62 +0,0 @@ -package quic - -import ( - "bytes" - "encoding/binary" - "errors" - - "github.com/lucas-clemente/quic-go/handshake" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" -) - -type publicReset struct { - rejectedPacketNumber protocol.PacketNumber - nonce uint64 -} - -func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte { - b := &bytes.Buffer{} - b.WriteByte(0x0a) - utils.WriteUint64(b, uint64(connectionID)) - utils.WriteUint32(b, uint32(handshake.TagPRST)) - utils.WriteUint32(b, 2) - utils.WriteUint32(b, uint32(handshake.TagRNON)) - utils.WriteUint32(b, 8) - utils.WriteUint32(b, uint32(handshake.TagRSEQ)) - utils.WriteUint32(b, 16) - utils.WriteUint64(b, nonceProof) - utils.WriteUint64(b, uint64(rejectedPacketNumber)) - return b.Bytes() -} - -func parsePublicReset(r *bytes.Reader) (*publicReset, error) { - pr := publicReset{} - msg, err := handshake.ParseHandshakeMessage(r) - if err != nil { - return nil, err - } - if msg.Tag != handshake.TagPRST { - return nil, errors.New("wrong public reset tag") - } - - rseq, ok := msg.Data[handshake.TagRSEQ] - if !ok { - return nil, errors.New("RSEQ missing") - } - if len(rseq) != 8 { - return nil, errors.New("invalid RSEQ tag") - } - pr.rejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq)) - - rnon, ok := msg.Data[handshake.TagRNON] - if !ok { - return nil, errors.New("RNON missing") - } - if len(rnon) != 8 { - return nil, errors.New("invalid RNON tag") - } - pr.nonce = binary.LittleEndian.Uint64(rnon) - - return &pr, nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/qerr/errorcode_string.go b/vendor/github.com/lucas-clemente/quic-go/qerr/errorcode_string.go index 5a8e0240e..22d0c85a7 100644 --- a/vendor/github.com/lucas-clemente/quic-go/qerr/errorcode_string.go +++ b/vendor/github.com/lucas-clemente/quic-go/qerr/errorcode_string.go @@ -1,8 +1,8 @@ -// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT +// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT. package qerr -import "fmt" +import "strconv" const ( _ErrorCode_name_0 = "InternalErrorStreamDataAfterTerminationInvalidPacketHeaderInvalidFrameDataInvalidFecDataInvalidRstStreamDataInvalidConnectionCloseDataInvalidGoawayDataInvalidAckDataInvalidVersionNegotiationPacketInvalidPublicRstPacketDecryptionFailureEncryptionFailurePacketTooLarge" @@ -19,7 +19,6 @@ var ( _ErrorCode_index_2 = [...]uint16{0, 15, 37, 57, 75, 96, 112, 127, 147, 167, 191, 226, 250, 279, 309, 340, 366, 385, 410, 425, 445, 457, 475, 505, 530, 547} _ErrorCode_index_3 = [...]uint16{0, 14, 29, 50, 65, 90, 119, 158, 184, 208, 231, 249, 279, 301, 322, 340, 366, 390, 425} _ErrorCode_index_4 = [...]uint16{0, 16, 45, 78, 97, 114, 144, 169, 192, 215, 238, 256, 276, 292, 308, 346, 379, 410, 448, 459, 477, 498, 532} - _ErrorCode_index_5 = [...]uint8{0, 34} ) func (i ErrorCode) String() string { @@ -42,6 +41,6 @@ func (i ErrorCode) String() string { case i == 97: return _ErrorCode_name_5 default: - return fmt.Sprintf("ErrorCode(%d)", i) + return "ErrorCode(" + strconv.FormatInt(int64(i), 10) + ")" } } diff --git a/vendor/github.com/lucas-clemente/quic-go/receive_stream.go b/vendor/github.com/lucas-clemente/quic-go/receive_stream.go new file mode 100644 index 000000000..9ae216f6e --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/receive_stream.go @@ -0,0 +1,286 @@ +package quic + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type receiveStreamI interface { + ReceiveStream + + handleStreamFrame(*wire.StreamFrame) error + handleRstStreamFrame(*wire.RstStreamFrame) error + closeForShutdown(error) + getWindowUpdate() protocol.ByteCount +} + +type receiveStream struct { + mutex sync.Mutex + + streamID protocol.StreamID + + sender streamSender + + frameQueue *streamFrameSorter + readPosInFrame int + readOffset protocol.ByteCount + + closeForShutdownErr error + cancelReadErr error + resetRemotelyErr StreamError + + closedForShutdown bool // set when CloseForShutdown() is called + finRead bool // set once we read a frame with a FinBit + canceledRead bool // set when CancelRead() is called + resetRemotely bool // set when HandleRstStreamFrame() is called + + readChan chan struct{} + readDeadline time.Time + + flowController flowcontrol.StreamFlowController + version protocol.VersionNumber +} + +var _ ReceiveStream = &receiveStream{} +var _ receiveStreamI = &receiveStream{} + +func newReceiveStream( + streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *receiveStream { + return &receiveStream{ + streamID: streamID, + sender: sender, + flowController: flowController, + frameQueue: newStreamFrameSorter(), + readChan: make(chan struct{}, 1), + version: version, + } +} + +func (s *receiveStream) StreamID() protocol.StreamID { + return s.streamID +} + +// Read implements io.Reader. It is not thread safe! +func (s *receiveStream) Read(p []byte) (int, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.finRead { + return 0, io.EOF + } + if s.canceledRead { + return 0, s.cancelReadErr + } + if s.resetRemotely { + return 0, s.resetRemotelyErr + } + if s.closedForShutdown { + return 0, s.closeForShutdownErr + } + + bytesRead := 0 + for bytesRead < len(p) { + frame := s.frameQueue.Head() + if frame == nil && bytesRead > 0 { + return bytesRead, s.closeForShutdownErr + } + + for { + // Stop waiting on errors + if s.closedForShutdown { + return bytesRead, s.closeForShutdownErr + } + if s.canceledRead { + return bytesRead, s.cancelReadErr + } + if s.resetRemotely { + return bytesRead, s.resetRemotelyErr + } + + deadline := s.readDeadline + if !deadline.IsZero() && !time.Now().Before(deadline) { + return bytesRead, errDeadline + } + + if frame != nil { + s.readPosInFrame = int(s.readOffset - frame.Offset) + break + } + + s.mutex.Unlock() + if deadline.IsZero() { + <-s.readChan + } else { + select { + case <-s.readChan: + case <-time.After(deadline.Sub(time.Now())): + } + } + s.mutex.Lock() + frame = s.frameQueue.Head() + } + + if bytesRead > len(p) { + return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) + } + if s.readPosInFrame > int(frame.DataLen()) { + return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen()) + } + + s.mutex.Unlock() + + copy(p[bytesRead:], frame.Data[s.readPosInFrame:]) + m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame) + s.readPosInFrame += m + bytesRead += m + s.readOffset += protocol.ByteCount(m) + + s.mutex.Lock() + // when a RST_STREAM was received, the was already informed about the final byteOffset for this stream + if !s.resetRemotely { + s.flowController.AddBytesRead(protocol.ByteCount(m)) + } + // this call triggers the flow controller to increase the flow control window, if necessary + if s.flowController.HasWindowUpdate() { + s.sender.onHasWindowUpdate(s.streamID) + } + + if s.readPosInFrame >= int(frame.DataLen()) { + s.frameQueue.Pop() + s.finRead = frame.FinBit + if frame.FinBit { + s.sender.onStreamCompleted(s.streamID) + return bytesRead, io.EOF + } + } + } + return bytesRead, nil +} + +func (s *receiveStream) CancelRead(errorCode protocol.ApplicationErrorCode) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.finRead { + return nil + } + if s.canceledRead { + return nil + } + s.canceledRead = true + s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode) + s.signalRead() + if s.version.UsesIETFFrameFormat() { + s.sender.queueControlFrame(&wire.StopSendingFrame{ + StreamID: s.streamID, + ErrorCode: errorCode, + }) + } + return nil +} + +func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { + maxOffset := frame.Offset + frame.DataLen() + if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil { + return err + } + + s.mutex.Lock() + defer s.mutex.Unlock() + if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData { + return err + } + s.signalRead() + return nil +} + +func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.closedForShutdown { + return nil + } + if err := s.flowController.UpdateHighestReceived(frame.ByteOffset, true); err != nil { + return err + } + // In gQUIC, error code 0 has a special meaning. + // The peer will reliably continue transmitting, but is not interested in reading from the stream. + // We should therefore just continue reading from the stream, until we encounter the FIN bit. + if !s.version.UsesIETFFrameFormat() && frame.ErrorCode == 0 { + return nil + } + + // ignore duplicate RST_STREAM frames for this stream (after checking their final offset) + if s.resetRemotely { + return nil + } + s.resetRemotely = true + s.resetRemotelyErr = streamCanceledError{ + errorCode: frame.ErrorCode, + error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode), + } + s.signalRead() + s.sender.onStreamCompleted(s.streamID) + return nil +} + +func (s *receiveStream) CloseRemote(offset protocol.ByteCount) { + s.handleStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset}) +} + +func (s *receiveStream) onClose(offset protocol.ByteCount) { + if s.canceledRead && !s.version.UsesIETFFrameFormat() { + s.sender.queueControlFrame(&wire.RstStreamFrame{ + StreamID: s.streamID, + ByteOffset: offset, + ErrorCode: 0, + }) + } +} + +func (s *receiveStream) SetReadDeadline(t time.Time) error { + s.mutex.Lock() + oldDeadline := s.readDeadline + s.readDeadline = t + s.mutex.Unlock() + // if the new deadline is before the currently set deadline, wake up Read() + if t.Before(oldDeadline) { + s.signalRead() + } + return nil +} + +// CloseForShutdown closes a stream abruptly. +// It makes Read unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *receiveStream) closeForShutdown(err error) { + s.mutex.Lock() + s.closedForShutdown = true + s.closeForShutdownErr = err + s.mutex.Unlock() + s.signalRead() +} + +func (s *receiveStream) getWindowUpdate() protocol.ByteCount { + return s.flowController.GetWindowUpdate() +} + +// signalRead performs a non-blocking send on the readChan +func (s *receiveStream) signalRead() { + select { + case s.readChan <- struct{}{}: + default: + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/send_stream.go b/vendor/github.com/lucas-clemente/quic-go/send_stream.go new file mode 100644 index 000000000..86aed1534 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/send_stream.go @@ -0,0 +1,313 @@ +package quic + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type sendStreamI interface { + SendStream + handleStopSendingFrame(*wire.StopSendingFrame) + popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool) + closeForShutdown(error) + handleMaxStreamDataFrame(*wire.MaxStreamDataFrame) +} + +type sendStream struct { + mutex sync.Mutex + + ctx context.Context + ctxCancel context.CancelFunc + + streamID protocol.StreamID + sender streamSender + + writeOffset protocol.ByteCount + + cancelWriteErr error + closeForShutdownErr error + + closedForShutdown bool // set when CloseForShutdown() is called + finishedWriting bool // set once Close() is called + canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received + finSent bool // set when a STREAM_FRAME with FIN bit has b + + dataForWriting []byte + writeChan chan struct{} + writeDeadline time.Time + + flowController flowcontrol.StreamFlowController + + version protocol.VersionNumber +} + +var _ SendStream = &sendStream{} +var _ sendStreamI = &sendStream{} + +func newSendStream( + streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *sendStream { + s := &sendStream{ + streamID: streamID, + sender: sender, + flowController: flowController, + writeChan: make(chan struct{}, 1), + version: version, + } + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + return s +} + +func (s *sendStream) StreamID() protocol.StreamID { + return s.streamID // same for receiveStream and sendStream +} + +func (s *sendStream) Write(p []byte) (int, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.finishedWriting { + return 0, fmt.Errorf("write on closed stream %d", s.streamID) + } + if s.canceledWrite { + return 0, s.cancelWriteErr + } + if s.closeForShutdownErr != nil { + return 0, s.closeForShutdownErr + } + if !s.writeDeadline.IsZero() && !time.Now().Before(s.writeDeadline) { + return 0, errDeadline + } + if len(p) == 0 { + return 0, nil + } + + s.dataForWriting = make([]byte, len(p)) + copy(s.dataForWriting, p) + s.sender.onHasStreamData(s.streamID) + + var bytesWritten int + var err error + for { + bytesWritten = len(p) - len(s.dataForWriting) + deadline := s.writeDeadline + if !deadline.IsZero() && !time.Now().Before(deadline) { + s.dataForWriting = nil + err = errDeadline + break + } + if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown { + break + } + + s.mutex.Unlock() + if deadline.IsZero() { + <-s.writeChan + } else { + select { + case <-s.writeChan: + case <-time.After(deadline.Sub(time.Now())): + } + } + s.mutex.Lock() + } + + if s.closeForShutdownErr != nil { + err = s.closeForShutdownErr + } else if s.cancelWriteErr != nil { + err = s.cancelWriteErr + } + return bytesWritten, err +} + +// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream +// maxBytes is the maximum length this frame (including frame header) will have. +func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.closeForShutdownErr != nil { + return nil, false + } + + frame := &wire.StreamFrame{ + StreamID: s.streamID, + Offset: s.writeOffset, + DataLenPresent: true, + } + maxDataLen := frame.MaxDataLen(maxBytes, s.version) + if maxDataLen == 0 { // a STREAM frame must have at least one byte of data + return nil, s.dataForWriting != nil + } + frame.Data, frame.FinBit = s.getDataForWriting(maxDataLen) + if len(frame.Data) == 0 && !frame.FinBit { + // this can happen if: + // - popStreamFrame is called but there's no data for writing + // - there's data for writing, but the stream is stream-level flow control blocked + // - there's data for writing, but the stream is connection-level flow control blocked + if s.dataForWriting == nil { + return nil, false + } + isBlocked, _ := s.flowController.IsBlocked() + return nil, !isBlocked + } + if frame.FinBit { + s.finSent = true + s.sender.onStreamCompleted(s.streamID) + } else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream + if isBlocked, offset := s.flowController.IsBlocked(); isBlocked { + s.sender.queueControlFrame(&wire.StreamBlockedFrame{ + StreamID: s.streamID, + Offset: offset, + }) + return frame, false + } + } + return frame, s.dataForWriting != nil +} + +func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { + if s.dataForWriting == nil { + return nil, s.finishedWriting && !s.finSent + } + + // TODO(#657): Flow control for the crypto stream + if s.streamID != s.version.CryptoStreamID() { + maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) + } + if maxBytes == 0 { + return nil, false + } + + var ret []byte + if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { + ret = s.dataForWriting[:maxBytes] + s.dataForWriting = s.dataForWriting[maxBytes:] + } else { + ret = s.dataForWriting + s.dataForWriting = nil + s.signalWrite() + } + s.writeOffset += protocol.ByteCount(len(ret)) + s.flowController.AddBytesSent(protocol.ByteCount(len(ret))) + return ret, s.finishedWriting && s.dataForWriting == nil && !s.finSent +} + +func (s *sendStream) Close() error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.canceledWrite { + return fmt.Errorf("Close called for canceled stream %d", s.streamID) + } + s.finishedWriting = true + s.sender.onHasStreamData(s.streamID) // need to send the FIN + s.ctxCancel() + return nil +} + +func (s *sendStream) CancelWrite(errorCode protocol.ApplicationErrorCode) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) +} + +// must be called after locking the mutex +func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, writeErr error) error { + if s.canceledWrite { + return nil + } + if s.finishedWriting { + return fmt.Errorf("CancelWrite for closed stream %d", s.streamID) + } + s.canceledWrite = true + s.cancelWriteErr = writeErr + s.signalWrite() + s.sender.queueControlFrame(&wire.RstStreamFrame{ + StreamID: s.streamID, + ByteOffset: s.writeOffset, + ErrorCode: errorCode, + }) + // TODO(#991): cancel retransmissions for this stream + s.ctxCancel() + s.sender.onStreamCompleted(s.streamID) + return nil +} + +func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { + s.mutex.Lock() + defer s.mutex.Unlock() + s.handleStopSendingFrameImpl(frame) +} + +func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) { + s.flowController.UpdateSendWindow(frame.ByteOffset) + s.mutex.Lock() + if s.dataForWriting != nil { + s.sender.onHasStreamData(s.streamID) + } + s.mutex.Unlock() +} + +// must be called after locking the mutex +func (s *sendStream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) { + writeErr := streamCanceledError{ + errorCode: frame.ErrorCode, + error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode), + } + errorCode := errorCodeStopping + if !s.version.UsesIETFFrameFormat() { + errorCode = errorCodeStoppingGQUIC + } + s.cancelWriteImpl(errorCode, writeErr) +} + +func (s *sendStream) Context() context.Context { + return s.ctx +} + +func (s *sendStream) SetWriteDeadline(t time.Time) error { + s.mutex.Lock() + oldDeadline := s.writeDeadline + s.writeDeadline = t + s.mutex.Unlock() + if t.Before(oldDeadline) { + s.signalWrite() + } + return nil +} + +// CloseForShutdown closes a stream abruptly. +// It makes Write unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *sendStream) closeForShutdown(err error) { + s.mutex.Lock() + s.closedForShutdown = true + s.closeForShutdownErr = err + s.mutex.Unlock() + s.signalWrite() + s.ctxCancel() +} + +func (s *sendStream) getWriteOffset() protocol.ByteCount { + return s.writeOffset +} + +// signalWrite performs a non-blocking send on the writeChan +func (s *sendStream) signalWrite() { + select { + case s.writeChan <- struct{}{}: + default: + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/server.go b/vendor/github.com/lucas-clemente/quic-go/server.go index 76f07bab3..dc23b8687 100644 --- a/vendor/github.com/lucas-clemente/quic-go/server.go +++ b/vendor/github.com/lucas-clemente/quic-go/server.go @@ -8,17 +8,21 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/crypto" - "github.com/lucas-clemente/quic-go/handshake" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) // packetHandler handles packets type packetHandler interface { Session + getCryptoStream() cryptoStreamI + handshakeStatus() <-chan error handlePacket(*receivedPacket) + GetVersion() protocol.VersionNumber run() error closeRemote(error) } @@ -30,18 +34,23 @@ type server struct { conn net.PacketConn + supportsTLS bool + serverTLS *serverTLS + certChain crypto.CertChain scfg *handshake.ServerConfig - sessions map[protocol.ConnectionID]packetHandler - sessionsMutex sync.RWMutex - deleteClosedSessionsAfter time.Duration + sessionsMutex sync.RWMutex + sessions map[protocol.ConnectionID]packetHandler + closed bool serverError error sessionQueue chan Session errorChan chan struct{} - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, <-chan handshakeEvent, error) + // set as members, so they can be set in the tests + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error) + deleteClosedSessionsAfter time.Duration } var _ Listener = &server{} @@ -74,11 +83,21 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, if err != nil { return nil, err } + config = populateServerConfig(config) + + // check if any of the supported versions supports TLS + var supportsTLS bool + for _, v := range config.Versions { + if v.UsesTLS() { + supportsTLS = true + break + } + } s := &server{ conn: conn, tlsConf: tlsConf, - config: populateServerConfig(config), + config: config, certChain: certChain, scfg: scfg, sessions: map[protocol.ConnectionID]packetHandler{}, @@ -86,16 +105,55 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, sessionQueue: make(chan Session, 5), errorChan: make(chan struct{}), + supportsTLS: supportsTLS, + } + if supportsTLS { + if err := s.setupTLS(); err != nil { + return nil, err + } } go s.serve() + utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } -var defaultAcceptSTK = func(clientAddr net.Addr, stk *STK) bool { - if stk == nil { +func (s *server) setupTLS() error { + cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie) + if err != nil { + return err + } + serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf) + if err != nil { + return err + } + s.serverTLS = serverTLS + // handle TLS connection establishment statelessly + go func() { + for { + select { + case <-s.errorChan: + return + case tlsSession := <-sessionChan: + connID := tlsSession.connID + sess := tlsSession.sess + if _, ok := s.sessions[connID]; ok { // drop this session if it already exists + return + } + s.sessionsMutex.Lock() + s.sessions[connID] = sess + s.sessionsMutex.Unlock() + s.runHandshakeAndSession(sess, connID) + } + } + }() + return nil +} + +var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool { + if cookie == nil { return false } - if time.Now().After(stk.sentTime.Add(protocol.STKExpiryTime)) { + if time.Now().After(cookie.SentTime.Add(protocol.CookieExpiryTime)) { return false } var sourceAddr string @@ -104,7 +162,7 @@ var defaultAcceptSTK = func(clientAddr net.Addr, stk *STK) bool { } else { sourceAddr = clientAddr.String() } - return sourceAddr == stk.remoteAddr + return sourceAddr == cookie.RemoteAddr } // populateServerConfig populates fields in the quic.Config with their default values, if none are set @@ -118,15 +176,19 @@ func populateServerConfig(config *Config) *Config { versions = protocol.SupportedVersions } - vsa := defaultAcceptSTK - if config.AcceptSTK != nil { - vsa = config.AcceptSTK + vsa := defaultAcceptCookie + if config.AcceptCookie != nil { + vsa = config.AcceptCookie } handshakeTimeout := protocol.DefaultHandshakeTimeout if config.HandshakeTimeout != 0 { handshakeTimeout = config.HandshakeTimeout } + idleTimeout := protocol.DefaultIdleTimeout + if config.IdleTimeout != 0 { + idleTimeout = config.IdleTimeout + } maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow if maxReceiveStreamFlowControlWindow == 0 { @@ -140,7 +202,9 @@ func populateServerConfig(config *Config) *Config { return &Config{ Versions: versions, HandshakeTimeout: handshakeTimeout, - AcceptSTK: vsa, + IdleTimeout: idleTimeout, + AcceptCookie: vsa, + KeepAlive: config.KeepAlive, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, } @@ -181,19 +245,29 @@ func (s *server) Accept() (Session, error) { // Close the server func (s *server) Close() error { s.sessionsMutex.Lock() + if s.closed { + s.sessionsMutex.Unlock() + return nil + } + s.closed = true + + var wg sync.WaitGroup for _, session := range s.sessions { if session != nil { - s.sessionsMutex.Unlock() - _ = session.Close(nil) - s.sessionsMutex.Lock() + wg.Add(1) + go func(sess packetHandler) { + // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped + _ = sess.Close(nil) + wg.Done() + }(session) } } s.sessionsMutex.Unlock() + wg.Wait() - if s.conn == nil { - return nil - } - return s.conn.Close() + err := s.conn.Close() + <-s.errorChan // wait for serve() to return + return err } // Addr returns the server's network address @@ -205,25 +279,39 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet rcvTime := time.Now() r := bytes.NewReader(packet) - hdr, err := ParsePublicHeader(r, protocol.PerspectiveClient) + hdr, err := wire.ParseHeaderSentByClient(r) if err != nil { return qerr.Error(qerr.InvalidPacketHeader, err.Error()) } hdr.Raw = packet[:len(packet)-r.Len()] + packetData := packet[len(packet)-r.Len():] + connID := hdr.ConnectionID + + if hdr.Type == protocol.PacketTypeInitial { + if s.supportsTLS { + go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData) + } + return nil + } s.sessionsMutex.RLock() - session, ok := s.sessions[hdr.ConnectionID] + session, sessionKnown := s.sessions[connID] s.sessionsMutex.RUnlock() + if sessionKnown && session == nil { + // Late packet for closed session + return nil + } + // ignore all Public Reset packets if hdr.ResetFlag { - if ok { - var pr *publicReset - pr, err = parsePublicReset(r) + if sessionKnown { + var pr *wire.PublicReset + pr, err = wire.ParsePublicReset(r) if err != nil { utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.") } else { - utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.rejectedPacketNumber) + utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.RejectedPacketNumber) } } else { utils.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID) @@ -231,37 +319,47 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return nil } - // a session is only created once the client sent a supported version - // if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated - // it is safe to drop it - if ok && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.VersionNumber) { - return nil - } - - // Send Version Negotiation Packet if the client is speaking a different protocol version - if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.VersionNumber) { - // drop packets that are too small to be valid first packets - if len(packet) < protocol.ClientHelloMinimumSize+len(hdr.Raw) { - return errors.New("dropping small packet with unknown version") - } - utils.Infof("Client offered version %d, sending VersionNegotiationPacket", hdr.VersionNumber) - _, err = pconn.WriteTo(composeVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr) + // If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset + // This should only happen after a server restart, when we still receive packets for connections that we lost the state for. + // TODO(#943): implement sending of IETF draft style stateless resets + if !sessionKnown && (!hdr.VersionFlag && hdr.Type != protocol.PacketTypeInitial) { + _, err = pconn.WriteTo(wire.WritePublicReset(connID, 0, 0), remoteAddr) return err } - if !ok { - if !hdr.VersionFlag { - _, err = pconn.WriteTo(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr) - return err + // a session is only created once the client sent a supported version + // if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated + // it is safe to drop it + if sessionKnown && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { + return nil + } + + // send a Version Negotiation Packet if the client is speaking a different protocol version + // since the client send a Public Header (only gQUIC has a Version Flag), we need to send a gQUIC Version Negotiation Packet + if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { + // drop packets that are too small to be valid first packets + if len(packet) < protocol.MinClientHelloSize+len(hdr.Raw) { + return errors.New("dropping small packet with unknown version") } - version := hdr.VersionNumber + utils.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version) + _, err := pconn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr) + return err + } + + // This is (potentially) a Client Hello. + // Make sure it has the minimum required size before spending any more ressources on it. + if !sessionKnown && len(packet) < protocol.MinClientHelloSize+len(hdr.Raw) { + return errors.New("dropping small packet for unknown connection") + } + + if !sessionKnown { + version := hdr.Version if !protocol.IsSupportedVersion(s.config.Versions, version) { return errors.New("Server BUG: negotiated version not supported") } - utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, version, remoteAddr) - var handshakeChan <-chan handshakeEvent - session, handshakeChan, err = s.newSession( + utils.Infof("Serving new connection: %x, version %s from %v", hdr.ConnectionID, version, remoteAddr) + session, err = s.newSession( &conn{pconn: pconn, currentAddr: remoteAddr}, version, hdr.ConnectionID, @@ -273,41 +371,35 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return err } s.sessionsMutex.Lock() - s.sessions[hdr.ConnectionID] = session + s.sessions[connID] = session s.sessionsMutex.Unlock() - go func() { - // session.run() returns as soon as the session is closed - _ = session.run() - s.removeConnection(hdr.ConnectionID) - }() - - go func() { - for { - ev := <-handshakeChan - if ev.err != nil { - return - } - if ev.encLevel == protocol.EncryptionForwardSecure { - break - } - } - s.sessionQueue <- session - }() - } - if session == nil { - // Late packet for closed session - return nil + s.runHandshakeAndSession(session, connID) } session.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - publicHeader: hdr, - data: packet[len(packet)-r.Len():], - rcvTime: rcvTime, + remoteAddr: remoteAddr, + header: hdr, + data: packetData, + rcvTime: rcvTime, }) return nil } +func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) { + go func() { + _ = session.run() + // session.run() returns as soon as the session is closed + s.removeConnection(connID) + }() + + go func() { + if err := <-session.handshakeStatus(); err != nil { + return + } + s.sessionQueue <- session + }() +} + func (s *server) removeConnection(id protocol.ConnectionID) { s.sessionsMutex.Lock() s.sessions[id] = nil @@ -319,20 +411,3 @@ func (s *server) removeConnection(id protocol.ConnectionID) { s.sessionsMutex.Unlock() }) } - -func composeVersionNegotiation(connectionID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { - fullReply := &bytes.Buffer{} - responsePublicHeader := PublicHeader{ - ConnectionID: connectionID, - PacketNumber: 1, - VersionFlag: true, - } - err := responsePublicHeader.Write(fullReply, protocol.VersionWhatever, protocol.PerspectiveServer) - if err != nil { - utils.Errorf("error composing version negotiation packet: %s", err.Error()) - } - for _, v := range versions { - utils.WriteUint32(fullReply, protocol.VersionNumberToTag(v)) - } - return fullReply.Bytes() -} diff --git a/vendor/github.com/lucas-clemente/quic-go/server_tls.go b/vendor/github.com/lucas-clemente/quic-go/server_tls.go new file mode 100644 index 000000000..5f270e349 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/server_tls.go @@ -0,0 +1,220 @@ +package quic + +import ( + "crypto/tls" + "errors" + "fmt" + "net" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +type nullAEAD struct { + aead crypto.AEAD +} + +var _ quicAEAD = &nullAEAD{} + +func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { + data, err := n.aead.Open(dst, src, packetNumber, associatedData) + return data, protocol.EncryptionUnencrypted, err +} + +type tlsSession struct { + connID protocol.ConnectionID + sess packetHandler +} + +type serverTLS struct { + conn net.PacketConn + config *Config + supportedVersions []protocol.VersionNumber + mintConf *mint.Config + cookieProtector mint.CookieProtector + params *handshake.TransportParameters + newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) + + sessionChan chan<- tlsSession +} + +func newServerTLS( + conn net.PacketConn, + config *Config, + cookieHandler *handshake.CookieHandler, + tlsConf *tls.Config, +) (*serverTLS, <-chan tlsSession, error) { + mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer) + if err != nil { + return nil, nil, err + } + mconf.RequireCookie = true + cs, err := mint.NewDefaultCookieProtector() + if err != nil { + return nil, nil, err + } + mconf.CookieProtector = cs + mconf.CookieHandler = cookieHandler + + sessionChan := make(chan tlsSession) + s := &serverTLS{ + conn: conn, + config: config, + supportedVersions: config.Versions, + mintConf: mconf, + sessionChan: sessionChan, + params: &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + IdleTimeout: config.IdleTimeout, + // TODO(#523): make these values configurable + MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveServer), + MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveServer), + }, + } + s.newMintConn = s.newMintConnImpl + return s, sessionChan, nil +} + +func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) { + utils.Debugf("Received a Packet. Handling it statelessly.") + sess, err := s.handleInitialImpl(remoteAddr, hdr, data) + if err != nil { + utils.Errorf("Error occured handling initial packet: %s", err) + return + } + if sess == nil { // a stateless reset was done + return + } + s.sessionChan <- tlsSession{ + connID: hdr.ConnectionID, + sess: sess, + } +} + +// will be set to s.newMintConn by the constructor +func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) { + extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v) + conf := s.mintConf.Clone() + conf.ExtensionHandler = extHandler + return newMintController(bc, conf, protocol.PerspectiveServer), extHandler.GetPeerParams(), nil +} + +func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Header, aead crypto.AEAD, closeErr error) error { + ccf := &wire.ConnectionCloseFrame{ + ErrorCode: qerr.HandshakeFailed, + ReasonPhrase: closeErr.Error(), + } + replyHdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + ConnectionID: clientHdr.ConnectionID, // echo the client's connection ID + PacketNumber: 1, // random packet number + Version: clientHdr.Version, + } + data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer) + if err != nil { + return err + } + _, err = s.conn.WriteTo(data, remoteAddr) + return err +} + +func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) { + if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize { + return nil, errors.New("dropping too small Initial packet") + } + // check version, if not matching send VNP + if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) { + utils.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) + _, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.supportedVersions), remoteAddr) + return nil, err + } + + // unpack packet and check stream frame contents + aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, hdr.Version) + if err != nil { + return nil, err + } + frame, err := unpackInitialPacket(aead, hdr, data, hdr.Version) + if err != nil { + utils.Debugf("Error unpacking initial packet: %s", err) + return nil, nil + } + sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead) + if err != nil { + if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil { + utils.Debugf("Error sending CONNECTION_CLOSE: ", ccerr) + } + return nil, err + } + return sess, nil +} + +func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, error) { + version := hdr.Version + bc := handshake.NewCryptoStreamConn(remoteAddr) + bc.AddDataForReading(frame.Data) + tls, paramsChan, err := s.newMintConn(bc, version) + if err != nil { + return nil, err + } + alert := tls.Handshake() + if alert == mint.AlertStatelessRetry { + // the HelloRetryRequest was written to the bufferConn + // Take that data and write send a Retry packet + replyHdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + ConnectionID: hdr.ConnectionID, // echo the client's connection ID + PacketNumber: hdr.PacketNumber, // echo the client's packet number + Version: version, + } + f := &wire.StreamFrame{ + StreamID: version.CryptoStreamID(), + Data: bc.GetDataForWriting(), + } + data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer) + if err != nil { + return nil, err + } + _, err = s.conn.WriteTo(data, remoteAddr) + return nil, err + } + if alert != mint.AlertNoAlert { + return nil, alert + } + if tls.State() != mint.StateServerNegotiated { + return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State()) + } + if alert := tls.Handshake(); alert != mint.AlertNoAlert { + return nil, alert + } + if tls.State() != mint.StateServerWaitFlight2 { + return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State()) + } + params := <-paramsChan + sess, err := newTLSServerSession( + &conn{pconn: s.conn, currentAddr: remoteAddr}, + hdr.ConnectionID, // TODO: we can use a server-chosen connection ID here + protocol.PacketNumber(1), // TODO: use a random packet number here + s.config, + tls, + bc, + aead, + ¶ms, + version, + ) + if err != nil { + return nil, err + } + cs := sess.getCryptoStream() + cs.setReadOffset(frame.DataLen()) + bc.SetStream(cs) + return sess, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/session.go b/vendor/github.com/lucas-clemente/quic-go/session.go index 88d8ba565..49c616511 100644 --- a/vendor/github.com/lucas-clemente/quic-go/session.go +++ b/vendor/github.com/lucas-clemente/quic-go/session.go @@ -9,42 +9,50 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/ackhandler" - "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/flowcontrol" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/handshake" + "github.com/lucas-clemente/quic-go/internal/ackhandler" + "github.com/lucas-clemente/quic-go/internal/congestion" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) type unpacker interface { - Unpack(publicHeaderBinary []byte, hdr *PublicHeader, data []byte) (*unpackedPacket, error) + Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) +} + +type streamGetter interface { + GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) + GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) +} + +type streamManager interface { + GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) + GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) + OpenStream() (Stream, error) + OpenStreamSync() (Stream, error) + AcceptStream() (Stream, error) + DeleteStream(protocol.StreamID) error + UpdateLimits(*handshake.TransportParameters) + HandleMaxStreamIDFrame(*wire.MaxStreamIDFrame) error + CloseWithError(error) } type receivedPacket struct { - remoteAddr net.Addr - publicHeader *PublicHeader - data []byte - rcvTime time.Time + remoteAddr net.Addr + header *wire.Header + data []byte + rcvTime time.Time } -var ( - errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream") - errWindowUpdateOnClosedStream = errors.New("WINDOW_UPDATE received for an already closed stream") -) - var ( newCryptoSetup = handshake.NewCryptoSetup newCryptoSetupClient = handshake.NewCryptoSetupClient ) -type handshakeEvent struct { - encLevel protocol.EncryptionLevel - err error -} - type closeError struct { err error remote bool @@ -55,20 +63,20 @@ type session struct { connectionID protocol.ConnectionID perspective protocol.Perspective version protocol.VersionNumber - tlsConf *tls.Config config *Config conn connection - streamsMap *streamsMap + streamsMap streamManager + cryptoStream cryptoStreamI rttStats *congestion.RTTStats sentPacketHandler ackhandler.SentPacketHandler receivedPacketHandler ackhandler.ReceivedPacketHandler streamFramer *streamFramer - - flowControlManager flowcontrol.FlowControlManager + windowUpdateQueue *windowUpdateQueue + connFlowController flowcontrol.ConnectionFlowController unpacker unpacker packer *packetPacker @@ -89,19 +97,16 @@ type session struct { undecryptablePackets []*receivedPacket receivedTooManyUndecrytablePacketsTime time.Time - // this channel is passed to the CryptoSetup and receives the current encryption level - // it is closed as soon as the handshake is complete - aeadChanged <-chan protocol.EncryptionLevel + // this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them + paramsChan <-chan handshake.TransportParameters + // the handshakeEvent channel is passed to the CryptoSetup. + // It receives when it makes sense to try decrypting undecryptable packets. + handshakeEvent <-chan struct{} + // handshakeChan is returned by handshakeStatus. + // It receives any error that might occur during the handshake. + // It is closed when the handshake is complete. + handshakeChan chan error handshakeComplete bool - // will be closed as soon as the handshake completes, and receive any error that might occur until then - // it is used to block WaitUntilHandshakeComplete() - handshakeCompleteChan chan error - // handshakeChan receives handshake events and is closed as soon the handshake completes - // the receiving end of this channel is passed to the creator of the session - // it receives at most 3 handshake events: 2 when the encryption level changes, and one error - handshakeChan chan<- handshakeEvent - - connectionParameters handshake.ConnectionParametersManager lastRcvdPacketNumber protocol.PacketNumber // Used to calculate the next packet number from the truncated wire @@ -110,6 +115,10 @@ type session struct { sessionCreationTime time.Time lastNetworkActivityTime time.Time + // pacingDeadline is the time when the next packet should be sent + pacingDeadline time.Time + + peerParams *handshake.TransportParameters timer *utils.Timer // keepAlivePingSent stores whether a Ping frame was sent to the peer or not @@ -118,27 +127,55 @@ type session struct { } var _ Session = &session{} +var _ streamSender = &session{} // newSession makes a new session func newSession( conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, - sCfg *handshake.ServerConfig, + scfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, -) (packetHandler, <-chan handshakeEvent, error) { +) (packetHandler, error) { + paramsChan := make(chan handshake.TransportParameters) + handshakeEvent := make(chan struct{}, 1) s := &session{ - conn: conn, - connectionID: connectionID, - perspective: protocol.PerspectiveServer, - version: v, - config: config, + conn: conn, + connectionID: connectionID, + perspective: protocol.PerspectiveServer, + version: v, + config: config, + handshakeEvent: handshakeEvent, + paramsChan: paramsChan, } - return s.setup(sCfg, "", nil) + s.preSetup() + transportParams := &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + MaxStreams: protocol.MaxIncomingStreams, + IdleTimeout: s.config.IdleTimeout, + } + cs, err := newCryptoSetup( + s.cryptoStream, + s.connectionID, + s.conn.RemoteAddr(), + s.version, + scfg, + transportParams, + s.config.Versions, + s.config.AcceptCookie, + paramsChan, + handshakeEvent, + ) + if err != nil { + return nil, err + } + s.cryptoSetup = cs + return s, s.postSetup(1) } -// declare this as a variable, such that we can it mock it in the tests +// declare this as a variable, so that we can it mock it in the tests var newClientSession = func( conn connection, hostname string, @@ -146,29 +183,134 @@ var newClientSession = func( connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, - negotiatedVersions []protocol.VersionNumber, -) (packetHandler, <-chan handshakeEvent, error) { + initialVersion protocol.VersionNumber, + negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton +) (packetHandler, error) { + paramsChan := make(chan handshake.TransportParameters) + handshakeEvent := make(chan struct{}, 1) s := &session{ - conn: conn, - connectionID: connectionID, - perspective: protocol.PerspectiveClient, - version: v, - tlsConf: tlsConf, - config: config, + conn: conn, + connectionID: connectionID, + perspective: protocol.PerspectiveClient, + version: v, + config: config, + handshakeEvent: handshakeEvent, + paramsChan: paramsChan, } - return s.setup(nil, hostname, negotiatedVersions) + s.preSetup() + transportParams := &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + MaxStreams: protocol.MaxIncomingStreams, + IdleTimeout: s.config.IdleTimeout, + OmitConnectionID: s.config.RequestConnectionIDOmission, + } + cs, err := newCryptoSetupClient( + s.cryptoStream, + hostname, + s.connectionID, + s.version, + tlsConf, + transportParams, + paramsChan, + handshakeEvent, + initialVersion, + negotiatedVersions, + ) + if err != nil { + return nil, err + } + s.cryptoSetup = cs + return s, s.postSetup(1) } -func (s *session) setup( - scfg *handshake.ServerConfig, +func newTLSServerSession( + conn connection, + connectionID protocol.ConnectionID, + initialPacketNumber protocol.PacketNumber, + config *Config, + tls handshake.MintTLS, + cryptoStreamConn *handshake.CryptoStreamConn, + nullAEAD crypto.AEAD, + peerParams *handshake.TransportParameters, + v protocol.VersionNumber, +) (packetHandler, error) { + handshakeEvent := make(chan struct{}, 1) + s := &session{ + conn: conn, + config: config, + connectionID: connectionID, + perspective: protocol.PerspectiveServer, + version: v, + handshakeEvent: handshakeEvent, + } + s.preSetup() + s.cryptoSetup = handshake.NewCryptoSetupTLSServer( + tls, + cryptoStreamConn, + nullAEAD, + handshakeEvent, + v, + ) + if err := s.postSetup(initialPacketNumber); err != nil { + return nil, err + } + s.peerParams = peerParams + s.processTransportParameters(peerParams) + s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} + return s, nil +} + +// declare this as a variable, such that we can it mock it in the tests +var newTLSClientSession = func( + conn connection, hostname string, - negotiatedVersions []protocol.VersionNumber, -) (packetHandler, <-chan handshakeEvent, error) { - aeadChanged := make(chan protocol.EncryptionLevel, 2) - s.aeadChanged = aeadChanged - handshakeChan := make(chan handshakeEvent, 3) - s.handshakeChan = handshakeChan - s.handshakeCompleteChan = make(chan error, 1) + v protocol.VersionNumber, + connectionID protocol.ConnectionID, + config *Config, + tls handshake.MintTLS, + paramsChan <-chan handshake.TransportParameters, + initialPacketNumber protocol.PacketNumber, +) (packetHandler, error) { + handshakeEvent := make(chan struct{}, 1) + s := &session{ + conn: conn, + config: config, + connectionID: connectionID, + perspective: protocol.PerspectiveClient, + version: v, + handshakeEvent: handshakeEvent, + paramsChan: paramsChan, + } + s.preSetup() + tls.SetCryptoStream(s.cryptoStream) + cs, err := handshake.NewCryptoSetupTLSClient( + s.cryptoStream, + s.connectionID, + hostname, + handshakeEvent, + tls, + v, + ) + if err != nil { + return nil, err + } + s.cryptoSetup = cs + return s, s.postSetup(initialPacketNumber) +} + +func (s *session) preSetup() { + s.rttStats = &congestion.RTTStats{} + s.connFlowController = flowcontrol.NewConnectionFlowController( + protocol.ReceiveConnectionFlowControlWindow, + protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), + s.rttStats, + ) + s.cryptoStream = s.newCryptoStream() +} + +func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { + s.handshakeChan = make(chan error, 1) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) @@ -180,70 +322,31 @@ func (s *session) setup( s.lastNetworkActivityTime = now s.sessionCreationTime = now - s.rttStats = &congestion.RTTStats{} - s.connectionParameters = handshake.NewConnectionParamatersManager(s.perspective, s.version, - s.config.MaxReceiveStreamFlowControlWindow, s.config.MaxReceiveConnectionFlowControlWindow) s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) - s.flowControlManager = flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats) - s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler() - s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters) - s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) + s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) - var err error - if s.perspective == protocol.PerspectiveServer { - cryptoStream, _ := s.GetOrOpenStream(1) - _, _ = s.AcceptStream() // don't expose the crypto stream - verifySourceAddr := func(clientAddr net.Addr, hstk *handshake.STK) bool { - var stk *STK - if hstk != nil { - stk = &STK{remoteAddr: hstk.RemoteAddr, sentTime: hstk.SentTime} - } - return s.config.AcceptSTK(clientAddr, stk) - } - s.cryptoSetup, err = newCryptoSetup( - s.connectionID, - s.conn.RemoteAddr(), - s.version, - scfg, - cryptoStream, - s.connectionParameters, - s.config.Versions, - verifySourceAddr, - aeadChanged, - ) + if s.version.UsesTLS() { + s.streamsMap = newStreamsMap(s, s.newFlowController, s.perspective, s.version) } else { - cryptoStream, _ := s.OpenStream() - s.cryptoSetup, err = newCryptoSetupClient( - hostname, - s.connectionID, - s.version, - cryptoStream, - s.tlsConf, - s.connectionParameters, - aeadChanged, - &handshake.TransportParameters{RequestConnectionIDTruncation: s.config.RequestConnectionIDTruncation}, - negotiatedVersions, - ) + s.streamsMap = newStreamsMapLegacy(s.newStream, s.perspective) } - if err != nil { - return nil, nil, err - } - + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) s.packer = newPacketPacker(s.connectionID, + initialPacketNumber, s.cryptoSetup, - s.connectionParameters, s.streamFramer, s.perspective, s.version, ) + s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.packer.QueueControlFrame) s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} - - return s, handshakeChan, nil + return nil } // run the session main loop func (s *session) run() error { - // Start the crypto stream handler + defer s.ctxCancel() + go func() { if err := s.cryptoSetup.HandleCryptoStream(); err != nil { s.Close(err) @@ -251,10 +354,11 @@ func (s *session) run() error { }() var closeErr closeError - aeadChanged := s.aeadChanged + handshakeEvent := s.handshakeEvent runLoop: for { + // Close immediately if requested select { case closeErr = <-s.closeChan: @@ -286,16 +390,23 @@ runLoop: } // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. - putPacketBuffer(p.publicHeader.Raw) - case l, ok := <-aeadChanged: + putPacketBuffer(p.header.Raw) + case p := <-s.paramsChan: + s.processTransportParameters(&p) + case _, ok := <-handshakeEvent: if !ok { // the aeadChanged chan was closed. This means that the handshake is completed. s.handshakeComplete = true - aeadChanged = nil // prevent this case from ever being selected again + handshakeEvent = nil // prevent this case from ever being selected again + s.sentPacketHandler.SetHandshakeComplete() + if !s.version.UsesTLS() && s.perspective == protocol.PerspectiveClient { + // In gQUIC, there's no equivalent to the Finished message in TLS + // The server knows that the handshake is complete when it receives the first forward-secure packet sent by the client. + // We need to make sure that the client actually sends such a packet. + s.packer.QueueControlFrame(&wire.PingFrame{}) + } close(s.handshakeChan) - close(s.handshakeCompleteChan) } else { s.tryDecryptingQueuedPackets() - s.handshakeChan <- handshakeEvent{encLevel: l} } } @@ -306,35 +417,43 @@ runLoop: s.sentPacketHandler.OnAlarm() } - if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.idleTimeout()/2 { + var pacingDeadline time.Time + if s.pacingDeadline.IsZero() { // the timer didn't have a pacing deadline set + pacingDeadline = s.sentPacketHandler.TimeUntilSend() + } + if s.config.KeepAlive && !s.keepAlivePingSent && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.peerParams.IdleTimeout/2 { // send the PING frame since there is no activity in the session - s.packer.QueueControlFrame(&frames.PingFrame{}) + s.packer.QueueControlFrame(&wire.PingFrame{}) s.keepAlivePingSent = true + } else if !pacingDeadline.IsZero() && now.Before(pacingDeadline) { + // If we get to this point before the pacing deadline, we should wait until that deadline. + // This can happen when scheduleSending is called, or a packet is received. + // Set the timer and restart the run loop. + s.pacingDeadline = pacingDeadline + continue } - if err := s.sendPacket(); err != nil { + if err := s.sendPackets(); err != nil { s.closeLocal(err) } + if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 { s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) } - if now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() { - s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) - } if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= s.config.HandshakeTimeout { s.closeLocal(qerr.Error(qerr.HandshakeTimeout, "Crypto handshake did not complete in time.")) } - s.garbageCollectStreams() + if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout { + s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) + } } // only send the error the handshakeChan when the handshake is not completed yet // otherwise this chan will already be closed if !s.handshakeComplete { - s.handshakeCompleteChan <- closeErr.err - s.handshakeChan <- handshakeEvent{err: closeErr.err} + s.handshakeChan <- closeErr.err } s.handleCloseError(closeErr) - defer s.ctxCancel() return closeErr.err } @@ -342,12 +461,16 @@ func (s *session) Context() context.Context { return s.ctx } +func (s *session) ConnectionState() ConnectionState { + return s.cryptoSetup.ConnectionState() +} + func (s *session) maybeResetTimer() { var deadline time.Time if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent { - deadline = s.lastNetworkActivityTime.Add(s.idleTimeout() / 2) + deadline = s.lastNetworkActivityTime.Add(s.peerParams.IdleTimeout / 2) } else { - deadline = s.lastNetworkActivityTime.Add(s.idleTimeout()) + deadline = s.lastNetworkActivityTime.Add(s.config.IdleTimeout) } if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() { @@ -363,20 +486,16 @@ func (s *session) maybeResetTimer() { if !s.receivedTooManyUndecrytablePacketsTime.IsZero() { deadline = utils.MinTime(deadline, s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout)) } + if !s.pacingDeadline.IsZero() { + deadline = utils.MinTime(deadline, s.pacingDeadline) + } s.timer.Reset(deadline) } -func (s *session) idleTimeout() time.Duration { - if s.handshakeComplete { - return s.connectionParameters.GetIdleConnectionStateLifetime() - } - return protocol.InitialIdleTimeout -} - func (s *session) handlePacketImpl(p *receivedPacket) error { if s.perspective == protocol.PerspectiveClient { - diversificationNonce := p.publicHeader.DiversificationNonce + diversificationNonce := p.header.DiversificationNonce if len(diversificationNonce) > 0 { s.cryptoSetup.SetDiversificationNonce(diversificationNonce) } @@ -389,7 +508,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { s.lastNetworkActivityTime = p.rcvTime s.keepAlivePingSent = false - hdr := p.publicHeader + hdr := p.header data := p.data // Calculate packet number @@ -406,16 +525,9 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { } else { utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID, packet.encryptionLevel) } + hdr.Log() } // if the decryption failed, this might be a packet sent by an attacker - // don't update the remote address - if quicErr, ok := err.(*qerr.QuicError); ok && quicErr.ErrorCode == qerr.DecryptionFailure { - return err - } - if s.perspective == protocol.PerspectiveServer { - // update the remote address, even if unpacking failed for any other reason than a decryption error - s.conn.SetCurrentRemoteAddr(p.remoteAddr) - } if err != nil { return err } @@ -425,36 +537,41 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames) - if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, isRetransmittable); err != nil { + if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, p.rcvTime, isRetransmittable); err != nil { return err } - return s.handleFrames(packet.frames) + return s.handleFrames(packet.frames, packet.encryptionLevel) } -func (s *session) handleFrames(fs []frames.Frame) error { +func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLevel) error { for _, ff := range fs { var err error - frames.LogFrame(ff, false) + wire.LogFrame(ff, false) switch frame := ff.(type) { - case *frames.StreamFrame: + case *wire.StreamFrame: err = s.handleStreamFrame(frame) - case *frames.AckFrame: - err = s.handleAckFrame(frame) - case *frames.ConnectionCloseFrame: + case *wire.AckFrame: + err = s.handleAckFrame(frame, encLevel) + case *wire.ConnectionCloseFrame: s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase)) - case *frames.GoawayFrame: + case *wire.GoawayFrame: err = errors.New("unimplemented: handling GOAWAY frames") - case *frames.StopWaitingFrame: - // LeastUnacked is guaranteed to have LeastUnacked > 0 - // therefore this will never underflow - s.receivedPacketHandler.SetLowerLimit(frame.LeastUnacked - 1) - case *frames.RstStreamFrame: + case *wire.StopWaitingFrame: // ignore STOP_WAITINGs + case *wire.RstStreamFrame: err = s.handleRstStreamFrame(frame) - case *frames.WindowUpdateFrame: - err = s.handleWindowUpdateFrame(frame) - case *frames.BlockedFrame: - case *frames.PingFrame: + case *wire.MaxDataFrame: + s.handleMaxDataFrame(frame) + case *wire.MaxStreamDataFrame: + err = s.handleMaxStreamDataFrame(frame) + case *wire.MaxStreamIDFrame: + err = s.handleMaxStreamIDFrame(frame) + case *wire.BlockedFrame: + case *wire.StreamBlockedFrame: + case *wire.StreamIDBlockedFrame: + case *wire.StopSendingFrame: + err = s.handleStopSendingFrame(frame) + case *wire.PingFrame: default: return errors.New("Session BUG: unexpected frame type") } @@ -463,11 +580,6 @@ func (s *session) handleFrames(fs []frames.Frame) error { switch err { case ackhandler.ErrDuplicateOrOutOfOrderAck: // Can happen e.g. when packets thought missing arrive late - case errRstStreamOnInvalidStream: - // Can happen when RST_STREAMs arrive early or late (?) - utils.Errorf("Ignoring error in session: %s", err.Error()) - case errWindowUpdateOnClosedStream: - // Can happen when we already sent the last StreamFrame with the FinBit, but the client already sent a WindowUpdate for this Stream default: return err } @@ -486,8 +598,14 @@ func (s *session) handlePacket(p *receivedPacket) { } } -func (s *session) handleStreamFrame(frame *frames.StreamFrame) error { - str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) +func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { + if frame.StreamID == s.version.CryptoStreamID() { + if frame.FinBit { + return errors.New("Received STREAM frame with FIN bit for the crypto stream") + } + return s.cryptoStream.handleStreamFrame(frame) + } + str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) if err != nil { return err } @@ -496,38 +614,71 @@ func (s *session) handleStreamFrame(frame *frames.StreamFrame) error { // ignore this StreamFrame return nil } - return str.AddStreamFrame(frame) + return str.handleStreamFrame(frame) } -func (s *session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error { - if frame.StreamID != 0 { - str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) - if err != nil { - return err - } - if str == nil { - return errWindowUpdateOnClosedStream - } +func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) { + s.connFlowController.UpdateSendWindow(frame.ByteOffset) +} + +func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { + if frame.StreamID == s.version.CryptoStreamID() { + s.cryptoStream.handleMaxStreamDataFrame(frame) + return nil } - _, err := s.flowControlManager.UpdateWindow(frame.StreamID, frame.ByteOffset) - return err -} - -func (s *session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { - str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) + str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) if err != nil { return err } if str == nil { - return errRstStreamOnInvalidStream + // stream is closed and already garbage collected + return nil } - - str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode)) - return s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) + str.handleMaxStreamDataFrame(frame) + return nil } -func (s *session) handleAckFrame(frame *frames.AckFrame) error { - return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime) +func (s *session) handleMaxStreamIDFrame(frame *wire.MaxStreamIDFrame) error { + return s.streamsMap.HandleMaxStreamIDFrame(frame) +} + +func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { + if frame.StreamID == s.version.CryptoStreamID() { + return errors.New("Received RST_STREAM frame for the crypto stream") + } + str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // stream is closed and already garbage collected + return nil + } + return str.handleRstStreamFrame(frame) +} + +func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error { + if frame.StreamID == s.version.CryptoStreamID() { + return errors.New("Received a STOP_SENDING frame for the crypto stream") + } + str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // stream is closed and already garbage collected + return nil + } + str.handleStopSendingFrame(frame) + return nil +} + +func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { + if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil { + return err + } + s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked()) + return nil } func (s *session) closeLocal(e error) { @@ -567,9 +718,10 @@ func (s *session) handleCloseError(closeErr closeError) error { utils.Errorf("Closing session with error: %s", closeErr.err.Error()) } + s.cryptoStream.closeForShutdown(quicErr) s.streamsMap.CloseWithError(quicErr) - if closeErr.err == errCloseSessionForNewVersion { + if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry { return nil } @@ -586,108 +738,131 @@ func (s *session) handleCloseError(closeErr closeError) error { return s.sendConnectionClose(quicErr) } -func (s *session) sendPacket() error { +func (s *session) processTransportParameters(params *handshake.TransportParameters) { + s.peerParams = params + s.streamsMap.UpdateLimits(params) + if params.OmitConnectionID { + s.packer.SetOmitConnectionID() + } + s.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow) + // the crypto stream is the only open stream at this moment + // so we don't need to update stream flow control windows +} + +func (s *session) sendPackets() error { + s.pacingDeadline = time.Time{} + if !s.sentPacketHandler.SendingAllowed() { // if congestion limited, at least try sending an ACK frame + return s.maybeSendAckOnlyPacket() + } + numPackets := s.sentPacketHandler.ShouldSendNumPackets() + for i := 0; i < numPackets; i++ { + sentPacket, err := s.sendPacket() + if err != nil { + return err + } + // If no packet was sent, or we're congestion limit, we're done here. + if !sentPacket || !s.sentPacketHandler.SendingAllowed() { + return nil + } + } + // Only start the pacing timer if we sent as many packets as we were allowed. + // There will probably be more to send when calling sendPacket again. + s.pacingDeadline = s.sentPacketHandler.TimeUntilSend() + return nil +} + +func (s *session) maybeSendAckOnlyPacket() error { + ack := s.receivedPacketHandler.GetAckFrame() + if ack == nil { + return nil + } + s.packer.QueueControlFrame(ack) + + if !s.version.UsesIETFFrameFormat() { // for gQUIC, maybe add a STOP_WAITING + if swf := s.sentPacketHandler.GetStopWaitingFrame(false); swf != nil { + s.packer.QueueControlFrame(swf) + } + } + packet, err := s.packer.PackAckPacket() + if err != nil { + return err + } + return s.sendPackedPacket(packet) +} + +func (s *session) sendPacket() (bool, error) { s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked()) - // Get WindowUpdate frames - // this call triggers the flow controller to increase the flow control windows, if necessary - windowUpdateFrames := s.getWindowUpdateFrames() - for _, wuf := range windowUpdateFrames { - s.packer.QueueControlFrame(wuf) + if offset := s.connFlowController.GetWindowUpdate(); offset != 0 { + s.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: offset}) } + if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { + s.packer.QueueControlFrame(&wire.BlockedFrame{Offset: offset}) + } + s.windowUpdateQueue.QueueAll() ack := s.receivedPacketHandler.GetAckFrame() if ack != nil { s.packer.QueueControlFrame(ack) } - // Repeatedly try sending until we don't have any more data, or run out of the congestion window + // check for retransmissions first for { - if !s.sentPacketHandler.SendingAllowed() { - if ack == nil { - return nil - } - // If we aren't allowed to send, at least try sending an ACK frame - swf := s.sentPacketHandler.GetStopWaitingFrame(false) - if swf != nil { - s.packer.QueueControlFrame(swf) - } - packet, err := s.packer.PackAckPacket() - if err != nil { - return err - } - return s.sendPackedPacket(packet) + retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission() + if retransmitPacket == nil { + break } - // check for retransmissions first - for { - retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission() - if retransmitPacket == nil { - break - } - - if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { - if s.handshakeComplete { - // Don't retransmit handshake packets when the handshake is complete - continue - } - utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) + // retransmit handshake packets + if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { + utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) + if !s.version.UsesIETFFrameFormat() { s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) - packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket) - if err != nil { - return err - } - if err = s.sendPackedPacket(packet); err != nil { - return err - } - } else { - utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) - // resend the frames that were in the packet - for _, frame := range retransmitPacket.GetFramesForRetransmission() { - switch f := frame.(type) { - case *frames.StreamFrame: - s.streamFramer.AddFrameForRetransmission(f) - case *frames.WindowUpdateFrame: - // only retransmit WindowUpdates if the stream is not yet closed and the we haven't sent another WindowUpdate with a higher ByteOffset for the stream - currentOffset, err := s.flowControlManager.GetReceiveWindow(f.StreamID) - if err == nil && f.ByteOffset >= currentOffset { - s.packer.QueueControlFrame(f) - } - default: - s.packer.QueueControlFrame(frame) - } - } } + packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket) + if err != nil { + return false, err + } + if err := s.sendPackedPacket(packet); err != nil { + return false, err + } + return true, nil } - hasRetransmission := s.streamFramer.HasFramesForRetransmission() - if ack != nil || hasRetransmission { - swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) - if swf != nil { - s.packer.QueueControlFrame(swf) + // queue all retransmittable frames sent in forward-secure packets + utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) + // resend the frames that were in the packet + for _, frame := range retransmitPacket.GetFramesForRetransmission() { + // TODO: only retransmit WINDOW_UPDATEs if they actually enlarge the window + switch f := frame.(type) { + case *wire.StreamFrame: + s.streamFramer.AddFrameForRetransmission(f) + default: + s.packer.QueueControlFrame(frame) } } - packet, err := s.packer.PackPacket() - if err != nil || packet == nil { - return err - } - if err = s.sendPackedPacket(packet); err != nil { - return err - } - - // send every window update twice - for _, f := range windowUpdateFrames { - s.packer.QueueControlFrame(f) - } - windowUpdateFrames = nil - ack = nil } + + hasRetransmission := s.streamFramer.HasFramesForRetransmission() + if !s.version.UsesIETFFrameFormat() && (ack != nil || hasRetransmission) { + if swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission); swf != nil { + s.packer.QueueControlFrame(swf) + } + } + packet, err := s.packer.PackPacket() + if err != nil || packet == nil { + return false, err + } + if err := s.sendPackedPacket(packet); err != nil { + return false, err + } + return true, nil } func (s *session) sendPackedPacket(packet *packedPacket) error { defer putPacketBuffer(packet.raw) err := s.sentPacketHandler.SentPacket(&ackhandler.Packet{ - PacketNumber: packet.number, + PacketNumber: packet.header.PacketNumber, Frames: packet.frames, Length: protocol.ByteCount(len(packet.raw)), EncryptionLevel: packet.encryptionLevel, @@ -701,7 +876,7 @@ func (s *session) sendPackedPacket(packet *packedPacket) error { func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error { s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked()) - packet, err := s.packer.PackConnectionClose(&frames.ConnectionCloseFrame{ + packet, err := s.packer.PackConnectionClose(&wire.ConnectionCloseFrame{ ErrorCode: quicErr.ErrorCode, ReasonPhrase: quicErr.ErrorMessage, }) @@ -717,18 +892,23 @@ func (s *session) logPacket(packet *packedPacket) { // We don't need to allocate the slices for calling the format functions return } - utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", packet.number, len(packet.raw), s.connectionID, packet.encryptionLevel) + utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", packet.header.PacketNumber, len(packet.raw), s.connectionID, packet.encryptionLevel) + packet.header.Log() for _, frame := range packet.frames { - frames.LogFrame(frame, true) + wire.LogFrame(frame, true) } } // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. -// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. +// It is *only* needed for gQUIC's H2. +// It will be removed as soon as gQUIC moves towards the IETF H2/QUIC stream mapping. func (s *session) GetOrOpenStream(id protocol.StreamID) (Stream, error) { - str, err := s.streamsMap.GetOrOpenStream(id) + str, err := s.streamsMap.GetOrOpenSendStream(id) if str != nil { - return str, err + if bstr, ok := str.(Stream); ok { + return bstr, err + } + return nil, fmt.Errorf("Stream %d is not a bidirectional stream", id) } // make sure to return an actual nil value here, not an Stream with value nil return nil, err @@ -748,47 +928,44 @@ func (s *session) OpenStreamSync() (Stream, error) { return s.streamsMap.OpenStreamSync() } -func (s *session) WaitUntilHandshakeComplete() error { - return <-s.handshakeCompleteChan +func (s *session) newStream(id protocol.StreamID) streamI { + flowController := s.newFlowController(id) + return newStream(id, s, flowController, s.version) } -func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { - s.packer.QueueControlFrame(&frames.RstStreamFrame{ - StreamID: id, - ByteOffset: offset, - }) - s.scheduleSending() -} - -func (s *session) newStream(id protocol.StreamID) *stream { - // TODO: find a better solution for determining which streams contribute to connection level flow control - if id == 1 || id == 3 { - s.flowControlManager.NewStream(id, false) - } else { - s.flowControlManager.NewStream(id, true) +func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController { + var initialSendWindow protocol.ByteCount + if s.peerParams != nil { + initialSendWindow = s.peerParams.StreamFlowControlWindow } - return newStream(id, s.scheduleSending, s.queueResetStreamFrame, s.flowControlManager) + return flowcontrol.NewStreamFlowController( + id, + s.version.StreamContributesToConnectionFlowControl(id), + s.connFlowController, + protocol.ReceiveStreamFlowControlWindow, + protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), + initialSendWindow, + s.rttStats, + ) } -// garbageCollectStreams goes through all streams and removes EOF'ed streams -// from the streams map. -func (s *session) garbageCollectStreams() { - s.streamsMap.Iterate(func(str *stream) (bool, error) { - id := str.StreamID() - if str.finished() { - err := s.streamsMap.RemoveStream(id) - if err != nil { - return false, err - } - s.flowControlManager.RemoveStream(id) - } - return true, nil - }) +func (s *session) newCryptoStream() cryptoStreamI { + id := s.version.CryptoStreamID() + flowController := flowcontrol.NewStreamFlowController( + id, + s.version.StreamContributesToConnectionFlowControl(id), + s.connFlowController, + protocol.ReceiveStreamFlowControlWindow, + protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), + 0, + s.rttStats, + ) + return newCryptoStream(s, flowController, s.version) } func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { utils.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber) - return s.conn.Write(writePublicReset(s.connectionID, rejectedPacketNumber, 0)) + return s.conn.Write(wire.WritePublicReset(s.connectionID, rejectedPacketNumber, 0)) } // scheduleSending signals that we have data for sending @@ -801,7 +978,7 @@ func (s *session) scheduleSending() { func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { if s.handshakeComplete { - utils.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.publicHeader, len(p.data)) + utils.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data)) return } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { @@ -810,10 +987,10 @@ func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { s.receivedTooManyUndecrytablePacketsTime = time.Now() s.maybeResetTimer() } - utils.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.publicHeader.PacketNumber) + utils.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber) return } - utils.Infof("Queueing packet 0x%x for later decryption", p.publicHeader.PacketNumber) + utils.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber) s.undecryptablePackets = append(s.undecryptablePackets, p) } @@ -824,13 +1001,25 @@ func (s *session) tryDecryptingQueuedPackets() { s.undecryptablePackets = s.undecryptablePackets[:0] } -func (s *session) getWindowUpdateFrames() []*frames.WindowUpdateFrame { - updates := s.flowControlManager.GetWindowUpdates() - res := make([]*frames.WindowUpdateFrame, len(updates)) - for i, u := range updates { - res[i] = &frames.WindowUpdateFrame{StreamID: u.StreamID, ByteOffset: u.Offset} +func (s *session) queueControlFrame(f wire.Frame) { + s.packer.QueueControlFrame(f) + s.scheduleSending() +} + +func (s *session) onHasWindowUpdate(id protocol.StreamID) { + s.windowUpdateQueue.Add(id) + s.scheduleSending() +} + +func (s *session) onHasStreamData(id protocol.StreamID) { + s.streamFramer.AddActiveStream(id) + s.scheduleSending() +} + +func (s *session) onStreamCompleted(id protocol.StreamID) { + if err := s.streamsMap.DeleteStream(id); err != nil { + s.Close(err) } - return res } func (s *session) LocalAddr() net.Addr { @@ -841,3 +1030,15 @@ func (s *session) LocalAddr() net.Addr { func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } + +func (s *session) handshakeStatus() <-chan error { + return s.handshakeChan +} + +func (s *session) getCryptoStream() cryptoStreamI { + return s.cryptoStream +} + +func (s *session) GetVersion() protocol.VersionNumber { + return s.version +} diff --git a/vendor/github.com/lucas-clemente/quic-go/stream.go b/vendor/github.com/lucas-clemente/quic-go/stream.go index ffadc183f..831234934 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream.go @@ -1,62 +1,82 @@ package quic import ( - "context" - "fmt" - "io" "net" "sync" "time" - "github.com/lucas-clemente/quic-go/flowcontrol" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) +const ( + errorCodeStopping protocol.ApplicationErrorCode = 0 + errorCodeStoppingGQUIC protocol.ApplicationErrorCode = 7 +) + +// The streamSender is notified by the stream about various events. +type streamSender interface { + queueControlFrame(wire.Frame) + onHasWindowUpdate(protocol.StreamID) + onHasStreamData(protocol.StreamID) + onStreamCompleted(protocol.StreamID) +} + +// Each of the both stream halves gets its own uniStreamSender. +// This is necessary in order to keep track when both halves have been completed. +type uniStreamSender struct { + streamSender + onStreamCompletedImpl func() +} + +func (s *uniStreamSender) queueControlFrame(f wire.Frame) { + s.streamSender.queueControlFrame(f) +} + +func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) { + s.streamSender.onHasWindowUpdate(id) +} + +func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) { + s.streamSender.onHasStreamData(id) +} + +func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { + s.onStreamCompletedImpl() +} + +var _ streamSender = &uniStreamSender{} + +type streamI interface { + Stream + closeForShutdown(error) + // for receiving + handleStreamFrame(*wire.StreamFrame) error + handleRstStreamFrame(*wire.RstStreamFrame) error + getWindowUpdate() protocol.ByteCount + // for sending + handleStopSendingFrame(*wire.StopSendingFrame) + popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool) + handleMaxStreamDataFrame(*wire.MaxStreamDataFrame) +} + +var _ receiveStreamI = (streamI)(nil) +var _ sendStreamI = (streamI)(nil) + // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // // Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. type stream struct { - mutex sync.Mutex + receiveStream + sendStream - ctx context.Context - ctxCancel context.CancelFunc + completedMutex sync.Mutex + sender streamSender + receiveStreamCompleted bool + sendStreamCompleted bool - streamID protocol.StreamID - onData func() - // onReset is a callback that should send a RST_STREAM - onReset func(protocol.StreamID, protocol.ByteCount) - - readPosInFrame int - writeOffset protocol.ByteCount - readOffset protocol.ByteCount - - // Once set, the errors must not be changed! - err error - - // cancelled is set when Cancel() is called - cancelled utils.AtomicBool - // finishedReading is set once we read a frame with a FinBit - finishedReading utils.AtomicBool - // finisedWriting is set once Close() is called - finishedWriting utils.AtomicBool - // resetLocally is set if Reset() is called - resetLocally utils.AtomicBool - // resetRemotely is set if RegisterRemoteError() is called - resetRemotely utils.AtomicBool - - frameQueue *streamFrameSorter - readChan chan struct{} - readDeadline time.Time - - dataForWriting []byte - finSent utils.AtomicBool - rstSent utils.AtomicBool - writeChan chan struct{} - writeDeadline time.Time - - flowControlManager flowcontrol.FlowControlManager + version protocol.VersionNumber } var _ Stream = &stream{} @@ -69,279 +89,58 @@ func (deadlineError) Timeout() bool { return true } var errDeadline net.Error = &deadlineError{} +type streamCanceledError struct { + error + errorCode protocol.ApplicationErrorCode +} + +func (streamCanceledError) Canceled() bool { return true } +func (e streamCanceledError) ErrorCode() protocol.ApplicationErrorCode { return e.errorCode } + +var _ StreamError = &streamCanceledError{} + // newStream creates a new Stream -func newStream(StreamID protocol.StreamID, - onData func(), - onReset func(protocol.StreamID, protocol.ByteCount), - flowControlManager flowcontrol.FlowControlManager) *stream { - s := &stream{ - onData: onData, - onReset: onReset, - streamID: StreamID, - flowControlManager: flowControlManager, - frameQueue: newStreamFrameSorter(), - readChan: make(chan struct{}, 1), - writeChan: make(chan struct{}, 1), +func newStream(streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *stream { + s := &stream{sender: sender} + senderForSendStream := &uniStreamSender{ + streamSender: sender, + onStreamCompletedImpl: func() { + s.completedMutex.Lock() + s.sendStreamCompleted = true + s.checkIfCompleted() + s.completedMutex.Unlock() + }, } - s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version) + senderForReceiveStream := &uniStreamSender{ + streamSender: sender, + onStreamCompletedImpl: func() { + s.completedMutex.Lock() + s.receiveStreamCompleted = true + s.checkIfCompleted() + s.completedMutex.Unlock() + }, + } + s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController, version) return s } -// Read implements io.Reader. It is not thread safe! -func (s *stream) Read(p []byte) (int, error) { - s.mutex.Lock() - err := s.err - s.mutex.Unlock() - if s.cancelled.Get() || s.resetLocally.Get() { - return 0, err - } - if s.finishedReading.Get() { - return 0, io.EOF - } - - bytesRead := 0 - for bytesRead < len(p) { - s.mutex.Lock() - frame := s.frameQueue.Head() - if frame == nil && bytesRead > 0 { - err = s.err - s.mutex.Unlock() - return bytesRead, err - } - - var err error - for { - // Stop waiting on errors - if s.resetLocally.Get() || s.cancelled.Get() { - err = s.err - break - } - - deadline := s.readDeadline - if !deadline.IsZero() && !time.Now().Before(deadline) { - err = errDeadline - break - } - - if frame != nil { - s.readPosInFrame = int(s.readOffset - frame.Offset) - break - } - - s.mutex.Unlock() - if deadline.IsZero() { - <-s.readChan - } else { - select { - case <-s.readChan: - case <-time.After(deadline.Sub(time.Now())): - } - } - s.mutex.Lock() - frame = s.frameQueue.Head() - } - s.mutex.Unlock() - - if err != nil { - return bytesRead, err - } - - m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame) - - if bytesRead > len(p) { - return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) - } - if s.readPosInFrame > int(frame.DataLen()) { - return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen()) - } - copy(p[bytesRead:], frame.Data[s.readPosInFrame:]) - - s.readPosInFrame += m - bytesRead += m - s.readOffset += protocol.ByteCount(m) - - // when a RST_STREAM was received, the was already informed about the final byteOffset for this stream - if !s.resetRemotely.Get() { - s.flowControlManager.AddBytesRead(s.streamID, protocol.ByteCount(m)) - } - s.onData() // so that a possible WINDOW_UPDATE is sent - - if s.readPosInFrame >= int(frame.DataLen()) { - fin := frame.FinBit - s.mutex.Lock() - s.frameQueue.Pop() - s.mutex.Unlock() - if fin { - s.finishedReading.Set(true) - return bytesRead, io.EOF - } - } - } - - return bytesRead, nil +// need to define StreamID() here, since both receiveStream and readStream have a StreamID() +func (s *stream) StreamID() protocol.StreamID { + // the result is same for receiveStream and sendStream + return s.sendStream.StreamID() } -func (s *stream) Write(p []byte) (int, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.resetLocally.Get() || s.err != nil { - return 0, s.err - } - if s.finishedWriting.Get() { - return 0, fmt.Errorf("write on closed stream %d", s.streamID) - } - if len(p) == 0 { - return 0, nil - } - - s.dataForWriting = make([]byte, len(p)) - copy(s.dataForWriting, p) - s.onData() - - var err error - for { - deadline := s.writeDeadline - if !deadline.IsZero() && !time.Now().Before(deadline) { - err = errDeadline - break - } - if s.dataForWriting == nil || s.err != nil { - break - } - - s.mutex.Unlock() - if deadline.IsZero() { - <-s.writeChan - } else { - select { - case <-s.writeChan: - case <-time.After(deadline.Sub(time.Now())): - } - } - s.mutex.Lock() - } - - if err != nil { - return 0, err - } - if s.err != nil { - return len(p) - len(s.dataForWriting), s.err - } - return len(p), nil -} - -func (s *stream) lenOfDataForWriting() protocol.ByteCount { - s.mutex.Lock() - var l protocol.ByteCount - if s.err == nil { - l = protocol.ByteCount(len(s.dataForWriting)) - } - s.mutex.Unlock() - return l -} - -func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.err != nil || s.dataForWriting == nil { - return nil - } - - var ret []byte - if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { - ret = s.dataForWriting[:maxBytes] - s.dataForWriting = s.dataForWriting[maxBytes:] - } else { - ret = s.dataForWriting - s.dataForWriting = nil - s.signalWrite() - } - s.writeOffset += protocol.ByteCount(len(ret)) - return ret -} - -// Close implements io.Closer func (s *stream) Close() error { - s.finishedWriting.Set(true) - s.ctxCancel() - s.onData() - return nil -} - -func (s *stream) shouldSendReset() bool { - if s.rstSent.Get() { - return false - } - return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin() -} - -func (s *stream) shouldSendFin() bool { - s.mutex.Lock() - res := s.finishedWriting.Get() && !s.finSent.Get() && s.err == nil && s.dataForWriting == nil - s.mutex.Unlock() - return res -} - -func (s *stream) sentFin() { - s.finSent.Set(true) -} - -// AddStreamFrame adds a new stream frame -func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { - maxOffset := frame.Offset + frame.DataLen() - err := s.flowControlManager.UpdateHighestReceived(s.streamID, maxOffset) - if err != nil { + if err := s.sendStream.Close(); err != nil { return err } - - s.mutex.Lock() - defer s.mutex.Unlock() - err = s.frameQueue.Push(frame) - if err != nil && err != errDuplicateStreamData { - return err - } - s.signalRead() - return nil -} - -// signalRead performs a non-blocking send on the readChan -func (s *stream) signalRead() { - select { - case s.readChan <- struct{}{}: - default: - } -} - -// signalRead performs a non-blocking send on the writeChan -func (s *stream) signalWrite() { - select { - case s.writeChan <- struct{}{}: - default: - } -} - -func (s *stream) SetReadDeadline(t time.Time) error { - s.mutex.Lock() - oldDeadline := s.readDeadline - s.readDeadline = t - s.mutex.Unlock() - // if the new deadline is before the currently set deadline, wake up Read() - if t.Before(oldDeadline) { - s.signalRead() - } - return nil -} - -func (s *stream) SetWriteDeadline(t time.Time) error { - s.mutex.Lock() - oldDeadline := s.writeDeadline - s.writeDeadline = t - s.mutex.Unlock() - if t.Before(oldDeadline) { - s.signalWrite() - } + // in gQUIC, we need to send a RST_STREAM with the final offset if CancelRead() was called + s.receiveStream.onClose(s.sendStream.getWriteOffset()) return nil } @@ -351,83 +150,31 @@ func (s *stream) SetDeadline(t time.Time) error { return nil } -// CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset -func (s *stream) CloseRemote(offset protocol.ByteCount) { - s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset}) +// CloseForShutdown closes a stream abruptly. +// It makes Read and Write unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *stream) closeForShutdown(err error) { + s.sendStream.closeForShutdown(err) + s.receiveStream.closeForShutdown(err) } -// Cancel is called by session to indicate that an error occurred -// The stream should will be closed immediately -func (s *stream) Cancel(err error) { - s.mutex.Lock() - s.cancelled.Set(true) - s.ctxCancel() - // errors must not be changed! - if s.err == nil { - s.err = err - s.signalRead() - s.signalWrite() +func (s *stream) handleRstStreamFrame(frame *wire.RstStreamFrame) error { + if err := s.receiveStream.handleRstStreamFrame(frame); err != nil { + return err } - s.mutex.Unlock() + if !s.version.UsesIETFFrameFormat() { + s.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: s.StreamID(), + ErrorCode: frame.ErrorCode, + }) + } + return nil } -// resets the stream locally -func (s *stream) Reset(err error) { - if s.resetLocally.Get() { - return +// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed. +// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed. +func (s *stream) checkIfCompleted() { + if s.sendStreamCompleted && s.receiveStreamCompleted { + s.sender.onStreamCompleted(s.StreamID()) } - s.mutex.Lock() - s.resetLocally.Set(true) - s.ctxCancel() - // errors must not be changed! - if s.err == nil { - s.err = err - s.signalRead() - s.signalWrite() - } - if s.shouldSendReset() { - s.onReset(s.streamID, s.writeOffset) - s.rstSent.Set(true) - } - s.mutex.Unlock() -} - -// resets the stream remotely -func (s *stream) RegisterRemoteError(err error) { - if s.resetRemotely.Get() { - return - } - s.mutex.Lock() - s.resetRemotely.Set(true) - s.ctxCancel() - // errors must not be changed! - if s.err == nil { - s.err = err - s.signalWrite() - } - if s.shouldSendReset() { - s.onReset(s.streamID, s.writeOffset) - s.rstSent.Set(true) - } - s.mutex.Unlock() -} - -func (s *stream) finishedWriteAndSentFin() bool { - return s.finishedWriting.Get() && s.finSent.Get() -} - -func (s *stream) finished() bool { - return s.cancelled.Get() || - (s.finishedReading.Get() && s.finishedWriteAndSentFin()) || - (s.resetRemotely.Get() && s.rstSent.Get()) || - (s.finishedReading.Get() && s.rstSent.Get()) || - (s.finishedWriteAndSentFin() && s.resetRemotely.Get()) -} - -func (s *stream) Context() context.Context { - return s.ctx -} - -func (s *stream) StreamID() protocol.StreamID { - return s.streamID } diff --git a/vendor/github.com/lucas-clemente/quic-go/stream_frame_sorter.go b/vendor/github.com/lucas-clemente/quic-go/stream_frame_sorter.go index 4a50150e2..e3a3a807a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream_frame_sorter.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream_frame_sorter.go @@ -3,13 +3,13 @@ package quic import ( "errors" - "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) type streamFrameSorter struct { - queuedFrames map[protocol.ByteCount]*frames.StreamFrame + queuedFrames map[protocol.ByteCount]*wire.StreamFrame readPosition protocol.ByteCount gaps *utils.ByteIntervalList } @@ -23,13 +23,13 @@ var ( func newStreamFrameSorter() *streamFrameSorter { s := streamFrameSorter{ gaps: utils.NewByteIntervalList(), - queuedFrames: make(map[protocol.ByteCount]*frames.StreamFrame), + queuedFrames: make(map[protocol.ByteCount]*wire.StreamFrame), } s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount}) return &s } -func (s *streamFrameSorter) Push(frame *frames.StreamFrame) error { +func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error { if frame.DataLen() == 0 { if frame.FinBit { s.queuedFrames[frame.Offset] = frame @@ -143,7 +143,7 @@ func (s *streamFrameSorter) Push(frame *frames.StreamFrame) error { return nil } -func (s *streamFrameSorter) Pop() *frames.StreamFrame { +func (s *streamFrameSorter) Pop() *wire.StreamFrame { frame := s.Head() if frame != nil { s.readPosition += frame.DataLen() @@ -152,7 +152,7 @@ func (s *streamFrameSorter) Pop() *frames.StreamFrame { return frame } -func (s *streamFrameSorter) Head() *frames.StreamFrame { +func (s *streamFrameSorter) Head() *wire.StreamFrame { frame, ok := s.queuedFrames[s.readPosition] if ok { return frame diff --git a/vendor/github.com/lucas-clemente/quic-go/stream_framer.go b/vendor/github.com/lucas-clemente/quic-go/stream_framer.go index 20f82e3e2..933d642bd 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream_framer.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream_framer.go @@ -1,189 +1,140 @@ package quic import ( - "github.com/lucas-clemente/quic-go/flowcontrol" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) type streamFramer struct { - streamsMap *streamsMap + streamGetter streamGetter + cryptoStream cryptoStreamI + version protocol.VersionNumber - flowControlManager flowcontrol.FlowControlManager + retransmissionQueue []*wire.StreamFrame - retransmissionQueue []*frames.StreamFrame - blockedFrameQueue []*frames.BlockedFrame + streamQueueMutex sync.Mutex + activeStreams map[protocol.StreamID]struct{} + streamQueue []protocol.StreamID + hasCryptoStreamData bool } -func newStreamFramer(streamsMap *streamsMap, flowControlManager flowcontrol.FlowControlManager) *streamFramer { +func newStreamFramer( + cryptoStream cryptoStreamI, + streamGetter streamGetter, + v protocol.VersionNumber, +) *streamFramer { return &streamFramer{ - streamsMap: streamsMap, - flowControlManager: flowControlManager, + streamGetter: streamGetter, + cryptoStream: cryptoStream, + activeStreams: make(map[protocol.StreamID]struct{}), + version: v, } } -func (f *streamFramer) AddFrameForRetransmission(frame *frames.StreamFrame) { +func (f *streamFramer) AddFrameForRetransmission(frame *wire.StreamFrame) { f.retransmissionQueue = append(f.retransmissionQueue, frame) } -func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*frames.StreamFrame { - fs, currentLen := f.maybePopFramesForRetransmission(maxLen) - return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...) +func (f *streamFramer) AddActiveStream(id protocol.StreamID) { + if id == f.version.CryptoStreamID() { // the crypto stream is handled separately + f.streamQueueMutex.Lock() + f.hasCryptoStreamData = true + f.streamQueueMutex.Unlock() + return + } + f.streamQueueMutex.Lock() + if _, ok := f.activeStreams[id]; !ok { + f.streamQueue = append(f.streamQueue, id) + f.activeStreams[id] = struct{}{} + } + f.streamQueueMutex.Unlock() } -func (f *streamFramer) PopBlockedFrame() *frames.BlockedFrame { - if len(f.blockedFrameQueue) == 0 { - return nil - } - frame := f.blockedFrameQueue[0] - f.blockedFrameQueue = f.blockedFrameQueue[1:] - return frame +func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.StreamFrame { + fs, currentLen := f.maybePopFramesForRetransmission(maxLen) + return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...) } func (f *streamFramer) HasFramesForRetransmission() bool { return len(f.retransmissionQueue) > 0 } -func (f *streamFramer) HasCryptoStreamFrame() bool { - // TODO(#657): Flow control - cs, _ := f.streamsMap.GetOrOpenStream(1) - return cs.lenOfDataForWriting() > 0 +func (f *streamFramer) HasCryptoStreamData() bool { + f.streamQueueMutex.Lock() + hasCryptoStreamData := f.hasCryptoStreamData + f.streamQueueMutex.Unlock() + return hasCryptoStreamData } -// TODO(lclemente): This is somewhat duplicate with the normal path for generating frames. -// TODO(#657): Flow control -func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *frames.StreamFrame { - if !f.HasCryptoStreamFrame() { - return nil - } - cs, _ := f.streamsMap.GetOrOpenStream(1) - frame := &frames.StreamFrame{ - StreamID: 1, - Offset: cs.writeOffset, - } - frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error - frame.Data = cs.getDataForWriting(maxLen - frameHeaderBytes) +func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { + f.streamQueueMutex.Lock() + frame, hasMoreData := f.cryptoStream.popStreamFrame(maxLen) + f.hasCryptoStreamData = hasMoreData + f.streamQueueMutex.Unlock() return frame } -func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount) (res []*frames.StreamFrame, currentLen protocol.ByteCount) { +func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) { for len(f.retransmissionQueue) > 0 { frame := f.retransmissionQueue[0] frame.DataLenPresent = true - frameHeaderLen, _ := frame.MinLength(protocol.VersionWhatever) // can never error - if currentLen+frameHeaderLen >= maxLen { + maxLen := maxTotalLen - currentLen + if frame.Length(f.version) > maxLen && maxLen < protocol.MinStreamFrameSize { break } - currentLen += frameHeaderLen - - splitFrame := maybeSplitOffFrame(frame, maxLen-currentLen) - if splitFrame != nil { // StreamFrame was split + splitFrame, err := frame.MaybeSplitOffFrame(maxLen, f.version) + if err != nil { // maxLen is too small. Can't split frame + break + } + if splitFrame != nil { // frame was split res = append(res, splitFrame) - currentLen += splitFrame.DataLen() + currentLen += splitFrame.Length(f.version) break } f.retransmissionQueue = f.retransmissionQueue[1:] res = append(res, frame) - currentLen += frame.DataLen() + currentLen += frame.Length(f.version) } return } -func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []*frames.StreamFrame) { - frame := &frames.StreamFrame{DataLenPresent: true} +func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) []*wire.StreamFrame { var currentLen protocol.ByteCount - - fn := func(s *stream) (bool, error) { - if s == nil || s.streamID == 1 /* crypto stream is handled separately */ { - return true, nil + var frames []*wire.StreamFrame + f.streamQueueMutex.Lock() + // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet + numActiveStreams := len(f.streamQueue) + for i := 0; i < numActiveStreams; i++ { + if maxTotalLen-currentLen < protocol.MinStreamFrameSize { + break } - - frame.StreamID = s.streamID - // not perfect, but thread-safe since writeOffset is only written when getting data - frame.Offset = s.writeOffset - frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error - if currentLen+frameHeaderBytes > maxBytes { - return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here + id := f.streamQueue[0] + f.streamQueue = f.streamQueue[1:] + // This should never return an error. Better check it anyway. + // The stream will only be in the streamQueue, if it enqueued itself there. + str, err := f.streamGetter.GetOrOpenSendStream(id) + // The stream can be nil if it completed after it said it had data. + if str == nil || err != nil { + delete(f.activeStreams, id) + continue } - maxLen := maxBytes - currentLen - frameHeaderBytes - - var sendWindowSize protocol.ByteCount - lenStreamData := s.lenOfDataForWriting() - if lenStreamData != 0 { - sendWindowSize, _ = f.flowControlManager.SendWindowSize(s.streamID) - maxLen = utils.MinByteCount(maxLen, sendWindowSize) + frame, hasMoreData := str.popStreamFrame(maxTotalLen - currentLen) + if hasMoreData { // put the stream back in the queue (at the end) + f.streamQueue = append(f.streamQueue, id) + } else { // no more data to send. Stream is not active any more + delete(f.activeStreams, id) } - - if maxLen == 0 { - return true, nil + if frame == nil { // can happen if the receiveStream was canceled after it said it had data + continue } - - var data []byte - if lenStreamData != 0 { - // Only getDataForWriting() if we didn't have data earlier, so that we - // don't send without FC approval (if a Write() raced). - data = s.getDataForWriting(maxLen) - } - - // This is unlikely, but check it nonetheless, the scheduler might have jumped in. Seems to happen in ~20% of cases in the tests. - shouldSendFin := s.shouldSendFin() - if data == nil && !shouldSendFin { - return true, nil - } - - if shouldSendFin { - frame.FinBit = true - s.sentFin() - } - - frame.Data = data - f.flowControlManager.AddBytesSent(s.streamID, protocol.ByteCount(len(data))) - - // Finally, check if we are now FC blocked and should queue a BLOCKED frame - if f.flowControlManager.RemainingConnectionWindowSize() == 0 { - // We are now connection-level FC blocked - f.blockedFrameQueue = append(f.blockedFrameQueue, &frames.BlockedFrame{StreamID: 0}) - } else if !frame.FinBit && sendWindowSize-frame.DataLen() == 0 { - // We are now stream-level FC blocked - f.blockedFrameQueue = append(f.blockedFrameQueue, &frames.BlockedFrame{StreamID: s.StreamID()}) - } - - res = append(res, frame) - currentLen += frameHeaderBytes + frame.DataLen() - - if currentLen == maxBytes { - return false, nil - } - - frame = &frames.StreamFrame{DataLenPresent: true} - return true, nil - } - - f.streamsMap.RoundRobinIterate(fn) - - return -} - -// maybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(frame), nil is returned and nothing is modified. -func maybeSplitOffFrame(frame *frames.StreamFrame, n protocol.ByteCount) *frames.StreamFrame { - if n >= frame.DataLen() { - return nil - } - - defer func() { - frame.Data = frame.Data[n:] - frame.Offset += n - }() - - return &frames.StreamFrame{ - FinBit: false, - StreamID: frame.StreamID, - Offset: frame.Offset, - Data: frame.Data[:n], - DataLenPresent: frame.DataLenPresent, + frames = append(frames, frame) + currentLen += frame.Length(f.version) } + f.streamQueueMutex.Unlock() + return frames } diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map.go b/vendor/github.com/lucas-clemente/quic-go/streams_map.go index 74be17e08..c3ce2ef4d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map.go +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map.go @@ -1,333 +1,218 @@ package quic import ( - "errors" "fmt" - "sync" - "github.com/lucas-clemente/quic-go/handshake" - "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/qerr" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type streamType int + +const ( + streamTypeOutgoingBidi streamType = iota + streamTypeIncomingBidi + streamTypeOutgoingUni + streamTypeIncomingUni ) type streamsMap struct { - mutex sync.RWMutex + perspective protocol.Perspective - perspective protocol.Perspective - connectionParameters handshake.ConnectionParametersManager + sender streamSender + newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController - streams map[protocol.StreamID]*stream - // needed for round-robin scheduling - openStreams []protocol.StreamID - roundRobinIndex uint32 - - nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() - highestStreamOpenedByPeer protocol.StreamID - nextStreamOrErrCond sync.Cond - openStreamOrErrCond sync.Cond - - closeErr error - nextStreamToAccept protocol.StreamID - - newStream newStreamLambda - - numOutgoingStreams uint32 - numIncomingStreams uint32 + outgoingBidiStreams *outgoingBidiStreamsMap + outgoingUniStreams *outgoingUniStreamsMap + incomingBidiStreams *incomingBidiStreamsMap + incomingUniStreams *incomingUniStreamsMap } -type streamLambda func(*stream) (bool, error) -type newStreamLambda func(protocol.StreamID) *stream +var _ streamManager = &streamsMap{} -var ( - errMapAccess = errors.New("streamsMap: Error accessing the streams map") -) - -func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap { - sm := streamsMap{ - perspective: pers, - streams: map[protocol.StreamID]*stream{}, - openStreams: make([]protocol.StreamID, 0), - newStream: newStream, - connectionParameters: connectionParameters, +func newStreamsMap( + sender streamSender, + newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController, + perspective protocol.Perspective, + version protocol.VersionNumber, +) streamManager { + m := &streamsMap{ + perspective: perspective, + newFlowController: newFlowController, + sender: sender, } - sm.nextStreamOrErrCond.L = &sm.mutex - sm.openStreamOrErrCond.L = &sm.mutex - - if pers == protocol.PerspectiveClient { - sm.nextStream = 1 - sm.nextStreamToAccept = 2 + var firstOutgoingBidiStream, firstOutgoingUniStream, firstIncomingBidiStream, firstIncomingUniStream protocol.StreamID + if perspective == protocol.PerspectiveServer { + firstOutgoingBidiStream = 1 + firstIncomingBidiStream = 4 // the crypto stream is handled separatedly + firstOutgoingUniStream = 3 + firstIncomingUniStream = 2 } else { - sm.nextStream = 2 - sm.nextStreamToAccept = 1 + firstOutgoingBidiStream = 4 // the crypto stream is handled separately + firstIncomingBidiStream = 1 + firstOutgoingUniStream = 2 + firstIncomingUniStream = 3 } - - return &sm + newBidiStream := func(id protocol.StreamID) streamI { + return newStream(id, m.sender, m.newFlowController(id), version) + } + newUniSendStream := func(id protocol.StreamID) sendStreamI { + return newSendStream(id, m.sender, m.newFlowController(id), version) + } + newUniReceiveStream := func(id protocol.StreamID) receiveStreamI { + return newReceiveStream(id, m.sender, m.newFlowController(id), version) + } + m.outgoingBidiStreams = newOutgoingBidiStreamsMap( + firstOutgoingBidiStream, + newBidiStream, + sender.queueControlFrame, + ) + // TODO(#523): make these values configurable + m.incomingBidiStreams = newIncomingBidiStreamsMap( + firstIncomingBidiStream, + protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, perspective), + protocol.MaxIncomingStreams, + sender.queueControlFrame, + newBidiStream, + ) + m.outgoingUniStreams = newOutgoingUniStreamsMap( + firstOutgoingUniStream, + newUniSendStream, + sender.queueControlFrame, + ) + // TODO(#523): make these values configurable + m.incomingUniStreams = newIncomingUniStreamsMap( + firstIncomingUniStream, + protocol.MaxUniStreamID(protocol.MaxIncomingStreams, perspective), + protocol.MaxIncomingStreams, + sender.queueControlFrame, + newUniReceiveStream, + ) + return m } -// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. -// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. -func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { - m.mutex.RLock() - s, ok := m.streams[id] - m.mutex.RUnlock() - if ok { - return s, nil // s may be nil - } - - // ... we don't have an existing stream - m.mutex.Lock() - defer m.mutex.Unlock() - // We need to check whether another invocation has already created a stream (between RUnlock() and Lock()). - s, ok = m.streams[id] - if ok { - return s, nil - } - +func (m *streamsMap) getStreamType(id protocol.StreamID) streamType { if m.perspective == protocol.PerspectiveServer { - if id%2 == 0 { - if id <= m.nextStream { // this is a server-side stream that we already opened. Must have been closed already - return nil, nil - } - return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) + switch id % 4 { + case 0: + return streamTypeIncomingBidi + case 1: + return streamTypeOutgoingBidi + case 2: + return streamTypeIncomingUni + case 3: + return streamTypeOutgoingUni } - if id <= m.highestStreamOpenedByPeer { // this is a client-side stream that doesn't exist anymore. Must have been closed already - return nil, nil - } - } - if m.perspective == protocol.PerspectiveClient { - if id%2 == 1 { - if id <= m.nextStream { // this is a client-side stream that we already opened. - return nil, nil - } - return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) - } - if id <= m.highestStreamOpenedByPeer { // this is a server-side stream that doesn't exist anymore. Must have been closed already - return nil, nil - } - } - - // sid is the next stream that will be opened - sid := m.highestStreamOpenedByPeer + 2 - // if there is no stream opened yet, and this is the server, stream 1 should be openend - if sid == 2 && m.perspective == protocol.PerspectiveServer { - sid = 1 - } - - for ; sid <= id; sid += 2 { - _, err := m.openRemoteStream(sid) - if err != nil { - return nil, err - } - } - - m.nextStreamOrErrCond.Broadcast() - return m.streams[id], nil -} - -func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { - if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() { - return nil, qerr.TooManyOpenStreams - } - if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { - return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) - } - - if m.perspective == protocol.PerspectiveServer { - m.numIncomingStreams++ } else { - m.numOutgoingStreams++ + switch id % 4 { + case 0: + return streamTypeOutgoingBidi + case 1: + return streamTypeIncomingBidi + case 2: + return streamTypeOutgoingUni + case 3: + return streamTypeIncomingUni + } } - - if id > m.highestStreamOpenedByPeer { - m.highestStreamOpenedByPeer = id - } - - s := m.newStream(id) - m.putStream(s) - return s, nil + panic("") } -func (m *streamsMap) openStreamImpl() (*stream, error) { - id := m.nextStream - if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() { - return nil, qerr.TooManyOpenStreams - } - - if m.perspective == protocol.PerspectiveServer { - m.numOutgoingStreams++ - } else { - m.numIncomingStreams++ - } - - m.nextStream += 2 - s := m.newStream(id) - m.putStream(s) - return s, nil +func (m *streamsMap) OpenStream() (Stream, error) { + return m.outgoingBidiStreams.OpenStream() } -// OpenStream opens the next available stream -func (m *streamsMap) OpenStream() (*stream, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - return m.openStreamImpl() +func (m *streamsMap) OpenStreamSync() (Stream, error) { + return m.outgoingBidiStreams.OpenStreamSync() } -func (m *streamsMap) OpenStreamSync() (*stream, error) { - m.mutex.Lock() - defer m.mutex.Unlock() +func (m *streamsMap) OpenUniStream() (SendStream, error) { + return m.outgoingUniStreams.OpenStream() +} - for { - if m.closeErr != nil { - return nil, m.closeErr - } - str, err := m.openStreamImpl() - if err == nil { - return str, err - } - if err != nil && err != qerr.TooManyOpenStreams { - return nil, err - } - m.openStreamOrErrCond.Wait() +func (m *streamsMap) OpenUniStreamSync() (SendStream, error) { + return m.outgoingUniStreams.OpenStreamSync() +} + +func (m *streamsMap) AcceptStream() (Stream, error) { + return m.incomingBidiStreams.AcceptStream() +} + +func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) { + return m.incomingUniStreams.AcceptStream() +} + +func (m *streamsMap) DeleteStream(id protocol.StreamID) error { + switch m.getStreamType(id) { + case streamTypeIncomingBidi: + return m.incomingBidiStreams.DeleteStream(id) + case streamTypeOutgoingBidi: + return m.outgoingBidiStreams.DeleteStream(id) + case streamTypeIncomingUni: + return m.incomingUniStreams.DeleteStream(id) + case streamTypeOutgoingUni: + return m.outgoingUniStreams.DeleteStream(id) + default: + panic("invalid stream type") } } -// AcceptStream returns the next stream opened by the peer -// it blocks until a new stream is opened -func (m *streamsMap) AcceptStream() (*stream, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - var str *stream - for { - var ok bool - if m.closeErr != nil { - return nil, m.closeErr - } - str, ok = m.streams[m.nextStreamToAccept] - if ok { - break - } - m.nextStreamOrErrCond.Wait() +func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + switch m.getStreamType(id) { + case streamTypeOutgoingBidi: + return m.outgoingBidiStreams.GetStream(id) + case streamTypeIncomingBidi: + return m.incomingBidiStreams.GetOrOpenStream(id) + case streamTypeIncomingUni: + return m.incomingUniStreams.GetOrOpenStream(id) + case streamTypeOutgoingUni: + // an outgoing unidirectional stream is a send stream, not a receive stream + return nil, fmt.Errorf("peer attempted to open receive stream %d", id) + default: + panic("invalid stream type") } - m.nextStreamToAccept += 2 - return str, nil } -func (m *streamsMap) Iterate(fn streamLambda) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - openStreams := append([]protocol.StreamID{}, m.openStreams...) - - for _, streamID := range openStreams { - cont, err := m.iterateFunc(streamID, fn) - if err != nil { - return err - } - if !cont { - break - } +func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + switch m.getStreamType(id) { + case streamTypeOutgoingBidi: + return m.outgoingBidiStreams.GetStream(id) + case streamTypeIncomingBidi: + return m.incomingBidiStreams.GetOrOpenStream(id) + case streamTypeOutgoingUni: + return m.outgoingUniStreams.GetStream(id) + case streamTypeIncomingUni: + // an incoming unidirectional stream is a receive stream, not a send stream + return nil, fmt.Errorf("peer attempted to open send stream %d", id) + default: + panic("invalid stream type") } - return nil } -// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false -// It uses a round-robin-like scheduling to ensure that every stream is considered fairly -// It prioritizes the crypto- and the header-stream (StreamIDs 1 and 3) -func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - numStreams := uint32(len(m.streams)) - startIndex := m.roundRobinIndex - - for _, i := range []protocol.StreamID{1, 3} { - cont, err := m.iterateFunc(i, fn) - if err != nil && err != errMapAccess { - return err - } - if !cont { - return nil - } +func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error { + id := f.StreamID + switch m.getStreamType(id) { + case streamTypeOutgoingBidi: + m.outgoingBidiStreams.SetMaxStream(id) + return nil + case streamTypeOutgoingUni: + m.outgoingUniStreams.SetMaxStream(id) + return nil + default: + return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id) } - - for i := uint32(0); i < numStreams; i++ { - streamID := m.openStreams[(i+startIndex)%numStreams] - if streamID == 1 || streamID == 3 { - continue - } - - cont, err := m.iterateFunc(streamID, fn) - if err != nil { - return err - } - m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams - if !cont { - break - } - } - return nil } -func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) { - str, ok := m.streams[streamID] - if !ok { - return true, errMapAccess - } - return fn(str) -} - -func (m *streamsMap) putStream(s *stream) error { - id := s.StreamID() - if _, ok := m.streams[id]; ok { - return fmt.Errorf("a stream with ID %d already exists", id) - } - - m.streams[id] = s - m.openStreams = append(m.openStreams, id) - return nil -} - -// Attention: this function must only be called if a mutex has been acquired previously -func (m *streamsMap) RemoveStream(id protocol.StreamID) error { - s, ok := m.streams[id] - if !ok || s == nil { - return fmt.Errorf("attempted to remove non-existing stream: %d", id) - } - - if id%2 == 0 { - m.numOutgoingStreams-- - } else { - m.numIncomingStreams-- - } - - for i, s := range m.openStreams { - if s == id { - // delete the streamID from the openStreams slice - m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])] - // adjust round-robin index, if necessary - if uint32(i) < m.roundRobinIndex { - m.roundRobinIndex-- - } - break - } - } - - delete(m.streams, id) - m.openStreamOrErrCond.Signal() - return nil +func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) { + m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamID) + m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamID) } func (m *streamsMap) CloseWithError(err error) { - m.mutex.Lock() - defer m.mutex.Unlock() - m.closeErr = err - m.nextStreamOrErrCond.Broadcast() - m.openStreamOrErrCond.Broadcast() - for _, s := range m.openStreams { - m.streams[s].Cancel(err) - } + m.outgoingBidiStreams.CloseWithError(err) + m.outgoingUniStreams.CloseWithError(err) + m.incomingBidiStreams.CloseWithError(err) + m.incomingUniStreams.CloseWithError(err) } diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go new file mode 100644 index 000000000..8a35f044f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go @@ -0,0 +1,123 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type incomingBidiStreamsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]streamI + + nextStream protocol.StreamID // the next stream that will be returned by AcceptStream() + highestStream protocol.StreamID // the highest stream that the peer openend + maxStream protocol.StreamID // the highest stream that the peer is allowed to open + maxNumStreams int // maximum number of streams + + newStream func(protocol.StreamID) streamI + queueMaxStreamID func(*wire.MaxStreamIDFrame) + + closeErr error +} + +func newIncomingBidiStreamsMap( + nextStream protocol.StreamID, + initialMaxStreamID protocol.StreamID, + maxNumStreams int, + queueControlFrame func(wire.Frame), + newStream func(protocol.StreamID) streamI, +) *incomingBidiStreamsMap { + m := &incomingBidiStreamsMap{ + streams: make(map[protocol.StreamID]streamI), + nextStream: nextStream, + maxStream: initialMaxStreamID, + maxNumStreams: maxNumStreams, + newStream: newStream, + queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) }, + } + m.cond.L = &m.mutex + return m +} + +func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + var str streamI + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStream] + if ok { + break + } + m.cond.Wait() + } + m.nextStream += 4 + return str, nil +} + +func (m *incomingBidiStreamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { + if id > m.maxStream { + return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + } + // if the id is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if id <= m.highestStream { + m.mutex.RLock() + s := m.streams[id] + m.mutex.RUnlock() + return s, nil + } + + m.mutex.Lock() + var start protocol.StreamID + if m.highestStream == 0 { + start = m.nextStream + } else { + start = m.highestStream + 4 + } + for newID := start; newID <= id; newID += 4 { + m.streams[newID] = m.newStream(newID) + m.cond.Signal() + } + m.highestStream = id + s := m.streams[id] + m.mutex.Unlock() + return s, nil +} + +func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if numNewStreams := m.maxNumStreams - len(m.streams); numNewStreams > 0 { + m.maxStream = m.highestStream + protocol.StreamID(numNewStreams*4) + m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream}) + } + return nil +} + +func (m *incomingBidiStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + m.mutex.Unlock() + m.cond.Broadcast() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go new file mode 100644 index 000000000..830b690d9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go @@ -0,0 +1,121 @@ +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream" +//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream" +type incomingItemsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]item + + nextStream protocol.StreamID // the next stream that will be returned by AcceptStream() + highestStream protocol.StreamID // the highest stream that the peer openend + maxStream protocol.StreamID // the highest stream that the peer is allowed to open + maxNumStreams int // maximum number of streams + + newStream func(protocol.StreamID) item + queueMaxStreamID func(*wire.MaxStreamIDFrame) + + closeErr error +} + +func newIncomingItemsMap( + nextStream protocol.StreamID, + initialMaxStreamID protocol.StreamID, + maxNumStreams int, + queueControlFrame func(wire.Frame), + newStream func(protocol.StreamID) item, +) *incomingItemsMap { + m := &incomingItemsMap{ + streams: make(map[protocol.StreamID]item), + nextStream: nextStream, + maxStream: initialMaxStreamID, + maxNumStreams: maxNumStreams, + newStream: newStream, + queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) }, + } + m.cond.L = &m.mutex + return m +} + +func (m *incomingItemsMap) AcceptStream() (item, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + var str item + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStream] + if ok { + break + } + m.cond.Wait() + } + m.nextStream += 4 + return str, nil +} + +func (m *incomingItemsMap) GetOrOpenStream(id protocol.StreamID) (item, error) { + if id > m.maxStream { + return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + } + // if the id is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if id <= m.highestStream { + m.mutex.RLock() + s := m.streams[id] + m.mutex.RUnlock() + return s, nil + } + + m.mutex.Lock() + var start protocol.StreamID + if m.highestStream == 0 { + start = m.nextStream + } else { + start = m.highestStream + 4 + } + for newID := start; newID <= id; newID += 4 { + m.streams[newID] = m.newStream(newID) + m.cond.Signal() + } + m.highestStream = id + s := m.streams[id] + m.mutex.Unlock() + return s, nil +} + +func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if numNewStreams := m.maxNumStreams - len(m.streams); numNewStreams > 0 { + m.maxStream = m.highestStream + protocol.StreamID(numNewStreams*4) + m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream}) + } + return nil +} + +func (m *incomingItemsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + m.mutex.Unlock() + m.cond.Broadcast() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go new file mode 100644 index 000000000..9091d6357 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go @@ -0,0 +1,123 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type incomingUniStreamsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]receiveStreamI + + nextStream protocol.StreamID // the next stream that will be returned by AcceptStream() + highestStream protocol.StreamID // the highest stream that the peer openend + maxStream protocol.StreamID // the highest stream that the peer is allowed to open + maxNumStreams int // maximum number of streams + + newStream func(protocol.StreamID) receiveStreamI + queueMaxStreamID func(*wire.MaxStreamIDFrame) + + closeErr error +} + +func newIncomingUniStreamsMap( + nextStream protocol.StreamID, + initialMaxStreamID protocol.StreamID, + maxNumStreams int, + queueControlFrame func(wire.Frame), + newStream func(protocol.StreamID) receiveStreamI, +) *incomingUniStreamsMap { + m := &incomingUniStreamsMap{ + streams: make(map[protocol.StreamID]receiveStreamI), + nextStream: nextStream, + maxStream: initialMaxStreamID, + maxNumStreams: maxNumStreams, + newStream: newStream, + queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) }, + } + m.cond.L = &m.mutex + return m +} + +func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + var str receiveStreamI + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStream] + if ok { + break + } + m.cond.Wait() + } + m.nextStream += 4 + return str, nil +} + +func (m *incomingUniStreamsMap) GetOrOpenStream(id protocol.StreamID) (receiveStreamI, error) { + if id > m.maxStream { + return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + } + // if the id is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if id <= m.highestStream { + m.mutex.RLock() + s := m.streams[id] + m.mutex.RUnlock() + return s, nil + } + + m.mutex.Lock() + var start protocol.StreamID + if m.highestStream == 0 { + start = m.nextStream + } else { + start = m.highestStream + 4 + } + for newID := start; newID <= id; newID += 4 { + m.streams[newID] = m.newStream(newID) + m.cond.Signal() + } + m.highestStream = id + s := m.streams[id] + m.mutex.Unlock() + return s, nil +} + +func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if numNewStreams := m.maxNumStreams - len(m.streams); numNewStreams > 0 { + m.maxStream = m.highestStream + protocol.StreamID(numNewStreams*4) + m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream}) + } + return nil +} + +func (m *incomingUniStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + m.mutex.Unlock() + m.cond.Broadcast() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_legacy.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_legacy.go new file mode 100644 index 000000000..152f20e5c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_legacy.go @@ -0,0 +1,263 @@ +package quic + +import ( + "errors" + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +type streamsMapLegacy struct { + mutex sync.RWMutex + + perspective protocol.Perspective + + streams map[protocol.StreamID]streamI + + nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() + highestStreamOpenedByPeer protocol.StreamID + nextStreamOrErrCond sync.Cond + openStreamOrErrCond sync.Cond + + closeErr error + nextStreamToAccept protocol.StreamID + + newStream func(protocol.StreamID) streamI + + numOutgoingStreams uint32 + numIncomingStreams uint32 + maxIncomingStreams uint32 + maxOutgoingStreams uint32 +} + +var _ streamManager = &streamsMapLegacy{} + +var errMapAccess = errors.New("streamsMap: Error accessing the streams map") + +func newStreamsMapLegacy(newStream func(protocol.StreamID) streamI, pers protocol.Perspective) streamManager { + // add some tolerance to the maximum incoming streams value + maxStreams := uint32(protocol.MaxIncomingStreams) + maxIncomingStreams := utils.MaxUint32( + maxStreams+protocol.MaxStreamsMinimumIncrement, + uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)), + ) + sm := streamsMapLegacy{ + perspective: pers, + streams: make(map[protocol.StreamID]streamI), + newStream: newStream, + maxIncomingStreams: maxIncomingStreams, + } + sm.nextStreamOrErrCond.L = &sm.mutex + sm.openStreamOrErrCond.L = &sm.mutex + + nextServerInitiatedStream := protocol.StreamID(2) + nextClientInitiatedStream := protocol.StreamID(3) + if pers == protocol.PerspectiveServer { + sm.highestStreamOpenedByPeer = 1 + } + if pers == protocol.PerspectiveServer { + sm.nextStreamToOpen = nextServerInitiatedStream + sm.nextStreamToAccept = nextClientInitiatedStream + } else { + sm.nextStreamToOpen = nextClientInitiatedStream + sm.nextStreamToAccept = nextServerInitiatedStream + } + return &sm +} + +// getStreamPerspective says which side should initiate a stream +func (m *streamsMapLegacy) streamInitiatedBy(id protocol.StreamID) protocol.Perspective { + if id%2 == 0 { + return protocol.PerspectiveServer + } + return protocol.PerspectiveClient +} + +func (m *streamsMapLegacy) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + // every bidirectional stream is also a receive stream + return m.getOrOpenStream(id) +} + +func (m *streamsMapLegacy) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + // every bidirectional stream is also a send stream + return m.getOrOpenStream(id) +} + +// getOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. +// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. +func (m *streamsMapLegacy) getOrOpenStream(id protocol.StreamID) (streamI, error) { + m.mutex.RLock() + s, ok := m.streams[id] + m.mutex.RUnlock() + if ok { + return s, nil + } + + // ... we don't have an existing stream + m.mutex.Lock() + defer m.mutex.Unlock() + // We need to check whether another invocation has already created a stream (between RUnlock() and Lock()). + s, ok = m.streams[id] + if ok { + return s, nil + } + + if m.perspective == m.streamInitiatedBy(id) { + if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already + return nil, nil + } + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) + } + if id <= m.highestStreamOpenedByPeer { // this is a peer-initiated stream that doesn't exist anymore. Must have been closed already + return nil, nil + } + + for sid := m.highestStreamOpenedByPeer + 2; sid <= id; sid += 2 { + if _, err := m.openRemoteStream(sid); err != nil { + return nil, err + } + } + + m.nextStreamOrErrCond.Broadcast() + return m.streams[id], nil +} + +func (m *streamsMapLegacy) openRemoteStream(id protocol.StreamID) (streamI, error) { + if m.numIncomingStreams >= m.maxIncomingStreams { + return nil, qerr.TooManyOpenStreams + } + if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) + } + + m.numIncomingStreams++ + if id > m.highestStreamOpenedByPeer { + m.highestStreamOpenedByPeer = id + } + + s := m.newStream(id) + return s, m.putStream(s) +} + +func (m *streamsMapLegacy) openStreamImpl() (streamI, error) { + if m.numOutgoingStreams >= m.maxOutgoingStreams { + return nil, qerr.TooManyOpenStreams + } + + m.numOutgoingStreams++ + s := m.newStream(m.nextStreamToOpen) + m.nextStreamToOpen += 2 + return s, m.putStream(s) +} + +// OpenStream opens the next available stream +func (m *streamsMapLegacy) OpenStream() (Stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + return m.openStreamImpl() +} + +func (m *streamsMapLegacy) OpenStreamSync() (Stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + if m.closeErr != nil { + return nil, m.closeErr + } + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.openStreamOrErrCond.Wait() + } +} + +// AcceptStream returns the next stream opened by the peer +// it blocks until a new stream is opened +func (m *streamsMapLegacy) AcceptStream() (Stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + var str streamI + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStreamToAccept] + if ok { + break + } + m.nextStreamOrErrCond.Wait() + } + m.nextStreamToAccept += 2 + return str, nil +} + +func (m *streamsMapLegacy) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + _, ok := m.streams[id] + if !ok { + return errMapAccess + } + delete(m.streams, id) + if m.streamInitiatedBy(id) == m.perspective { + m.numOutgoingStreams-- + } else { + m.numIncomingStreams-- + } + m.openStreamOrErrCond.Signal() + return nil +} + +func (m *streamsMapLegacy) putStream(s streamI) error { + id := s.StreamID() + if _, ok := m.streams[id]; ok { + return fmt.Errorf("a stream with ID %d already exists", id) + } + m.streams[id] = s + return nil +} + +func (m *streamsMapLegacy) CloseWithError(err error) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.closeErr = err + m.nextStreamOrErrCond.Broadcast() + m.openStreamOrErrCond.Broadcast() + for _, s := range m.streams { + s.closeForShutdown(err) + } +} + +// TODO(#952): this won't be needed when gQUIC supports stateless handshakes +func (m *streamsMapLegacy) UpdateLimits(params *handshake.TransportParameters) { + m.mutex.Lock() + m.maxOutgoingStreams = params.MaxStreams + for id, str := range m.streams { + str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: id, + ByteOffset: params.StreamFlowControlWindow, + }) + } + m.mutex.Unlock() + m.openStreamOrErrCond.Broadcast() +} + +// should never be called, since MAX_STREAM_ID frames can only be unpacked for IETF QUIC +func (m *streamsMapLegacy) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error { + return errors.New("gQUIC doesn't have MAX_STREAM_ID frames") +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go new file mode 100644 index 000000000..d2c92dec2 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go @@ -0,0 +1,122 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +type outgoingBidiStreamsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]streamI + + nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamID // the maximum stream ID we're allowed to open + highestBlocked protocol.StreamID // the highest stream ID that we queued a STREAM_ID_BLOCKED frame for + + newStream func(protocol.StreamID) streamI + queueStreamIDBlocked func(*wire.StreamIDBlockedFrame) + + closeErr error +} + +func newOutgoingBidiStreamsMap( + nextStream protocol.StreamID, + newStream func(protocol.StreamID) streamI, + queueControlFrame func(wire.Frame), +) *outgoingBidiStreamsMap { + m := &outgoingBidiStreamsMap{ + streams: make(map[protocol.StreamID]streamI), + nextStream: nextStream, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamIDBlockedFrame) { queueControlFrame(f) }, + } + m.cond.L = &m.mutex + return m +} + +func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.openStreamImpl() +} + +func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.cond.Wait() + } +} + +func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) { + if m.closeErr != nil { + return nil, m.closeErr + } + if m.nextStream > m.maxStream { + if m.maxStream == 0 || m.highestBlocked < m.maxStream { + m.queueStreamIDBlocked(&wire.StreamIDBlockedFrame{StreamID: m.maxStream}) + m.highestBlocked = m.maxStream + } + return nil, qerr.TooManyOpenStreams + } + s := m.newStream(m.nextStream) + m.streams[m.nextStream] = s + m.nextStream += 4 + return s, nil +} + +func (m *outgoingBidiStreamsMap) GetStream(id protocol.StreamID) (streamI, error) { + if id >= m.nextStream { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) + } + m.mutex.RLock() + s := m.streams[id] + m.mutex.RUnlock() + return s, nil +} + +func (m *outgoingBidiStreamsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + return nil +} + +func (m *outgoingBidiStreamsMap) SetMaxStream(id protocol.StreamID) { + m.mutex.Lock() + if id > m.maxStream { + m.maxStream = id + m.cond.Broadcast() + } + m.mutex.Unlock() +} + +func (m *outgoingBidiStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + m.cond.Broadcast() + m.mutex.Unlock() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go new file mode 100644 index 000000000..5a2836026 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go @@ -0,0 +1,123 @@ +package quic + +import ( + "fmt" + "sync" + + "github.com/cheekybits/genny/generic" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +type item generic.Type + +//go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream" +//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream" +type outgoingItemsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]item + + nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamID // the maximum stream ID we're allowed to open + highestBlocked protocol.StreamID // the highest stream ID that we queued a STREAM_ID_BLOCKED frame for + + newStream func(protocol.StreamID) item + queueStreamIDBlocked func(*wire.StreamIDBlockedFrame) + + closeErr error +} + +func newOutgoingItemsMap( + nextStream protocol.StreamID, + newStream func(protocol.StreamID) item, + queueControlFrame func(wire.Frame), +) *outgoingItemsMap { + m := &outgoingItemsMap{ + streams: make(map[protocol.StreamID]item), + nextStream: nextStream, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamIDBlockedFrame) { queueControlFrame(f) }, + } + m.cond.L = &m.mutex + return m +} + +func (m *outgoingItemsMap) OpenStream() (item, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.openStreamImpl() +} + +func (m *outgoingItemsMap) OpenStreamSync() (item, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.cond.Wait() + } +} + +func (m *outgoingItemsMap) openStreamImpl() (item, error) { + if m.closeErr != nil { + return nil, m.closeErr + } + if m.nextStream > m.maxStream { + if m.maxStream == 0 || m.highestBlocked < m.maxStream { + m.queueStreamIDBlocked(&wire.StreamIDBlockedFrame{StreamID: m.maxStream}) + m.highestBlocked = m.maxStream + } + return nil, qerr.TooManyOpenStreams + } + s := m.newStream(m.nextStream) + m.streams[m.nextStream] = s + m.nextStream += 4 + return s, nil +} + +func (m *outgoingItemsMap) GetStream(id protocol.StreamID) (item, error) { + if id >= m.nextStream { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) + } + m.mutex.RLock() + s := m.streams[id] + m.mutex.RUnlock() + return s, nil +} + +func (m *outgoingItemsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + return nil +} + +func (m *outgoingItemsMap) SetMaxStream(id protocol.StreamID) { + m.mutex.Lock() + if id > m.maxStream { + m.maxStream = id + m.cond.Broadcast() + } + m.mutex.Unlock() +} + +func (m *outgoingItemsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + m.cond.Broadcast() + m.mutex.Unlock() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go new file mode 100644 index 000000000..77511b780 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go @@ -0,0 +1,122 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +type outgoingUniStreamsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]sendStreamI + + nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamID // the maximum stream ID we're allowed to open + highestBlocked protocol.StreamID // the highest stream ID that we queued a STREAM_ID_BLOCKED frame for + + newStream func(protocol.StreamID) sendStreamI + queueStreamIDBlocked func(*wire.StreamIDBlockedFrame) + + closeErr error +} + +func newOutgoingUniStreamsMap( + nextStream protocol.StreamID, + newStream func(protocol.StreamID) sendStreamI, + queueControlFrame func(wire.Frame), +) *outgoingUniStreamsMap { + m := &outgoingUniStreamsMap{ + streams: make(map[protocol.StreamID]sendStreamI), + nextStream: nextStream, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamIDBlockedFrame) { queueControlFrame(f) }, + } + m.cond.L = &m.mutex + return m +} + +func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.openStreamImpl() +} + +func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.cond.Wait() + } +} + +func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) { + if m.closeErr != nil { + return nil, m.closeErr + } + if m.nextStream > m.maxStream { + if m.maxStream == 0 || m.highestBlocked < m.maxStream { + m.queueStreamIDBlocked(&wire.StreamIDBlockedFrame{StreamID: m.maxStream}) + m.highestBlocked = m.maxStream + } + return nil, qerr.TooManyOpenStreams + } + s := m.newStream(m.nextStream) + m.streams[m.nextStream] = s + m.nextStream += 4 + return s, nil +} + +func (m *outgoingUniStreamsMap) GetStream(id protocol.StreamID) (sendStreamI, error) { + if id >= m.nextStream { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) + } + m.mutex.RLock() + s := m.streams[id] + m.mutex.RUnlock() + return s, nil +} + +func (m *outgoingUniStreamsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + return nil +} + +func (m *outgoingUniStreamsMap) SetMaxStream(id protocol.StreamID) { + m.mutex.Lock() + if id > m.maxStream { + m.maxStream = id + m.cond.Broadcast() + } + m.mutex.Unlock() +} + +func (m *outgoingUniStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + m.cond.Broadcast() + m.mutex.Unlock() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/alert.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/alert.go new file mode 100644 index 000000000..430e45542 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/alert.go @@ -0,0 +1,101 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mint + +import "strconv" + +type Alert uint8 + +const ( + // alert level + AlertLevelWarning = 1 + AlertLevelError = 2 +) + +const ( + AlertCloseNotify Alert = 0 + AlertUnexpectedMessage Alert = 10 + AlertBadRecordMAC Alert = 20 + AlertDecryptionFailed Alert = 21 + AlertRecordOverflow Alert = 22 + AlertDecompressionFailure Alert = 30 + AlertHandshakeFailure Alert = 40 + AlertBadCertificate Alert = 42 + AlertUnsupportedCertificate Alert = 43 + AlertCertificateRevoked Alert = 44 + AlertCertificateExpired Alert = 45 + AlertCertificateUnknown Alert = 46 + AlertIllegalParameter Alert = 47 + AlertUnknownCA Alert = 48 + AlertAccessDenied Alert = 49 + AlertDecodeError Alert = 50 + AlertDecryptError Alert = 51 + AlertProtocolVersion Alert = 70 + AlertInsufficientSecurity Alert = 71 + AlertInternalError Alert = 80 + AlertInappropriateFallback Alert = 86 + AlertUserCanceled Alert = 90 + AlertNoRenegotiation Alert = 100 + AlertMissingExtension Alert = 109 + AlertUnsupportedExtension Alert = 110 + AlertCertificateUnobtainable Alert = 111 + AlertUnrecognizedName Alert = 112 + AlertBadCertificateStatsResponse Alert = 113 + AlertBadCertificateHashValue Alert = 114 + AlertUnknownPSKIdentity Alert = 115 + AlertNoApplicationProtocol Alert = 120 + AlertStatelessRetry Alert = 253 + AlertWouldBlock Alert = 254 + AlertNoAlert Alert = 255 +) + +var alertText = map[Alert]string{ + AlertCloseNotify: "close notify", + AlertUnexpectedMessage: "unexpected message", + AlertBadRecordMAC: "bad record MAC", + AlertDecryptionFailed: "decryption failed", + AlertRecordOverflow: "record overflow", + AlertDecompressionFailure: "decompression failure", + AlertHandshakeFailure: "handshake failure", + AlertBadCertificate: "bad certificate", + AlertUnsupportedCertificate: "unsupported certificate", + AlertCertificateRevoked: "revoked certificate", + AlertCertificateExpired: "expired certificate", + AlertCertificateUnknown: "unknown certificate", + AlertIllegalParameter: "illegal parameter", + AlertUnknownCA: "unknown certificate authority", + AlertAccessDenied: "access denied", + AlertDecodeError: "error decoding message", + AlertDecryptError: "error decrypting message", + AlertProtocolVersion: "protocol version not supported", + AlertInsufficientSecurity: "insufficient security level", + AlertInternalError: "internal error", + AlertInappropriateFallback: "inappropriate fallback", + AlertUserCanceled: "user canceled", + AlertMissingExtension: "missing extension", + AlertUnsupportedExtension: "unsupported extension", + AlertCertificateUnobtainable: "certificate unobtainable", + AlertUnrecognizedName: "unrecognized name", + AlertBadCertificateStatsResponse: "bad certificate status response", + AlertBadCertificateHashValue: "bad certificate hash value", + AlertUnknownPSKIdentity: "unknown PSK identity", + AlertNoApplicationProtocol: "no application protocol", + AlertNoRenegotiation: "no renegotiation", + AlertStatelessRetry: "stateless retry", + AlertWouldBlock: "would have blocked", + AlertNoAlert: "no alert", +} + +func (e Alert) String() string { + s, ok := alertText[e] + if ok { + return s + } + return "alert(" + strconv.Itoa(int(e)) + ")" +} + +func (e Alert) Error() string { + return e.String() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/client-state-machine.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/client-state-machine.go new file mode 100644 index 000000000..ddd021816 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/client-state-machine.go @@ -0,0 +1,1062 @@ +package mint + +import ( + "bytes" + "crypto" + "crypto/x509" + "hash" + "time" +) + +// Client State Machine +// +// START <----+ +// Send ClientHello | | Recv HelloRetryRequest +// / v | +// | WAIT_SH ---+ +// Can | | Recv ServerHello +// send | V +// early | WAIT_EE +// data | | Recv EncryptedExtensions +// | +--------+--------+ +// | Using | | Using certificate +// | PSK | v +// | | WAIT_CERT_CR +// | | Recv | | Recv CertificateRequest +// | | Certificate | v +// | | | WAIT_CERT +// | | | | Recv Certificate +// | | v v +// | | WAIT_CV +// | | | Recv CertificateVerify +// | +> WAIT_FINISHED <+ +// | | Recv Finished +// \ | +// | [Send EndOfEarlyData] +// | [Send Certificate [+ CertificateVerify]] +// | Send Finished +// Can send v +// app data --> CONNECTED +// after +// here +// +// State Instructions +// START Send(CH); [RekeyOut; SendEarlyData] +// WAIT_SH Send(CH) || RekeyIn +// WAIT_EE {} +// WAIT_CERT_CR {} +// WAIT_CERT {} +// WAIT_CV {} +// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut; +// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) + +type ClientStateStart struct { + Config *Config + Opts ConnectionOptions + Params ConnectionParameters + + cookie []byte + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage + hsCtx HandshakeContext +} + +var _ HandshakeState = &ClientStateStart{} + +func (state ClientStateStart) State() State { + return StateClientStart +} + +func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + // key_shares + offeredDH := map[NamedGroup][]byte{} + ks := KeyShareExtension{ + HandshakeType: HandshakeTypeClientHello, + Shares: make([]KeyShareEntry, len(state.Config.Groups)), + } + for i, group := range state.Config.Groups { + pub, priv, err := newKeyShare(group) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err) + return nil, nil, AlertInternalError + } + + ks.Shares[i].Group = group + ks.Shares[i].KeyExchange = pub + offeredDH[group] = priv + } + + logf(logTypeHandshake, "opts: %+v", state.Opts) + + // supported_versions, supported_groups, signature_algorithms, server_name + sv := SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello, Versions: []uint16{supportedVersion}} + sni := ServerNameExtension(state.Opts.ServerName) + sg := SupportedGroupsExtension{Groups: state.Config.Groups} + sa := SignatureAlgorithmsExtension{Algorithms: state.Config.SignatureSchemes} + + state.Params.ServerName = state.Opts.ServerName + + // Application Layer Protocol Negotiation + var alpn *ALPNExtension + if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) { + alpn = &ALPNExtension{Protocols: state.Opts.NextProtos} + } + + // Construct base ClientHello + ch := &ClientHelloBody{ + LegacyVersion: wireVersion(state.hsCtx.hIn), + CipherSuites: state.Config.CipherSuites, + } + _, err := prng.Read(ch.Random[:]) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err) + return nil, nil, AlertInternalError + } + for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} { + err := ch.Extensions.Add(ext) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err) + return nil, nil, AlertInternalError + } + } + // XXX: These optional extensions can't be folded into the above because Go + // interface-typed values are never reported as nil + if alpn != nil { + err := ch.Extensions.Add(alpn) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) + return nil, nil, AlertInternalError + } + } + if state.cookie != nil { + err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie}) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Run the external extension handler. + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Handle PSK and EarlyData just before transmitting, so that we can + // calculate the PSK binder value + var psk *PreSharedKeyExtension + var ed *EarlyDataExtension + var offeredPSK PreSharedKey + var earlyHash crypto.Hash + var earlySecret []byte + var clientEarlyTrafficKeys keySet + var clientHello *HandshakeMessage + if key, ok := state.Config.PSKs.Get(state.Opts.ServerName); ok { + offeredPSK = key + + // Narrow ciphersuites to ones that match PSK hash + params, ok := cipherSuiteMap[key.CipherSuite] + if !ok { + logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite") + return nil, nil, AlertInternalError + } + + compatibleSuites := []CipherSuite{} + for _, suite := range ch.CipherSuites { + if cipherSuiteMap[suite].Hash == params.Hash { + compatibleSuites = append(compatibleSuites, suite) + } + } + ch.CipherSuites = compatibleSuites + + // Signal early data if we're going to do it + if len(state.Opts.EarlyData) > 0 { + state.Params.ClientSendingEarlyData = true + ed = &EarlyDataExtension{} + err = ch.Extensions.Add(ed) + if err != nil { + logf(logTypeHandshake, "Error adding early data extension: %v", err) + return nil, nil, AlertInternalError + } + } + + // Signal supported PSK key exchange modes + if len(state.Config.PSKModes) == 0 { + logf(logTypeHandshake, "PSK selected, but no PSKModes") + return nil, nil, AlertInternalError + } + kem := &PSKKeyExchangeModesExtension{KEModes: state.Config.PSKModes} + err = ch.Extensions.Add(kem) + if err != nil { + logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err) + return nil, nil, AlertInternalError + } + + // Add the shim PSK extension to the ClientHello + logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity) + psk = &PreSharedKeyExtension{ + HandshakeType: HandshakeTypeClientHello, + Identities: []PSKIdentity{ + { + Identity: key.Identity, + ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd, + }, + }, + Binders: []PSKBinderEntry{ + // Note: Stub to get the length fields right + {Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())}, + }, + } + ch.Extensions.Add(psk) + + // Compute the binder key + h0 := params.Hash.New().Sum(nil) + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + + earlyHash = params.Hash + earlySecret = HkdfExtract(params.Hash, zero, key.Key) + logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) + + binderLabel := labelExternalBinder + if key.IsResumption { + binderLabel = labelResumptionBinder + } + binderKey := deriveSecret(params, earlySecret, binderLabel, h0) + logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey) + + // Compute the binder value + trunc, err := ch.Truncated() + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err) + return nil, nil, AlertInternalError + } + + truncHash := params.Hash.New() + truncHash.Write(trunc) + + binder := computeFinishedData(params, binderKey, truncHash.Sum(nil)) + + // Replace the PSK extension + psk.Binders[0].Binder = binder + ch.Extensions.Add(psk) + + // If we got here, the earlier marshal succeeded (in ch.Truncated()), so + // this one should too. + clientHello, _ = state.hsCtx.hOut.HandshakeMessageFromBody(ch) + + // Compute early traffic keys + h := params.Hash.New() + h.Write(clientHello.Marshal()) + chHash := h.Sum(nil) + + earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) + logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret) + clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret) + } else if len(state.Opts.EarlyData) > 0 { + logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK") + return nil, nil, AlertInternalError + } else { + clientHello, err = state.hsCtx.hOut.HandshakeMessageFromBody(ch) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err) + return nil, nil, AlertInternalError + } + } + + logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]") + state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2. + nextState := ClientStateWaitSH{ + Config: state.Config, + Opts: state.Opts, + Params: state.Params, + hsCtx: state.hsCtx, + OfferedDH: offeredDH, + OfferedPSK: offeredPSK, + + earlySecret: earlySecret, + earlyHash: earlyHash, + + firstClientHello: state.firstClientHello, + helloRetryRequest: state.helloRetryRequest, + clientHello: clientHello, + } + + toSend := []HandshakeAction{ + QueueHandshakeMessage{clientHello}, + SendQueuedHandshake{}, + } + if state.Params.ClientSendingEarlyData { + toSend = append(toSend, []HandshakeAction{ + RekeyOut{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys}, + SendEarlyData{}, + }...) + } + + return nextState, toSend, AlertNoAlert +} + +type ClientStateWaitSH struct { + Config *Config + Opts ConnectionOptions + Params ConnectionParameters + hsCtx HandshakeContext + OfferedDH map[NamedGroup][]byte + OfferedPSK PreSharedKey + PSK []byte + + earlySecret []byte + earlyHash crypto.Hash + + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage + clientHello *HandshakeMessage +} + +var _ HandshakeState = &ClientStateWaitSH{} + +func (state ClientStateWaitSH) State() State { + return StateClientWaitSH +} + +func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } + + if hm == nil || hm.msgType != HandshakeTypeServerHello { + logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + sh := &ServerHelloBody{} + if _, err := sh.Unmarshal(hm.body); err != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + // Common SH/HRR processing first. + // 1. Check that sh.version is TLS 1.2 + if sh.Version != tls12Version { + logf(logTypeHandshake, "[ClientStateWaitSH] illegal legacy version [%v]", sh.Version) + return nil, nil, AlertIllegalParameter + } + + // 2. Check that it responded with a valid version. + supportedVersions := SupportedVersionsExtension{HandshakeType: HandshakeTypeServerHello} + foundSupportedVersions, err := sh.Extensions.Find(&supportedVersions) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] invalid supported_versions extension [%v]", err) + return nil, nil, AlertDecodeError + } + if !foundSupportedVersions { + logf(logTypeHandshake, "[ClientStateWaitSH] no supported_versions extension") + return nil, nil, AlertMissingExtension + } + if supportedVersions.Versions[0] != supportedVersion { + logf(logTypeHandshake, "[ClientStateWaitSH] unsupported version [%x]", supportedVersions.Versions[0]) + return nil, nil, AlertProtocolVersion + } + // 3. Check that the server provided a supported ciphersuite + supportedCipherSuite := false + for _, suite := range state.Config.CipherSuites { + supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite) + } + if !supportedCipherSuite { + logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite) + return nil, nil, AlertHandshakeFailure + } + + // Now check for the sentinel. + + if sh.Random == hrrRandomSentinel { + // This is actually HRR. + hrr := sh + + // Narrow the supported ciphersuites to the server-provided one + state.Config.CipherSuites = []CipherSuite{hrr.CipherSuite} + + // Handle external extensions. + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + // The only thing we know how to respond to in an HRR is the Cookie + // extension, so if there is either no Cookie extension or anything other + // than a Cookie extension and SupportedVersions we have to fail. + serverCookie := new(CookieExtension) + foundCookie, err := hrr.Extensions.Find(serverCookie) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] Invalid server cookie extension [%v]", err) + return nil, nil, AlertDecodeError + } + if !foundCookie || len(hrr.Extensions) != 2 { + logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions)) + return nil, nil, AlertIllegalParameter + } + + // Hash the body into a pseudo-message + // XXX: Ignoring some errors here + params := cipherSuiteMap[hrr.CipherSuite] + h := params.Hash.New() + h.Write(state.clientHello.Marshal()) + firstClientHello := &HandshakeMessage{ + msgType: HandshakeTypeMessageHash, + body: h.Sum(nil), + } + + logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]") + return ClientStateStart{ + Config: state.Config, + Opts: state.Opts, + hsCtx: state.hsCtx, + cookie: serverCookie.Cookie, + firstClientHello: firstClientHello, + helloRetryRequest: hm, + }, nil, AlertNoAlert + } + + // This is SH. + // Handle external extensions. + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Do PSK or key agreement depending on extensions + serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello} + serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} + + foundExts, err := sh.Extensions.Parse( + []ExtensionBody{ + &serverPSK, + &serverKeyShare, + }) + if err != nil { + logf(logTypeHandshake, "[ClientWaitSH] Error processing extensions [%v]", err) + return nil, nil, AlertDecodeError + } + + if foundExts[ExtensionTypePreSharedKey] && (serverPSK.SelectedIdentity == 0) { + state.Params.UsingPSK = true + } + + var dhSecret []byte + if foundExts[ExtensionTypeKeyShare] { + sks := serverKeyShare.Shares[0] + priv, ok := state.OfferedDH[sks.Group] + if !ok { + logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group") + return nil, nil, AlertIllegalParameter + } + + state.Params.UsingDH = true + dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv) + } + + suite := sh.CipherSuite + state.Params.CipherSuite = suite + + params, ok := cipherSuiteMap[suite] + if !ok { + logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite) + return nil, nil, AlertHandshakeFailure + } + + // Start up the handshake hash + handshakeHash := params.Hash.New() + handshakeHash.Write(state.firstClientHello.Marshal()) + handshakeHash.Write(state.helloRetryRequest.Marshal()) + handshakeHash.Write(state.clientHello.Marshal()) + handshakeHash.Write(hm.Marshal()) + + // Compute handshake secrets + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + + var earlySecret []byte + if state.Params.UsingPSK { + if params.Hash != state.earlyHash { + logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]", + state.earlyHash, suite, params.Hash) + } + + earlySecret = state.earlySecret + } else { + earlySecret = HkdfExtract(params.Hash, zero, zero) + } + + if dhSecret == nil { + dhSecret = zero + } + + h0 := params.Hash.New().Sum(nil) + h2 := handshakeHash.Sum(nil) + preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) + handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret) + clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) + serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) + preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) + masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) + + logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) + logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) + logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) + logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) + logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) + + serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) + + logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") + nextState := ClientStateWaitEE{ + Config: state.Config, + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: params, + handshakeHash: handshakeHash, + masterSecret: masterSecret, + clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: serverHandshakeTrafficSecret, + } + toSend := []HandshakeAction{ + RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys}, + } + return nextState, toSend, AlertNoAlert +} + +type ClientStateWaitEE struct { + Config *Config + Params ConnectionParameters + hsCtx HandshakeContext + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +var _ HandshakeState = &ClientStateWaitEE{} + +func (state ClientStateWaitEE) State() State { + return StateClientWaitEE +} + +func (state ClientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } + if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions { + logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + ee := EncryptedExtensionsBody{} + if err := safeUnmarshal(&ee, hm.body); err != nil { + logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + // Handle external extensions. + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + serverALPN := &ALPNExtension{} + serverEarlyData := &EarlyDataExtension{} + + foundExts, err := ee.Extensions.Parse( + []ExtensionBody{ + serverALPN, + serverEarlyData, + }) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding extensions: %v", err) + return nil, nil, AlertDecodeError + } + + state.Params.UsingEarlyData = foundExts[ExtensionTypeEarlyData] + + if foundExts[ExtensionTypeALPN] && len(serverALPN.Protocols) > 0 { + state.Params.NextProto = serverALPN.Protocols[0] + } + + state.handshakeHash.Write(hm.Marshal()) + + if state.Params.UsingPSK { + logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]") + nextState := ClientStateWaitFinished{ + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.Config.Certificates, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert + } + + logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]") + nextState := ClientStateWaitCertCR{ + Config: state.Config, + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert +} + +type ClientStateWaitCertCR struct { + Config *Config + Params ConnectionParameters + hsCtx HandshakeContext + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +var _ HandshakeState = &ClientStateWaitCertCR{} + +func (state ClientStateWaitCertCR) State() State { + return StateClientWaitCertCR +} + +func (state ClientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } + if hm == nil { + logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + bodyGeneric, err := hm.ToBody() + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + state.handshakeHash.Write(hm.Marshal()) + + switch body := bodyGeneric.(type) { + case *CertificateBody: + logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]") + nextState := ClientStateWaitCV{ + Config: state.Config, + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + serverCertificate: body, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert + + case *CertificateRequestBody: + // A certificate request in the handshake should have a zero-length context + if len(body.CertificateRequestContext) > 0 { + logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err) + return nil, nil, AlertIllegalParameter + } + + state.Params.UsingClientAuth = true + + logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]") + nextState := ClientStateWaitCert{ + Config: state.Config, + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + serverCertificateRequest: body, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert + } + + return nil, nil, AlertUnexpectedMessage +} + +type ClientStateWaitCert struct { + Config *Config + Params ConnectionParameters + hsCtx HandshakeContext + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + + serverCertificateRequest *CertificateRequestBody + + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +var _ HandshakeState = &ClientStateWaitCert{} + +func (state ClientStateWaitCert) State() State { + return StateClientWaitCert +} + +func (state ClientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } + if hm == nil || hm.msgType != HandshakeTypeCertificate { + logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + cert := &CertificateBody{} + if err := safeUnmarshal(cert, hm.body); err != nil { + logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + state.handshakeHash.Write(hm.Marshal()) + + logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]") + nextState := ClientStateWaitCV{ + Config: state.Config, + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + serverCertificate: cert, + serverCertificateRequest: state.serverCertificateRequest, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert +} + +type ClientStateWaitCV struct { + Config *Config + Params ConnectionParameters + hsCtx HandshakeContext + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + + serverCertificate *CertificateBody + serverCertificateRequest *CertificateRequestBody + + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +var _ HandshakeState = &ClientStateWaitCV{} + +func (state ClientStateWaitCV) State() State { + return StateClientWaitCV +} + +func (state ClientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } + if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { + logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + certVerify := CertificateVerifyBody{} + if err := safeUnmarshal(&certVerify, hm.body); err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + hcv := state.handshakeHash.Sum(nil) + logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) + + serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey + if err := certVerify.Verify(serverPublicKey, hcv); err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify") + return nil, nil, AlertHandshakeFailure + } + + certs := make([]*x509.Certificate, len(state.serverCertificate.CertificateList)) + rawCerts := make([][]byte, len(state.serverCertificate.CertificateList)) + for i, certEntry := range state.serverCertificate.CertificateList { + certs[i] = certEntry.CertData + rawCerts[i] = certEntry.CertData.Raw + } + + var verifiedChains [][]*x509.Certificate + if !state.Config.InsecureSkipVerify { + opts := x509.VerifyOptions{ + Roots: state.Config.RootCAs, + CurrentTime: state.Config.time(), + DNSName: state.Config.ServerName, + Intermediates: x509.NewCertPool(), + } + + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + var err error + verifiedChains, err = certs[0].Verify(opts) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Certificate verification failed: %s", err) + return nil, nil, AlertBadCertificate + } + } + + if state.Config.VerifyPeerCertificate != nil { + if err := state.Config.VerifyPeerCertificate(rawCerts, verifiedChains); err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate: %s", err) + return nil, nil, AlertBadCertificate + } + } + + state.handshakeHash.Write(hm.Marshal()) + + logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]") + nextState := ClientStateWaitFinished{ + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.Config.Certificates, + serverCertificateRequest: state.serverCertificateRequest, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + peerCertificates: certs, + verifiedChains: verifiedChains, + } + return nextState, nil, AlertNoAlert +} + +type ClientStateWaitFinished struct { + Params ConnectionParameters + hsCtx HandshakeContext + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + + certificates []*Certificate + serverCertificateRequest *CertificateRequestBody + peerCertificates []*x509.Certificate + verifiedChains [][]*x509.Certificate + + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +var _ HandshakeState = &ClientStateWaitFinished{} + +func (state ClientStateWaitFinished) State() State { + return StateClientWaitFinished +} + +func (state ClientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } + if hm == nil || hm.msgType != HandshakeTypeFinished { + logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + // Verify server's Finished + h3 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3) + logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3) + + serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3) + logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) + + fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)} + if err := safeUnmarshal(fin, hm.body); err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + if !bytes.Equal(fin.VerifyData, serverFinishedData) { + logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]", + fin.VerifyData, serverFinishedData) + return nil, nil, AlertHandshakeFailure + } + + // Update the handshake hash with the Finished + state.handshakeHash.Write(hm.Marshal()) + logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal()) + h4 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4) + + // Compute traffic secrets and keys + clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4) + serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4) + logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret) + logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) + + clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret) + serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret) + + exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4) + logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret) + + // Assemble client's second flight + toSend := []HandshakeAction{} + + if state.Params.UsingEarlyData { + // Note: We only send EOED if the server is actually going to use the early + // data. Otherwise, it will never see it, and the transcripts will + // mismatch. + // EOED marshal is infallible + eoedm, _ := state.hsCtx.hOut.HandshakeMessageFromBody(&EndOfEarlyDataBody{}) + toSend = append(toSend, QueueHandshakeMessage{eoedm}) + + state.handshakeHash.Write(eoedm.Marshal()) + logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal()) + } + + clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}) + + if state.Params.UsingClientAuth { + // Extract constraints from certicateRequest + schemes := SignatureAlgorithmsExtension{} + gotSchemes, err := state.serverCertificateRequest.Extensions.Find(&schemes) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING invalid signature_schemes extension [%v]", err) + return nil, nil, AlertDecodeError + } + if !gotSchemes { + logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found") + return nil, nil, AlertIllegalParameter + } + + // Select a certificate + cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates) + if err != nil { + // XXX: Signal this to the application layer? + logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) + + certificate := &CertificateBody{} + certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, QueueHandshakeMessage{certm}) + state.handshakeHash.Write(certm.Marshal()) + } else { + // Create and send Certificate, CertificateVerify + certificate := &CertificateBody{ + CertificateList: make([]CertificateEntry, len(cert.Chain)), + } + for i, entry := range cert.Chain { + certificate.CertificateList[i] = CertificateEntry{CertData: entry} + } + certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, QueueHandshakeMessage{certm}) + state.handshakeHash.Write(certm.Marshal()) + + hcv := state.handshakeHash.Sum(nil) + logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) + + certificateVerify := &CertificateVerifyBody{Algorithm: certScheme} + logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash) + + err = certificateVerify.Sign(cert.PrivateKey, hcv) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err) + return nil, nil, AlertInternalError + } + certvm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificateVerify) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, QueueHandshakeMessage{certvm}) + state.handshakeHash.Write(certvm.Marshal()) + } + } + + // Compute the client's Finished message + h5 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5) + + clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5) + logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData) + + fin = &FinishedBody{ + VerifyDataLen: len(clientFinishedData), + VerifyData: clientFinishedData, + } + finm, err := state.hsCtx.hOut.HandshakeMessageFromBody(fin) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err) + return nil, nil, AlertInternalError + } + + // Compute the resumption secret + state.handshakeHash.Write(finm.Marshal()) + h6 := state.handshakeHash.Sum(nil) + + resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6) + logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) + + toSend = append(toSend, []HandshakeAction{ + QueueHandshakeMessage{finm}, + SendQueuedHandshake{}, + RekeyIn{epoch: EpochApplicationData, KeySet: serverTrafficKeys}, + RekeyOut{epoch: EpochApplicationData, KeySet: clientTrafficKeys}, + }...) + + logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]") + nextState := StateConnected{ + Params: state.Params, + hsCtx: state.hsCtx, + isClient: true, + cryptoParams: state.cryptoParams, + resumptionSecret: resumptionSecret, + clientTrafficSecret: clientTrafficSecret, + serverTrafficSecret: serverTrafficSecret, + exporterSecret: exporterSecret, + peerCertificates: state.peerCertificates, + verifiedChains: state.verifiedChains, + } + return nextState, toSend, AlertNoAlert +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/common.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/common.go new file mode 100644 index 000000000..565d15e32 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/common.go @@ -0,0 +1,252 @@ +package mint + +import ( + "fmt" + "strconv" +) + +const ( + supportedVersion uint16 = 0x7f16 // draft-22 + tls12Version uint16 = 0x0303 + tls10Version uint16 = 0x0301 + dtls12WireVersion uint16 = 0xfefd +) + +var ( + // Flags for some minor compat issues + allowWrongVersionNumber = true + allowPKCS1 = true +) + +// enum {...} ContentType; +type RecordType byte + +const ( + RecordTypeAlert RecordType = 21 + RecordTypeHandshake RecordType = 22 + RecordTypeApplicationData RecordType = 23 +) + +// enum {...} HandshakeType; +type HandshakeType byte + +const ( + // Omitted: *_RESERVED + HandshakeTypeClientHello HandshakeType = 1 + HandshakeTypeServerHello HandshakeType = 2 + HandshakeTypeNewSessionTicket HandshakeType = 4 + HandshakeTypeEndOfEarlyData HandshakeType = 5 + HandshakeTypeHelloRetryRequest HandshakeType = 6 + HandshakeTypeEncryptedExtensions HandshakeType = 8 + HandshakeTypeCertificate HandshakeType = 11 + HandshakeTypeCertificateRequest HandshakeType = 13 + HandshakeTypeCertificateVerify HandshakeType = 15 + HandshakeTypeServerConfiguration HandshakeType = 17 + HandshakeTypeFinished HandshakeType = 20 + HandshakeTypeKeyUpdate HandshakeType = 24 + HandshakeTypeMessageHash HandshakeType = 254 +) + +var hrrRandomSentinel = [32]byte{ + 0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, + 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91, + 0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, + 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c, +} + +// uint8 CipherSuite[2]; +type CipherSuite uint16 + +const ( + // XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero + // value for this type so that we can detect when a field is set. + CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000 + TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301 + TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302 + TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303 + TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304 + TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305 +) + +func (c CipherSuite) String() string { + switch c { + case CIPHER_SUITE_UNKNOWN: + return "unknown" + case TLS_AES_128_GCM_SHA256: + return "TLS_AES_128_GCM_SHA256" + case TLS_AES_256_GCM_SHA384: + return "TLS_AES_256_GCM_SHA384" + case TLS_CHACHA20_POLY1305_SHA256: + return "TLS_CHACHA20_POLY1305_SHA256" + case TLS_AES_128_CCM_SHA256: + return "TLS_AES_128_CCM_SHA256" + case TLS_AES_256_CCM_8_SHA256: + return "TLS_AES_256_CCM_8_SHA256" + } + // cannot use %x here, since it calls String(), leading to infinite recursion + return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16)) +} + +// enum {...} SignatureScheme +type SignatureScheme uint16 + +const ( + // RSASSA-PKCS1-v1_5 algorithms + RSA_PKCS1_SHA1 SignatureScheme = 0x0201 + RSA_PKCS1_SHA256 SignatureScheme = 0x0401 + RSA_PKCS1_SHA384 SignatureScheme = 0x0501 + RSA_PKCS1_SHA512 SignatureScheme = 0x0601 + // ECDSA algorithms + ECDSA_P256_SHA256 SignatureScheme = 0x0403 + ECDSA_P384_SHA384 SignatureScheme = 0x0503 + ECDSA_P521_SHA512 SignatureScheme = 0x0603 + // RSASSA-PSS algorithms + RSA_PSS_SHA256 SignatureScheme = 0x0804 + RSA_PSS_SHA384 SignatureScheme = 0x0805 + RSA_PSS_SHA512 SignatureScheme = 0x0806 + // EdDSA algorithms + Ed25519 SignatureScheme = 0x0807 + Ed448 SignatureScheme = 0x0808 +) + +// enum {...} ExtensionType +type ExtensionType uint16 + +const ( + ExtensionTypeServerName ExtensionType = 0 + ExtensionTypeSupportedGroups ExtensionType = 10 + ExtensionTypeSignatureAlgorithms ExtensionType = 13 + ExtensionTypeALPN ExtensionType = 16 + ExtensionTypeKeyShare ExtensionType = 40 + ExtensionTypePreSharedKey ExtensionType = 41 + ExtensionTypeEarlyData ExtensionType = 42 + ExtensionTypeSupportedVersions ExtensionType = 43 + ExtensionTypeCookie ExtensionType = 44 + ExtensionTypePSKKeyExchangeModes ExtensionType = 45 + ExtensionTypeTicketEarlyDataInfo ExtensionType = 46 +) + +// enum {...} NamedGroup +type NamedGroup uint16 + +const ( + // Elliptic Curve Groups. + P256 NamedGroup = 23 + P384 NamedGroup = 24 + P521 NamedGroup = 25 + // ECDH functions. + X25519 NamedGroup = 29 + X448 NamedGroup = 30 + // Finite field groups. + FFDHE2048 NamedGroup = 256 + FFDHE3072 NamedGroup = 257 + FFDHE4096 NamedGroup = 258 + FFDHE6144 NamedGroup = 259 + FFDHE8192 NamedGroup = 260 +) + +// enum {...} PskKeyExchangeMode; +type PSKKeyExchangeMode uint8 + +const ( + PSKModeKE PSKKeyExchangeMode = 0 + PSKModeDHEKE PSKKeyExchangeMode = 1 +) + +// enum { +// update_not_requested(0), update_requested(1), (255) +// } KeyUpdateRequest; +type KeyUpdateRequest uint8 + +const ( + KeyUpdateNotRequested KeyUpdateRequest = 0 + KeyUpdateRequested KeyUpdateRequest = 1 +) + +type State uint8 + +const ( + // states valid for the client + StateClientStart State = iota + StateClientWaitSH + StateClientWaitEE + StateClientWaitCert + StateClientWaitCV + StateClientWaitFinished + StateClientWaitCertCR + StateClientConnected + // states valid for the server + StateServerStart State = iota + StateServerRecvdCH + StateServerNegotiated + StateServerWaitEOED + StateServerWaitFlight2 + StateServerWaitCert + StateServerWaitCV + StateServerWaitFinished + StateServerConnected +) + +func (s State) String() string { + switch s { + case StateClientStart: + return "Client START" + case StateClientWaitSH: + return "Client WAIT_SH" + case StateClientWaitEE: + return "Client WAIT_EE" + case StateClientWaitCert: + return "Client WAIT_CERT" + case StateClientWaitCV: + return "Client WAIT_CV" + case StateClientWaitFinished: + return "Client WAIT_FINISHED" + case StateClientConnected: + return "Client CONNECTED" + case StateServerStart: + return "Server START" + case StateServerRecvdCH: + return "Server RECVD_CH" + case StateServerNegotiated: + return "Server NEGOTIATED" + case StateServerWaitEOED: + return "Server WAIT_EOED" + case StateServerWaitFlight2: + return "Server WAIT_FLIGHT2" + case StateServerWaitCert: + return "Server WAIT_CERT" + case StateServerWaitCV: + return "Server WAIT_CV" + case StateServerWaitFinished: + return "Server WAIT_FINISHED" + case StateServerConnected: + return "Server CONNECTED" + default: + return fmt.Sprintf("unknown state: %d", s) + } +} + +// Epochs for DTLS (also used for key phase labelling) +type Epoch uint16 + +const ( + EpochClear Epoch = 0 + EpochEarlyData Epoch = 1 + EpochHandshakeData Epoch = 2 + EpochApplicationData Epoch = 3 + EpochUpdate Epoch = 4 +) + +func (e Epoch) label() string { + switch e { + case EpochClear: + return "clear" + case EpochEarlyData: + return "early data" + case EpochHandshakeData: + return "handshake" + case EpochApplicationData: + return "application data" + } + return "Application data (updated)" +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/conn.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/conn.go new file mode 100644 index 000000000..ffbead3fc --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/conn.go @@ -0,0 +1,884 @@ +package mint + +import ( + "crypto" + "crypto/x509" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "reflect" + "sync" + "time" +) + +var WouldBlock = fmt.Errorf("Would have blocked") + +type Certificate struct { + Chain []*x509.Certificate + PrivateKey crypto.Signer +} + +type PreSharedKey struct { + CipherSuite CipherSuite + IsResumption bool + Identity []byte + Key []byte + NextProto string + ReceivedAt time.Time + ExpiresAt time.Time + TicketAgeAdd uint32 +} + +type PreSharedKeyCache interface { + Get(string) (PreSharedKey, bool) + Put(string, PreSharedKey) + Size() int +} + +// A CookieHandler can be used to give the application more fine-grained control over Cookies. +// Generate receives the Conn as an argument, so the CookieHandler can decide when to send the cookie based on that, and offload state to the client by encoding that into the Cookie. +// When the client echoes the Cookie, Validate is called. The application can then recover the state from the cookie. +type CookieHandler interface { + // Generate a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest + // If Generate returns nil, mint will not send a HelloRetryRequest. + Generate(*Conn) ([]byte, error) + // Validate is called when receiving a ClientHello containing a Cookie. + // If validation failed, the handshake is aborted. + Validate(*Conn, []byte) bool +} + +type PSKMapCache map[string]PreSharedKey + +func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) { + psk, ok = cache[key] + return +} + +func (cache *PSKMapCache) Put(key string, psk PreSharedKey) { + (*cache)[key] = psk +} + +func (cache PSKMapCache) Size() int { + return len(cache) +} + +// Config is the struct used to pass configuration settings to a TLS client or +// server instance. The settings for client and server are pretty different, +// but we just throw them all in here. +type Config struct { + // Client fields + ServerName string + + // Server fields + SendSessionTickets bool + TicketLifetime uint32 + TicketLen int + EarlyDataLifetime uint32 + AllowEarlyData bool + // Require the client to echo a cookie. + RequireCookie bool + // A CookieHandler can be used to set and validate a cookie. + // The cookie returned by the CookieHandler will be part of the cookie sent on the wire, and encoded using the CookieProtector. + // If no CookieHandler is set, mint will always send a cookie. + // The CookieHandler can be used to decide on a per-connection basis, if a cookie should be sent. + CookieHandler CookieHandler + // The CookieProtector is used to encrypt / decrypt cookies. + // It should make sure that the Cookie cannot be read and tampered with by the client. + // If non-blocking mode is used, and cookies are required, this field has to be set. + // In blocking mode, a default cookie protector is used, if this is unused. + CookieProtector CookieProtector + // The ExtensionHandler is used to add custom extensions. + ExtensionHandler AppExtensionHandler + RequireClientAuth bool + + // Time returns the current time as the number of seconds since the epoch. + // If Time is nil, TLS uses time.Now. + Time func() time.Time + // RootCAs defines the set of root certificate authorities + // that clients use when verifying server certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *x509.CertPool + // InsecureSkipVerify controls whether a client verifies the + // server's certificate chain and host name. + // If InsecureSkipVerify is true, TLS accepts any certificate + // presented by the server and any host name in that certificate. + // In this mode, TLS is susceptible to man-in-the-middle attacks. + // This should be used only for testing. + InsecureSkipVerify bool + + // Shared fields + Certificates []*Certificate + // VerifyPeerCertificate, if not nil, is called after normal + // certificate verification by either a TLS client or server. It + // receives the raw ASN.1 certificates provided by the peer and also + // any verified chains that normal processing found. If it returns a + // non-nil error, the handshake is aborted and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. If normal verification is disabled by + // setting InsecureSkipVerify then this callback will be considered but + // the verifiedChains argument will always be nil. + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + + CipherSuites []CipherSuite + Groups []NamedGroup + SignatureSchemes []SignatureScheme + NextProtos []string + PSKs PreSharedKeyCache + PSKModes []PSKKeyExchangeMode + NonBlocking bool + UseDTLS bool + + // The same config object can be shared among different connections, so it + // needs its own mutex + mutex sync.RWMutex +} + +// Clone returns a shallow clone of c. It is safe to clone a Config that is +// being used concurrently by a TLS client or server. +func (c *Config) Clone() *Config { + c.mutex.Lock() + defer c.mutex.Unlock() + + return &Config{ + ServerName: c.ServerName, + + SendSessionTickets: c.SendSessionTickets, + TicketLifetime: c.TicketLifetime, + TicketLen: c.TicketLen, + EarlyDataLifetime: c.EarlyDataLifetime, + AllowEarlyData: c.AllowEarlyData, + RequireCookie: c.RequireCookie, + CookieHandler: c.CookieHandler, + CookieProtector: c.CookieProtector, + ExtensionHandler: c.ExtensionHandler, + RequireClientAuth: c.RequireClientAuth, + Time: c.Time, + RootCAs: c.RootCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + + Certificates: c.Certificates, + VerifyPeerCertificate: c.VerifyPeerCertificate, + CipherSuites: c.CipherSuites, + Groups: c.Groups, + SignatureSchemes: c.SignatureSchemes, + NextProtos: c.NextProtos, + PSKs: c.PSKs, + PSKModes: c.PSKModes, + NonBlocking: c.NonBlocking, + UseDTLS: c.UseDTLS, + } +} + +func (c *Config) Init(isClient bool) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + // Set defaults + if len(c.CipherSuites) == 0 { + c.CipherSuites = defaultSupportedCipherSuites + } + if len(c.Groups) == 0 { + c.Groups = defaultSupportedGroups + } + if len(c.SignatureSchemes) == 0 { + c.SignatureSchemes = defaultSignatureSchemes + } + if c.TicketLen == 0 { + c.TicketLen = defaultTicketLen + } + if !reflect.ValueOf(c.PSKs).IsValid() { + c.PSKs = &PSKMapCache{} + } + if len(c.PSKModes) == 0 { + c.PSKModes = defaultPSKModes + } + return nil +} + +func (c *Config) ValidForServer() bool { + return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) || + (len(c.Certificates) > 0 && + len(c.Certificates[0].Chain) > 0 && + c.Certificates[0].PrivateKey != nil) +} + +func (c *Config) ValidForClient() bool { + return len(c.ServerName) > 0 +} + +func (c *Config) time() time.Time { + t := c.Time + if t == nil { + t = time.Now + } + return t() +} + +var ( + defaultSupportedCipherSuites = []CipherSuite{ + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + } + + defaultSupportedGroups = []NamedGroup{ + P256, + P384, + FFDHE2048, + X25519, + } + + defaultSignatureSchemes = []SignatureScheme{ + RSA_PSS_SHA256, + RSA_PSS_SHA384, + RSA_PSS_SHA512, + ECDSA_P256_SHA256, + ECDSA_P384_SHA384, + ECDSA_P521_SHA512, + } + + defaultTicketLen = 16 + + defaultPSKModes = []PSKKeyExchangeMode{ + PSKModeKE, + PSKModeDHEKE, + } +) + +type ConnectionState struct { + HandshakeState State + CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) + PeerCertificates []*x509.Certificate // certificate chain presented by remote peer + VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates + NextProto string // Selected ALPN proto +} + +// Conn implements the net.Conn interface, as with "crypto/tls" +// * Read, Write, and Close are provided locally +// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn +type Conn struct { + config *Config + conn net.Conn + isClient bool + + EarlyData []byte + + state StateConnected + hState HandshakeState + handshakeMutex sync.Mutex + handshakeAlert Alert + handshakeComplete bool + + readBuffer []byte + in, out *RecordLayer + hsCtx HandshakeContext +} + +func NewConn(conn net.Conn, config *Config, isClient bool) *Conn { + c := &Conn{conn: conn, config: config, isClient: isClient} + if !config.UseDTLS { + c.in = NewRecordLayerTLS(c.conn) + c.out = NewRecordLayerTLS(c.conn) + c.hsCtx.hIn = NewHandshakeLayerTLS(c.in) + c.hsCtx.hOut = NewHandshakeLayerTLS(c.out) + } else { + c.in = NewRecordLayerDTLS(c.conn) + c.out = NewRecordLayerDTLS(c.conn) + c.hsCtx.hIn = NewHandshakeLayerDTLS(c.in) + c.hsCtx.hOut = NewHandshakeLayerDTLS(c.out) + } + c.hsCtx.hIn.nonblocking = c.config.NonBlocking + return c +} + +// Read up +func (c *Conn) consumeRecord() error { + pt, err := c.in.ReadRecord() + if pt == nil { + logf(logTypeIO, "extendBuffer returns error %v", err) + return err + } + + switch pt.contentType { + case RecordTypeHandshake: + logf(logTypeHandshake, "Received post-handshake message") + // We do not support fragmentation of post-handshake handshake messages. + // TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage() + start := 0 + headerLen := handshakeHeaderLenTLS + if c.config.UseDTLS { + headerLen = handshakeHeaderLenDTLS + } + for start < len(pt.fragment) { + if len(pt.fragment[start:]) < headerLen { + return fmt.Errorf("Post-handshake handshake message too short for header") + } + + hm := &HandshakeMessage{} + hm.msgType = HandshakeType(pt.fragment[start]) + hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3]) + + if len(pt.fragment[start+headerLen:]) < hmLen { + return fmt.Errorf("Post-handshake handshake message too short for body") + } + hm.body = pt.fragment[start+headerLen : start+headerLen+hmLen] + + // XXX: If we want to support more advanced cases, e.g., post-handshake + // authentication, we'll need to allow transitions other than + // Connected -> Connected + state, actions, alert := c.state.ProcessMessage(hm) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error in state transition: %v", alert) + c.sendAlert(alert) + return io.EOF + } + + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return io.EOF + } + } + + var connected bool + c.state, connected = state.(StateConnected) + if !connected { + logf(logTypeHandshake, "Disconnected after state transition: %v", alert) + c.sendAlert(alert) + return io.EOF + } + + start += headerLen + hmLen + } + case RecordTypeAlert: + logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer) + if len(pt.fragment) != 2 { + c.sendAlert(AlertUnexpectedMessage) + return io.EOF + } + if Alert(pt.fragment[1]) == AlertCloseNotify { + return io.EOF + } + + switch pt.fragment[0] { + case AlertLevelWarning: + // drop on the floor + case AlertLevelError: + return Alert(pt.fragment[1]) + default: + c.sendAlert(AlertUnexpectedMessage) + return io.EOF + } + + case RecordTypeApplicationData: + c.readBuffer = append(c.readBuffer, pt.fragment...) + logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer) + } + + return err +} + +// Read application data up to the size of buffer. Handshake and alert records +// are consumed by the Conn object directly. +func (c *Conn) Read(buffer []byte) (int, error) { + if _, connected := c.hState.(StateConnected); !connected && c.config.NonBlocking { + return 0, errors.New("Read called before the handshake completed") + } + logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer)) + if alert := c.Handshake(); alert != AlertNoAlert { + return 0, alert + } + + if len(buffer) == 0 { + return 0, nil + } + + // Lock the input channel + c.in.Lock() + defer c.in.Unlock() + for len(c.readBuffer) == 0 { + err := c.consumeRecord() + + // err can be nil if consumeRecord processed a non app-data + // record. + if err != nil { + if c.config.NonBlocking || err != WouldBlock { + logf(logTypeIO, "conn.Read returns err=%v", err) + return 0, err + } + } + } + + var read int + n := len(buffer) + logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer)) + if len(c.readBuffer) <= n { + buffer = buffer[:len(c.readBuffer)] + copy(buffer, c.readBuffer) + read = len(c.readBuffer) + c.readBuffer = c.readBuffer[:0] + } else { + logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n) + copy(buffer[:n], c.readBuffer[:n]) + c.readBuffer = c.readBuffer[n:] + read = n + } + + logf(logTypeVerbose, "Returning %v", string(buffer)) + return read, nil +} + +// Write application data +func (c *Conn) Write(buffer []byte) (int, error) { + // Lock the output channel + c.out.Lock() + defer c.out.Unlock() + + // Send full-size fragments + var start int + sent := 0 + for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { + err := c.out.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeApplicationData, + fragment: buffer[start : start+maxFragmentLen], + }) + + if err != nil { + return sent, err + } + sent += maxFragmentLen + } + + // Send a final partial fragment if necessary + if start < len(buffer) { + err := c.out.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeApplicationData, + fragment: buffer[start:], + }) + + if err != nil { + return sent, err + } + sent += len(buffer[start:]) + } + return sent, nil +} + +// sendAlert sends a TLS alert message. +// c.out.Mutex <= L. +func (c *Conn) sendAlert(err Alert) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + var level int + switch err { + case AlertNoRenegotiation, AlertCloseNotify: + level = AlertLevelWarning + default: + level = AlertLevelError + } + + buf := []byte{byte(err), byte(level)} + c.out.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeAlert, + fragment: buf, + }) + + // close_notify and end_of_early_data are not actually errors + if level == AlertLevelWarning { + return &net.OpError{Op: "local error", Err: err} + } + + return c.Close() +} + +// Close closes the connection. +func (c *Conn) Close() error { + // XXX crypto/tls has an interlock with Write here. Do we need that? + + return c.conn.Close() +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated with the connection. +// A zero value for t means Read and Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying connection. +// A zero value for t means Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { + label := "[server]" + if c.isClient { + label = "[client]" + } + + switch action := actionGeneric.(type) { + case QueueHandshakeMessage: + logf(logTypeHandshake, "%s queuing handshake message type=%v", label, action.Message.msgType) + err := c.hsCtx.hOut.QueueMessage(action.Message) + if err != nil { + logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) + return AlertInternalError + } + + case SendQueuedHandshake: + err := c.hsCtx.hOut.SendQueuedMessages() + if err != nil { + logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) + return AlertInternalError + } + case RekeyIn: + logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.epoch.label(), action.KeySet) + err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) + if err != nil { + logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err) + return AlertInternalError + } + + case RekeyOut: + logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.epoch.label(), action.KeySet) + err := c.out.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) + if err != nil { + logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err) + return AlertInternalError + } + + case SendEarlyData: + logf(logTypeHandshake, "%s Sending early data...", label) + _, err := c.Write(c.EarlyData) + if err != nil { + logf(logTypeHandshake, "%s Error writing early data: %v", label, err) + return AlertInternalError + } + + case ReadPastEarlyData: + logf(logTypeHandshake, "%s Reading past early data...", label) + // Scan past all records that fail to decrypt + _, err := c.in.PeekRecordType(!c.config.NonBlocking) + if err == nil { + break + } + _, ok := err.(DecryptError) + + for ok { + _, err = c.in.PeekRecordType(!c.config.NonBlocking) + if err == nil { + break + } + _, ok = err.(DecryptError) + } + + case ReadEarlyData: + logf(logTypeHandshake, "%s Reading early data...", label) + t, err := c.in.PeekRecordType(!c.config.NonBlocking) + if err != nil { + logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err) + return AlertInternalError + } + logf(logTypeHandshake, "%s Got record type(1): %v", label, t) + + for t == RecordTypeApplicationData { + // Read a record into the buffer. Note that this is safe + // in blocking mode because we read the record in in + // PeekRecordType. + pt, err := c.in.ReadRecord() + if err != nil { + logf(logTypeHandshake, "%s Error reading early data record: %v", label, err) + return AlertInternalError + } + + logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment) + c.EarlyData = append(c.EarlyData, pt.fragment...) + + t, err = c.in.PeekRecordType(!c.config.NonBlocking) + if err != nil { + logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err) + return AlertInternalError + } + logf(logTypeHandshake, "%s Got record type (2): %v", label, t) + } + logf(logTypeHandshake, "%s Done reading early data", label) + + case StorePSK: + logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity) + if c.isClient { + // Clients look up PSKs based on server name + c.config.PSKs.Put(c.config.ServerName, action.PSK) + } else { + // Servers look them up based on the identity in the extension + c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK) + } + + default: + logf(logTypeHandshake, "%s Unknown actionuction type", label) + return AlertInternalError + } + + return AlertNoAlert +} + +func (c *Conn) HandshakeSetup() Alert { + var state HandshakeState + var actions []HandshakeAction + var alert Alert + + if err := c.config.Init(c.isClient); err != nil { + logf(logTypeHandshake, "Error initializing config: %v", err) + return AlertInternalError + } + + opts := ConnectionOptions{ + ServerName: c.config.ServerName, + NextProtos: c.config.NextProtos, + EarlyData: c.EarlyData, + } + + if c.isClient { + state, actions, alert = ClientStateStart{Config: c.config, Opts: opts, hsCtx: c.hsCtx}.Next(nil) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error initializing client state: %v", alert) + return alert + } + + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + return alert + } + } + } else { + if c.config.RequireCookie && c.config.CookieProtector == nil { + logf(logTypeHandshake, "RequireCookie set, but no CookieProtector provided. Using default cookie protector. Stateless Retry not possible.") + if c.config.NonBlocking { + logf(logTypeHandshake, "Not possible in non-blocking mode.") + return AlertInternalError + } + var err error + c.config.CookieProtector, err = NewDefaultCookieProtector() + if err != nil { + logf(logTypeHandshake, "Error initializing cookie source: %v", alert) + return AlertInternalError + } + } + state = ServerStateStart{Config: c.config, conn: c, hsCtx: c.hsCtx} + } + + c.hState = state + return AlertNoAlert +} + +type handshakeMessageReader interface { + ReadMessage() (*HandshakeMessage, Alert) +} + +type handshakeMessageReaderImpl struct { + hsCtx *HandshakeContext +} + +var _ handshakeMessageReader = &handshakeMessageReaderImpl{} + +func (r *handshakeMessageReaderImpl) ReadMessage() (*HandshakeMessage, Alert) { + hm, err := r.hsCtx.hIn.ReadMessage() + if err == WouldBlock { + return nil, AlertWouldBlock + } + if err != nil { + logf(logTypeHandshake, "[client] Error reading message: %v", err) + return nil, AlertCloseNotify + } + + // Once you have read a message, you no longer need the outgoing queue + // for DTLS. + r.hsCtx.hOut.ClearQueuedMessages() + + return hm, AlertNoAlert +} + +// Handshake causes a TLS handshake on the connection. The `isClient` member +// determines whether a client or server handshake is performed. If a +// handshake has already been performed, then its result will be returned. +func (c *Conn) Handshake() Alert { + label := "[server]" + if c.isClient { + label = "[client]" + } + + // TODO Lock handshakeMutex + // TODO Remove CloseNotify hack + if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify { + logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert) + return c.handshakeAlert + } + if c.handshakeComplete { + return AlertNoAlert + } + + if c.hState == nil { + logf(logTypeHandshake, "%s First time through handshake (or after stateless retry), setting up", label) + alert := c.HandshakeSetup() + if alert != AlertNoAlert || (c.isClient && c.config.NonBlocking) { + return alert + } + } + + logf(logTypeHandshake, "(Re-)entering handshake, state=%v", c.hState) + state := c.hState + _, connected := state.(StateConnected) + + hmr := &handshakeMessageReaderImpl{hsCtx: &c.hsCtx} + for !connected { + var alert Alert + var actions []HandshakeAction + // Advance the state machine + state, actions, alert = state.Next(hmr) + if alert == WouldBlock { + logf(logTypeHandshake, "%s Would block reading message: %s", label, alert) + return AlertWouldBlock + } + if alert == AlertCloseNotify { + logf(logTypeHandshake, "%s Error reading message: %s", label, alert) + c.sendAlert(AlertCloseNotify) + return AlertCloseNotify + } + if alert != AlertNoAlert && alert != AlertStatelessRetry { + logf(logTypeHandshake, "Error in state transition: %v", alert) + return alert + } + + for index, action := range actions { + logf(logTypeHandshake, "%s taking next action (%d)", label, index) + if alert := c.takeAction(action); alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return alert + } + } + + c.hState = state + logf(logTypeHandshake, "state is now %s", c.GetHsState()) + _, connected = state.(StateConnected) + if connected { + c.state = state.(StateConnected) + c.handshakeComplete = true + } + + if c.config.NonBlocking { + if alert == AlertStatelessRetry { + return AlertStatelessRetry + } + return AlertNoAlert + } + } + + // Send NewSessionTicket if acting as server + if !c.isClient && c.config.SendSessionTickets { + actions, alert := c.state.NewSessionTicket( + c.config.TicketLen, + c.config.TicketLifetime, + c.config.EarlyDataLifetime) + + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return alert + } + } + } + + return AlertNoAlert +} + +func (c *Conn) SendKeyUpdate(requestUpdate bool) error { + if !c.handshakeComplete { + return fmt.Errorf("Cannot update keys until after handshake") + } + + request := KeyUpdateNotRequested + if requestUpdate { + request = KeyUpdateRequested + } + + // Create the key update and update state + actions, alert := c.state.KeyUpdate(request) + if alert != AlertNoAlert { + c.sendAlert(alert) + return fmt.Errorf("Alert while generating key update: %v", alert) + } + + // Take actions (send key update and rekey) + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + c.sendAlert(alert) + return fmt.Errorf("Alert during key update actions: %v", alert) + } + } + + return nil +} + +func (c *Conn) GetHsState() State { + return c.hState.State() +} + +func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { + _, connected := c.hState.(StateConnected) + if !connected { + return nil, fmt.Errorf("Cannot compute exporter when state is not connected") + } + + if c.state.exporterSecret == nil { + return nil, fmt.Errorf("Internal error: no exporter secret") + } + + h0 := c.state.cryptoParams.Hash.New().Sum(nil) + tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0) + + hc := c.state.cryptoParams.Hash.New().Sum(context) + return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil +} + +func (c *Conn) ConnectionState() ConnectionState { + state := ConnectionState{ + HandshakeState: c.GetHsState(), + } + + if c.handshakeComplete { + state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite] + state.NextProto = c.state.Params.NextProto + state.VerifiedChains = c.state.verifiedChains + state.PeerCertificates = c.state.peerCertificates + } + + return state +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/cookie-protector.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/cookie-protector.go new file mode 100644 index 000000000..73dd80bae --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/cookie-protector.go @@ -0,0 +1,86 @@ +package mint + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "fmt" + "io" + + "golang.org/x/crypto/hkdf" +) + +// CookieProtector is used to create and verify a cookie +type CookieProtector interface { + // NewToken creates a new token + NewToken([]byte) ([]byte, error) + // DecodeToken decodes a token + DecodeToken([]byte) ([]byte, error) +} + +const cookieSecretSize = 32 +const cookieNonceSize = 32 + +// The DefaultCookieProtector is a simple implementation for the CookieProtector. +type DefaultCookieProtector struct { + secret []byte +} + +var _ CookieProtector = &DefaultCookieProtector{} + +// NewDefaultCookieProtector creates a source for source address tokens +func NewDefaultCookieProtector() (CookieProtector, error) { + secret := make([]byte, cookieSecretSize) + if _, err := rand.Read(secret); err != nil { + return nil, err + } + return &DefaultCookieProtector{secret: secret}, nil +} + +// NewToken encodes data into a new token. +func (s *DefaultCookieProtector) NewToken(data []byte) ([]byte, error) { + nonce := make([]byte, cookieNonceSize) + if _, err := rand.Read(nonce); err != nil { + return nil, err + } + aead, aeadNonce, err := s.createAEAD(nonce) + if err != nil { + return nil, err + } + return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil +} + +// DecodeToken decodes a token. +func (s *DefaultCookieProtector) DecodeToken(p []byte) ([]byte, error) { + if len(p) < cookieNonceSize { + return nil, fmt.Errorf("Token too short: %d", len(p)) + } + nonce := p[:cookieNonceSize] + aead, aeadNonce, err := s.createAEAD(nonce) + if err != nil { + return nil, err + } + return aead.Open(nil, aeadNonce, p[cookieNonceSize:], nil) +} + +func (s *DefaultCookieProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { + h := hkdf.New(sha256.New, s.secret, nonce, []byte("mint cookie source")) + key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 + if _, err := io.ReadFull(h, key); err != nil { + return nil, nil, err + } + aeadNonce := make([]byte, 12) + if _, err := io.ReadFull(h, aeadNonce); err != nil { + return nil, nil, err + } + c, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, nil, err + } + return aead, aeadNonce, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/crypto.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/crypto.go new file mode 100644 index 000000000..5fa70d4f4 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/crypto.go @@ -0,0 +1,618 @@ +package mint + +import ( + "bytes" + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/asn1" + "fmt" + "math/big" + + "golang.org/x/crypto/curve25519" + + // Blank includes to ensure hash support + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" +) + +var prng = rand.Reader + +type aeadFactory func(key []byte) (cipher.AEAD, error) + +type CipherSuiteParams struct { + Suite CipherSuite + Cipher aeadFactory // Cipher factory + Hash crypto.Hash // Hash function + KeyLen int // Key length in octets + IvLen int // IV length in octets +} + +type signatureAlgorithm uint8 + +const ( + signatureAlgorithmUnknown = iota + signatureAlgorithmRSA_PKCS1 + signatureAlgorithmRSA_PSS + signatureAlgorithmECDSA +) + +var ( + hashMap = map[SignatureScheme]crypto.Hash{ + RSA_PKCS1_SHA1: crypto.SHA1, + RSA_PKCS1_SHA256: crypto.SHA256, + RSA_PKCS1_SHA384: crypto.SHA384, + RSA_PKCS1_SHA512: crypto.SHA512, + ECDSA_P256_SHA256: crypto.SHA256, + ECDSA_P384_SHA384: crypto.SHA384, + ECDSA_P521_SHA512: crypto.SHA512, + RSA_PSS_SHA256: crypto.SHA256, + RSA_PSS_SHA384: crypto.SHA384, + RSA_PSS_SHA512: crypto.SHA512, + } + + sigMap = map[SignatureScheme]signatureAlgorithm{ + RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1, + RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1, + RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1, + RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1, + ECDSA_P256_SHA256: signatureAlgorithmECDSA, + ECDSA_P384_SHA384: signatureAlgorithmECDSA, + ECDSA_P521_SHA512: signatureAlgorithmECDSA, + RSA_PSS_SHA256: signatureAlgorithmRSA_PSS, + RSA_PSS_SHA384: signatureAlgorithmRSA_PSS, + RSA_PSS_SHA512: signatureAlgorithmRSA_PSS, + } + + curveMap = map[SignatureScheme]NamedGroup{ + ECDSA_P256_SHA256: P256, + ECDSA_P384_SHA384: P384, + ECDSA_P521_SHA512: P521, + } + + newAESGCM = func(key []byte) (cipher.AEAD, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + // TLS always uses 12-byte nonces + return cipher.NewGCMWithNonceSize(block, 12) + } + + cipherSuiteMap = map[CipherSuite]CipherSuiteParams{ + TLS_AES_128_GCM_SHA256: { + Suite: TLS_AES_128_GCM_SHA256, + Cipher: newAESGCM, + Hash: crypto.SHA256, + KeyLen: 16, + IvLen: 12, + }, + TLS_AES_256_GCM_SHA384: { + Suite: TLS_AES_256_GCM_SHA384, + Cipher: newAESGCM, + Hash: crypto.SHA384, + KeyLen: 32, + IvLen: 12, + }, + } + + x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{ + RSA_PKCS1_SHA1: x509.SHA1WithRSA, + RSA_PKCS1_SHA256: x509.SHA256WithRSA, + RSA_PKCS1_SHA384: x509.SHA384WithRSA, + RSA_PKCS1_SHA512: x509.SHA512WithRSA, + ECDSA_P256_SHA256: x509.ECDSAWithSHA256, + ECDSA_P384_SHA384: x509.ECDSAWithSHA384, + ECDSA_P521_SHA512: x509.ECDSAWithSHA512, + } + + defaultRSAKeySize = 2048 +) + +func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) { + switch group { + case P256: + crv = elliptic.P256() + case P384: + crv = elliptic.P384() + case P521: + crv = elliptic.P521() + } + return +} + +func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) { + switch key.Curve.Params().Name { + case elliptic.P256().Params().Name: + g = P256 + case elliptic.P384().Params().Name: + g = P384 + case elliptic.P521().Params().Name: + g = P521 + } + return +} + +func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) { + size = 0 + switch group { + case X25519: + size = 32 + case P256: + size = 65 + case P384: + size = 97 + case P521: + size = 133 + case FFDHE2048: + size = 256 + case FFDHE3072: + size = 384 + case FFDHE4096: + size = 512 + case FFDHE6144: + size = 768 + case FFDHE8192: + size = 1024 + } + return +} + +func primeFromNamedGroup(group NamedGroup) (p *big.Int) { + switch group { + case FFDHE2048: + p = finiteFieldPrime2048 + case FFDHE3072: + p = finiteFieldPrime3072 + case FFDHE4096: + p = finiteFieldPrime4096 + case FFDHE6144: + p = finiteFieldPrime6144 + case FFDHE8192: + p = finiteFieldPrime8192 + } + return +} + +func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool { + sigType := sigMap[alg] + switch key.(type) { + case *rsa.PrivateKey: + return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS + case *ecdsa.PrivateKey: + return sigType == signatureAlgorithmECDSA + default: + return false + } +} + +func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) { + primeLen := len(p.Bytes()) + for { + // g = 2 for all ffdhe groups + priv, err = rand.Int(prng, p) + if err != nil { + return + } + + pub = big.NewInt(0) + pub.Exp(big.NewInt(2), priv, p) + + if len(pub.Bytes()) == primeLen { + return + } + } +} + +func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) { + switch group { + case P256, P384, P521: + var x, y *big.Int + crv := curveFromNamedGroup(group) + priv, x, y, err = elliptic.GenerateKey(crv, prng) + if err != nil { + return + } + + pub = elliptic.Marshal(crv, x, y) + return + + case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192: + p := primeFromNamedGroup(group) + x, X, err2 := ffdheKeyShareFromPrime(p) + if err2 != nil { + err = err2 + return + } + + priv = x.Bytes() + pubBytes := X.Bytes() + + numBytes := keyExchangeSizeFromNamedGroup(group) + + pub = make([]byte, numBytes) + copy(pub[numBytes-len(pubBytes):], pubBytes) + + return + + case X25519: + var private, public [32]byte + _, err = prng.Read(private[:]) + if err != nil { + return + } + + curve25519.ScalarBaseMult(&public, &private) + priv = private[:] + pub = public[:] + return + + default: + return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group) + } +} + +func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) { + switch group { + case P256, P384, P521: + if len(pub) != keyExchangeSizeFromNamedGroup(group) { + return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") + } + + crv := curveFromNamedGroup(group) + pubX, pubY := elliptic.Unmarshal(crv, pub) + x, _ := crv.Params().ScalarMult(pubX, pubY, priv) + xBytes := x.Bytes() + + numBytes := len(crv.Params().P.Bytes()) + + ret := make([]byte, numBytes) + copy(ret[numBytes-len(xBytes):], xBytes) + + return ret, nil + + case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192: + numBytes := keyExchangeSizeFromNamedGroup(group) + if len(pub) != numBytes { + return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") + } + p := primeFromNamedGroup(group) + x := big.NewInt(0).SetBytes(priv) + Y := big.NewInt(0).SetBytes(pub) + ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes() + + ret := make([]byte, numBytes) + copy(ret[numBytes-len(ZBytes):], ZBytes) + + return ret, nil + + case X25519: + if len(pub) != keyExchangeSizeFromNamedGroup(group) { + return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") + } + + var private, public, ret [32]byte + copy(private[:], priv) + copy(public[:], pub) + curve25519.ScalarMult(&ret, &private, &public) + + return ret[:], nil + + default: + return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group) + } +} + +func newSigningKey(sig SignatureScheme) (crypto.Signer, error) { + switch sig { + case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256, + RSA_PKCS1_SHA384, RSA_PKCS1_SHA512, + RSA_PSS_SHA256, RSA_PSS_SHA384, + RSA_PSS_SHA512: + return rsa.GenerateKey(prng, defaultRSAKeySize) + case ECDSA_P256_SHA256: + return ecdsa.GenerateKey(elliptic.P256(), prng) + case ECDSA_P384_SHA384: + return ecdsa.GenerateKey(elliptic.P384(), prng) + case ECDSA_P521_SHA512: + return ecdsa.GenerateKey(elliptic.P521(), prng) + default: + return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig) + } +} + +// XXX(rlb): Copied from crypto/x509 +type ecdsaSignature struct { + R, S *big.Int +} + +func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) { + var opts crypto.SignerOpts + + hash := hashMap[alg] + if hash == crypto.SHA1 { + return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden") + } + + sigType := sigMap[alg] + var realInput []byte + switch key := privateKey.(type) { + case *rsa.PrivateKey: + switch { + case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: + logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size()) + opts = hash + case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: + fallthrough + case sigType == signatureAlgorithmRSA_PSS: + logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size()) + opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash} + default: + return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key") + } + + h := hash.New() + h.Write(sigInput) + realInput = h.Sum(nil) + case *ecdsa.PrivateKey: + if sigType != signatureAlgorithmECDSA { + return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key") + } + + algGroup := curveMap[alg] + keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey)) + if algGroup != keyGroup { + return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination") + } + + h := hash.New() + h.Write(sigInput) + realInput = h.Sum(nil) + default: + return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type") + } + + sig, err := privateKey.Sign(prng, realInput, opts) + logf(logTypeCrypto, "signature: %x", sig) + return sig, err +} + +func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error { + hash := hashMap[alg] + + if hash == crypto.SHA1 { + return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden") + } + + sigType := sigMap[alg] + switch pub := publicKey.(type) { + case *rsa.PublicKey: + switch { + case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: + logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size()) + + h := hash.New() + h.Write(sigInput) + realInput := h.Sum(nil) + return rsa.VerifyPKCS1v15(pub, hash, realInput, sig) + case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: + fallthrough + case sigType == signatureAlgorithmRSA_PSS: + logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size()) + opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash} + + h := hash.New() + h.Write(sigInput) + realInput := h.Sum(nil) + return rsa.VerifyPSS(pub, hash, realInput, sig, opts) + default: + return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key") + } + + case *ecdsa.PublicKey: + if sigType != signatureAlgorithmECDSA { + return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key") + } + + if curveMap[alg] != namedGroupFromECDSAKey(pub) { + return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key") + } + + ecdsaSig := new(ecdsaSignature) + if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil { + return err + } else if len(rest) != 0 { + return fmt.Errorf("tls.verify: trailing data after ECDSA signature") + } + if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { + return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values") + } + + h := hash.New() + h.Write(sigInput) + realInput := h.Sum(nil) + if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) { + return fmt.Errorf("tls.verify: ECDSA verification failure") + } + return nil + default: + return fmt.Errorf("tls.verify: Unsupported key type") + } +} + +// 0 +// | +// v +// PSK -> HKDF-Extract = Early Secret +// | +// +-----> Derive-Secret(., +// | "ext binder" | +// | "res binder", +// | "") +// | = binder_key +// | +// +-----> Derive-Secret(., "c e traffic", +// | ClientHello) +// | = client_early_traffic_secret +// | +// +-----> Derive-Secret(., "e exp master", +// | ClientHello) +// | = early_exporter_master_secret +// v +// Derive-Secret(., "derived", "") +// | +// v +// (EC)DHE -> HKDF-Extract = Handshake Secret +// | +// +-----> Derive-Secret(., "c hs traffic", +// | ClientHello...ServerHello) +// | = client_handshake_traffic_secret +// | +// +-----> Derive-Secret(., "s hs traffic", +// | ClientHello...ServerHello) +// | = server_handshake_traffic_secret +// v +// Derive-Secret(., "derived", "") +// | +// v +// 0 -> HKDF-Extract = Master Secret +// | +// +-----> Derive-Secret(., "c ap traffic", +// | ClientHello...server Finished) +// | = client_application_traffic_secret_0 +// | +// +-----> Derive-Secret(., "s ap traffic", +// | ClientHello...server Finished) +// | = server_application_traffic_secret_0 +// | +// +-----> Derive-Secret(., "exp master", +// | ClientHello...server Finished) +// | = exporter_master_secret +// | +// +-----> Derive-Secret(., "res master", +// ClientHello...client Finished) +// = resumption_master_secret + +// From RFC 5869 +// PRK = HMAC-Hash(salt, IKM) +func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte { + salt := saltIn + + // if [salt is] not provided, it is set to a string of HashLen zeros + if salt == nil { + salt = bytes.Repeat([]byte{0}, hash.Size()) + } + + h := hmac.New(hash.New, salt) + h.Write(input) + out := h.Sum(nil) + + logf(logTypeCrypto, "HKDF Extract:\n") + logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt) + logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input) + logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out) + + return out +} + +const ( + labelExternalBinder = "ext binder" + labelResumptionBinder = "res binder" + labelEarlyTrafficSecret = "c e traffic" + labelEarlyExporterSecret = "e exp master" + labelClientHandshakeTrafficSecret = "c hs traffic" + labelServerHandshakeTrafficSecret = "s hs traffic" + labelClientApplicationTrafficSecret = "c ap traffic" + labelServerApplicationTrafficSecret = "s ap traffic" + labelExporterSecret = "exp master" + labelResumptionSecret = "res master" + labelDerived = "derived" + labelFinished = "finished" + labelResumption = "resumption" +) + +// struct HkdfLabel { +// uint16 length; +// opaque label<9..255>; +// opaque hash_value<0..255>; +// }; +func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte { + label := "tls13 " + labelIn + + labelLen := len(label) + hashLen := len(hashValue) + hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen) + hkdfLabel[0] = byte(outLen >> 8) + hkdfLabel[1] = byte(outLen) + hkdfLabel[2] = byte(labelLen) + copy(hkdfLabel[3:3+labelLen], []byte(label)) + hkdfLabel[3+labelLen] = byte(hashLen) + copy(hkdfLabel[3+labelLen+1:], hashValue) + + return hkdfLabel +} + +func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte { + out := []byte{} + T := []byte{} + i := byte(1) + for len(out) < outLen { + block := append(T, info...) + block = append(block, i) + + h := hmac.New(hash.New, prk) + h.Write(block) + + T = h.Sum(nil) + out = append(out, T...) + i++ + } + return out[:outLen] +} + +func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte { + info := hkdfEncodeLabel(label, hashValue, outLen) + derived := HkdfExpand(hash, secret, info, outLen) + + logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen) + logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret) + logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue) + logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info) + logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived) + + return derived +} + +func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte { + return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size()) +} + +func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte { + macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size()) + mac := hmac.New(params.Hash.New, macKey) + mac.Write(input) + return mac.Sum(nil) +} + +type keySet struct { + cipher aeadFactory + key []byte + iv []byte +} + +func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet { + logf(logTypeCrypto, "making traffic keys: secret=%x", secret) + return keySet{ + cipher: params.Cipher, + key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen), + iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen), + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/dtls.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/dtls.go new file mode 100644 index 000000000..df4f1aa11 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/dtls.go @@ -0,0 +1,28 @@ +package mint + +import ( + "fmt" +) + +// This file is a placeholder. DTLS-specific stuff (timer management, +// ACKs, retransmits, etc. will eventually go here. +const ( + initialMtu = 1200 +) + +func wireVersion(h *HandshakeLayer) uint16 { + if h.datagram { + return dtls12WireVersion + } + return tls12Version +} + +func dtlsConvertVersion(version uint16) uint16 { + if version == tls12Version { + return dtls12WireVersion + } + if version == tls10Version { + return 0xfeff + } + panic(fmt.Sprintf("Internal error, unexpected version=%d", version)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/extensions.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/extensions.go new file mode 100644 index 000000000..07cb16c62 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/extensions.go @@ -0,0 +1,626 @@ +package mint + +import ( + "bytes" + "fmt" + "github.com/bifurcation/mint/syntax" +) + +type ExtensionBody interface { + Type() ExtensionType + Marshal() ([]byte, error) + Unmarshal(data []byte) (int, error) +} + +// struct { +// ExtensionType extension_type; +// opaque extension_data<0..2^16-1>; +// } Extension; +type Extension struct { + ExtensionType ExtensionType + ExtensionData []byte `tls:"head=2"` +} + +func (ext Extension) Marshal() ([]byte, error) { + return syntax.Marshal(ext) +} + +func (ext *Extension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, ext) +} + +type ExtensionList []Extension + +type extensionListInner struct { + List []Extension `tls:"head=2"` +} + +func (el ExtensionList) Marshal() ([]byte, error) { + return syntax.Marshal(extensionListInner{el}) +} + +func (el *ExtensionList) Unmarshal(data []byte) (int, error) { + var list extensionListInner + read, err := syntax.Unmarshal(data, &list) + if err != nil { + return 0, err + } + + *el = list.List + return read, nil +} + +func (el *ExtensionList) Add(src ExtensionBody) error { + data, err := src.Marshal() + if err != nil { + return err + } + + if el == nil { + el = new(ExtensionList) + } + + // If one already exists with this type, replace it + for i := range *el { + if (*el)[i].ExtensionType == src.Type() { + (*el)[i].ExtensionData = data + return nil + } + } + + // Otherwise append + *el = append(*el, Extension{ + ExtensionType: src.Type(), + ExtensionData: data, + }) + return nil +} + +func (el ExtensionList) Parse(dsts []ExtensionBody) (map[ExtensionType]bool, error) { + found := make(map[ExtensionType]bool) + + for _, dst := range dsts { + for _, ext := range el { + if ext.ExtensionType == dst.Type() { + if found[dst.Type()] { + return nil, fmt.Errorf("Duplicate extension of type [%v]", dst.Type()) + } + + err := safeUnmarshal(dst, ext.ExtensionData) + if err != nil { + return nil, err + } + + found[dst.Type()] = true + } + } + } + + return found, nil +} + +func (el ExtensionList) Find(dst ExtensionBody) (bool, error) { + for _, ext := range el { + if ext.ExtensionType == dst.Type() { + err := safeUnmarshal(dst, ext.ExtensionData) + if err != nil { + return true, err + } + return true, nil + } + } + return false, nil +} + +// struct { +// NameType name_type; +// select (name_type) { +// case host_name: HostName; +// } name; +// } ServerName; +// +// enum { +// host_name(0), (255) +// } NameType; +// +// opaque HostName<1..2^16-1>; +// +// struct { +// ServerName server_name_list<1..2^16-1> +// } ServerNameList; +// +// But we only care about the case where there's a single DNS hostname. We +// will never create anything else, and throw if we receive something else +// +// 2 1 2 +// | listLen | NameType | nameLen | name | +type ServerNameExtension string + +type serverNameInner struct { + NameType uint8 + HostName []byte `tls:"head=2,min=1"` +} + +type serverNameListInner struct { + ServerNameList []serverNameInner `tls:"head=2,min=1"` +} + +func (sni ServerNameExtension) Type() ExtensionType { + return ExtensionTypeServerName +} + +func (sni ServerNameExtension) Marshal() ([]byte, error) { + list := serverNameListInner{ + ServerNameList: []serverNameInner{{ + NameType: 0x00, // host_name + HostName: []byte(sni), + }}, + } + + return syntax.Marshal(list) +} + +func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) { + var list serverNameListInner + read, err := syntax.Unmarshal(data, &list) + if err != nil { + return 0, err + } + + // Syntax requires at least one entry + // Entries beyond the first are ignored + if nameType := list.ServerNameList[0].NameType; nameType != 0x00 { + return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType) + } + + *sni = ServerNameExtension(list.ServerNameList[0].HostName) + return read, nil +} + +// struct { +// NamedGroup group; +// opaque key_exchange<1..2^16-1>; +// } KeyShareEntry; +// +// struct { +// select (Handshake.msg_type) { +// case client_hello: +// KeyShareEntry client_shares<0..2^16-1>; +// +// case hello_retry_request: +// NamedGroup selected_group; +// +// case server_hello: +// KeyShareEntry server_share; +// }; +// } KeyShare; +type KeyShareEntry struct { + Group NamedGroup + KeyExchange []byte `tls:"head=2,min=1"` +} + +func (kse KeyShareEntry) SizeValid() bool { + return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group) +} + +type KeyShareExtension struct { + HandshakeType HandshakeType + SelectedGroup NamedGroup + Shares []KeyShareEntry +} + +type KeyShareClientHelloInner struct { + ClientShares []KeyShareEntry `tls:"head=2,min=0"` +} +type KeyShareHelloRetryInner struct { + SelectedGroup NamedGroup +} +type KeyShareServerHelloInner struct { + ServerShare KeyShareEntry +} + +func (ks KeyShareExtension) Type() ExtensionType { + return ExtensionTypeKeyShare +} + +func (ks KeyShareExtension) Marshal() ([]byte, error) { + switch ks.HandshakeType { + case HandshakeTypeClientHello: + for _, share := range ks.Shares { + if !share.SizeValid() { + return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") + } + } + return syntax.Marshal(KeyShareClientHelloInner{ks.Shares}) + + case HandshakeTypeHelloRetryRequest: + if len(ks.Shares) > 0 { + return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest") + } + + return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup}) + + case HandshakeTypeServerHello: + if len(ks.Shares) != 1 { + return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share") + } + + if !ks.Shares[0].SizeValid() { + return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") + } + + return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]}) + + default: + return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed") + } +} + +func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) { + switch ks.HandshakeType { + case HandshakeTypeClientHello: + var inner KeyShareClientHelloInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + for _, share := range inner.ClientShares { + if !share.SizeValid() { + return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") + } + } + + ks.Shares = inner.ClientShares + return read, nil + + case HandshakeTypeHelloRetryRequest: + var inner KeyShareHelloRetryInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + ks.SelectedGroup = inner.SelectedGroup + return read, nil + + case HandshakeTypeServerHello: + var inner KeyShareServerHelloInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + if !inner.ServerShare.SizeValid() { + return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") + } + + ks.Shares = []KeyShareEntry{inner.ServerShare} + return read, nil + + default: + return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed") + } +} + +// struct { +// NamedGroup named_group_list<2..2^16-1>; +// } NamedGroupList; +type SupportedGroupsExtension struct { + Groups []NamedGroup `tls:"head=2,min=2"` +} + +func (sg SupportedGroupsExtension) Type() ExtensionType { + return ExtensionTypeSupportedGroups +} + +func (sg SupportedGroupsExtension) Marshal() ([]byte, error) { + return syntax.Marshal(sg) +} + +func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, sg) +} + +// struct { +// SignatureScheme supported_signature_algorithms<2..2^16-2>; +// } SignatureSchemeList +type SignatureAlgorithmsExtension struct { + Algorithms []SignatureScheme `tls:"head=2,min=2"` +} + +func (sa SignatureAlgorithmsExtension) Type() ExtensionType { + return ExtensionTypeSignatureAlgorithms +} + +func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) { + return syntax.Marshal(sa) +} + +func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, sa) +} + +// struct { +// opaque identity<1..2^16-1>; +// uint32 obfuscated_ticket_age; +// } PskIdentity; +// +// opaque PskBinderEntry<32..255>; +// +// struct { +// select (Handshake.msg_type) { +// case client_hello: +// PskIdentity identities<7..2^16-1>; +// PskBinderEntry binders<33..2^16-1>; +// +// case server_hello: +// uint16 selected_identity; +// }; +// +// } PreSharedKeyExtension; +type PSKIdentity struct { + Identity []byte `tls:"head=2,min=1"` + ObfuscatedTicketAge uint32 +} + +type PSKBinderEntry struct { + Binder []byte `tls:"head=1,min=32"` +} + +type PreSharedKeyExtension struct { + HandshakeType HandshakeType + Identities []PSKIdentity + Binders []PSKBinderEntry + SelectedIdentity uint16 +} + +type preSharedKeyClientInner struct { + Identities []PSKIdentity `tls:"head=2,min=7"` + Binders []PSKBinderEntry `tls:"head=2,min=33"` +} + +type preSharedKeyServerInner struct { + SelectedIdentity uint16 +} + +func (psk PreSharedKeyExtension) Type() ExtensionType { + return ExtensionTypePreSharedKey +} + +func (psk PreSharedKeyExtension) Marshal() ([]byte, error) { + switch psk.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Marshal(preSharedKeyClientInner{ + Identities: psk.Identities, + Binders: psk.Binders, + }) + + case HandshakeTypeServerHello: + if len(psk.Identities) > 0 || len(psk.Binders) > 0 { + return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index") + } + return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity}) + + default: + return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported") + } +} + +func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) { + switch psk.HandshakeType { + case HandshakeTypeClientHello: + var inner preSharedKeyClientInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + if len(inner.Identities) != len(inner.Binders) { + return 0, fmt.Errorf("Lengths of identities and binders not equal") + } + + psk.Identities = inner.Identities + psk.Binders = inner.Binders + return read, nil + + case HandshakeTypeServerHello: + var inner preSharedKeyServerInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + psk.SelectedIdentity = inner.SelectedIdentity + return read, nil + + default: + return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported") + } +} + +func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) { + for i, localID := range psk.Identities { + if bytes.Equal(localID.Identity, id) { + return psk.Binders[i].Binder, true + } + } + return nil, false +} + +// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode; +// +// struct { +// PskKeyExchangeMode ke_modes<1..255>; +// } PskKeyExchangeModes; +type PSKKeyExchangeModesExtension struct { + KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"` +} + +func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType { + return ExtensionTypePSKKeyExchangeModes +} + +func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) { + return syntax.Marshal(pkem) +} + +func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, pkem) +} + +// struct { +// } EarlyDataIndication; + +type EarlyDataExtension struct{} + +func (ed EarlyDataExtension) Type() ExtensionType { + return ExtensionTypeEarlyData +} + +func (ed EarlyDataExtension) Marshal() ([]byte, error) { + return []byte{}, nil +} + +func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) { + return 0, nil +} + +// struct { +// uint32 max_early_data_size; +// } TicketEarlyDataInfo; + +type TicketEarlyDataInfoExtension struct { + MaxEarlyDataSize uint32 +} + +func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType { + return ExtensionTypeTicketEarlyDataInfo +} + +func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) { + return syntax.Marshal(tedi) +} + +func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, tedi) +} + +// opaque ProtocolName<1..2^8-1>; +// +// struct { +// ProtocolName protocol_name_list<2..2^16-1> +// } ProtocolNameList; +type ALPNExtension struct { + Protocols []string +} + +type protocolNameInner struct { + Name []byte `tls:"head=1,min=1"` +} + +type alpnExtensionInner struct { + Protocols []protocolNameInner `tls:"head=2,min=2"` +} + +func (alpn ALPNExtension) Type() ExtensionType { + return ExtensionTypeALPN +} + +func (alpn ALPNExtension) Marshal() ([]byte, error) { + protocols := make([]protocolNameInner, len(alpn.Protocols)) + for i, protocol := range alpn.Protocols { + protocols[i] = protocolNameInner{[]byte(protocol)} + } + return syntax.Marshal(alpnExtensionInner{protocols}) +} + +func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) { + var inner alpnExtensionInner + read, err := syntax.Unmarshal(data, &inner) + + if err != nil { + return 0, err + } + + alpn.Protocols = make([]string, len(inner.Protocols)) + for i, protocol := range inner.Protocols { + alpn.Protocols[i] = string(protocol.Name) + } + return read, nil +} + +// struct { +// ProtocolVersion versions<2..254>; +// } SupportedVersions; +type SupportedVersionsExtension struct { + HandshakeType HandshakeType + Versions []uint16 +} + +type SupportedVersionsClientHelloInner struct { + Versions []uint16 `tls:"head=1,min=2,max=254"` +} + +type SupportedVersionsServerHelloInner struct { + Version uint16 +} + +func (sv SupportedVersionsExtension) Type() ExtensionType { + return ExtensionTypeSupportedVersions +} + +func (sv SupportedVersionsExtension) Marshal() ([]byte, error) { + switch sv.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Marshal(SupportedVersionsClientHelloInner{sv.Versions}) + case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest: + return syntax.Marshal(SupportedVersionsServerHelloInner{sv.Versions[0]}) + default: + return nil, fmt.Errorf("tls.supported_versions: Handshake type not allowed") + } +} + +func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) { + switch sv.HandshakeType { + case HandshakeTypeClientHello: + var inner SupportedVersionsClientHelloInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + sv.Versions = inner.Versions + return read, nil + + case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest: + var inner SupportedVersionsServerHelloInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + sv.Versions = []uint16{inner.Version} + return read, nil + + default: + return 0, fmt.Errorf("tls.supported_versions: Handshake type not allowed") + } +} + +// struct { +// opaque cookie<1..2^16-1>; +// } Cookie; +type CookieExtension struct { + Cookie []byte `tls:"head=2,min=1"` +} + +func (c CookieExtension) Type() ExtensionType { + return ExtensionTypeCookie +} + +func (c CookieExtension) Marshal() ([]byte, error) { + return syntax.Marshal(c) +} + +func (c *CookieExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, c) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/ffdhe.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/ffdhe.go new file mode 100644 index 000000000..59d1f7f9d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/ffdhe.go @@ -0,0 +1,147 @@ +package mint + +import ( + "encoding/hex" + "math/big" +) + +var ( + finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B423861285C97FFFFFFFFFFFFFFFF" + finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex) + finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes) + + finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + + "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + + "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + + "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + + "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + + "3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF" + finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex) + finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes) + + finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + + "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + + "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + + "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + + "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + + "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + + "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + + "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + + "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + + "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + + "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" + + "FFFFFFFFFFFFFFFF" + finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex) + finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes) + + finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + + "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + + "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + + "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + + "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + + "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + + "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + + "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + + "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + + "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + + "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" + + "0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" + + "3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" + + "CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" + + "A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" + + "0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" + + "763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" + + "B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" + + "D72B03746AE77F5E62292C311562A846505DC82DB854338A" + + "E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" + + "5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" + + "A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF" + finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex) + finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes) + + finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + + "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + + "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + + "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + + "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + + "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + + "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + + "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + + "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + + "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + + "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" + + "0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" + + "3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" + + "CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" + + "A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" + + "0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" + + "763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" + + "B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" + + "D72B03746AE77F5E62292C311562A846505DC82DB854338A" + + "E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" + + "5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" + + "A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" + + "1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" + + "0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" + + "CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" + + "2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" + + "BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" + + "51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" + + "D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" + + "1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" + + "FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" + + "97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" + + "D68C8BB7C5C6424CFFFFFFFFFFFFFFFF" + finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex) + finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes) +) diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/frame-reader.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/frame-reader.go new file mode 100644 index 000000000..54f40ce2c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/frame-reader.go @@ -0,0 +1,98 @@ +// Read a generic "framed" packet consisting of a header and a +// This is used for both TLS Records and TLS Handshake Messages +package mint + +type framing interface { + headerLen() int + defaultReadLen() int + frameLen(hdr []byte) (int, error) +} + +const ( + kFrameReaderHdr = 0 + kFrameReaderBody = 1 +) + +type frameNextAction func(f *frameReader) error + +type frameReader struct { + details framing + state uint8 + header []byte + body []byte + working []byte + writeOffset int + remainder []byte +} + +func newFrameReader(d framing) *frameReader { + hdr := make([]byte, d.headerLen()) + return &frameReader{ + d, + kFrameReaderHdr, + hdr, + nil, + hdr, + 0, + nil, + } +} + +func dup(a []byte) []byte { + r := make([]byte, len(a)) + copy(r, a) + return r +} + +func (f *frameReader) needed() int { + tmp := (len(f.working) - f.writeOffset) - len(f.remainder) + if tmp < 0 { + return 0 + } + return tmp +} + +func (f *frameReader) addChunk(in []byte) { + // Append to the buffer. + logf(logTypeFrameReader, "Appending %v", len(in)) + f.remainder = append(f.remainder, in...) +} + +func (f *frameReader) process() (hdr []byte, body []byte, err error) { + for f.needed() == 0 { + logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset) + // Fill out our working block + copied := copy(f.working[f.writeOffset:], f.remainder) + f.remainder = f.remainder[copied:] + f.writeOffset += copied + if f.writeOffset < len(f.working) { + logf(logTypeVerbose, "Read would have blocked 1") + return nil, nil, WouldBlock + } + // Reset the write offset, because we are now full. + f.writeOffset = 0 + + // We have read a full frame + if f.state == kFrameReaderBody { + logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder)) + f.state = kFrameReaderHdr + f.working = f.header + return dup(f.header), dup(f.body), nil + } + + // We have read the header + bodyLen, err := f.details.frameLen(f.header) + if err != nil { + return nil, nil, err + } + logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen) + + f.body = make([]byte, bodyLen) + f.working = f.body + f.writeOffset = 0 + f.state = kFrameReaderBody + } + + logf(logTypeVerbose, "Read would have blocked 2") + return nil, nil, WouldBlock +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/handshake-layer.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/handshake-layer.go new file mode 100644 index 000000000..888c5f364 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/handshake-layer.go @@ -0,0 +1,495 @@ +package mint + +import ( + "fmt" + "io" + "net" +) + +const ( + handshakeHeaderLenTLS = 4 // handshake message header length + handshakeHeaderLenDTLS = 12 // handshake message header length + maxHandshakeMessageLen = 1 << 24 // max handshake message length +) + +// struct { +// HandshakeType msg_type; /* handshake type */ +// uint24 length; /* bytes in message */ +// select (HandshakeType) { +// ... +// } body; +// } Handshake; +// +// We do the select{...} part in a different layer, so we treat the +// actual message body as opaque: +// +// struct { +// HandshakeType msg_type; +// opaque msg<0..2^24-1> +// } Handshake; +// +type HandshakeMessage struct { + msgType HandshakeType + seq uint32 + body []byte + datagram bool + offset uint32 // Used for DTLS + length uint32 + records []uint64 // Used for DTLS + cipher *cipherState +} + +// Note: This could be done with the `syntax` module, using the simplified +// syntax as discussed above. However, since this is so simple, there's not +// much benefit to doing so. +// When datagram is set, we marshal this as a whole DTLS record. +func (hm *HandshakeMessage) Marshal() []byte { + if hm == nil { + return []byte{} + } + + fragLen := len(hm.body) + var data []byte + + if hm.datagram { + data = make([]byte, handshakeHeaderLenDTLS+fragLen) + } else { + data = make([]byte, handshakeHeaderLenTLS+fragLen) + } + tmp := data + tmp = encodeUint(uint64(hm.msgType), 1, tmp) + tmp = encodeUint(uint64(hm.length), 3, tmp) + if hm.datagram { + tmp = encodeUint(uint64(hm.seq), 2, tmp) + tmp = encodeUint(uint64(hm.offset), 3, tmp) + tmp = encodeUint(uint64(fragLen), 3, tmp) + } + copy(tmp, hm.body) + return data +} + +func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) { + logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body) + + var body HandshakeMessageBody + switch hm.msgType { + case HandshakeTypeClientHello: + body = new(ClientHelloBody) + case HandshakeTypeServerHello: + body = new(ServerHelloBody) + case HandshakeTypeEncryptedExtensions: + body = new(EncryptedExtensionsBody) + case HandshakeTypeCertificate: + body = new(CertificateBody) + case HandshakeTypeCertificateRequest: + body = new(CertificateRequestBody) + case HandshakeTypeCertificateVerify: + body = new(CertificateVerifyBody) + case HandshakeTypeFinished: + body = &FinishedBody{VerifyDataLen: len(hm.body)} + case HandshakeTypeNewSessionTicket: + body = new(NewSessionTicketBody) + case HandshakeTypeKeyUpdate: + body = new(KeyUpdateBody) + case HandshakeTypeEndOfEarlyData: + body = new(EndOfEarlyDataBody) + default: + return body, fmt.Errorf("tls.handshakemessage: Unsupported body type") + } + + err := safeUnmarshal(body, hm.body) + return body, err +} + +func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) { + data, err := body.Marshal() + if err != nil { + return nil, err + } + + m := &HandshakeMessage{ + msgType: body.Type(), + body: data, + seq: h.msgSeq, + datagram: h.datagram, + length: uint32(len(data)), + } + h.msgSeq++ + return m, nil +} + +type HandshakeLayer struct { + nonblocking bool // Should we operate in nonblocking mode + conn *RecordLayer // Used for reading/writing records + frame *frameReader // The buffered frame reader + datagram bool // Is this DTLS? + msgSeq uint32 // The DTLS message sequence number + queued []*HandshakeMessage // In/out queue + sent []*HandshakeMessage // Sent messages for DTLS + maxFragmentLen int +} + +type handshakeLayerFrameDetails struct { + datagram bool +} + +func (d handshakeLayerFrameDetails) headerLen() int { + if d.datagram { + return handshakeHeaderLenDTLS + } + return handshakeHeaderLenTLS +} + +func (d handshakeLayerFrameDetails) defaultReadLen() int { + return d.headerLen() + maxFragmentLen +} + +func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { + logf(logTypeIO, "Header=%x", hdr) + // The length of this fragment (as opposed to the message) + // is always the last three bytes for both TLS and DTLS + val, _ := decodeUint(hdr[len(hdr)-3:], 3) + return int(val), nil +} + +func NewHandshakeLayerTLS(r *RecordLayer) *HandshakeLayer { + h := HandshakeLayer{} + h.conn = r + h.datagram = false + h.frame = newFrameReader(&handshakeLayerFrameDetails{false}) + h.maxFragmentLen = maxFragmentLen + return &h +} + +func NewHandshakeLayerDTLS(r *RecordLayer) *HandshakeLayer { + h := HandshakeLayer{} + h.conn = r + h.datagram = true + h.frame = newFrameReader(&handshakeLayerFrameDetails{true}) + h.maxFragmentLen = initialMtu // Not quite right + return &h +} + +func (h *HandshakeLayer) readRecord() error { + logf(logTypeVerbose, "Trying to read record") + pt, err := h.conn.ReadRecord() + if err != nil { + return err + } + + if pt.contentType != RecordTypeHandshake && + pt.contentType != RecordTypeAlert { + return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType) + } + + if pt.contentType == RecordTypeAlert { + logf(logTypeIO, "read alert %v", pt.fragment[1]) + if len(pt.fragment) < 2 { + h.sendAlert(AlertUnexpectedMessage) + return io.EOF + } + return Alert(pt.fragment[1]) + } + + h.frame.addChunk(pt.fragment) + + return nil +} + +// sendAlert sends a TLS alert message. +func (h *HandshakeLayer) sendAlert(err Alert) error { + tmp := make([]byte, 2) + tmp[0] = AlertLevelError + tmp[1] = byte(err) + h.conn.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeAlert, + fragment: tmp}, + ) + + // closeNotify is a special case in that it isn't an error: + if err != AlertCloseNotify { + return &net.OpError{Op: "local error", Err: err} + } + return nil +} + +func (h *HandshakeLayer) noteMessageDelivered(seq uint32) { + h.msgSeq = seq + 1 + var i int + var m *HandshakeMessage + for i, m = range h.queued { + if m.seq > seq { + break + } + } + h.queued = h.queued[i:] +} + +func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) { + if hm.seq < h.msgSeq { + return nil, WouldBlock + } + + if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) { + // TODO(ekr@rtfm.com): Check the length? + // This is complete. + h.noteMessageDelivered(hm.seq) + return hm, nil + } + + // Now insert sorted. + var i int + for i = 0; i < len(h.queued); i++ { + f := h.queued[i] + if hm.seq < f.seq { + break + } + if hm.offset < f.offset { + break + } + } + tmp := make([]*HandshakeMessage, 0, len(h.queued)+1) + tmp = append(tmp, h.queued[:i]...) + tmp = append(tmp, hm) + tmp = append(tmp, h.queued[i:]...) + h.queued = tmp + + return h.checkMessageAvailable() +} + +func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) { + if len(h.queued) == 0 { + return nil, WouldBlock + } + + hm := h.queued[0] + if hm.seq != h.msgSeq { + return nil, WouldBlock + } + + if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) { + // TODO(ekr@rtfm.com): Check the length? + // This is complete. + h.noteMessageDelivered(hm.seq) + return hm, nil + } + + // OK, this at least might complete the message. + end := uint32(0) + buf := make([]byte, hm.length) + + for _, f := range h.queued { + // Out of fragments + if f.seq > hm.seq { + break + } + + if f.length != uint32(len(buf)) { + return nil, fmt.Errorf("Mismatched DTLS length") + } + + if f.offset > end { + break + } + + if f.offset+uint32(len(f.body)) > end { + // OK, this is adding something we don't know about + copy(buf[f.offset:], f.body) + end = f.offset + uint32(len(f.body)) + if end == hm.length { + h2 := *hm + h2.offset = 0 + h2.body = buf + h.noteMessageDelivered(hm.seq) + return &h2, nil + } + } + + } + + return nil, WouldBlock +} + +func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { + var hdr, body []byte + var err error + + hm, err := h.checkMessageAvailable() + if err == nil { + return hm, err + } + if err != WouldBlock { + return nil, err + } + for { + logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder)) + if h.frame.needed() > 0 { + logf(logTypeVerbose, "Trying to read a new record") + err = h.readRecord() + + if err != nil && (h.nonblocking || err != WouldBlock) { + return nil, err + } + } + + hdr, body, err = h.frame.process() + if err == nil { + break + } + if err != nil && (h.nonblocking || err != WouldBlock) { + return nil, err + } + } + + logf(logTypeHandshake, "read handshake message") + + hm = &HandshakeMessage{} + hm.msgType = HandshakeType(hdr[0]) + hm.datagram = h.datagram + hm.body = make([]byte, len(body)) + copy(hm.body, body) + logf(logTypeHandshake, "Read message with type: %v", hm.msgType) + if h.datagram { + tmp, hdr := decodeUint(hdr[1:], 3) + hm.length = uint32(tmp) + tmp, hdr = decodeUint(hdr, 2) + hm.seq = uint32(tmp) + tmp, hdr = decodeUint(hdr, 3) + hm.offset = uint32(tmp) + + return h.newFragmentReceived(hm) + } + + hm.length = uint32(len(body)) + return hm, nil +} + +func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error { + hm.cipher = h.conn.cipher + h.queued = append(h.queued, hm) + return nil +} + +func (h *HandshakeLayer) SendQueuedMessages() error { + logf(logTypeHandshake, "Sending outgoing messages") + err := h.WriteMessages(h.queued) + h.ClearQueuedMessages() // This isn't going to work for DTLS, but we'll + // get there. + return err +} + +func (h *HandshakeLayer) ClearQueuedMessages() { + logf(logTypeHandshake, "Clearing outgoing hs message queue") + h.queued = nil +} + +func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (int, error) { + var buf []byte + + // Figure out if we're going to want the full header or just + // the body + hdrlen := 0 + if hm.datagram { + hdrlen = handshakeHeaderLenDTLS + } else if start == 0 { + hdrlen = handshakeHeaderLenTLS + } + + // Compute the amount of body we can fit in + room -= hdrlen + if room == 0 { + // This works because we are doing one record per + // message + panic("Too short max fragment len") + } + bodylen := len(hm.body) - start + if bodylen > room { + bodylen = room + } + body := hm.body[start : start+bodylen] + + // Encode the data. + if hdrlen > 0 { + hm2 := *hm + hm2.offset = uint32(start) + hm2.body = body + buf = hm2.Marshal() + } else { + buf = body + } + + return start + bodylen, h.conn.writeRecordWithPadding( + &TLSPlaintext{ + contentType: RecordTypeHandshake, + fragment: buf, + }, + hm.cipher, 0) +} + +func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error { + start := int(0) + + if len(hm.body) > maxHandshakeMessageLen { + return fmt.Errorf("Tried to write a handshake message that's too long") + } + + // Always make one pass through to allow EOED (which is empty). + for { + var err error + start, err = h.writeFragment(hm, start, h.maxFragmentLen) + if err != nil { + return err + } + if start >= len(hm.body) { + break + } + } + + return nil +} + +func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error { + for _, hm := range hms { + logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body) + + err := h.WriteMessage(hm) + if err != nil { + return err + } + } + return nil +} + +func encodeUint(v uint64, size int, out []byte) []byte { + for i := size - 1; i >= 0; i-- { + out[i] = byte(v & 0xff) + v >>= 8 + } + return out[size:] +} + +func decodeUint(in []byte, size int) (uint64, []byte) { + val := uint64(0) + + for i := 0; i < size; i++ { + val <<= 8 + val += uint64(in[i]) + } + return val, in[size:] +} + +type marshalledPDU interface { + Marshal() ([]byte, error) + Unmarshal(data []byte) (int, error) +} + +func safeUnmarshal(pdu marshalledPDU, data []byte) error { + read, err := pdu.Unmarshal(data) + if err != nil { + return err + } + if len(data) != read { + return fmt.Errorf("Invalid encoding: Extra data not consumed") + } + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/handshake-messages.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/handshake-messages.go new file mode 100644 index 000000000..5a229f1d0 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/handshake-messages.go @@ -0,0 +1,481 @@ +package mint + +import ( + "bytes" + "crypto" + "crypto/x509" + "encoding/binary" + "fmt" + + "github.com/bifurcation/mint/syntax" +) + +type HandshakeMessageBody interface { + Type() HandshakeType + Marshal() ([]byte, error) + Unmarshal(data []byte) (int, error) +} + +// struct { +// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */ +// Random random; +// opaque legacy_session_id<0..32>; +// CipherSuite cipher_suites<2..2^16-2>; +// opaque legacy_compression_methods<1..2^8-1>; +// Extension extensions<0..2^16-1>; +// } ClientHello; +type ClientHelloBody struct { + LegacyVersion uint16 + Random [32]byte + LegacySessionID []byte + CipherSuites []CipherSuite + Extensions ExtensionList +} + +type clientHelloBodyInnerTLS struct { + LegacyVersion uint16 + Random [32]byte + LegacySessionID []byte `tls:"head=1,max=32"` + CipherSuites []CipherSuite `tls:"head=2,min=2"` + LegacyCompressionMethods []byte `tls:"head=1,min=1"` + Extensions []Extension `tls:"head=2"` +} + +type clientHelloBodyInnerDTLS struct { + LegacyVersion uint16 + Random [32]byte + LegacySessionID []byte `tls:"head=1,max=32"` + EmptyCookie uint8 + CipherSuites []CipherSuite `tls:"head=2,min=2"` + LegacyCompressionMethods []byte `tls:"head=1,min=1"` + Extensions []Extension `tls:"head=2"` +} + +func (ch ClientHelloBody) Type() HandshakeType { + return HandshakeTypeClientHello +} + +func (ch ClientHelloBody) Marshal() ([]byte, error) { + if ch.LegacyVersion == tls12Version { + return syntax.Marshal(clientHelloBodyInnerTLS{ + LegacyVersion: ch.LegacyVersion, + Random: ch.Random, + LegacySessionID: []byte{}, + CipherSuites: ch.CipherSuites, + LegacyCompressionMethods: []byte{0}, + Extensions: ch.Extensions, + }) + } else { + return syntax.Marshal(clientHelloBodyInnerDTLS{ + LegacyVersion: ch.LegacyVersion, + Random: ch.Random, + LegacySessionID: []byte{}, + CipherSuites: ch.CipherSuites, + LegacyCompressionMethods: []byte{0}, + Extensions: ch.Extensions, + }) + } + +} + +func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) { + var read int + var err error + + // Note that this might be 0, in which case we do TLS. That + // makes the tests easier. + if ch.LegacyVersion != dtls12WireVersion { + var inner clientHelloBodyInnerTLS + read, err = syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 { + return 0, fmt.Errorf("tls.clienthello: Invalid compression method") + } + + ch.LegacyVersion = inner.LegacyVersion + ch.Random = inner.Random + ch.LegacySessionID = inner.LegacySessionID + ch.CipherSuites = inner.CipherSuites + ch.Extensions = inner.Extensions + } else { + var inner clientHelloBodyInnerDTLS + read, err = syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + if inner.EmptyCookie != 0 { + return 0, fmt.Errorf("tls.clienthello: Invalid cookie") + } + + if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 { + return 0, fmt.Errorf("tls.clienthello: Invalid compression method") + } + + ch.LegacyVersion = inner.LegacyVersion + ch.Random = inner.Random + ch.LegacySessionID = inner.LegacySessionID + ch.CipherSuites = inner.CipherSuites + ch.Extensions = inner.Extensions + } + return read, nil +} + +// TODO: File a spec bug to clarify this +func (ch ClientHelloBody) Truncated() ([]byte, error) { + if len(ch.Extensions) == 0 { + return nil, fmt.Errorf("tls.clienthello.truncate: No extensions") + } + + pskExt := ch.Extensions[len(ch.Extensions)-1] + if pskExt.ExtensionType != ExtensionTypePreSharedKey { + return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK") + } + + body, err := ch.Marshal() + if err != nil { + return nil, err + } + chm := &HandshakeMessage{ + msgType: ch.Type(), + body: body, + length: uint32(len(body)), + } + chData := chm.Marshal() + + psk := PreSharedKeyExtension{ + HandshakeType: HandshakeTypeClientHello, + } + _, err = psk.Unmarshal(pskExt.ExtensionData) + if err != nil { + return nil, err + } + + // Marshal just the binders so that we know how much to truncate + binders := struct { + Binders []PSKBinderEntry `tls:"head=2,min=33"` + }{Binders: psk.Binders} + binderData, _ := syntax.Marshal(binders) + binderLen := len(binderData) + + chLen := len(chData) + return chData[:chLen-binderLen], nil +} + +// struct { +// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */ +// Random random; +// opaque legacy_session_id_echo<0..32>; +// CipherSuite cipher_suite; +// uint8 legacy_compression_method = 0; +// Extension extensions<6..2^16-1>; +// } ServerHello; +type ServerHelloBody struct { + Version uint16 + Random [32]byte + LegacySessionID []byte `tls:"head=1,max=32"` + CipherSuite CipherSuite + LegacyCompressionMethod uint8 + Extensions ExtensionList `tls:"head=2"` +} + +func (sh ServerHelloBody) Type() HandshakeType { + return HandshakeTypeServerHello +} + +func (sh ServerHelloBody) Marshal() ([]byte, error) { + return syntax.Marshal(sh) +} + +func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, sh) +} + +// struct { +// opaque verify_data[verify_data_length]; +// } Finished; +// +// verifyDataLen is not a field in the TLS struct, but we add it here so +// that calling code can tell us how much data to expect when we marshal / +// unmarshal. (We could add this to the marshal/unmarshal methods, but let's +// try to keep the signature consistent for now.) +// +// For similar reasons, we don't use the `syntax` module here, because this +// struct doesn't map well to standard TLS presentation language concepts. +// +// TODO: File a spec bug +type FinishedBody struct { + VerifyDataLen int + VerifyData []byte +} + +func (fin FinishedBody) Type() HandshakeType { + return HandshakeTypeFinished +} + +func (fin FinishedBody) Marshal() ([]byte, error) { + if len(fin.VerifyData) != fin.VerifyDataLen { + return nil, fmt.Errorf("tls.finished: data length mismatch") + } + + body := make([]byte, len(fin.VerifyData)) + copy(body, fin.VerifyData) + return body, nil +} + +func (fin *FinishedBody) Unmarshal(data []byte) (int, error) { + if len(data) < fin.VerifyDataLen { + return 0, fmt.Errorf("tls.finished: Malformed finished; too short") + } + + fin.VerifyData = make([]byte, fin.VerifyDataLen) + copy(fin.VerifyData, data[:fin.VerifyDataLen]) + return fin.VerifyDataLen, nil +} + +// struct { +// Extension extensions<0..2^16-1>; +// } EncryptedExtensions; +// +// Marshal() and Unmarshal() are handled by ExtensionList +type EncryptedExtensionsBody struct { + Extensions ExtensionList `tls:"head=2"` +} + +func (ee EncryptedExtensionsBody) Type() HandshakeType { + return HandshakeTypeEncryptedExtensions +} + +func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) { + return syntax.Marshal(ee) +} + +func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, ee) +} + +// opaque ASN1Cert<1..2^24-1>; +// +// struct { +// ASN1Cert cert_data; +// Extension extensions<0..2^16-1> +// } CertificateEntry; +// +// struct { +// opaque certificate_request_context<0..2^8-1>; +// CertificateEntry certificate_list<0..2^24-1>; +// } Certificate; +type CertificateEntry struct { + CertData *x509.Certificate + Extensions ExtensionList +} + +type CertificateBody struct { + CertificateRequestContext []byte + CertificateList []CertificateEntry +} + +type certificateEntryInner struct { + CertData []byte `tls:"head=3,min=1"` + Extensions ExtensionList `tls:"head=2"` +} + +type certificateBodyInner struct { + CertificateRequestContext []byte `tls:"head=1"` + CertificateList []certificateEntryInner `tls:"head=3"` +} + +func (c CertificateBody) Type() HandshakeType { + return HandshakeTypeCertificate +} + +func (c CertificateBody) Marshal() ([]byte, error) { + inner := certificateBodyInner{ + CertificateRequestContext: c.CertificateRequestContext, + CertificateList: make([]certificateEntryInner, len(c.CertificateList)), + } + + for i, entry := range c.CertificateList { + inner.CertificateList[i] = certificateEntryInner{ + CertData: entry.CertData.Raw, + Extensions: entry.Extensions, + } + } + + return syntax.Marshal(inner) +} + +func (c *CertificateBody) Unmarshal(data []byte) (int, error) { + inner := certificateBodyInner{} + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return read, err + } + + c.CertificateRequestContext = inner.CertificateRequestContext + c.CertificateList = make([]CertificateEntry, len(inner.CertificateList)) + + for i, entry := range inner.CertificateList { + c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData) + if err != nil { + return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err) + } + + c.CertificateList[i].Extensions = entry.Extensions + } + + return read, nil +} + +// struct { +// SignatureScheme algorithm; +// opaque signature<0..2^16-1>; +// } CertificateVerify; +type CertificateVerifyBody struct { + Algorithm SignatureScheme + Signature []byte `tls:"head=2"` +} + +func (cv CertificateVerifyBody) Type() HandshakeType { + return HandshakeTypeCertificateVerify +} + +func (cv CertificateVerifyBody) Marshal() ([]byte, error) { + return syntax.Marshal(cv) +} + +func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, cv) +} + +func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte { + // TODO: Change context for client auth + // TODO: Put this in a const + const context = "TLS 1.3, server CertificateVerify" + sigInput := bytes.Repeat([]byte{0x20}, 64) + sigInput = append(sigInput, []byte(context)...) + sigInput = append(sigInput, []byte{0}...) + sigInput = append(sigInput, data...) + return sigInput +} + +func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) { + sigInput := cv.EncodeSignatureInput(handshakeHash) + cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput) + logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature) + return +} + +func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error { + sigInput := cv.EncodeSignatureInput(handshakeHash) + logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature) + return verify(cv.Algorithm, publicKey, sigInput, cv.Signature) +} + +// struct { +// opaque certificate_request_context<0..2^8-1>; +// Extension extensions<2..2^16-1>; +// } CertificateRequest; +type CertificateRequestBody struct { + CertificateRequestContext []byte `tls:"head=1"` + Extensions ExtensionList `tls:"head=2"` +} + +func (cr CertificateRequestBody) Type() HandshakeType { + return HandshakeTypeCertificateRequest +} + +func (cr CertificateRequestBody) Marshal() ([]byte, error) { + return syntax.Marshal(cr) +} + +func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, cr) +} + +// struct { +// uint32 ticket_lifetime; +// uint32 ticket_age_add; +// opaque ticket_nonce<1..255>; +// opaque ticket<1..2^16-1>; +// Extension extensions<0..2^16-2>; +// } NewSessionTicket; +type NewSessionTicketBody struct { + TicketLifetime uint32 + TicketAgeAdd uint32 + TicketNonce []byte `tls:"head=1,min=1"` + Ticket []byte `tls:"head=2,min=1"` + Extensions ExtensionList `tls:"head=2"` +} + +const ticketNonceLen = 16 + +func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) { + buf := make([]byte, 4+ticketNonceLen+ticketLen) + _, err := prng.Read(buf) + if err != nil { + return nil, err + } + + tkt := &NewSessionTicketBody{ + TicketLifetime: ticketLifetime, + TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]), + TicketNonce: buf[4 : 4+ticketNonceLen], + Ticket: buf[4+ticketNonceLen:], + } + + return tkt, err +} + +func (tkt NewSessionTicketBody) Type() HandshakeType { + return HandshakeTypeNewSessionTicket +} + +func (tkt NewSessionTicketBody) Marshal() ([]byte, error) { + return syntax.Marshal(tkt) +} + +func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, tkt) +} + +// enum { +// update_not_requested(0), update_requested(1), (255) +// } KeyUpdateRequest; +// +// struct { +// KeyUpdateRequest request_update; +// } KeyUpdate; +type KeyUpdateBody struct { + KeyUpdateRequest KeyUpdateRequest +} + +func (ku KeyUpdateBody) Type() HandshakeType { + return HandshakeTypeKeyUpdate +} + +func (ku KeyUpdateBody) Marshal() ([]byte, error) { + return syntax.Marshal(ku) +} + +func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, ku) +} + +// struct {} EndOfEarlyData; +type EndOfEarlyDataBody struct{} + +func (eoed EndOfEarlyDataBody) Type() HandshakeType { + return HandshakeTypeEndOfEarlyData +} + +func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) { + return []byte{}, nil +} + +func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) { + return 0, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/log.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/log.go new file mode 100644 index 000000000..2fba90de7 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/log.go @@ -0,0 +1,55 @@ +package mint + +import ( + "fmt" + "log" + "os" + "strings" +) + +// We use this environment variable to control logging. It should be a +// comma-separated list of log tags (see below) or "*" to enable all logging. +const logConfigVar = "MINT_LOG" + +// Pre-defined log types +const ( + logTypeCrypto = "crypto" + logTypeHandshake = "handshake" + logTypeNegotiation = "negotiation" + logTypeIO = "io" + logTypeFrameReader = "frame" + logTypeVerbose = "verbose" +) + +var ( + logFunction = log.Printf + logAll = false + logSettings = map[string]bool{} +) + +func init() { + parseLogEnv(os.Environ()) +} + +func parseLogEnv(env []string) { + for _, stmt := range env { + if strings.HasPrefix(stmt, logConfigVar+"=") { + val := stmt[len(logConfigVar)+1:] + + if val == "*" { + logAll = true + } else { + for _, t := range strings.Split(val, ",") { + logSettings[t] = true + } + } + } + } +} + +func logf(tag string, format string, args ...interface{}) { + if logAll || logSettings[tag] { + fullFormat := fmt.Sprintf("[%s] %s", tag, format) + logFunction(fullFormat, args...) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/negotiation.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/negotiation.go new file mode 100644 index 000000000..4697bbc80 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/negotiation.go @@ -0,0 +1,217 @@ +package mint + +import ( + "bytes" + "encoding/hex" + "fmt" + "time" +) + +func VersionNegotiation(offered, supported []uint16) (bool, uint16) { + for _, offeredVersion := range offered { + for _, supportedVersion := range supported { + logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion) + if offeredVersion == supportedVersion { + // XXX: Should probably be highest supported version, but for now, we + // only support one version, so it doesn't really matter. + return true, offeredVersion + } + } + } + + return false, 0 +} + +func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) { + for _, share := range keyShares { + for _, group := range groups { + if group != share.Group { + continue + } + + pub, priv, err := newKeyShare(share.Group) + if err != nil { + // If we encounter an error, just keep looking + continue + } + + dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv) + if err != nil { + // If we encounter an error, just keep looking + continue + } + + return true, group, pub, dhSecret + } + } + + return false, 0, nil, nil +} + +const ( + ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds +) + +func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) { + logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size()) + for i, id := range identities { + identityHex := hex.EncodeToString(id.Identity) + + psk, ok := psks.Get(identityHex) + if !ok { + logf(logTypeNegotiation, "No PSK for identity %x", identityHex) + continue + } + + // For resumption, make sure the ticket age is correct + if psk.IsResumption { + extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd + knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond) + ticketAgeDelta := knownTicketAge - extTicketAge + if knownTicketAge < extTicketAge { + ticketAgeDelta = extTicketAge - knownTicketAge + } + if ticketAgeDelta > ticketAgeTolerance { + logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity) + logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]", + extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance) + return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity) + } + } + + params, ok := cipherSuiteMap[psk.CipherSuite] + if !ok { + err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite) + return false, 0, nil, CipherSuiteParams{}, err + } + + // Compute binder + binderLabel := labelExternalBinder + if psk.IsResumption { + binderLabel = labelResumptionBinder + } + + h0 := params.Hash.New().Sum(nil) + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + earlySecret := HkdfExtract(params.Hash, zero, psk.Key) + binderKey := deriveSecret(params, earlySecret, binderLabel, h0) + + // context = ClientHello[truncated] + // context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated] + ctxHash := params.Hash.New() + ctxHash.Write(context) + + binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil)) + if !bytes.Equal(binder, binders[i].Binder) { + logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder) + return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity) + } + + logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity) + return true, i, &psk, params, nil + } + + logf(logTypeNegotiation, "Failed to find a usable PSK") + return false, 0, nil, CipherSuiteParams{}, nil +} + +func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) { + logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes) + dhAllowed := false + dhRequired := true + for _, mode := range modes { + dhAllowed = dhAllowed || (mode == PSKModeDHEKE) + dhRequired = dhRequired && (mode == PSKModeDHEKE) + } + + // Use PSK if we can meet DH requirement and modes were provided + usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0) + + // Use DH if allowed + usingDH := canDoDH && (dhAllowed || !usingPSK) + + logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK) + return usingDH, usingPSK +} + +func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) { + // Select for server name if provided + candidates := certs + if serverName != nil { + candidatesByName := []*Certificate{} + for _, cert := range certs { + for _, name := range cert.Chain[0].DNSNames { + if len(*serverName) > 0 && name == *serverName { + candidatesByName = append(candidatesByName, cert) + } + } + } + + if len(candidatesByName) == 0 { + return nil, 0, fmt.Errorf("No certificates available for server name: %s", *serverName) + } + + candidates = candidatesByName + } + + // Select for signature scheme + for _, cert := range candidates { + for _, scheme := range signatureSchemes { + if !schemeValidForKey(scheme, cert.PrivateKey) { + continue + } + + return cert, scheme, nil + } + } + + return nil, 0, fmt.Errorf("No certificates compatible with signature schemes") +} + +func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool { + usingEarlyData := gotEarlyData && usingPSK && allowEarlyData + logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData) + return usingEarlyData +} + +func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) { + for _, s1 := range offered { + if psk != nil { + if s1 == psk.CipherSuite { + return s1, nil + } + continue + } + + for _, s2 := range supported { + if s1 == s2 { + return s1, nil + } + } + } + + return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil) +} + +func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) { + for _, p1 := range offered { + if psk != nil { + if p1 != psk.NextProto { + continue + } + } + + for _, p2 := range supported { + if p1 == p2 { + return p1, nil + } + } + } + + // If the client offers ALPN on resumption, it must match the earlier one + var err error + if psk != nil && psk.IsResumption && (len(offered) > 0) { + err = fmt.Errorf("ALPN for PSK not provided") + } + return "", err +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/record-layer.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/record-layer.go new file mode 100644 index 000000000..761a868da --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/record-layer.go @@ -0,0 +1,393 @@ +package mint + +import ( + "bytes" + "crypto/cipher" + "fmt" + "io" + "sync" +) + +const ( + sequenceNumberLen = 8 // sequence number length + recordHeaderLenTLS = 5 // record header length (TLS) + recordHeaderLenDTLS = 13 // record header length (DTLS) + maxFragmentLen = 1 << 14 // max number of bytes in a record +) + +type DecryptError string + +func (err DecryptError) Error() string { + return string(err) +} + +// struct { +// ContentType type; +// ProtocolVersion record_version [0301 for CH, 0303 for others] +// uint16 length; +// opaque fragment[TLSPlaintext.length]; +// } TLSPlaintext; +type TLSPlaintext struct { + // Omitted: record_version (static) + // Omitted: length (computed from fragment) + contentType RecordType + fragment []byte +} + +type cipherState struct { + epoch Epoch // DTLS epoch + ivLength int // Length of the seq and nonce fields + seq []byte // Zero-padded sequence number + iv []byte // Buffer for the IV + cipher cipher.AEAD // AEAD cipher +} + +type RecordLayer struct { + sync.Mutex + + version uint16 // The current version number + conn io.ReadWriter // The underlying connection + frame *frameReader // The buffered frame reader + nextData []byte // The next record to send + cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" + cachedError error // Error on the last record read + + cipher *cipherState + datagram bool +} + +type recordLayerFrameDetails struct { + datagram bool +} + +func (d recordLayerFrameDetails) headerLen() int { + if d.datagram { + return recordHeaderLenDTLS + } + return recordHeaderLenTLS +} + +func (d recordLayerFrameDetails) defaultReadLen() int { + return d.headerLen() + maxFragmentLen +} + +func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { + return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil +} + +func newCipherStateNull() *cipherState { + return &cipherState{EpochClear, 0, bytes.Repeat([]byte{0}, sequenceNumberLen), nil, nil} +} + +func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) { + cipher, err := factory(key) + if err != nil { + return nil, err + } + + return &cipherState{epoch, len(iv), bytes.Repeat([]byte{0}, sequenceNumberLen), iv, cipher}, nil +} + +func NewRecordLayerTLS(conn io.ReadWriter) *RecordLayer { + r := RecordLayer{} + r.conn = conn + r.frame = newFrameReader(recordLayerFrameDetails{false}) + r.cipher = newCipherStateNull() + r.version = tls10Version + return &r +} + +func NewRecordLayerDTLS(conn io.ReadWriter) *RecordLayer { + r := RecordLayer{} + r.conn = conn + r.frame = newFrameReader(recordLayerFrameDetails{true}) + r.cipher = newCipherStateNull() + r.datagram = true + return &r +} + +func (r *RecordLayer) SetVersion(v uint16) { + r.version = v +} + +func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error { + cipher, err := newCipherStateAead(epoch, factory, key, iv) + if err != nil { + return err + } + r.cipher = cipher + return nil +} + +func (c *cipherState) formatSeq(datagram bool) []byte { + seq := append([]byte{}, c.seq...) + if datagram { + seq[0] = byte(c.epoch >> 8) + seq[1] = byte(c.epoch & 0xff) + } + return seq +} + +func (c *cipherState) computeNonce(seq []byte) []byte { + nonce := make([]byte, len(c.iv)) + copy(nonce, c.iv) + + offset := len(c.iv) - len(seq) + for i, b := range seq { + nonce[i+offset] ^= b + } + + return nonce +} + +func (c *cipherState) incrementSequenceNumber() { + var i int + for i = len(c.seq) - 1; i >= 0; i-- { + c.seq[i]++ + if c.seq[i] != 0 { + break + } + } + + if i < 0 { + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. + // TODO(ekr@rtfm.com): Check for DTLS here + // because the limit is sooner. + panic("TLS: sequence number wraparound") + } +} + +func (c *cipherState) overhead() int { + if c.cipher == nil { + return 0 + } + return c.cipher.Overhead() +} + +func (r *RecordLayer) encrypt(cipher *cipherState, seq []byte, pt *TLSPlaintext, padLen int) *TLSPlaintext { + logf(logTypeIO, "Encrypt seq=[%x]", seq) + // Expand the fragment to hold contentType, padding, and overhead + originalLen := len(pt.fragment) + plaintextLen := originalLen + 1 + padLen + ciphertextLen := plaintextLen + cipher.overhead() + + // Assemble the revised plaintext + out := &TLSPlaintext{ + + contentType: RecordTypeApplicationData, + fragment: make([]byte, ciphertextLen), + } + copy(out.fragment, pt.fragment) + out.fragment[originalLen] = byte(pt.contentType) + for i := 1; i <= padLen; i++ { + out.fragment[originalLen+i] = 0 + } + + // Encrypt the fragment + payload := out.fragment[:plaintextLen] + cipher.cipher.Seal(payload[:0], cipher.computeNonce(seq), payload, nil) + return out +} + +func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int, error) { + logf(logTypeIO, "Decrypt seq=[%x]", seq) + if len(pt.fragment) < r.cipher.overhead() { + msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead()) + return nil, 0, DecryptError(msg) + } + + decryptLen := len(pt.fragment) - r.cipher.overhead() + out := &TLSPlaintext{ + contentType: pt.contentType, + fragment: make([]byte, decryptLen), + } + + // Decrypt + _, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil) + if err != nil { + logf(logTypeIO, "AEAD decryption failure [%x]", pt) + return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed") + } + + // Find the padding boundary + padLen := 0 + for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ { + } + + // Transfer the content type + newLen := decryptLen - padLen - 1 + out.contentType = RecordType(out.fragment[newLen]) + + // Truncate the message to remove contentType, padding, overhead + out.fragment = out.fragment[:newLen] + return out, padLen, nil +} + +func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { + var pt *TLSPlaintext + var err error + + for { + pt, err = r.nextRecord() + if err == nil { + break + } + if !block || err != WouldBlock { + return 0, err + } + } + return pt.contentType, nil +} + +func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { + pt, err := r.nextRecord() + + // Consume the cached record if there was one + r.cachedRecord = nil + r.cachedError = nil + + return pt, err +} + +func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { + cipher := r.cipher + if r.cachedRecord != nil { + logf(logTypeIO, "Returning cached record") + return r.cachedRecord, r.cachedError + } + + // Loop until one of three things happens: + // + // 1. We get a frame + // 2. We try to read off the socket and get nothing, in which case + // return WouldBlock + // 3. We get an error. + err := WouldBlock + var header, body []byte + + for err != nil { + if r.frame.needed() > 0 { + buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen) + n, err := r.conn.Read(buf) + if err != nil { + logf(logTypeIO, "Error reading, %v", err) + return nil, err + } + + if n == 0 { + return nil, WouldBlock + } + + logf(logTypeIO, "Read %v bytes", n) + + buf = buf[:n] + r.frame.addChunk(buf) + } + + header, body, err = r.frame.process() + // Loop around on WouldBlock to see if some + // data is now available. + if err != nil && err != WouldBlock { + return nil, err + } + } + + pt := &TLSPlaintext{} + // Validate content type + switch RecordType(header[0]) { + default: + return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0]) + case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData: + pt.contentType = RecordType(header[0]) + } + + // Validate version + if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) { + return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2]) + } + + // Validate size < max + size := (int(header[len(header)-2]) << 8) + int(header[len(header)-1]) + + if size > maxFragmentLen+256 { + return nil, fmt.Errorf("tls.record: Ciphertext size too big") + } + + pt.fragment = make([]byte, size) + copy(pt.fragment, body) + + // Attempt to decrypt fragment + if cipher.cipher != nil { + seq := cipher.seq + if r.datagram { + seq = header[3:11] + } + // TODO(ekr@rtfm.com): Handle the wrong epoch. + // TODO(ekr@rtfm.com): Handle duplicates. + logf(logTypeIO, "RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", cipher.epoch.label(), seq, pt.contentType, pt.fragment) + pt, _, err = r.decrypt(pt, seq) + if err != nil { + logf(logTypeIO, "Decryption failed") + return nil, err + } + } + + // Check that plaintext length is not too long + if len(pt.fragment) > maxFragmentLen { + return nil, fmt.Errorf("tls.record: Plaintext size too big") + } + + logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment) + + r.cachedRecord = pt + cipher.incrementSequenceNumber() + return pt, nil +} + +func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error { + return r.writeRecordWithPadding(pt, r.cipher, 0) +} + +func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error { + return r.writeRecordWithPadding(pt, r.cipher, padLen) +} + +func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error { + seq := cipher.formatSeq(r.datagram) + + if cipher.cipher != nil { + logf(logTypeIO, "RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) + pt = r.encrypt(cipher, seq, pt, padLen) + } else if padLen > 0 { + return fmt.Errorf("tls.record: Padding can only be done on encrypted records") + } + + if len(pt.fragment) > maxFragmentLen { + return fmt.Errorf("tls.record: Record size too big") + } + + length := len(pt.fragment) + var header []byte + + if !r.datagram { + header = []byte{byte(pt.contentType), + byte(r.version >> 8), byte(r.version & 0xff), + byte(length >> 8), byte(length)} + } else { + version := dtlsConvertVersion(r.version) + header = []byte{byte(pt.contentType), + byte(version >> 8), byte(version & 0xff), + seq[0], seq[1], seq[2], seq[3], + seq[4], seq[5], seq[6], seq[7], + byte(length >> 8), byte(length)} + } + record := append(header, pt.fragment...) + + logf(logTypeIO, "RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) + + cipher.incrementSequenceNumber() + _, err := r.conn.Write(record) + return err +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/server-state-machine.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/server-state-machine.go new file mode 100644 index 000000000..c3c916101 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/server-state-machine.go @@ -0,0 +1,1102 @@ +package mint + +import ( + "bytes" + "crypto/x509" + "fmt" + "hash" + "reflect" + + "github.com/bifurcation/mint/syntax" +) + +// Server State Machine +// +// START <-----+ +// Recv ClientHello | | Send HelloRetryRequest +// v | +// RECVD_CH ----+ +// | Select parameters +// | Send ServerHello +// v +// NEGOTIATED +// | Send EncryptedExtensions +// | [Send CertificateRequest] +// Can send | [Send Certificate + CertificateVerify] +// app data --> | Send Finished +// after +--------+--------+ +// here No 0-RTT | | 0-RTT +// | v +// | WAIT_EOED <---+ +// | Recv | | | Recv +// | EndOfEarlyData | | | early data +// | | +-----+ +// +> WAIT_FLIGHT2 <-+ +// | +// +--------+--------+ +// No auth | | Client auth +// | | +// | v +// | WAIT_CERT +// | Recv | | Recv Certificate +// | empty | v +// | Certificate | WAIT_CV +// | | | Recv +// | v | CertificateVerify +// +-> WAIT_FINISHED <---+ +// | Recv Finished +// v +// CONNECTED +// +// NB: Not using state RECVD_CH +// +// State Instructions +// START {} +// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)] +// WAIT_EOED RekeyIn; +// WAIT_FLIGHT2 {} +// WAIT_CERT_CR {} +// WAIT_CERT {} +// WAIT_CV {} +// WAIT_FINISHED RekeyIn; RekeyOut; +// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) + +// A cookie can be sent to the client in a HRR. +type cookie struct { + // The CipherSuite that was selected when the client sent the first ClientHello + CipherSuite CipherSuite + ClientHelloHash []byte `tls:"head=2"` + + // The ApplicationCookie can be provided by the application (by setting a Config.CookieHandler) + ApplicationCookie []byte `tls:"head=2"` +} + +type ServerStateStart struct { + Config *Config + conn *Conn + hsCtx HandshakeContext +} + +var _ HandshakeState = &ServerStateStart{} + +func (state ServerStateStart) State() State { + return StateServerStart +} + +func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } + if hm == nil || hm.msgType != HandshakeTypeClientHello { + logf(logTypeHandshake, "[ServerStateStart] unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + ch := &ClientHelloBody{LegacyVersion: wireVersion(state.hsCtx.hIn)} + if err := safeUnmarshal(ch, hm.body); err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + // We are strict about these things because we only support 1.3 + if ch.LegacyVersion != wireVersion(state.hsCtx.hIn) { + logf(logTypeHandshake, "[ServerStateStart] Invalid version number: %v", ch.LegacyVersion) + return nil, nil, AlertDecodeError + } + + clientHello := hm + connParams := ConnectionParameters{} + + supportedVersions := &SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello} + serverName := new(ServerNameExtension) + supportedGroups := new(SupportedGroupsExtension) + signatureAlgorithms := new(SignatureAlgorithmsExtension) + clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello} + clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello} + clientEarlyData := &EarlyDataExtension{} + clientALPN := new(ALPNExtension) + clientPSKModes := new(PSKKeyExchangeModesExtension) + clientCookie := new(CookieExtension) + + // Handle external extensions. + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + foundExts, err := ch.Extensions.Parse( + []ExtensionBody{ + supportedVersions, + serverName, + supportedGroups, + signatureAlgorithms, + clientEarlyData, + clientKeyShares, + clientPSK, + clientALPN, + clientPSKModes, + clientCookie, + }) + + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error parsing extensions [%v]", err) + return nil, nil, AlertDecodeError + } + + clientSentCookie := len(clientCookie.Cookie) > 0 + + if foundExts[ExtensionTypeServerName] { + connParams.ServerName = string(*serverName) + } + + // If the client didn't send supportedVersions or doesn't support 1.3, + // then we're done here. + if !foundExts[ExtensionTypeSupportedVersions] { + logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions") + return nil, nil, AlertProtocolVersion + } + versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion}) + if !versionOK { + logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version") + return nil, nil, AlertProtocolVersion + } + + // The client sent a cookie. So this is probably the second ClientHello (sent as a response to a HRR) + var firstClientHello *HandshakeMessage + var initialCipherSuite CipherSuiteParams // the cipher suite that was negotiated when sending the HelloRetryRequest + if clientSentCookie { + plainCookie, err := state.Config.CookieProtector.DecodeToken(clientCookie.Cookie) + if err != nil { + logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Error decoding token [%v]", err)) + return nil, nil, AlertDecryptError + } + cookie := &cookie{} + if rb, err := syntax.Unmarshal(plainCookie, cookie); err != nil && rb != len(plainCookie) { // this should never happen + logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Error unmarshaling cookie [%v]", err)) + return nil, nil, AlertInternalError + } + // restore the hash of initial ClientHello from the cookie + firstClientHello = &HandshakeMessage{ + msgType: HandshakeTypeMessageHash, + body: cookie.ClientHelloHash, + } + // have the application validate its part of the cookie + if state.Config.CookieHandler != nil && !state.Config.CookieHandler.Validate(state.conn, cookie.ApplicationCookie) { + logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch") + return nil, nil, AlertAccessDenied + } + var ok bool + initialCipherSuite, ok = cipherSuiteMap[cookie.CipherSuite] + if !ok { + logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Cookie contained invalid cipher suite: %#x", cookie.CipherSuite)) + return nil, nil, AlertInternalError + } + } + + if len(ch.LegacySessionID) != 0 && len(ch.LegacySessionID) != 32 { + logf(logTypeHandshake, "[ServerStateStart] invalid session ID") + return nil, nil, AlertIllegalParameter + } + + // Figure out if we can do DH + canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Config.Groups) + + // Figure out if we can do PSK + var canDoPSK bool + var selectedPSK int + var params CipherSuiteParams + var psk *PreSharedKey + if len(clientPSK.Identities) > 0 { + contextBase := []byte{} + if clientSentCookie { + contextBase = append(contextBase, firstClientHello.Marshal()...) + // fill in the cookie sent by the client. Needed to calculate the correct hash + cookieExt := &CookieExtension{Cookie: clientCookie.Cookie} + hrr, err := state.generateHRR(params.Suite, + ch.LegacySessionID, cookieExt) + if err != nil { + return nil, nil, AlertInternalError + } + contextBase = append(contextBase, hrr.Marshal()...) + } + chTrunc, err := ch.Truncated() + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err) + return nil, nil, AlertDecodeError + } + context := append(contextBase, chTrunc...) + + canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Config.PSKs) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err) + return nil, nil, AlertInternalError + } + if clientSentCookie && initialCipherSuite.Suite != params.Suite { + logf(logTypeHandshake, "[ServerStateStart] Would have selected a different CipherSuite after receiving the client's Cookie") + return nil, nil, AlertInternalError + } + } + + // Figure out if we actually should do DH / PSK + connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes) + + // Select a ciphersuite + connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Config.CipherSuites) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err) + return nil, nil, AlertHandshakeFailure + } + if clientSentCookie && initialCipherSuite.Suite != connParams.CipherSuite { + logf(logTypeHandshake, "[ServerStateStart] Would have selected a different CipherSuite after receiving the client's Cookie") + return nil, nil, AlertInternalError + } + + var helloRetryRequest *HandshakeMessage + if state.Config.RequireCookie { + // Send a cookie if required + // NB: Need to do this here because it's after ciphersuite selection, which + // has to be after PSK selection. + var shouldSendHRR bool + var cookieExt *CookieExtension + if !clientSentCookie { // this is the first ClientHello that we receive + var appCookie []byte + if state.Config.CookieHandler == nil { // if Config.RequireCookie is set, but no CookieHandler was provided, we definitely need to send a cookie + shouldSendHRR = true + } else { // if the CookieHandler was set, we just send a cookie when the application provides one + var err error + appCookie, err = state.Config.CookieHandler.Generate(state.conn) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err) + return nil, nil, AlertInternalError + } + shouldSendHRR = appCookie != nil + } + if shouldSendHRR { + params := cipherSuiteMap[connParams.CipherSuite] + h := params.Hash.New() + h.Write(clientHello.Marshal()) + plainCookie, err := syntax.Marshal(cookie{ + CipherSuite: connParams.CipherSuite, + ClientHelloHash: h.Sum(nil), + ApplicationCookie: appCookie, + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error marshalling cookie [%v]", err) + return nil, nil, AlertInternalError + } + cookieData, err := state.Config.CookieProtector.NewToken(plainCookie) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error encoding cookie [%v]", err) + return nil, nil, AlertInternalError + } + cookieExt = &CookieExtension{Cookie: cookieData} + } + } else { + cookieExt = &CookieExtension{Cookie: clientCookie.Cookie} + } + + // Generate a HRR. We will need it in both of the two cases: + // 1. We need to send a Cookie. Then this HRR will be sent on the wire + // 2. We need to validate a cookie. Then we need its hash + // Ignoring errors because everything here is newly constructed, so there + // shouldn't be marshal errors + if shouldSendHRR || clientSentCookie { + helloRetryRequest, err = state.generateHRR(connParams.CipherSuite, + ch.LegacySessionID, cookieExt) + if err != nil { + return nil, nil, AlertInternalError + } + } + + if shouldSendHRR { + toSend := []HandshakeAction{ + QueueHandshakeMessage{helloRetryRequest}, + SendQueuedHandshake{}, + } + logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]") + return state, toSend, AlertStatelessRetry + } + } + + // If we've got no entropy to make keys from, fail + if !connParams.UsingDH && !connParams.UsingPSK { + logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated") + return nil, nil, AlertHandshakeFailure + } + + var pskSecret []byte + var cert *Certificate + var certScheme SignatureScheme + if connParams.UsingPSK { + pskSecret = psk.Key + } else { + psk = nil + + // If we're not using a PSK mode, then we need to have certain extensions + if !(foundExts[ExtensionTypeServerName] && + foundExts[ExtensionTypeSupportedGroups] && + foundExts[ExtensionTypeSignatureAlgorithms]) { + logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v)", foundExts) + return nil, nil, AlertMissingExtension + } + + // Select a certificate + name := string(*serverName) + var err error + cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Config.Certificates) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err) + return nil, nil, AlertAccessDenied + } + } + + if !connParams.UsingDH { + dhSecret = nil + } + + // Figure out if we're going to do early data + var clientEarlyTrafficSecret []byte + connParams.ClientSendingEarlyData = foundExts[ExtensionTypeEarlyData] + connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Config.AllowEarlyData) + if connParams.UsingEarlyData { + h := params.Hash.New() + h.Write(clientHello.Marshal()) + chHash := h.Sum(nil) + + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + earlySecret := HkdfExtract(params.Hash, zero, pskSecret) + clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) + } + + // Select a next protocol + connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Config.NextProtos) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err) + return nil, nil, AlertNoApplicationProtocol + } + + logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") + state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2. + return ServerStateNegotiated{ + Config: state.Config, + Params: connParams, + hsCtx: state.hsCtx, + dhGroup: dhGroup, + dhPublic: dhPublic, + dhSecret: dhSecret, + pskSecret: pskSecret, + selectedPSK: selectedPSK, + cert: cert, + certScheme: certScheme, + legacySessionId: ch.LegacySessionID, + clientEarlyTrafficSecret: clientEarlyTrafficSecret, + + firstClientHello: firstClientHello, + helloRetryRequest: helloRetryRequest, + clientHello: clientHello, + }, nil, AlertNoAlert +} + +func (state *ServerStateStart) generateHRR(cs CipherSuite, legacySessionId []byte, + cookieExt *CookieExtension) (*HandshakeMessage, error) { + var helloRetryRequest *HandshakeMessage + hrr := &ServerHelloBody{ + Version: tls12Version, + Random: hrrRandomSentinel, + CipherSuite: cs, + LegacySessionID: legacySessionId, + LegacyCompressionMethod: 0, + } + + sv := &SupportedVersionsExtension{ + HandshakeType: HandshakeTypeServerHello, + Versions: []uint16{supportedVersion}, + } + + if err := hrr.Extensions.Add(sv); err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error adding SupportedVersion [%v]", err) + return nil, err + } + + if err := hrr.Extensions.Add(cookieExt); err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error adding CookieExtension [%v]", err) + return nil, err + } + // Run the external extension handler. + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err) + return nil, err + } + } + helloRetryRequest, err := state.hsCtx.hOut.HandshakeMessageFromBody(hrr) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err) + return nil, err + } + return helloRetryRequest, nil +} + +type ServerStateNegotiated struct { + Config *Config + Params ConnectionParameters + hsCtx HandshakeContext + dhGroup NamedGroup + dhPublic []byte + dhSecret []byte + pskSecret []byte + clientEarlyTrafficSecret []byte + selectedPSK int + cert *Certificate + certScheme SignatureScheme + legacySessionId []byte + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage + clientHello *HandshakeMessage +} + +var _ HandshakeState = &ServerStateNegotiated{} + +func (state ServerStateNegotiated) State() State { + return StateServerNegotiated +} + +func (state ServerStateNegotiated) Next(_ handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + // Create the ServerHello + sh := &ServerHelloBody{ + Version: tls12Version, + CipherSuite: state.Params.CipherSuite, + LegacySessionID: state.legacySessionId, + LegacyCompressionMethod: 0, + } + if _, err := prng.Read(sh.Random[:]); err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err) + return nil, nil, AlertInternalError + } + + err := sh.Extensions.Add(&SupportedVersionsExtension{ + HandshakeType: HandshakeTypeServerHello, + Versions: []uint16{supportedVersion}, + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported_versions extension [%v]", err) + return nil, nil, AlertInternalError + } + if state.Params.UsingDH { + logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension") + err := sh.Extensions.Add(&KeyShareExtension{ + HandshakeType: HandshakeTypeServerHello, + Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}}, + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err) + return nil, nil, AlertInternalError + } + } + if state.Params.UsingPSK { + logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension") + err := sh.Extensions.Add(&PreSharedKeyExtension{ + HandshakeType: HandshakeTypeServerHello, + SelectedIdentity: uint16(state.selectedPSK), + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Run the external extension handler. + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) + return nil, nil, AlertInternalError + } + } + + serverHello, err := state.hsCtx.hOut.HandshakeMessageFromBody(sh) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err) + return nil, nil, AlertInternalError + } + + // Look up crypto params + params, ok := cipherSuiteMap[sh.CipherSuite] + if !ok { + logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite) + return nil, nil, AlertHandshakeFailure + } + + // Start up the handshake hash + handshakeHash := params.Hash.New() + handshakeHash.Write(state.firstClientHello.Marshal()) + handshakeHash.Write(state.helloRetryRequest.Marshal()) + handshakeHash.Write(state.clientHello.Marshal()) + handshakeHash.Write(serverHello.Marshal()) + + // Compute handshake secrets + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + + var earlySecret []byte + if state.Params.UsingPSK { + earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret) + } else { + earlySecret = HkdfExtract(params.Hash, zero, zero) + } + + if state.dhSecret == nil { + state.dhSecret = zero + } + + h0 := params.Hash.New().Sum(nil) + h2 := handshakeHash.Sum(nil) + preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) + handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret) + clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) + serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) + preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) + masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) + + logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret) + logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) + logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) + logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) + logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) + + clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret) + serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) + + // Send an EncryptedExtensions message (even if it's empty) + eeList := ExtensionList{} + if state.Params.NextProto != "" { + logf(logTypeHandshake, "[server] sending ALPN extension") + err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}}) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + } + if state.Params.UsingEarlyData { + logf(logTypeHandshake, "[server] sending EDI extension") + err = eeList.Add(&EarlyDataExtension{}) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + } + ee := &EncryptedExtensionsBody{eeList} + + // Run the external extension handler. + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) + return nil, nil, AlertInternalError + } + } + + eem, err := state.hsCtx.hOut.HandshakeMessageFromBody(ee) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + + handshakeHash.Write(eem.Marshal()) + + toSend := []HandshakeAction{ + QueueHandshakeMessage{serverHello}, + RekeyOut{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys}, + QueueHandshakeMessage{eem}, + } + + // Authenticate with a certificate if required + if !state.Params.UsingPSK { + // Send a CertificateRequest message if we want client auth + if state.Config.RequireClientAuth { + state.Params.UsingClientAuth = true + + // XXX: We don't support sending any constraints besides a list of + // supported signature algorithms + cr := &CertificateRequestBody{} + schemes := &SignatureAlgorithmsExtension{Algorithms: state.Config.SignatureSchemes} + err := cr.Extensions.Add(schemes) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err) + return nil, nil, AlertInternalError + } + + crm, err := state.hsCtx.hOut.HandshakeMessageFromBody(cr) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err) + return nil, nil, AlertInternalError + } + //TODO state.state.serverCertificateRequest = cr + + toSend = append(toSend, QueueHandshakeMessage{crm}) + handshakeHash.Write(crm.Marshal()) + } + + // Create and send Certificate, CertificateVerify + certificate := &CertificateBody{ + CertificateList: make([]CertificateEntry, len(state.cert.Chain)), + } + for i, entry := range state.cert.Chain { + certificate.CertificateList[i] = CertificateEntry{CertData: entry} + } + certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, QueueHandshakeMessage{certm}) + handshakeHash.Write(certm.Marshal()) + + certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme} + logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash) + + hcv := handshakeHash.Sum(nil) + logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) + + err = certificateVerify.Sign(state.cert.PrivateKey, hcv) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err) + return nil, nil, AlertInternalError + } + certvm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificateVerify) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, QueueHandshakeMessage{certvm}) + handshakeHash.Write(certvm.Marshal()) + } + + // Compute secrets resulting from the server's first flight + h3 := handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3) + logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3) + + serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3) + logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) + + // Assemble the Finished message + fin := &FinishedBody{ + VerifyDataLen: len(serverFinishedData), + VerifyData: serverFinishedData, + } + finm, _ := state.hsCtx.hOut.HandshakeMessageFromBody(fin) + + toSend = append(toSend, QueueHandshakeMessage{finm}) + handshakeHash.Write(finm.Marshal()) + toSend = append(toSend, SendQueuedHandshake{}) + + // Compute traffic secrets + h4 := handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4) + logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4) + + clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4) + serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4) + logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret) + logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) + + serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret) + toSend = append(toSend, RekeyOut{epoch: EpochApplicationData, KeySet: serverTrafficKeys}) + + exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4) + logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret) + + if state.Params.UsingEarlyData { + clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret) + + logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]") + nextState := ServerStateWaitEOED{ + Config: state.Config, + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: params, + handshakeHash: handshakeHash, + masterSecret: masterSecret, + clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, + clientTrafficSecret: clientTrafficSecret, + serverTrafficSecret: serverTrafficSecret, + exporterSecret: exporterSecret, + } + toSend = append(toSend, []HandshakeAction{ + RekeyIn{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys}, + ReadEarlyData{}, + }...) + return nextState, toSend, AlertNoAlert + } + + logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]") + toSend = append(toSend, []HandshakeAction{ + RekeyIn{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}, + ReadPastEarlyData{}, + }...) + waitFlight2 := ServerStateWaitFlight2{ + Config: state.Config, + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: params, + handshakeHash: handshakeHash, + masterSecret: masterSecret, + clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, + clientTrafficSecret: clientTrafficSecret, + serverTrafficSecret: serverTrafficSecret, + exporterSecret: exporterSecret, + } + return waitFlight2, toSend, AlertNoAlert +} + +type ServerStateWaitEOED struct { + Config *Config + Params ConnectionParameters + hsCtx HandshakeContext + cryptoParams CipherSuiteParams + masterSecret []byte + clientHandshakeTrafficSecret []byte + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +var _ HandshakeState = &ServerStateWaitEOED{} + +func (state ServerStateWaitEOED) State() State { + return StateServerWaitEOED +} + +func (state ServerStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } + if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData { + logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + if len(hm.body) > 0 { + logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]") + return nil, nil, AlertDecodeError + } + + state.handshakeHash.Write(hm.Marshal()) + + clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) + + logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]") + toSend := []HandshakeAction{ + RekeyIn{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}, + } + waitFlight2 := ServerStateWaitFlight2{ + Config: state.Config, + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + return waitFlight2, toSend, AlertNoAlert +} + +type ServerStateWaitFlight2 struct { + Config *Config + Params ConnectionParameters + hsCtx HandshakeContext + cryptoParams CipherSuiteParams + masterSecret []byte + clientHandshakeTrafficSecret []byte + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +var _ HandshakeState = &ServerStateWaitFlight2{} + +func (state ServerStateWaitFlight2) State() State { + return StateServerWaitFlight2 +} + +func (state ServerStateWaitFlight2) Next(_ handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + if state.Params.UsingClientAuth { + logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]") + nextState := ServerStateWaitCert{ + Config: state.Config, + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert + } + + logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]") + nextState := ServerStateWaitFinished{ + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: state.cryptoParams, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + handshakeHash: state.handshakeHash, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert +} + +type ServerStateWaitCert struct { + Config *Config + Params ConnectionParameters + hsCtx HandshakeContext + cryptoParams CipherSuiteParams + masterSecret []byte + clientHandshakeTrafficSecret []byte + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +var _ HandshakeState = &ServerStateWaitCert{} + +func (state ServerStateWaitCert) State() State { + return StateServerWaitCert +} + +func (state ServerStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } + if hm == nil || hm.msgType != HandshakeTypeCertificate { + logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + cert := &CertificateBody{} + if err := safeUnmarshal(cert, hm.body); err != nil { + logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") + return nil, nil, AlertDecodeError + } + + state.handshakeHash.Write(hm.Marshal()) + + if len(cert.CertificateList) == 0 { + logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate") + + logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]") + nextState := ServerStateWaitFinished{ + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: state.cryptoParams, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + handshakeHash: state.handshakeHash, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert + } + + logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]") + nextState := ServerStateWaitCV{ + Config: state.Config, + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: state.cryptoParams, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + handshakeHash: state.handshakeHash, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + clientCertificate: cert, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert +} + +type ServerStateWaitCV struct { + Config *Config + Params ConnectionParameters + hsCtx HandshakeContext + cryptoParams CipherSuiteParams + + masterSecret []byte + clientHandshakeTrafficSecret []byte + + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte + + clientCertificate *CertificateBody +} + +var _ HandshakeState = &ServerStateWaitCV{} + +func (state ServerStateWaitCV) State() State { + return StateServerWaitCV +} + +func (state ServerStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } + if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { + logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm)) + return nil, nil, AlertUnexpectedMessage + } + + certVerify := &CertificateVerifyBody{} + if err := safeUnmarshal(certVerify, hm.body); err != nil { + logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err) + return nil, nil, AlertDecodeError + } + + rawCerts := make([][]byte, len(state.clientCertificate.CertificateList)) + certs := make([]*x509.Certificate, len(state.clientCertificate.CertificateList)) + for i, certEntry := range state.clientCertificate.CertificateList { + certs[i] = certEntry.CertData + rawCerts[i] = certEntry.CertData.Raw + } + + // Verify client signature over handshake hash + hcv := state.handshakeHash.Sum(nil) + logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) + + clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey + if err := certVerify.Verify(clientPublicKey, hcv); err != nil { + logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err) + return nil, nil, AlertHandshakeFailure + } + + if state.Config.VerifyPeerCertificate != nil { + // TODO(#171): pass in the verified chains, once we support different client auth types + if err := state.Config.VerifyPeerCertificate(rawCerts, nil); err != nil { + logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate: %s", err) + return nil, nil, AlertBadCertificate + } + } + + // If it passes, record the certificateVerify in the transcript hash + state.handshakeHash.Write(hm.Marshal()) + + logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]") + nextState := ServerStateWaitFinished{ + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: state.cryptoParams, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + handshakeHash: state.handshakeHash, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + peerCertificates: certs, + verifiedChains: nil, // TODO(#171): set this value + } + return nextState, nil, AlertNoAlert +} + +type ServerStateWaitFinished struct { + Params ConnectionParameters + hsCtx HandshakeContext + cryptoParams CipherSuiteParams + + masterSecret []byte + clientHandshakeTrafficSecret []byte + peerCertificates []*x509.Certificate + verifiedChains [][]*x509.Certificate + + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +var _ HandshakeState = &ServerStateWaitFinished{} + +func (state ServerStateWaitFinished) State() State { + return StateServerWaitFinished +} + +func (state ServerStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } + if hm == nil || hm.msgType != HandshakeTypeFinished { + logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()} + if err := safeUnmarshal(fin, hm.body); err != nil { + logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err) + return nil, nil, AlertDecodeError + } + + // Verify client Finished data + h5 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5) + + clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5) + logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData) + + if !bytes.Equal(fin.VerifyData, clientFinishedData) { + logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify") + return nil, nil, AlertHandshakeFailure + } + + // Compute the resumption secret + state.handshakeHash.Write(hm.Marshal()) + h6 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6) + + resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6) + logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) + + // Compute client traffic keys + clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + + logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]") + nextState := StateConnected{ + Params: state.Params, + hsCtx: state.hsCtx, + isClient: false, + cryptoParams: state.cryptoParams, + resumptionSecret: resumptionSecret, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + peerCertificates: state.peerCertificates, + verifiedChains: state.verifiedChains, + } + toSend := []HandshakeAction{ + RekeyIn{epoch: EpochApplicationData, KeySet: clientTrafficKeys}, + } + return nextState, toSend, AlertNoAlert +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/state-machine.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/state-machine.go new file mode 100644 index 000000000..556fc09d7 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/state-machine.go @@ -0,0 +1,241 @@ +package mint + +import ( + "crypto/x509" + "time" +) + +// Marker interface for actions that an implementation should take based on +// state transitions. +type HandshakeAction interface{} + +type QueueHandshakeMessage struct { + Message *HandshakeMessage +} + +type SendQueuedHandshake struct{} + +type SendEarlyData struct{} + +type ReadEarlyData struct{} + +type ReadPastEarlyData struct{} + +type RekeyIn struct { + epoch Epoch + KeySet keySet +} + +type RekeyOut struct { + epoch Epoch + KeySet keySet +} + +type StorePSK struct { + PSK PreSharedKey +} + +type HandshakeState interface { + Next(handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) + State() State +} + +type AppExtensionHandler interface { + Send(hs HandshakeType, el *ExtensionList) error + Receive(hs HandshakeType, el *ExtensionList) error +} + +// ConnectionOptions objects represent per-connection settings for a client +// initiating a connection +type ConnectionOptions struct { + ServerName string + NextProtos []string + EarlyData []byte +} + +// ConnectionParameters objects represent the parameters negotiated for a +// connection. +type ConnectionParameters struct { + UsingPSK bool + UsingDH bool + ClientSendingEarlyData bool + UsingEarlyData bool + UsingClientAuth bool + + CipherSuite CipherSuite + ServerName string + NextProto string +} + +// Working state for the handshake. +type HandshakeContext struct { + hIn, hOut *HandshakeLayer +} + +func (hc *HandshakeContext) SetVersion(version uint16) { + if hc.hIn.conn != nil { + hc.hIn.conn.SetVersion(version) + } + if hc.hOut.conn != nil { + hc.hOut.conn.SetVersion(version) + } +} + +// StateConnected is symmetric between client and server +type StateConnected struct { + Params ConnectionParameters + hsCtx HandshakeContext + isClient bool + cryptoParams CipherSuiteParams + resumptionSecret []byte + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte + peerCertificates []*x509.Certificate + verifiedChains [][]*x509.Certificate +} + +var _ HandshakeState = &StateConnected{} + +func (state StateConnected) State() State { + if state.isClient { + return StateClientConnected + } + return StateServerConnected +} + +func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) { + var trafficKeys keySet + if state.isClient { + state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, + labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) + trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + } else { + state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret, + labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) + trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) + } + + kum, err := state.hsCtx.hOut.HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request}) + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err) + return nil, AlertInternalError + } + + toSend := []HandshakeAction{ + QueueHandshakeMessage{kum}, + SendQueuedHandshake{}, + RekeyOut{epoch: EpochUpdate, KeySet: trafficKeys}, + } + return toSend, AlertNoAlert +} + +func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) { + tkt, err := NewSessionTicket(length, lifetime) + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err) + return nil, AlertInternalError + } + + err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime}) + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err) + return nil, AlertInternalError + } + + resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, + labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size()) + + newPSK := PreSharedKey{ + CipherSuite: state.cryptoParams.Suite, + IsResumption: true, + Identity: tkt.Ticket, + Key: resumptionKey, + NextProto: state.Params.NextProto, + ReceivedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second), + TicketAgeAdd: tkt.TicketAgeAdd, + } + + tktm, err := state.hsCtx.hOut.HandshakeMessageFromBody(tkt) + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err) + return nil, AlertInternalError + } + + toSend := []HandshakeAction{ + StorePSK{newPSK}, + QueueHandshakeMessage{tktm}, + SendQueuedHandshake{}, + } + return toSend, AlertNoAlert +} + +// Next does nothing for this state. +func (state StateConnected) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + return state, nil, AlertNoAlert +} + +func (state StateConnected) ProcessMessage(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil { + logf(logTypeHandshake, "[StateConnected] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + bodyGeneric, err := hm.ToBody() + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + switch body := bodyGeneric.(type) { + case *KeyUpdateBody: + var trafficKeys keySet + if !state.isClient { + state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, + labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) + trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + } else { + state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret, + labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) + trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) + } + + toSend := []HandshakeAction{RekeyIn{epoch: EpochUpdate, KeySet: trafficKeys}} + + // If requested, roll outbound keys and send a KeyUpdate + if body.KeyUpdateRequest == KeyUpdateRequested { + logf(logTypeHandshake, "Received key update, update requested", body.KeyUpdateRequest) + moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested) + if alert != AlertNoAlert { + return nil, nil, alert + } + toSend = append(toSend, moreToSend...) + } + return state, toSend, AlertNoAlert + case *NewSessionTicketBody: + // XXX: Allow NewSessionTicket in both directions? + if !state.isClient { + return nil, nil, AlertUnexpectedMessage + } + + resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, + labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size()) + psk := PreSharedKey{ + CipherSuite: state.cryptoParams.Suite, + IsResumption: true, + Identity: body.Ticket, + Key: resumptionKey, + NextProto: state.Params.NextProto, + ReceivedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second), + TicketAgeAdd: body.TicketAgeAdd, + } + + toSend := []HandshakeAction{StorePSK{psk}} + return state, toSend, AlertNoAlert + } + + logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType) + return nil, nil, AlertUnexpectedMessage +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/syntax/decode.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/syntax/decode.go new file mode 100644 index 000000000..92c036fcf --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/syntax/decode.go @@ -0,0 +1,310 @@ +package syntax + +import ( + "bytes" + "fmt" + "reflect" + "runtime" +) + +func Unmarshal(data []byte, v interface{}) (int, error) { + // Check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + d := decodeState{} + d.Write(data) + return d.unmarshal(v) +} + +// Unmarshaler is the interface implemented by types that can +// unmarshal a TLS description of themselves. Note that unlike the +// JSON unmarshaler interface, it is not known a priori how much of +// the input data will be consumed. So the Unmarshaler must state +// how much of the input data it consumed. +type Unmarshaler interface { + UnmarshalTLS([]byte) (int, error) +} + +// These are the options that can be specified in the struct tag. Right now, +// all of them apply to variable-length vectors and nothing else +type decOpts struct { + head uint // length of length in bytes + min uint // minimum size in bytes + max uint // maximum size in bytes + varint bool // whether to decode as a varint +} + +type decodeState struct { + bytes.Buffer +} + +func (d *decodeState) unmarshal(v interface{}) (read int, err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + if s, ok := r.(string); ok { + panic(s) + } + err = r.(error) + } + }() + + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)") + } + + read = d.value(rv) + return read, nil +} + +func (e *decodeState) value(v reflect.Value) int { + return valueDecoder(v)(e, v, decOpts{}) +} + +type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int + +func valueDecoder(v reflect.Value) decoderFunc { + return typeDecoder(v.Type().Elem()) +} + +func typeDecoder(t reflect.Type) decoderFunc { + // Note: Omits the caching / wait-group things that encoding/json uses + return newTypeDecoder(t) +} + +var ( + unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem() +) + +func newTypeDecoder(t reflect.Type) decoderFunc { + if t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(unmarshalerType) { + return unmarshalerDecoder + } + + switch t.Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uintDecoder + case reflect.Array: + return newArrayDecoder(t) + case reflect.Slice: + return newSliceDecoder(t) + case reflect.Struct: + return newStructDecoder(t) + case reflect.Ptr: + return newPointerDecoder(t) + default: + panic(fmt.Errorf("Unsupported type (%s)", t)) + } +} + +///// Specific decoders below + +func unmarshalerDecoder(d *decodeState, v reflect.Value, opts decOpts) int { + um, ok := v.Interface().(Unmarshaler) + if !ok { + panic(fmt.Errorf("Non-Unmarshaler passed to unmarshalerEncoder")) + } + + read, err := um.UnmarshalTLS(d.Bytes()) + if err != nil { + panic(err) + } + + if read > d.Len() { + panic(fmt.Errorf("Invalid return value from UnmarshalTLS")) + } + + d.Next(read) + return read +} + +////////// + +func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int { + if opts.varint { + return varintDecoder(d, v, opts) + } + + uintLen := int(v.Elem().Type().Size()) + buf := d.Next(uintLen) + if len(buf) != uintLen { + panic(fmt.Errorf("Insufficient data to read uint")) + } + + return setUintFromBuffer(v, buf) +} + +func varintDecoder(d *decodeState, v reflect.Value, opts decOpts) int { + // Read the first octet and decide the size of the presented varint + first := d.Next(1) + if len(first) != 1 { + panic(fmt.Errorf("Insufficient data to read varint length")) + } + + uintLen := int(v.Elem().Type().Size()) + twoBits := uint(first[0] >> 6) + varintLen := 1 << twoBits + + if uintLen < varintLen { + panic(fmt.Errorf("Uint too small to fit varint: %d < %d")) + } + + rest := d.Next(varintLen - 1) + if len(rest) != varintLen-1 { + panic(fmt.Errorf("Insufficient data to read varint")) + } + + buf := append(first, rest...) + buf[0] &= 0x3f + return setUintFromBuffer(v, buf) +} + +func setUintFromBuffer(v reflect.Value, buf []byte) int { + val := uint64(0) + for _, b := range buf { + val = (val << 8) + uint64(b) + } + + v.Elem().SetUint(val) + return len(buf) +} + +////////// + +type arrayDecoder struct { + elemDec decoderFunc +} + +func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { + n := v.Elem().Type().Len() + read := 0 + for i := 0; i < n; i += 1 { + read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts) + } + return read +} + +func newArrayDecoder(t reflect.Type) decoderFunc { + dec := &arrayDecoder{typeDecoder(t.Elem())} + return dec.decode +} + +////////// + +type sliceDecoder struct { + elementType reflect.Type + elementDec decoderFunc +} + +func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { + if opts.head == 0 { + panic(fmt.Errorf("Cannot decode a slice without a header length")) + } + + lengthBytes := d.Next(int(opts.head)) + if len(lengthBytes) != int(opts.head) { + panic(fmt.Errorf("Not enough data to read header")) + } + + length := uint(0) + for _, b := range lengthBytes { + length = (length << 8) + uint(b) + } + + if opts.max > 0 && length > opts.max { + panic(fmt.Errorf("Length of vector exceeds declared max")) + } + if length < opts.min { + panic(fmt.Errorf("Length of vector below declared min")) + } + + data := d.Next(int(length)) + if len(data) != int(length) { + panic(fmt.Errorf("Available data less than declared length [%d < %d]", len(data), length)) + } + + elemBuf := &decodeState{} + elemBuf.Write(data) + elems := []reflect.Value{} + read := int(opts.head) + for elemBuf.Len() > 0 { + elem := reflect.New(sd.elementType) + read += sd.elementDec(elemBuf, elem, opts) + elems = append(elems, elem) + } + + v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems))) + for i := 0; i < len(elems); i += 1 { + v.Elem().Index(i).Set(elems[i].Elem()) + } + return read +} + +func newSliceDecoder(t reflect.Type) decoderFunc { + dec := &sliceDecoder{ + elementType: t.Elem(), + elementDec: typeDecoder(t.Elem()), + } + return dec.decode +} + +////////// + +type structDecoder struct { + fieldOpts []decOpts + fieldDecs []decoderFunc +} + +func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { + read := 0 + for i := range sd.fieldDecs { + read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i]) + } + return read +} + +func newStructDecoder(t reflect.Type) decoderFunc { + n := t.NumField() + sd := structDecoder{ + fieldOpts: make([]decOpts, n), + fieldDecs: make([]decoderFunc, n), + } + + for i := 0; i < n; i += 1 { + f := t.Field(i) + + tag := f.Tag.Get("tls") + tagOpts := parseTag(tag) + + sd.fieldOpts[i] = decOpts{ + head: tagOpts["head"], + max: tagOpts["max"], + min: tagOpts["min"], + varint: tagOpts[varintOption] > 0, + } + + sd.fieldDecs[i] = typeDecoder(f.Type) + } + + return sd.decode +} + +////////// + +type pointerDecoder struct { + base decoderFunc +} + +func (pd *pointerDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { + v.Elem().Set(reflect.New(v.Elem().Type().Elem())) + return pd.base(d, v.Elem(), opts) +} + +func newPointerDecoder(t reflect.Type) decoderFunc { + baseDecoder := typeDecoder(t.Elem()) + pd := pointerDecoder{base: baseDecoder} + return pd.decode +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/syntax/encode.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/syntax/encode.go new file mode 100644 index 000000000..63283936b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/syntax/encode.go @@ -0,0 +1,266 @@ +package syntax + +import ( + "bytes" + "fmt" + "reflect" + "runtime" +) + +func Marshal(v interface{}) ([]byte, error) { + e := &encodeState{} + err := e.marshal(v, encOpts{}) + if err != nil { + return nil, err + } + return e.Bytes(), nil +} + +// Marshaler is the interface implemented by types that +// have a defined TLS encoding. +type Marshaler interface { + MarshalTLS() ([]byte, error) +} + +// These are the options that can be specified in the struct tag. Right now, +// all of them apply to variable-length vectors and nothing else +type encOpts struct { + head uint // length of length in bytes + min uint // minimum size in bytes + max uint // maximum size in bytes + varint bool // whether to encode as a varint +} + +type encodeState struct { + bytes.Buffer +} + +func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + if s, ok := r.(string); ok { + panic(s) + } + err = r.(error) + } + }() + e.reflectValue(reflect.ValueOf(v), opts) + return nil +} + +func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) { + valueEncoder(v)(e, v, opts) +} + +type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts) + +func valueEncoder(v reflect.Value) encoderFunc { + if !v.IsValid() { + panic(fmt.Errorf("Cannot encode an invalid value")) + } + return typeEncoder(v.Type()) +} + +func typeEncoder(t reflect.Type) encoderFunc { + // Note: Omits the caching / wait-group things that encoding/json uses + return newTypeEncoder(t) +} + +var ( + marshalerType = reflect.TypeOf(new(Marshaler)).Elem() +) + +func newTypeEncoder(t reflect.Type) encoderFunc { + if t.Implements(marshalerType) { + return marshalerEncoder + } + + switch t.Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uintEncoder + case reflect.Array: + return newArrayEncoder(t) + case reflect.Slice: + return newSliceEncoder(t) + case reflect.Struct: + return newStructEncoder(t) + case reflect.Ptr: + return newPointerEncoder(t) + default: + panic(fmt.Errorf("Unsupported type (%s)", t)) + } +} + +///// Specific encoders below + +func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { + if v.Kind() == reflect.Ptr && v.IsNil() { + panic(fmt.Errorf("Cannot encode nil pointer")) + } + + m, ok := v.Interface().(Marshaler) + if !ok { + panic(fmt.Errorf("Non-Marshaler passed to marshalerEncoder")) + } + + b, err := m.MarshalTLS() + if err == nil { + _, err = e.Write(b) + } + + if err != nil { + panic(err) + } +} + +////////// + +func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) { + if opts.varint { + varintEncoder(e, v, opts) + return + } + + writeUint(e, v.Uint(), int(v.Type().Size())) +} + +func varintEncoder(e *encodeState, v reflect.Value, opts encOpts) { + u := v.Uint() + if (u >> 62) > 0 { + panic(fmt.Errorf("uint value is too big for varint")) + } + + var varintLen int + for _, len := range []uint{1, 2, 4, 8} { + if u < (uint64(1) << (8*len - 2)) { + varintLen = int(len) + break + } + } + + twoBits := map[int]uint64{1: 0x00, 2: 0x01, 4: 0x02, 8: 0x03}[varintLen] + shift := uint(8*varintLen - 2) + writeUint(e, u|(twoBits<> uint(8*(len-i-1))) + } + e.Write(data) +} + +////////// + +type arrayEncoder struct { + elemEnc encoderFunc +} + +func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { + n := v.Len() + for i := 0; i < n; i += 1 { + ae.elemEnc(e, v.Index(i), opts) + } +} + +func newArrayEncoder(t reflect.Type) encoderFunc { + enc := &arrayEncoder{typeEncoder(t.Elem())} + return enc.encode +} + +////////// + +type sliceEncoder struct { + ae *arrayEncoder +} + +func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { + if opts.head == 0 { + panic(fmt.Errorf("Cannot encode a slice without a header length")) + } + + arrayState := &encodeState{} + se.ae.encode(arrayState, v, opts) + + n := uint(arrayState.Len()) + if opts.max > 0 && n > opts.max { + panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max)) + } + if n>>(8*opts.head) > 0 { + panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head)) + } + if n < opts.min { + panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min)) + } + + for i := int(opts.head - 1); i >= 0; i -= 1 { + e.WriteByte(byte(n >> (8 * uint(i)))) + } + e.Write(arrayState.Bytes()) +} + +func newSliceEncoder(t reflect.Type) encoderFunc { + enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}} + return enc.encode +} + +////////// + +type structEncoder struct { + fieldOpts []encOpts + fieldEncs []encoderFunc +} + +func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { + for i := range se.fieldEncs { + se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i]) + } +} + +func newStructEncoder(t reflect.Type) encoderFunc { + n := t.NumField() + se := structEncoder{ + fieldOpts: make([]encOpts, n), + fieldEncs: make([]encoderFunc, n), + } + + for i := 0; i < n; i += 1 { + f := t.Field(i) + tag := f.Tag.Get("tls") + tagOpts := parseTag(tag) + + se.fieldOpts[i] = encOpts{ + head: tagOpts["head"], + max: tagOpts["max"], + min: tagOpts["min"], + varint: tagOpts[varintOption] > 0, + } + se.fieldEncs[i] = typeEncoder(f.Type) + } + + return se.encode +} + +////////// + +type pointerEncoder struct { + base encoderFunc +} + +func (pe pointerEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { + if v.IsNil() { + panic(fmt.Errorf("Cannot marshal a struct containing a nil pointer")) + } + + pe.base(e, v.Elem(), opts) +} + +func newPointerEncoder(t reflect.Type) encoderFunc { + baseEncoder := typeEncoder(t.Elem()) + pe := pointerEncoder{base: baseEncoder} + return pe.encode +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/syntax/tags.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/syntax/tags.go new file mode 100644 index 000000000..1bb3718e2 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/syntax/tags.go @@ -0,0 +1,40 @@ +package syntax + +import ( + "strconv" + "strings" +) + +// `tls:"head=2,min=2,max=255,varint"` + +type tagOptions map[string]uint + +var ( + varintOption = "varint" +) + +// parseTag parses a struct field's "tls" tag as a comma-separated list of +// name=value pairs, where the values MUST be unsigned integers +func parseTag(tag string) tagOptions { + opts := tagOptions{} + for _, token := range strings.Split(tag, ",") { + if token == varintOption { + opts[varintOption] = 1 + continue + } + + parts := strings.Split(token, "=") + if len(parts[0]) == 0 { + continue + } + + if len(parts) == 1 { + continue + } + + if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 { + opts[parts[0]] = uint(val) + } + } + return opts +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/tls.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/tls.go new file mode 100644 index 000000000..4d2286922 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/tls.go @@ -0,0 +1,179 @@ +package mint + +// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls + +import ( + "errors" + "net" + "strings" + "time" +) + +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Server(conn net.Conn, config *Config) *Conn { + return NewConn(conn, config, false) +} + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. +func Client(conn net.Conn, config *Config) *Conn { + return NewConn(conn, config, true) +} + +// A listener implements a network listener (net.Listener) for TLS connections. +type Listener struct { + net.Listener + config *Config +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection c is a *tls.Conn. +func (l *Listener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + if err != nil { + return + } + server := Server(c, l.config) + err = server.Handshake() + if err == AlertNoAlert { + err = nil + } + c = server + return +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func NewListener(inner net.Listener, config *Config) (net.Listener, error) { + if config != nil && config.NonBlocking { + return nil, errors.New("listening not possible in non-blocking mode") + } + l := new(Listener) + l.Listener = inner + l.config = config + return l, nil +} + +// Listen creates a TLS listener accepting connections on the +// given network address using net.Listen. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Listen(network, laddr string, config *Config) (net.Listener, error) { + if config == nil || !config.ValidForServer() { + return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config") + } + l, err := net.Listen(network, laddr) + if err != nil { + return nil, err + } + return NewListener(l, config) +} + +type TimeoutError struct{} + +func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (TimeoutError) Timeout() bool { return true } +func (TimeoutError) Temporary() bool { return true } + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { + if config != nil && config.NonBlocking { + return nil, errors.New("dialing not possible in non-blocking mode") + } + + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := dialer.Timeout + + if !dialer.Deadline.IsZero() { + deadlineTimeout := dialer.Deadline.Sub(time.Now()) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + var errChannel chan error + + if timeout != 0 { + errChannel = make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- TimeoutError{} + }) + } + + rawConn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + if config == nil { + config = &Config{} + } else { + config = config.Clone() + } + + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + config.ServerName = hostname + + } + + // Set up DTLS as needed. + config.UseDTLS = (network == "udp") + + conn := Client(rawConn, config) + + if timeout == 0 { + err = conn.Handshake() + if err == AlertNoAlert { + err = nil + } + } else { + go func() { + errChannel <- conn.Handshake() + }() + + err = <-errChannel + if err == AlertNoAlert { + err = nil + } + } + + if err != nil { + rawConn.Close() + return nil, err + } + + return conn, nil +} + +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/cheekybits/genny/generic/doc.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/cheekybits/genny/generic/doc.go new file mode 100644 index 000000000..3bd6c869c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/cheekybits/genny/generic/doc.go @@ -0,0 +1,2 @@ +// Package generic contains the generic marker types. +package generic diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/cheekybits/genny/generic/generic.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/cheekybits/genny/generic/generic.go new file mode 100644 index 000000000..04a2306cb --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/cheekybits/genny/generic/generic.go @@ -0,0 +1,13 @@ +package generic + +// Type is the placeholder type that indicates a generic value. +// When genny is executed, variables of this type will be replaced with +// references to the specific types. +// var GenericType generic.Type +type Type interface{} + +// Number is the placehoder type that indiccates a generic numerical value. +// When genny is executed, variables of this type will be replaced with +// references to the specific types. +// var GenericType generic.Number +type Number float64 diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/const_amd64.h b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/const_amd64.h new file mode 100644 index 000000000..b3f74162f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/const_amd64.h @@ -0,0 +1,8 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This code was translated into a form compatible with 6a from the public +// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html + +#define REDMASK51 0x0007FFFFFFFFFFFF diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/const_amd64.s b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/const_amd64.s new file mode 100644 index 000000000..ee7b4bd5f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/const_amd64.s @@ -0,0 +1,20 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This code was translated into a form compatible with 6a from the public +// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html + +// +build amd64,!gccgo,!appengine + +// These constants cannot be encoded in non-MOVQ immediates. +// We access them directly from memory instead. + +DATA ·_121666_213(SB)/8, $996687872 +GLOBL ·_121666_213(SB), 8, $8 + +DATA ·_2P0(SB)/8, $0xFFFFFFFFFFFDA +GLOBL ·_2P0(SB), 8, $8 + +DATA ·_2P1234(SB)/8, $0xFFFFFFFFFFFFE +GLOBL ·_2P1234(SB), 8, $8 diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/cswap_amd64.s b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/cswap_amd64.s new file mode 100644 index 000000000..cd793a5b5 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/cswap_amd64.s @@ -0,0 +1,65 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build amd64,!gccgo,!appengine + +// func cswap(inout *[4][5]uint64, v uint64) +TEXT ·cswap(SB),7,$0 + MOVQ inout+0(FP),DI + MOVQ v+8(FP),SI + + SUBQ $1, SI + NOTQ SI + MOVQ SI, X15 + PSHUFD $0x44, X15, X15 + + MOVOU 0(DI), X0 + MOVOU 16(DI), X2 + MOVOU 32(DI), X4 + MOVOU 48(DI), X6 + MOVOU 64(DI), X8 + MOVOU 80(DI), X1 + MOVOU 96(DI), X3 + MOVOU 112(DI), X5 + MOVOU 128(DI), X7 + MOVOU 144(DI), X9 + + MOVO X1, X10 + MOVO X3, X11 + MOVO X5, X12 + MOVO X7, X13 + MOVO X9, X14 + + PXOR X0, X10 + PXOR X2, X11 + PXOR X4, X12 + PXOR X6, X13 + PXOR X8, X14 + PAND X15, X10 + PAND X15, X11 + PAND X15, X12 + PAND X15, X13 + PAND X15, X14 + PXOR X10, X0 + PXOR X10, X1 + PXOR X11, X2 + PXOR X11, X3 + PXOR X12, X4 + PXOR X12, X5 + PXOR X13, X6 + PXOR X13, X7 + PXOR X14, X8 + PXOR X14, X9 + + MOVOU X0, 0(DI) + MOVOU X2, 16(DI) + MOVOU X4, 32(DI) + MOVOU X6, 48(DI) + MOVOU X8, 64(DI) + MOVOU X1, 80(DI) + MOVOU X3, 96(DI) + MOVOU X5, 112(DI) + MOVOU X7, 128(DI) + MOVOU X9, 144(DI) + RET diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/curve25519.go b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/curve25519.go new file mode 100644 index 000000000..cb8fbc57b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/curve25519.go @@ -0,0 +1,834 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// We have an implementation in amd64 assembly so this code is only run on +// non-amd64 platforms. The amd64 assembly does not support gccgo. +// +build !amd64 gccgo appengine + +package curve25519 + +import ( + "encoding/binary" +) + +// This code is a port of the public domain, "ref10" implementation of +// curve25519 from SUPERCOP 20130419 by D. J. Bernstein. + +// fieldElement represents an element of the field GF(2^255 - 19). An element +// t, entries t[0]...t[9], represents the integer t[0]+2^26 t[1]+2^51 t[2]+2^77 +// t[3]+2^102 t[4]+...+2^230 t[9]. Bounds on each t[i] vary depending on +// context. +type fieldElement [10]int32 + +func feZero(fe *fieldElement) { + for i := range fe { + fe[i] = 0 + } +} + +func feOne(fe *fieldElement) { + feZero(fe) + fe[0] = 1 +} + +func feAdd(dst, a, b *fieldElement) { + for i := range dst { + dst[i] = a[i] + b[i] + } +} + +func feSub(dst, a, b *fieldElement) { + for i := range dst { + dst[i] = a[i] - b[i] + } +} + +func feCopy(dst, src *fieldElement) { + for i := range dst { + dst[i] = src[i] + } +} + +// feCSwap replaces (f,g) with (g,f) if b == 1; replaces (f,g) with (f,g) if b == 0. +// +// Preconditions: b in {0,1}. +func feCSwap(f, g *fieldElement, b int32) { + b = -b + for i := range f { + t := b & (f[i] ^ g[i]) + f[i] ^= t + g[i] ^= t + } +} + +// load3 reads a 24-bit, little-endian value from in. +func load3(in []byte) int64 { + var r int64 + r = int64(in[0]) + r |= int64(in[1]) << 8 + r |= int64(in[2]) << 16 + return r +} + +// load4 reads a 32-bit, little-endian value from in. +func load4(in []byte) int64 { + return int64(binary.LittleEndian.Uint32(in)) +} + +func feFromBytes(dst *fieldElement, src *[32]byte) { + h0 := load4(src[:]) + h1 := load3(src[4:]) << 6 + h2 := load3(src[7:]) << 5 + h3 := load3(src[10:]) << 3 + h4 := load3(src[13:]) << 2 + h5 := load4(src[16:]) + h6 := load3(src[20:]) << 7 + h7 := load3(src[23:]) << 5 + h8 := load3(src[26:]) << 4 + h9 := load3(src[29:]) << 2 + + var carry [10]int64 + carry[9] = (h9 + 1<<24) >> 25 + h0 += carry[9] * 19 + h9 -= carry[9] << 25 + carry[1] = (h1 + 1<<24) >> 25 + h2 += carry[1] + h1 -= carry[1] << 25 + carry[3] = (h3 + 1<<24) >> 25 + h4 += carry[3] + h3 -= carry[3] << 25 + carry[5] = (h5 + 1<<24) >> 25 + h6 += carry[5] + h5 -= carry[5] << 25 + carry[7] = (h7 + 1<<24) >> 25 + h8 += carry[7] + h7 -= carry[7] << 25 + + carry[0] = (h0 + 1<<25) >> 26 + h1 += carry[0] + h0 -= carry[0] << 26 + carry[2] = (h2 + 1<<25) >> 26 + h3 += carry[2] + h2 -= carry[2] << 26 + carry[4] = (h4 + 1<<25) >> 26 + h5 += carry[4] + h4 -= carry[4] << 26 + carry[6] = (h6 + 1<<25) >> 26 + h7 += carry[6] + h6 -= carry[6] << 26 + carry[8] = (h8 + 1<<25) >> 26 + h9 += carry[8] + h8 -= carry[8] << 26 + + dst[0] = int32(h0) + dst[1] = int32(h1) + dst[2] = int32(h2) + dst[3] = int32(h3) + dst[4] = int32(h4) + dst[5] = int32(h5) + dst[6] = int32(h6) + dst[7] = int32(h7) + dst[8] = int32(h8) + dst[9] = int32(h9) +} + +// feToBytes marshals h to s. +// Preconditions: +// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +// +// Write p=2^255-19; q=floor(h/p). +// Basic claim: q = floor(2^(-255)(h + 19 2^(-25)h9 + 2^(-1))). +// +// Proof: +// Have |h|<=p so |q|<=1 so |19^2 2^(-255) q|<1/4. +// Also have |h-2^230 h9|<2^230 so |19 2^(-255)(h-2^230 h9)|<1/4. +// +// Write y=2^(-1)-19^2 2^(-255)q-19 2^(-255)(h-2^230 h9). +// Then 0> 25 + q = (h[0] + q) >> 26 + q = (h[1] + q) >> 25 + q = (h[2] + q) >> 26 + q = (h[3] + q) >> 25 + q = (h[4] + q) >> 26 + q = (h[5] + q) >> 25 + q = (h[6] + q) >> 26 + q = (h[7] + q) >> 25 + q = (h[8] + q) >> 26 + q = (h[9] + q) >> 25 + + // Goal: Output h-(2^255-19)q, which is between 0 and 2^255-20. + h[0] += 19 * q + // Goal: Output h-2^255 q, which is between 0 and 2^255-20. + + carry[0] = h[0] >> 26 + h[1] += carry[0] + h[0] -= carry[0] << 26 + carry[1] = h[1] >> 25 + h[2] += carry[1] + h[1] -= carry[1] << 25 + carry[2] = h[2] >> 26 + h[3] += carry[2] + h[2] -= carry[2] << 26 + carry[3] = h[3] >> 25 + h[4] += carry[3] + h[3] -= carry[3] << 25 + carry[4] = h[4] >> 26 + h[5] += carry[4] + h[4] -= carry[4] << 26 + carry[5] = h[5] >> 25 + h[6] += carry[5] + h[5] -= carry[5] << 25 + carry[6] = h[6] >> 26 + h[7] += carry[6] + h[6] -= carry[6] << 26 + carry[7] = h[7] >> 25 + h[8] += carry[7] + h[7] -= carry[7] << 25 + carry[8] = h[8] >> 26 + h[9] += carry[8] + h[8] -= carry[8] << 26 + carry[9] = h[9] >> 25 + h[9] -= carry[9] << 25 + // h10 = carry9 + + // Goal: Output h[0]+...+2^255 h10-2^255 q, which is between 0 and 2^255-20. + // Have h[0]+...+2^230 h[9] between 0 and 2^255-1; + // evidently 2^255 h10-2^255 q = 0. + // Goal: Output h[0]+...+2^230 h[9]. + + s[0] = byte(h[0] >> 0) + s[1] = byte(h[0] >> 8) + s[2] = byte(h[0] >> 16) + s[3] = byte((h[0] >> 24) | (h[1] << 2)) + s[4] = byte(h[1] >> 6) + s[5] = byte(h[1] >> 14) + s[6] = byte((h[1] >> 22) | (h[2] << 3)) + s[7] = byte(h[2] >> 5) + s[8] = byte(h[2] >> 13) + s[9] = byte((h[2] >> 21) | (h[3] << 5)) + s[10] = byte(h[3] >> 3) + s[11] = byte(h[3] >> 11) + s[12] = byte((h[3] >> 19) | (h[4] << 6)) + s[13] = byte(h[4] >> 2) + s[14] = byte(h[4] >> 10) + s[15] = byte(h[4] >> 18) + s[16] = byte(h[5] >> 0) + s[17] = byte(h[5] >> 8) + s[18] = byte(h[5] >> 16) + s[19] = byte((h[5] >> 24) | (h[6] << 1)) + s[20] = byte(h[6] >> 7) + s[21] = byte(h[6] >> 15) + s[22] = byte((h[6] >> 23) | (h[7] << 3)) + s[23] = byte(h[7] >> 5) + s[24] = byte(h[7] >> 13) + s[25] = byte((h[7] >> 21) | (h[8] << 4)) + s[26] = byte(h[8] >> 4) + s[27] = byte(h[8] >> 12) + s[28] = byte((h[8] >> 20) | (h[9] << 6)) + s[29] = byte(h[9] >> 2) + s[30] = byte(h[9] >> 10) + s[31] = byte(h[9] >> 18) +} + +// feMul calculates h = f * g +// Can overlap h with f or g. +// +// Preconditions: +// |f| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc. +// |g| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc. +// +// Postconditions: +// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +// +// Notes on implementation strategy: +// +// Using schoolbook multiplication. +// Karatsuba would save a little in some cost models. +// +// Most multiplications by 2 and 19 are 32-bit precomputations; +// cheaper than 64-bit postcomputations. +// +// There is one remaining multiplication by 19 in the carry chain; +// one *19 precomputation can be merged into this, +// but the resulting data flow is considerably less clean. +// +// There are 12 carries below. +// 10 of them are 2-way parallelizable and vectorizable. +// Can get away with 11 carries, but then data flow is much deeper. +// +// With tighter constraints on inputs can squeeze carries into int32. +func feMul(h, f, g *fieldElement) { + f0 := f[0] + f1 := f[1] + f2 := f[2] + f3 := f[3] + f4 := f[4] + f5 := f[5] + f6 := f[6] + f7 := f[7] + f8 := f[8] + f9 := f[9] + g0 := g[0] + g1 := g[1] + g2 := g[2] + g3 := g[3] + g4 := g[4] + g5 := g[5] + g6 := g[6] + g7 := g[7] + g8 := g[8] + g9 := g[9] + g1_19 := 19 * g1 // 1.4*2^29 + g2_19 := 19 * g2 // 1.4*2^30; still ok + g3_19 := 19 * g3 + g4_19 := 19 * g4 + g5_19 := 19 * g5 + g6_19 := 19 * g6 + g7_19 := 19 * g7 + g8_19 := 19 * g8 + g9_19 := 19 * g9 + f1_2 := 2 * f1 + f3_2 := 2 * f3 + f5_2 := 2 * f5 + f7_2 := 2 * f7 + f9_2 := 2 * f9 + f0g0 := int64(f0) * int64(g0) + f0g1 := int64(f0) * int64(g1) + f0g2 := int64(f0) * int64(g2) + f0g3 := int64(f0) * int64(g3) + f0g4 := int64(f0) * int64(g4) + f0g5 := int64(f0) * int64(g5) + f0g6 := int64(f0) * int64(g6) + f0g7 := int64(f0) * int64(g7) + f0g8 := int64(f0) * int64(g8) + f0g9 := int64(f0) * int64(g9) + f1g0 := int64(f1) * int64(g0) + f1g1_2 := int64(f1_2) * int64(g1) + f1g2 := int64(f1) * int64(g2) + f1g3_2 := int64(f1_2) * int64(g3) + f1g4 := int64(f1) * int64(g4) + f1g5_2 := int64(f1_2) * int64(g5) + f1g6 := int64(f1) * int64(g6) + f1g7_2 := int64(f1_2) * int64(g7) + f1g8 := int64(f1) * int64(g8) + f1g9_38 := int64(f1_2) * int64(g9_19) + f2g0 := int64(f2) * int64(g0) + f2g1 := int64(f2) * int64(g1) + f2g2 := int64(f2) * int64(g2) + f2g3 := int64(f2) * int64(g3) + f2g4 := int64(f2) * int64(g4) + f2g5 := int64(f2) * int64(g5) + f2g6 := int64(f2) * int64(g6) + f2g7 := int64(f2) * int64(g7) + f2g8_19 := int64(f2) * int64(g8_19) + f2g9_19 := int64(f2) * int64(g9_19) + f3g0 := int64(f3) * int64(g0) + f3g1_2 := int64(f3_2) * int64(g1) + f3g2 := int64(f3) * int64(g2) + f3g3_2 := int64(f3_2) * int64(g3) + f3g4 := int64(f3) * int64(g4) + f3g5_2 := int64(f3_2) * int64(g5) + f3g6 := int64(f3) * int64(g6) + f3g7_38 := int64(f3_2) * int64(g7_19) + f3g8_19 := int64(f3) * int64(g8_19) + f3g9_38 := int64(f3_2) * int64(g9_19) + f4g0 := int64(f4) * int64(g0) + f4g1 := int64(f4) * int64(g1) + f4g2 := int64(f4) * int64(g2) + f4g3 := int64(f4) * int64(g3) + f4g4 := int64(f4) * int64(g4) + f4g5 := int64(f4) * int64(g5) + f4g6_19 := int64(f4) * int64(g6_19) + f4g7_19 := int64(f4) * int64(g7_19) + f4g8_19 := int64(f4) * int64(g8_19) + f4g9_19 := int64(f4) * int64(g9_19) + f5g0 := int64(f5) * int64(g0) + f5g1_2 := int64(f5_2) * int64(g1) + f5g2 := int64(f5) * int64(g2) + f5g3_2 := int64(f5_2) * int64(g3) + f5g4 := int64(f5) * int64(g4) + f5g5_38 := int64(f5_2) * int64(g5_19) + f5g6_19 := int64(f5) * int64(g6_19) + f5g7_38 := int64(f5_2) * int64(g7_19) + f5g8_19 := int64(f5) * int64(g8_19) + f5g9_38 := int64(f5_2) * int64(g9_19) + f6g0 := int64(f6) * int64(g0) + f6g1 := int64(f6) * int64(g1) + f6g2 := int64(f6) * int64(g2) + f6g3 := int64(f6) * int64(g3) + f6g4_19 := int64(f6) * int64(g4_19) + f6g5_19 := int64(f6) * int64(g5_19) + f6g6_19 := int64(f6) * int64(g6_19) + f6g7_19 := int64(f6) * int64(g7_19) + f6g8_19 := int64(f6) * int64(g8_19) + f6g9_19 := int64(f6) * int64(g9_19) + f7g0 := int64(f7) * int64(g0) + f7g1_2 := int64(f7_2) * int64(g1) + f7g2 := int64(f7) * int64(g2) + f7g3_38 := int64(f7_2) * int64(g3_19) + f7g4_19 := int64(f7) * int64(g4_19) + f7g5_38 := int64(f7_2) * int64(g5_19) + f7g6_19 := int64(f7) * int64(g6_19) + f7g7_38 := int64(f7_2) * int64(g7_19) + f7g8_19 := int64(f7) * int64(g8_19) + f7g9_38 := int64(f7_2) * int64(g9_19) + f8g0 := int64(f8) * int64(g0) + f8g1 := int64(f8) * int64(g1) + f8g2_19 := int64(f8) * int64(g2_19) + f8g3_19 := int64(f8) * int64(g3_19) + f8g4_19 := int64(f8) * int64(g4_19) + f8g5_19 := int64(f8) * int64(g5_19) + f8g6_19 := int64(f8) * int64(g6_19) + f8g7_19 := int64(f8) * int64(g7_19) + f8g8_19 := int64(f8) * int64(g8_19) + f8g9_19 := int64(f8) * int64(g9_19) + f9g0 := int64(f9) * int64(g0) + f9g1_38 := int64(f9_2) * int64(g1_19) + f9g2_19 := int64(f9) * int64(g2_19) + f9g3_38 := int64(f9_2) * int64(g3_19) + f9g4_19 := int64(f9) * int64(g4_19) + f9g5_38 := int64(f9_2) * int64(g5_19) + f9g6_19 := int64(f9) * int64(g6_19) + f9g7_38 := int64(f9_2) * int64(g7_19) + f9g8_19 := int64(f9) * int64(g8_19) + f9g9_38 := int64(f9_2) * int64(g9_19) + h0 := f0g0 + f1g9_38 + f2g8_19 + f3g7_38 + f4g6_19 + f5g5_38 + f6g4_19 + f7g3_38 + f8g2_19 + f9g1_38 + h1 := f0g1 + f1g0 + f2g9_19 + f3g8_19 + f4g7_19 + f5g6_19 + f6g5_19 + f7g4_19 + f8g3_19 + f9g2_19 + h2 := f0g2 + f1g1_2 + f2g0 + f3g9_38 + f4g8_19 + f5g7_38 + f6g6_19 + f7g5_38 + f8g4_19 + f9g3_38 + h3 := f0g3 + f1g2 + f2g1 + f3g0 + f4g9_19 + f5g8_19 + f6g7_19 + f7g6_19 + f8g5_19 + f9g4_19 + h4 := f0g4 + f1g3_2 + f2g2 + f3g1_2 + f4g0 + f5g9_38 + f6g8_19 + f7g7_38 + f8g6_19 + f9g5_38 + h5 := f0g5 + f1g4 + f2g3 + f3g2 + f4g1 + f5g0 + f6g9_19 + f7g8_19 + f8g7_19 + f9g6_19 + h6 := f0g6 + f1g5_2 + f2g4 + f3g3_2 + f4g2 + f5g1_2 + f6g0 + f7g9_38 + f8g8_19 + f9g7_38 + h7 := f0g7 + f1g6 + f2g5 + f3g4 + f4g3 + f5g2 + f6g1 + f7g0 + f8g9_19 + f9g8_19 + h8 := f0g8 + f1g7_2 + f2g6 + f3g5_2 + f4g4 + f5g3_2 + f6g2 + f7g1_2 + f8g0 + f9g9_38 + h9 := f0g9 + f1g8 + f2g7 + f3g6 + f4g5 + f5g4 + f6g3 + f7g2 + f8g1 + f9g0 + var carry [10]int64 + + // |h0| <= (1.1*1.1*2^52*(1+19+19+19+19)+1.1*1.1*2^50*(38+38+38+38+38)) + // i.e. |h0| <= 1.2*2^59; narrower ranges for h2, h4, h6, h8 + // |h1| <= (1.1*1.1*2^51*(1+1+19+19+19+19+19+19+19+19)) + // i.e. |h1| <= 1.5*2^58; narrower ranges for h3, h5, h7, h9 + + carry[0] = (h0 + (1 << 25)) >> 26 + h1 += carry[0] + h0 -= carry[0] << 26 + carry[4] = (h4 + (1 << 25)) >> 26 + h5 += carry[4] + h4 -= carry[4] << 26 + // |h0| <= 2^25 + // |h4| <= 2^25 + // |h1| <= 1.51*2^58 + // |h5| <= 1.51*2^58 + + carry[1] = (h1 + (1 << 24)) >> 25 + h2 += carry[1] + h1 -= carry[1] << 25 + carry[5] = (h5 + (1 << 24)) >> 25 + h6 += carry[5] + h5 -= carry[5] << 25 + // |h1| <= 2^24; from now on fits into int32 + // |h5| <= 2^24; from now on fits into int32 + // |h2| <= 1.21*2^59 + // |h6| <= 1.21*2^59 + + carry[2] = (h2 + (1 << 25)) >> 26 + h3 += carry[2] + h2 -= carry[2] << 26 + carry[6] = (h6 + (1 << 25)) >> 26 + h7 += carry[6] + h6 -= carry[6] << 26 + // |h2| <= 2^25; from now on fits into int32 unchanged + // |h6| <= 2^25; from now on fits into int32 unchanged + // |h3| <= 1.51*2^58 + // |h7| <= 1.51*2^58 + + carry[3] = (h3 + (1 << 24)) >> 25 + h4 += carry[3] + h3 -= carry[3] << 25 + carry[7] = (h7 + (1 << 24)) >> 25 + h8 += carry[7] + h7 -= carry[7] << 25 + // |h3| <= 2^24; from now on fits into int32 unchanged + // |h7| <= 2^24; from now on fits into int32 unchanged + // |h4| <= 1.52*2^33 + // |h8| <= 1.52*2^33 + + carry[4] = (h4 + (1 << 25)) >> 26 + h5 += carry[4] + h4 -= carry[4] << 26 + carry[8] = (h8 + (1 << 25)) >> 26 + h9 += carry[8] + h8 -= carry[8] << 26 + // |h4| <= 2^25; from now on fits into int32 unchanged + // |h8| <= 2^25; from now on fits into int32 unchanged + // |h5| <= 1.01*2^24 + // |h9| <= 1.51*2^58 + + carry[9] = (h9 + (1 << 24)) >> 25 + h0 += carry[9] * 19 + h9 -= carry[9] << 25 + // |h9| <= 2^24; from now on fits into int32 unchanged + // |h0| <= 1.8*2^37 + + carry[0] = (h0 + (1 << 25)) >> 26 + h1 += carry[0] + h0 -= carry[0] << 26 + // |h0| <= 2^25; from now on fits into int32 unchanged + // |h1| <= 1.01*2^24 + + h[0] = int32(h0) + h[1] = int32(h1) + h[2] = int32(h2) + h[3] = int32(h3) + h[4] = int32(h4) + h[5] = int32(h5) + h[6] = int32(h6) + h[7] = int32(h7) + h[8] = int32(h8) + h[9] = int32(h9) +} + +// feSquare calculates h = f*f. Can overlap h with f. +// +// Preconditions: +// |f| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc. +// +// Postconditions: +// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +func feSquare(h, f *fieldElement) { + f0 := f[0] + f1 := f[1] + f2 := f[2] + f3 := f[3] + f4 := f[4] + f5 := f[5] + f6 := f[6] + f7 := f[7] + f8 := f[8] + f9 := f[9] + f0_2 := 2 * f0 + f1_2 := 2 * f1 + f2_2 := 2 * f2 + f3_2 := 2 * f3 + f4_2 := 2 * f4 + f5_2 := 2 * f5 + f6_2 := 2 * f6 + f7_2 := 2 * f7 + f5_38 := 38 * f5 // 1.31*2^30 + f6_19 := 19 * f6 // 1.31*2^30 + f7_38 := 38 * f7 // 1.31*2^30 + f8_19 := 19 * f8 // 1.31*2^30 + f9_38 := 38 * f9 // 1.31*2^30 + f0f0 := int64(f0) * int64(f0) + f0f1_2 := int64(f0_2) * int64(f1) + f0f2_2 := int64(f0_2) * int64(f2) + f0f3_2 := int64(f0_2) * int64(f3) + f0f4_2 := int64(f0_2) * int64(f4) + f0f5_2 := int64(f0_2) * int64(f5) + f0f6_2 := int64(f0_2) * int64(f6) + f0f7_2 := int64(f0_2) * int64(f7) + f0f8_2 := int64(f0_2) * int64(f8) + f0f9_2 := int64(f0_2) * int64(f9) + f1f1_2 := int64(f1_2) * int64(f1) + f1f2_2 := int64(f1_2) * int64(f2) + f1f3_4 := int64(f1_2) * int64(f3_2) + f1f4_2 := int64(f1_2) * int64(f4) + f1f5_4 := int64(f1_2) * int64(f5_2) + f1f6_2 := int64(f1_2) * int64(f6) + f1f7_4 := int64(f1_2) * int64(f7_2) + f1f8_2 := int64(f1_2) * int64(f8) + f1f9_76 := int64(f1_2) * int64(f9_38) + f2f2 := int64(f2) * int64(f2) + f2f3_2 := int64(f2_2) * int64(f3) + f2f4_2 := int64(f2_2) * int64(f4) + f2f5_2 := int64(f2_2) * int64(f5) + f2f6_2 := int64(f2_2) * int64(f6) + f2f7_2 := int64(f2_2) * int64(f7) + f2f8_38 := int64(f2_2) * int64(f8_19) + f2f9_38 := int64(f2) * int64(f9_38) + f3f3_2 := int64(f3_2) * int64(f3) + f3f4_2 := int64(f3_2) * int64(f4) + f3f5_4 := int64(f3_2) * int64(f5_2) + f3f6_2 := int64(f3_2) * int64(f6) + f3f7_76 := int64(f3_2) * int64(f7_38) + f3f8_38 := int64(f3_2) * int64(f8_19) + f3f9_76 := int64(f3_2) * int64(f9_38) + f4f4 := int64(f4) * int64(f4) + f4f5_2 := int64(f4_2) * int64(f5) + f4f6_38 := int64(f4_2) * int64(f6_19) + f4f7_38 := int64(f4) * int64(f7_38) + f4f8_38 := int64(f4_2) * int64(f8_19) + f4f9_38 := int64(f4) * int64(f9_38) + f5f5_38 := int64(f5) * int64(f5_38) + f5f6_38 := int64(f5_2) * int64(f6_19) + f5f7_76 := int64(f5_2) * int64(f7_38) + f5f8_38 := int64(f5_2) * int64(f8_19) + f5f9_76 := int64(f5_2) * int64(f9_38) + f6f6_19 := int64(f6) * int64(f6_19) + f6f7_38 := int64(f6) * int64(f7_38) + f6f8_38 := int64(f6_2) * int64(f8_19) + f6f9_38 := int64(f6) * int64(f9_38) + f7f7_38 := int64(f7) * int64(f7_38) + f7f8_38 := int64(f7_2) * int64(f8_19) + f7f9_76 := int64(f7_2) * int64(f9_38) + f8f8_19 := int64(f8) * int64(f8_19) + f8f9_38 := int64(f8) * int64(f9_38) + f9f9_38 := int64(f9) * int64(f9_38) + h0 := f0f0 + f1f9_76 + f2f8_38 + f3f7_76 + f4f6_38 + f5f5_38 + h1 := f0f1_2 + f2f9_38 + f3f8_38 + f4f7_38 + f5f6_38 + h2 := f0f2_2 + f1f1_2 + f3f9_76 + f4f8_38 + f5f7_76 + f6f6_19 + h3 := f0f3_2 + f1f2_2 + f4f9_38 + f5f8_38 + f6f7_38 + h4 := f0f4_2 + f1f3_4 + f2f2 + f5f9_76 + f6f8_38 + f7f7_38 + h5 := f0f5_2 + f1f4_2 + f2f3_2 + f6f9_38 + f7f8_38 + h6 := f0f6_2 + f1f5_4 + f2f4_2 + f3f3_2 + f7f9_76 + f8f8_19 + h7 := f0f7_2 + f1f6_2 + f2f5_2 + f3f4_2 + f8f9_38 + h8 := f0f8_2 + f1f7_4 + f2f6_2 + f3f5_4 + f4f4 + f9f9_38 + h9 := f0f9_2 + f1f8_2 + f2f7_2 + f3f6_2 + f4f5_2 + var carry [10]int64 + + carry[0] = (h0 + (1 << 25)) >> 26 + h1 += carry[0] + h0 -= carry[0] << 26 + carry[4] = (h4 + (1 << 25)) >> 26 + h5 += carry[4] + h4 -= carry[4] << 26 + + carry[1] = (h1 + (1 << 24)) >> 25 + h2 += carry[1] + h1 -= carry[1] << 25 + carry[5] = (h5 + (1 << 24)) >> 25 + h6 += carry[5] + h5 -= carry[5] << 25 + + carry[2] = (h2 + (1 << 25)) >> 26 + h3 += carry[2] + h2 -= carry[2] << 26 + carry[6] = (h6 + (1 << 25)) >> 26 + h7 += carry[6] + h6 -= carry[6] << 26 + + carry[3] = (h3 + (1 << 24)) >> 25 + h4 += carry[3] + h3 -= carry[3] << 25 + carry[7] = (h7 + (1 << 24)) >> 25 + h8 += carry[7] + h7 -= carry[7] << 25 + + carry[4] = (h4 + (1 << 25)) >> 26 + h5 += carry[4] + h4 -= carry[4] << 26 + carry[8] = (h8 + (1 << 25)) >> 26 + h9 += carry[8] + h8 -= carry[8] << 26 + + carry[9] = (h9 + (1 << 24)) >> 25 + h0 += carry[9] * 19 + h9 -= carry[9] << 25 + + carry[0] = (h0 + (1 << 25)) >> 26 + h1 += carry[0] + h0 -= carry[0] << 26 + + h[0] = int32(h0) + h[1] = int32(h1) + h[2] = int32(h2) + h[3] = int32(h3) + h[4] = int32(h4) + h[5] = int32(h5) + h[6] = int32(h6) + h[7] = int32(h7) + h[8] = int32(h8) + h[9] = int32(h9) +} + +// feMul121666 calculates h = f * 121666. Can overlap h with f. +// +// Preconditions: +// |f| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc. +// +// Postconditions: +// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +func feMul121666(h, f *fieldElement) { + h0 := int64(f[0]) * 121666 + h1 := int64(f[1]) * 121666 + h2 := int64(f[2]) * 121666 + h3 := int64(f[3]) * 121666 + h4 := int64(f[4]) * 121666 + h5 := int64(f[5]) * 121666 + h6 := int64(f[6]) * 121666 + h7 := int64(f[7]) * 121666 + h8 := int64(f[8]) * 121666 + h9 := int64(f[9]) * 121666 + var carry [10]int64 + + carry[9] = (h9 + (1 << 24)) >> 25 + h0 += carry[9] * 19 + h9 -= carry[9] << 25 + carry[1] = (h1 + (1 << 24)) >> 25 + h2 += carry[1] + h1 -= carry[1] << 25 + carry[3] = (h3 + (1 << 24)) >> 25 + h4 += carry[3] + h3 -= carry[3] << 25 + carry[5] = (h5 + (1 << 24)) >> 25 + h6 += carry[5] + h5 -= carry[5] << 25 + carry[7] = (h7 + (1 << 24)) >> 25 + h8 += carry[7] + h7 -= carry[7] << 25 + + carry[0] = (h0 + (1 << 25)) >> 26 + h1 += carry[0] + h0 -= carry[0] << 26 + carry[2] = (h2 + (1 << 25)) >> 26 + h3 += carry[2] + h2 -= carry[2] << 26 + carry[4] = (h4 + (1 << 25)) >> 26 + h5 += carry[4] + h4 -= carry[4] << 26 + carry[6] = (h6 + (1 << 25)) >> 26 + h7 += carry[6] + h6 -= carry[6] << 26 + carry[8] = (h8 + (1 << 25)) >> 26 + h9 += carry[8] + h8 -= carry[8] << 26 + + h[0] = int32(h0) + h[1] = int32(h1) + h[2] = int32(h2) + h[3] = int32(h3) + h[4] = int32(h4) + h[5] = int32(h5) + h[6] = int32(h6) + h[7] = int32(h7) + h[8] = int32(h8) + h[9] = int32(h9) +} + +// feInvert sets out = z^-1. +func feInvert(out, z *fieldElement) { + var t0, t1, t2, t3 fieldElement + var i int + + feSquare(&t0, z) + for i = 1; i < 1; i++ { + feSquare(&t0, &t0) + } + feSquare(&t1, &t0) + for i = 1; i < 2; i++ { + feSquare(&t1, &t1) + } + feMul(&t1, z, &t1) + feMul(&t0, &t0, &t1) + feSquare(&t2, &t0) + for i = 1; i < 1; i++ { + feSquare(&t2, &t2) + } + feMul(&t1, &t1, &t2) + feSquare(&t2, &t1) + for i = 1; i < 5; i++ { + feSquare(&t2, &t2) + } + feMul(&t1, &t2, &t1) + feSquare(&t2, &t1) + for i = 1; i < 10; i++ { + feSquare(&t2, &t2) + } + feMul(&t2, &t2, &t1) + feSquare(&t3, &t2) + for i = 1; i < 20; i++ { + feSquare(&t3, &t3) + } + feMul(&t2, &t3, &t2) + feSquare(&t2, &t2) + for i = 1; i < 10; i++ { + feSquare(&t2, &t2) + } + feMul(&t1, &t2, &t1) + feSquare(&t2, &t1) + for i = 1; i < 50; i++ { + feSquare(&t2, &t2) + } + feMul(&t2, &t2, &t1) + feSquare(&t3, &t2) + for i = 1; i < 100; i++ { + feSquare(&t3, &t3) + } + feMul(&t2, &t3, &t2) + feSquare(&t2, &t2) + for i = 1; i < 50; i++ { + feSquare(&t2, &t2) + } + feMul(&t1, &t2, &t1) + feSquare(&t1, &t1) + for i = 1; i < 5; i++ { + feSquare(&t1, &t1) + } + feMul(out, &t1, &t0) +} + +func scalarMult(out, in, base *[32]byte) { + var e [32]byte + + copy(e[:], in[:]) + e[0] &= 248 + e[31] &= 127 + e[31] |= 64 + + var x1, x2, z2, x3, z3, tmp0, tmp1 fieldElement + feFromBytes(&x1, base) + feOne(&x2) + feCopy(&x3, &x1) + feOne(&z3) + + swap := int32(0) + for pos := 254; pos >= 0; pos-- { + b := e[pos/8] >> uint(pos&7) + b &= 1 + swap ^= int32(b) + feCSwap(&x2, &x3, swap) + feCSwap(&z2, &z3, swap) + swap = int32(b) + + feSub(&tmp0, &x3, &z3) + feSub(&tmp1, &x2, &z2) + feAdd(&x2, &x2, &z2) + feAdd(&z2, &x3, &z3) + feMul(&z3, &tmp0, &x2) + feMul(&z2, &z2, &tmp1) + feSquare(&tmp0, &tmp1) + feSquare(&tmp1, &x2) + feAdd(&x3, &z3, &z2) + feSub(&z2, &z3, &z2) + feMul(&x2, &tmp1, &tmp0) + feSub(&tmp1, &tmp1, &tmp0) + feSquare(&z2, &z2) + feMul121666(&z3, &tmp1) + feSquare(&x3, &x3) + feAdd(&tmp0, &tmp0, &z3) + feMul(&z3, &x1, &z2) + feMul(&z2, &tmp1, &tmp0) + } + + feCSwap(&x2, &x3, swap) + feCSwap(&z2, &z3, swap) + + feInvert(&z2, &z2) + feMul(&x2, &x2, &z2) + feToBytes(out, &x2) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/doc.go b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/doc.go new file mode 100644 index 000000000..da9b10d9c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/doc.go @@ -0,0 +1,23 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package curve25519 provides an implementation of scalar multiplication on +// the elliptic curve known as curve25519. See https://cr.yp.to/ecdh.html +package curve25519 // import "golang.org/x/crypto/curve25519" + +// basePoint is the x coordinate of the generator of the curve. +var basePoint = [32]byte{9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + +// ScalarMult sets dst to the product in*base where dst and base are the x +// coordinates of group points and all values are in little-endian form. +func ScalarMult(dst, in, base *[32]byte) { + scalarMult(dst, in, base) +} + +// ScalarBaseMult sets dst to the product in*base where dst and base are the x +// coordinates of group points, base is the standard generator and all values +// are in little-endian form. +func ScalarBaseMult(dst, in *[32]byte) { + ScalarMult(dst, in, &basePoint) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/freeze_amd64.s b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/freeze_amd64.s new file mode 100644 index 000000000..390816106 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/freeze_amd64.s @@ -0,0 +1,73 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This code was translated into a form compatible with 6a from the public +// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html + +// +build amd64,!gccgo,!appengine + +#include "const_amd64.h" + +// func freeze(inout *[5]uint64) +TEXT ·freeze(SB),7,$0-8 + MOVQ inout+0(FP), DI + + MOVQ 0(DI),SI + MOVQ 8(DI),DX + MOVQ 16(DI),CX + MOVQ 24(DI),R8 + MOVQ 32(DI),R9 + MOVQ $REDMASK51,AX + MOVQ AX,R10 + SUBQ $18,R10 + MOVQ $3,R11 +REDUCELOOP: + MOVQ SI,R12 + SHRQ $51,R12 + ANDQ AX,SI + ADDQ R12,DX + MOVQ DX,R12 + SHRQ $51,R12 + ANDQ AX,DX + ADDQ R12,CX + MOVQ CX,R12 + SHRQ $51,R12 + ANDQ AX,CX + ADDQ R12,R8 + MOVQ R8,R12 + SHRQ $51,R12 + ANDQ AX,R8 + ADDQ R12,R9 + MOVQ R9,R12 + SHRQ $51,R12 + ANDQ AX,R9 + IMUL3Q $19,R12,R12 + ADDQ R12,SI + SUBQ $1,R11 + JA REDUCELOOP + MOVQ $1,R12 + CMPQ R10,SI + CMOVQLT R11,R12 + CMPQ AX,DX + CMOVQNE R11,R12 + CMPQ AX,CX + CMOVQNE R11,R12 + CMPQ AX,R8 + CMOVQNE R11,R12 + CMPQ AX,R9 + CMOVQNE R11,R12 + NEGQ R12 + ANDQ R12,AX + ANDQ R12,R10 + SUBQ R10,SI + SUBQ AX,DX + SUBQ AX,CX + SUBQ AX,R8 + SUBQ AX,R9 + MOVQ SI,0(DI) + MOVQ DX,8(DI) + MOVQ CX,16(DI) + MOVQ R8,24(DI) + MOVQ R9,32(DI) + RET diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/ladderstep_amd64.s b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/ladderstep_amd64.s new file mode 100644 index 000000000..9e9040b25 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/ladderstep_amd64.s @@ -0,0 +1,1377 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This code was translated into a form compatible with 6a from the public +// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html + +// +build amd64,!gccgo,!appengine + +#include "const_amd64.h" + +// func ladderstep(inout *[5][5]uint64) +TEXT ·ladderstep(SB),0,$296-8 + MOVQ inout+0(FP),DI + + MOVQ 40(DI),SI + MOVQ 48(DI),DX + MOVQ 56(DI),CX + MOVQ 64(DI),R8 + MOVQ 72(DI),R9 + MOVQ SI,AX + MOVQ DX,R10 + MOVQ CX,R11 + MOVQ R8,R12 + MOVQ R9,R13 + ADDQ ·_2P0(SB),AX + ADDQ ·_2P1234(SB),R10 + ADDQ ·_2P1234(SB),R11 + ADDQ ·_2P1234(SB),R12 + ADDQ ·_2P1234(SB),R13 + ADDQ 80(DI),SI + ADDQ 88(DI),DX + ADDQ 96(DI),CX + ADDQ 104(DI),R8 + ADDQ 112(DI),R9 + SUBQ 80(DI),AX + SUBQ 88(DI),R10 + SUBQ 96(DI),R11 + SUBQ 104(DI),R12 + SUBQ 112(DI),R13 + MOVQ SI,0(SP) + MOVQ DX,8(SP) + MOVQ CX,16(SP) + MOVQ R8,24(SP) + MOVQ R9,32(SP) + MOVQ AX,40(SP) + MOVQ R10,48(SP) + MOVQ R11,56(SP) + MOVQ R12,64(SP) + MOVQ R13,72(SP) + MOVQ 40(SP),AX + MULQ 40(SP) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 40(SP),AX + SHLQ $1,AX + MULQ 48(SP) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 40(SP),AX + SHLQ $1,AX + MULQ 56(SP) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 40(SP),AX + SHLQ $1,AX + MULQ 64(SP) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 40(SP),AX + SHLQ $1,AX + MULQ 72(SP) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 48(SP),AX + MULQ 48(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 48(SP),AX + SHLQ $1,AX + MULQ 56(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 48(SP),AX + SHLQ $1,AX + MULQ 64(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 48(SP),DX + IMUL3Q $38,DX,AX + MULQ 72(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 56(SP),AX + MULQ 56(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 56(SP),DX + IMUL3Q $38,DX,AX + MULQ 64(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 56(SP),DX + IMUL3Q $38,DX,AX + MULQ 72(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 64(SP),DX + IMUL3Q $19,DX,AX + MULQ 64(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 64(SP),DX + IMUL3Q $38,DX,AX + MULQ 72(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 72(SP),DX + IMUL3Q $19,DX,AX + MULQ 72(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,CX:SI + ANDQ DX,SI + SHLQ $13,R9:R8 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R11:R10 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R13:R12 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R15:R14 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + ANDQ DX,SI + MOVQ CX,R8 + SHRQ $51,CX + ADDQ R10,CX + ANDQ DX,R8 + MOVQ CX,R9 + SHRQ $51,CX + ADDQ R12,CX + ANDQ DX,R9 + MOVQ CX,AX + SHRQ $51,CX + ADDQ R14,CX + ANDQ DX,AX + MOVQ CX,R10 + SHRQ $51,CX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,80(SP) + MOVQ R8,88(SP) + MOVQ R9,96(SP) + MOVQ AX,104(SP) + MOVQ R10,112(SP) + MOVQ 0(SP),AX + MULQ 0(SP) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 0(SP),AX + SHLQ $1,AX + MULQ 8(SP) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 0(SP),AX + SHLQ $1,AX + MULQ 16(SP) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 0(SP),AX + SHLQ $1,AX + MULQ 24(SP) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 0(SP),AX + SHLQ $1,AX + MULQ 32(SP) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 8(SP),AX + MULQ 8(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 8(SP),AX + SHLQ $1,AX + MULQ 16(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 8(SP),AX + SHLQ $1,AX + MULQ 24(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 8(SP),DX + IMUL3Q $38,DX,AX + MULQ 32(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 16(SP),AX + MULQ 16(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 16(SP),DX + IMUL3Q $38,DX,AX + MULQ 24(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 16(SP),DX + IMUL3Q $38,DX,AX + MULQ 32(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 24(SP),DX + IMUL3Q $19,DX,AX + MULQ 24(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 24(SP),DX + IMUL3Q $38,DX,AX + MULQ 32(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 32(SP),DX + IMUL3Q $19,DX,AX + MULQ 32(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,CX:SI + ANDQ DX,SI + SHLQ $13,R9:R8 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R11:R10 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R13:R12 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R15:R14 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + ANDQ DX,SI + MOVQ CX,R8 + SHRQ $51,CX + ADDQ R10,CX + ANDQ DX,R8 + MOVQ CX,R9 + SHRQ $51,CX + ADDQ R12,CX + ANDQ DX,R9 + MOVQ CX,AX + SHRQ $51,CX + ADDQ R14,CX + ANDQ DX,AX + MOVQ CX,R10 + SHRQ $51,CX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,120(SP) + MOVQ R8,128(SP) + MOVQ R9,136(SP) + MOVQ AX,144(SP) + MOVQ R10,152(SP) + MOVQ SI,SI + MOVQ R8,DX + MOVQ R9,CX + MOVQ AX,R8 + MOVQ R10,R9 + ADDQ ·_2P0(SB),SI + ADDQ ·_2P1234(SB),DX + ADDQ ·_2P1234(SB),CX + ADDQ ·_2P1234(SB),R8 + ADDQ ·_2P1234(SB),R9 + SUBQ 80(SP),SI + SUBQ 88(SP),DX + SUBQ 96(SP),CX + SUBQ 104(SP),R8 + SUBQ 112(SP),R9 + MOVQ SI,160(SP) + MOVQ DX,168(SP) + MOVQ CX,176(SP) + MOVQ R8,184(SP) + MOVQ R9,192(SP) + MOVQ 120(DI),SI + MOVQ 128(DI),DX + MOVQ 136(DI),CX + MOVQ 144(DI),R8 + MOVQ 152(DI),R9 + MOVQ SI,AX + MOVQ DX,R10 + MOVQ CX,R11 + MOVQ R8,R12 + MOVQ R9,R13 + ADDQ ·_2P0(SB),AX + ADDQ ·_2P1234(SB),R10 + ADDQ ·_2P1234(SB),R11 + ADDQ ·_2P1234(SB),R12 + ADDQ ·_2P1234(SB),R13 + ADDQ 160(DI),SI + ADDQ 168(DI),DX + ADDQ 176(DI),CX + ADDQ 184(DI),R8 + ADDQ 192(DI),R9 + SUBQ 160(DI),AX + SUBQ 168(DI),R10 + SUBQ 176(DI),R11 + SUBQ 184(DI),R12 + SUBQ 192(DI),R13 + MOVQ SI,200(SP) + MOVQ DX,208(SP) + MOVQ CX,216(SP) + MOVQ R8,224(SP) + MOVQ R9,232(SP) + MOVQ AX,240(SP) + MOVQ R10,248(SP) + MOVQ R11,256(SP) + MOVQ R12,264(SP) + MOVQ R13,272(SP) + MOVQ 224(SP),SI + IMUL3Q $19,SI,AX + MOVQ AX,280(SP) + MULQ 56(SP) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 232(SP),DX + IMUL3Q $19,DX,AX + MOVQ AX,288(SP) + MULQ 48(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 200(SP),AX + MULQ 40(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 200(SP),AX + MULQ 48(SP) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 200(SP),AX + MULQ 56(SP) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 200(SP),AX + MULQ 64(SP) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 200(SP),AX + MULQ 72(SP) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 208(SP),AX + MULQ 40(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 208(SP),AX + MULQ 48(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 208(SP),AX + MULQ 56(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 208(SP),AX + MULQ 64(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 208(SP),DX + IMUL3Q $19,DX,AX + MULQ 72(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 216(SP),AX + MULQ 40(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 216(SP),AX + MULQ 48(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 216(SP),AX + MULQ 56(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 216(SP),DX + IMUL3Q $19,DX,AX + MULQ 64(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 216(SP),DX + IMUL3Q $19,DX,AX + MULQ 72(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 224(SP),AX + MULQ 40(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 224(SP),AX + MULQ 48(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 280(SP),AX + MULQ 64(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 280(SP),AX + MULQ 72(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 232(SP),AX + MULQ 40(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 288(SP),AX + MULQ 56(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 288(SP),AX + MULQ 64(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 288(SP),AX + MULQ 72(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,CX:SI + ANDQ DX,SI + SHLQ $13,R9:R8 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R11:R10 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R13:R12 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R15:R14 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + MOVQ CX,R8 + SHRQ $51,CX + ANDQ DX,SI + ADDQ R10,CX + MOVQ CX,R9 + SHRQ $51,CX + ANDQ DX,R8 + ADDQ R12,CX + MOVQ CX,AX + SHRQ $51,CX + ANDQ DX,R9 + ADDQ R14,CX + MOVQ CX,R10 + SHRQ $51,CX + ANDQ DX,AX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,40(SP) + MOVQ R8,48(SP) + MOVQ R9,56(SP) + MOVQ AX,64(SP) + MOVQ R10,72(SP) + MOVQ 264(SP),SI + IMUL3Q $19,SI,AX + MOVQ AX,200(SP) + MULQ 16(SP) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 272(SP),DX + IMUL3Q $19,DX,AX + MOVQ AX,208(SP) + MULQ 8(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 240(SP),AX + MULQ 0(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 240(SP),AX + MULQ 8(SP) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 240(SP),AX + MULQ 16(SP) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 240(SP),AX + MULQ 24(SP) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 240(SP),AX + MULQ 32(SP) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 248(SP),AX + MULQ 0(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 248(SP),AX + MULQ 8(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 248(SP),AX + MULQ 16(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 248(SP),AX + MULQ 24(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 248(SP),DX + IMUL3Q $19,DX,AX + MULQ 32(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 256(SP),AX + MULQ 0(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 256(SP),AX + MULQ 8(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 256(SP),AX + MULQ 16(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 256(SP),DX + IMUL3Q $19,DX,AX + MULQ 24(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 256(SP),DX + IMUL3Q $19,DX,AX + MULQ 32(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 264(SP),AX + MULQ 0(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 264(SP),AX + MULQ 8(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 200(SP),AX + MULQ 24(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 200(SP),AX + MULQ 32(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 272(SP),AX + MULQ 0(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 208(SP),AX + MULQ 16(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 208(SP),AX + MULQ 24(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 208(SP),AX + MULQ 32(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,CX:SI + ANDQ DX,SI + SHLQ $13,R9:R8 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R11:R10 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R13:R12 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R15:R14 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + MOVQ CX,R8 + SHRQ $51,CX + ANDQ DX,SI + ADDQ R10,CX + MOVQ CX,R9 + SHRQ $51,CX + ANDQ DX,R8 + ADDQ R12,CX + MOVQ CX,AX + SHRQ $51,CX + ANDQ DX,R9 + ADDQ R14,CX + MOVQ CX,R10 + SHRQ $51,CX + ANDQ DX,AX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,DX + MOVQ R8,CX + MOVQ R9,R11 + MOVQ AX,R12 + MOVQ R10,R13 + ADDQ ·_2P0(SB),DX + ADDQ ·_2P1234(SB),CX + ADDQ ·_2P1234(SB),R11 + ADDQ ·_2P1234(SB),R12 + ADDQ ·_2P1234(SB),R13 + ADDQ 40(SP),SI + ADDQ 48(SP),R8 + ADDQ 56(SP),R9 + ADDQ 64(SP),AX + ADDQ 72(SP),R10 + SUBQ 40(SP),DX + SUBQ 48(SP),CX + SUBQ 56(SP),R11 + SUBQ 64(SP),R12 + SUBQ 72(SP),R13 + MOVQ SI,120(DI) + MOVQ R8,128(DI) + MOVQ R9,136(DI) + MOVQ AX,144(DI) + MOVQ R10,152(DI) + MOVQ DX,160(DI) + MOVQ CX,168(DI) + MOVQ R11,176(DI) + MOVQ R12,184(DI) + MOVQ R13,192(DI) + MOVQ 120(DI),AX + MULQ 120(DI) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 120(DI),AX + SHLQ $1,AX + MULQ 128(DI) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 120(DI),AX + SHLQ $1,AX + MULQ 136(DI) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 120(DI),AX + SHLQ $1,AX + MULQ 144(DI) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 120(DI),AX + SHLQ $1,AX + MULQ 152(DI) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 128(DI),AX + MULQ 128(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 128(DI),AX + SHLQ $1,AX + MULQ 136(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 128(DI),AX + SHLQ $1,AX + MULQ 144(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 128(DI),DX + IMUL3Q $38,DX,AX + MULQ 152(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 136(DI),AX + MULQ 136(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 136(DI),DX + IMUL3Q $38,DX,AX + MULQ 144(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 136(DI),DX + IMUL3Q $38,DX,AX + MULQ 152(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 144(DI),DX + IMUL3Q $19,DX,AX + MULQ 144(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 144(DI),DX + IMUL3Q $38,DX,AX + MULQ 152(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 152(DI),DX + IMUL3Q $19,DX,AX + MULQ 152(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,CX:SI + ANDQ DX,SI + SHLQ $13,R9:R8 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R11:R10 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R13:R12 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R15:R14 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + ANDQ DX,SI + MOVQ CX,R8 + SHRQ $51,CX + ADDQ R10,CX + ANDQ DX,R8 + MOVQ CX,R9 + SHRQ $51,CX + ADDQ R12,CX + ANDQ DX,R9 + MOVQ CX,AX + SHRQ $51,CX + ADDQ R14,CX + ANDQ DX,AX + MOVQ CX,R10 + SHRQ $51,CX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,120(DI) + MOVQ R8,128(DI) + MOVQ R9,136(DI) + MOVQ AX,144(DI) + MOVQ R10,152(DI) + MOVQ 160(DI),AX + MULQ 160(DI) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 160(DI),AX + SHLQ $1,AX + MULQ 168(DI) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 160(DI),AX + SHLQ $1,AX + MULQ 176(DI) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 160(DI),AX + SHLQ $1,AX + MULQ 184(DI) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 160(DI),AX + SHLQ $1,AX + MULQ 192(DI) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 168(DI),AX + MULQ 168(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 168(DI),AX + SHLQ $1,AX + MULQ 176(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 168(DI),AX + SHLQ $1,AX + MULQ 184(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 168(DI),DX + IMUL3Q $38,DX,AX + MULQ 192(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 176(DI),AX + MULQ 176(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 176(DI),DX + IMUL3Q $38,DX,AX + MULQ 184(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 176(DI),DX + IMUL3Q $38,DX,AX + MULQ 192(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 184(DI),DX + IMUL3Q $19,DX,AX + MULQ 184(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 184(DI),DX + IMUL3Q $38,DX,AX + MULQ 192(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 192(DI),DX + IMUL3Q $19,DX,AX + MULQ 192(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,CX:SI + ANDQ DX,SI + SHLQ $13,R9:R8 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R11:R10 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R13:R12 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R15:R14 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + ANDQ DX,SI + MOVQ CX,R8 + SHRQ $51,CX + ADDQ R10,CX + ANDQ DX,R8 + MOVQ CX,R9 + SHRQ $51,CX + ADDQ R12,CX + ANDQ DX,R9 + MOVQ CX,AX + SHRQ $51,CX + ADDQ R14,CX + ANDQ DX,AX + MOVQ CX,R10 + SHRQ $51,CX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,160(DI) + MOVQ R8,168(DI) + MOVQ R9,176(DI) + MOVQ AX,184(DI) + MOVQ R10,192(DI) + MOVQ 184(DI),SI + IMUL3Q $19,SI,AX + MOVQ AX,0(SP) + MULQ 16(DI) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 192(DI),DX + IMUL3Q $19,DX,AX + MOVQ AX,8(SP) + MULQ 8(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 160(DI),AX + MULQ 0(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 160(DI),AX + MULQ 8(DI) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 160(DI),AX + MULQ 16(DI) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 160(DI),AX + MULQ 24(DI) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 160(DI),AX + MULQ 32(DI) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 168(DI),AX + MULQ 0(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 168(DI),AX + MULQ 8(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 168(DI),AX + MULQ 16(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 168(DI),AX + MULQ 24(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 168(DI),DX + IMUL3Q $19,DX,AX + MULQ 32(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 176(DI),AX + MULQ 0(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 176(DI),AX + MULQ 8(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 176(DI),AX + MULQ 16(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 176(DI),DX + IMUL3Q $19,DX,AX + MULQ 24(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 176(DI),DX + IMUL3Q $19,DX,AX + MULQ 32(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 184(DI),AX + MULQ 0(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 184(DI),AX + MULQ 8(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 0(SP),AX + MULQ 24(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 0(SP),AX + MULQ 32(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 192(DI),AX + MULQ 0(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 8(SP),AX + MULQ 16(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 8(SP),AX + MULQ 24(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 8(SP),AX + MULQ 32(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,CX:SI + ANDQ DX,SI + SHLQ $13,R9:R8 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R11:R10 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R13:R12 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R15:R14 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + MOVQ CX,R8 + SHRQ $51,CX + ANDQ DX,SI + ADDQ R10,CX + MOVQ CX,R9 + SHRQ $51,CX + ANDQ DX,R8 + ADDQ R12,CX + MOVQ CX,AX + SHRQ $51,CX + ANDQ DX,R9 + ADDQ R14,CX + MOVQ CX,R10 + SHRQ $51,CX + ANDQ DX,AX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,160(DI) + MOVQ R8,168(DI) + MOVQ R9,176(DI) + MOVQ AX,184(DI) + MOVQ R10,192(DI) + MOVQ 144(SP),SI + IMUL3Q $19,SI,AX + MOVQ AX,0(SP) + MULQ 96(SP) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 152(SP),DX + IMUL3Q $19,DX,AX + MOVQ AX,8(SP) + MULQ 88(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 120(SP),AX + MULQ 80(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 120(SP),AX + MULQ 88(SP) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 120(SP),AX + MULQ 96(SP) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 120(SP),AX + MULQ 104(SP) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 120(SP),AX + MULQ 112(SP) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 128(SP),AX + MULQ 80(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 128(SP),AX + MULQ 88(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 128(SP),AX + MULQ 96(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 128(SP),AX + MULQ 104(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 128(SP),DX + IMUL3Q $19,DX,AX + MULQ 112(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 136(SP),AX + MULQ 80(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 136(SP),AX + MULQ 88(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 136(SP),AX + MULQ 96(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 136(SP),DX + IMUL3Q $19,DX,AX + MULQ 104(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 136(SP),DX + IMUL3Q $19,DX,AX + MULQ 112(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 144(SP),AX + MULQ 80(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 144(SP),AX + MULQ 88(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 0(SP),AX + MULQ 104(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 0(SP),AX + MULQ 112(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 152(SP),AX + MULQ 80(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 8(SP),AX + MULQ 96(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 8(SP),AX + MULQ 104(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 8(SP),AX + MULQ 112(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,CX:SI + ANDQ DX,SI + SHLQ $13,R9:R8 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R11:R10 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R13:R12 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R15:R14 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + MOVQ CX,R8 + SHRQ $51,CX + ANDQ DX,SI + ADDQ R10,CX + MOVQ CX,R9 + SHRQ $51,CX + ANDQ DX,R8 + ADDQ R12,CX + MOVQ CX,AX + SHRQ $51,CX + ANDQ DX,R9 + ADDQ R14,CX + MOVQ CX,R10 + SHRQ $51,CX + ANDQ DX,AX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,40(DI) + MOVQ R8,48(DI) + MOVQ R9,56(DI) + MOVQ AX,64(DI) + MOVQ R10,72(DI) + MOVQ 160(SP),AX + MULQ ·_121666_213(SB) + SHRQ $13,AX + MOVQ AX,SI + MOVQ DX,CX + MOVQ 168(SP),AX + MULQ ·_121666_213(SB) + SHRQ $13,AX + ADDQ AX,CX + MOVQ DX,R8 + MOVQ 176(SP),AX + MULQ ·_121666_213(SB) + SHRQ $13,AX + ADDQ AX,R8 + MOVQ DX,R9 + MOVQ 184(SP),AX + MULQ ·_121666_213(SB) + SHRQ $13,AX + ADDQ AX,R9 + MOVQ DX,R10 + MOVQ 192(SP),AX + MULQ ·_121666_213(SB) + SHRQ $13,AX + ADDQ AX,R10 + IMUL3Q $19,DX,DX + ADDQ DX,SI + ADDQ 80(SP),SI + ADDQ 88(SP),CX + ADDQ 96(SP),R8 + ADDQ 104(SP),R9 + ADDQ 112(SP),R10 + MOVQ SI,80(DI) + MOVQ CX,88(DI) + MOVQ R8,96(DI) + MOVQ R9,104(DI) + MOVQ R10,112(DI) + MOVQ 104(DI),SI + IMUL3Q $19,SI,AX + MOVQ AX,0(SP) + MULQ 176(SP) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 112(DI),DX + IMUL3Q $19,DX,AX + MOVQ AX,8(SP) + MULQ 168(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 80(DI),AX + MULQ 160(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 80(DI),AX + MULQ 168(SP) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 80(DI),AX + MULQ 176(SP) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 80(DI),AX + MULQ 184(SP) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 80(DI),AX + MULQ 192(SP) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 88(DI),AX + MULQ 160(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 88(DI),AX + MULQ 168(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 88(DI),AX + MULQ 176(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 88(DI),AX + MULQ 184(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 88(DI),DX + IMUL3Q $19,DX,AX + MULQ 192(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 96(DI),AX + MULQ 160(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 96(DI),AX + MULQ 168(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 96(DI),AX + MULQ 176(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 96(DI),DX + IMUL3Q $19,DX,AX + MULQ 184(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 96(DI),DX + IMUL3Q $19,DX,AX + MULQ 192(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 104(DI),AX + MULQ 160(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 104(DI),AX + MULQ 168(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 0(SP),AX + MULQ 184(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 0(SP),AX + MULQ 192(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 112(DI),AX + MULQ 160(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 8(SP),AX + MULQ 176(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 8(SP),AX + MULQ 184(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 8(SP),AX + MULQ 192(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,CX:SI + ANDQ DX,SI + SHLQ $13,R9:R8 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R11:R10 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R13:R12 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R15:R14 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + MOVQ CX,R8 + SHRQ $51,CX + ANDQ DX,SI + ADDQ R10,CX + MOVQ CX,R9 + SHRQ $51,CX + ANDQ DX,R8 + ADDQ R12,CX + MOVQ CX,AX + SHRQ $51,CX + ANDQ DX,R9 + ADDQ R14,CX + MOVQ CX,R10 + SHRQ $51,CX + ANDQ DX,AX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,80(DI) + MOVQ R8,88(DI) + MOVQ R9,96(DI) + MOVQ AX,104(DI) + MOVQ R10,112(DI) + RET diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/mont25519_amd64.go b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/mont25519_amd64.go new file mode 100644 index 000000000..5822bd533 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/mont25519_amd64.go @@ -0,0 +1,240 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build amd64,!gccgo,!appengine + +package curve25519 + +// These functions are implemented in the .s files. The names of the functions +// in the rest of the file are also taken from the SUPERCOP sources to help +// people following along. + +//go:noescape + +func cswap(inout *[5]uint64, v uint64) + +//go:noescape + +func ladderstep(inout *[5][5]uint64) + +//go:noescape + +func freeze(inout *[5]uint64) + +//go:noescape + +func mul(dest, a, b *[5]uint64) + +//go:noescape + +func square(out, in *[5]uint64) + +// mladder uses a Montgomery ladder to calculate (xr/zr) *= s. +func mladder(xr, zr *[5]uint64, s *[32]byte) { + var work [5][5]uint64 + + work[0] = *xr + setint(&work[1], 1) + setint(&work[2], 0) + work[3] = *xr + setint(&work[4], 1) + + j := uint(6) + var prevbit byte + + for i := 31; i >= 0; i-- { + for j < 8 { + bit := ((*s)[i] >> j) & 1 + swap := bit ^ prevbit + prevbit = bit + cswap(&work[1], uint64(swap)) + ladderstep(&work) + j-- + } + j = 7 + } + + *xr = work[1] + *zr = work[2] +} + +func scalarMult(out, in, base *[32]byte) { + var e [32]byte + copy(e[:], (*in)[:]) + e[0] &= 248 + e[31] &= 127 + e[31] |= 64 + + var t, z [5]uint64 + unpack(&t, base) + mladder(&t, &z, &e) + invert(&z, &z) + mul(&t, &t, &z) + pack(out, &t) +} + +func setint(r *[5]uint64, v uint64) { + r[0] = v + r[1] = 0 + r[2] = 0 + r[3] = 0 + r[4] = 0 +} + +// unpack sets r = x where r consists of 5, 51-bit limbs in little-endian +// order. +func unpack(r *[5]uint64, x *[32]byte) { + r[0] = uint64(x[0]) | + uint64(x[1])<<8 | + uint64(x[2])<<16 | + uint64(x[3])<<24 | + uint64(x[4])<<32 | + uint64(x[5])<<40 | + uint64(x[6]&7)<<48 + + r[1] = uint64(x[6])>>3 | + uint64(x[7])<<5 | + uint64(x[8])<<13 | + uint64(x[9])<<21 | + uint64(x[10])<<29 | + uint64(x[11])<<37 | + uint64(x[12]&63)<<45 + + r[2] = uint64(x[12])>>6 | + uint64(x[13])<<2 | + uint64(x[14])<<10 | + uint64(x[15])<<18 | + uint64(x[16])<<26 | + uint64(x[17])<<34 | + uint64(x[18])<<42 | + uint64(x[19]&1)<<50 + + r[3] = uint64(x[19])>>1 | + uint64(x[20])<<7 | + uint64(x[21])<<15 | + uint64(x[22])<<23 | + uint64(x[23])<<31 | + uint64(x[24])<<39 | + uint64(x[25]&15)<<47 + + r[4] = uint64(x[25])>>4 | + uint64(x[26])<<4 | + uint64(x[27])<<12 | + uint64(x[28])<<20 | + uint64(x[29])<<28 | + uint64(x[30])<<36 | + uint64(x[31]&127)<<44 +} + +// pack sets out = x where out is the usual, little-endian form of the 5, +// 51-bit limbs in x. +func pack(out *[32]byte, x *[5]uint64) { + t := *x + freeze(&t) + + out[0] = byte(t[0]) + out[1] = byte(t[0] >> 8) + out[2] = byte(t[0] >> 16) + out[3] = byte(t[0] >> 24) + out[4] = byte(t[0] >> 32) + out[5] = byte(t[0] >> 40) + out[6] = byte(t[0] >> 48) + + out[6] ^= byte(t[1]<<3) & 0xf8 + out[7] = byte(t[1] >> 5) + out[8] = byte(t[1] >> 13) + out[9] = byte(t[1] >> 21) + out[10] = byte(t[1] >> 29) + out[11] = byte(t[1] >> 37) + out[12] = byte(t[1] >> 45) + + out[12] ^= byte(t[2]<<6) & 0xc0 + out[13] = byte(t[2] >> 2) + out[14] = byte(t[2] >> 10) + out[15] = byte(t[2] >> 18) + out[16] = byte(t[2] >> 26) + out[17] = byte(t[2] >> 34) + out[18] = byte(t[2] >> 42) + out[19] = byte(t[2] >> 50) + + out[19] ^= byte(t[3]<<1) & 0xfe + out[20] = byte(t[3] >> 7) + out[21] = byte(t[3] >> 15) + out[22] = byte(t[3] >> 23) + out[23] = byte(t[3] >> 31) + out[24] = byte(t[3] >> 39) + out[25] = byte(t[3] >> 47) + + out[25] ^= byte(t[4]<<4) & 0xf0 + out[26] = byte(t[4] >> 4) + out[27] = byte(t[4] >> 12) + out[28] = byte(t[4] >> 20) + out[29] = byte(t[4] >> 28) + out[30] = byte(t[4] >> 36) + out[31] = byte(t[4] >> 44) +} + +// invert calculates r = x^-1 mod p using Fermat's little theorem. +func invert(r *[5]uint64, x *[5]uint64) { + var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t [5]uint64 + + square(&z2, x) /* 2 */ + square(&t, &z2) /* 4 */ + square(&t, &t) /* 8 */ + mul(&z9, &t, x) /* 9 */ + mul(&z11, &z9, &z2) /* 11 */ + square(&t, &z11) /* 22 */ + mul(&z2_5_0, &t, &z9) /* 2^5 - 2^0 = 31 */ + + square(&t, &z2_5_0) /* 2^6 - 2^1 */ + for i := 1; i < 5; i++ { /* 2^20 - 2^10 */ + square(&t, &t) + } + mul(&z2_10_0, &t, &z2_5_0) /* 2^10 - 2^0 */ + + square(&t, &z2_10_0) /* 2^11 - 2^1 */ + for i := 1; i < 10; i++ { /* 2^20 - 2^10 */ + square(&t, &t) + } + mul(&z2_20_0, &t, &z2_10_0) /* 2^20 - 2^0 */ + + square(&t, &z2_20_0) /* 2^21 - 2^1 */ + for i := 1; i < 20; i++ { /* 2^40 - 2^20 */ + square(&t, &t) + } + mul(&t, &t, &z2_20_0) /* 2^40 - 2^0 */ + + square(&t, &t) /* 2^41 - 2^1 */ + for i := 1; i < 10; i++ { /* 2^50 - 2^10 */ + square(&t, &t) + } + mul(&z2_50_0, &t, &z2_10_0) /* 2^50 - 2^0 */ + + square(&t, &z2_50_0) /* 2^51 - 2^1 */ + for i := 1; i < 50; i++ { /* 2^100 - 2^50 */ + square(&t, &t) + } + mul(&z2_100_0, &t, &z2_50_0) /* 2^100 - 2^0 */ + + square(&t, &z2_100_0) /* 2^101 - 2^1 */ + for i := 1; i < 100; i++ { /* 2^200 - 2^100 */ + square(&t, &t) + } + mul(&t, &t, &z2_100_0) /* 2^200 - 2^0 */ + + square(&t, &t) /* 2^201 - 2^1 */ + for i := 1; i < 50; i++ { /* 2^250 - 2^50 */ + square(&t, &t) + } + mul(&t, &t, &z2_50_0) /* 2^250 - 2^0 */ + + square(&t, &t) /* 2^251 - 2^1 */ + square(&t, &t) /* 2^252 - 2^2 */ + square(&t, &t) /* 2^253 - 2^3 */ + + square(&t, &t) /* 2^254 - 2^4 */ + + square(&t, &t) /* 2^255 - 2^5 */ + mul(r, &t, &z11) /* 2^255 - 21 */ +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/mul_amd64.s b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/mul_amd64.s new file mode 100644 index 000000000..5ce80a2e5 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/mul_amd64.s @@ -0,0 +1,169 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This code was translated into a form compatible with 6a from the public +// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html + +// +build amd64,!gccgo,!appengine + +#include "const_amd64.h" + +// func mul(dest, a, b *[5]uint64) +TEXT ·mul(SB),0,$16-24 + MOVQ dest+0(FP), DI + MOVQ a+8(FP), SI + MOVQ b+16(FP), DX + + MOVQ DX,CX + MOVQ 24(SI),DX + IMUL3Q $19,DX,AX + MOVQ AX,0(SP) + MULQ 16(CX) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 32(SI),DX + IMUL3Q $19,DX,AX + MOVQ AX,8(SP) + MULQ 8(CX) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 0(SI),AX + MULQ 0(CX) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 0(SI),AX + MULQ 8(CX) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 0(SI),AX + MULQ 16(CX) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 0(SI),AX + MULQ 24(CX) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 0(SI),AX + MULQ 32(CX) + MOVQ AX,BX + MOVQ DX,BP + MOVQ 8(SI),AX + MULQ 0(CX) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 8(SI),AX + MULQ 8(CX) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 8(SI),AX + MULQ 16(CX) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 8(SI),AX + MULQ 24(CX) + ADDQ AX,BX + ADCQ DX,BP + MOVQ 8(SI),DX + IMUL3Q $19,DX,AX + MULQ 32(CX) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 16(SI),AX + MULQ 0(CX) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 16(SI),AX + MULQ 8(CX) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 16(SI),AX + MULQ 16(CX) + ADDQ AX,BX + ADCQ DX,BP + MOVQ 16(SI),DX + IMUL3Q $19,DX,AX + MULQ 24(CX) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 16(SI),DX + IMUL3Q $19,DX,AX + MULQ 32(CX) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 24(SI),AX + MULQ 0(CX) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 24(SI),AX + MULQ 8(CX) + ADDQ AX,BX + ADCQ DX,BP + MOVQ 0(SP),AX + MULQ 24(CX) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 0(SP),AX + MULQ 32(CX) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 32(SI),AX + MULQ 0(CX) + ADDQ AX,BX + ADCQ DX,BP + MOVQ 8(SP),AX + MULQ 16(CX) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 8(SP),AX + MULQ 24(CX) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 8(SP),AX + MULQ 32(CX) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ $REDMASK51,SI + SHLQ $13,R9:R8 + ANDQ SI,R8 + SHLQ $13,R11:R10 + ANDQ SI,R10 + ADDQ R9,R10 + SHLQ $13,R13:R12 + ANDQ SI,R12 + ADDQ R11,R12 + SHLQ $13,R15:R14 + ANDQ SI,R14 + ADDQ R13,R14 + SHLQ $13,BP:BX + ANDQ SI,BX + ADDQ R15,BX + IMUL3Q $19,BP,DX + ADDQ DX,R8 + MOVQ R8,DX + SHRQ $51,DX + ADDQ R10,DX + MOVQ DX,CX + SHRQ $51,DX + ANDQ SI,R8 + ADDQ R12,DX + MOVQ DX,R9 + SHRQ $51,DX + ANDQ SI,CX + ADDQ R14,DX + MOVQ DX,AX + SHRQ $51,DX + ANDQ SI,R9 + ADDQ BX,DX + MOVQ DX,R10 + SHRQ $51,DX + ANDQ SI,AX + IMUL3Q $19,DX,DX + ADDQ DX,R8 + ANDQ SI,R10 + MOVQ R8,0(DI) + MOVQ CX,8(DI) + MOVQ R9,16(DI) + MOVQ AX,24(DI) + MOVQ R10,32(DI) + RET diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/square_amd64.s b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/square_amd64.s new file mode 100644 index 000000000..12f73734f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/curve25519/square_amd64.s @@ -0,0 +1,132 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This code was translated into a form compatible with 6a from the public +// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html + +// +build amd64,!gccgo,!appengine + +#include "const_amd64.h" + +// func square(out, in *[5]uint64) +TEXT ·square(SB),7,$0-16 + MOVQ out+0(FP), DI + MOVQ in+8(FP), SI + + MOVQ 0(SI),AX + MULQ 0(SI) + MOVQ AX,CX + MOVQ DX,R8 + MOVQ 0(SI),AX + SHLQ $1,AX + MULQ 8(SI) + MOVQ AX,R9 + MOVQ DX,R10 + MOVQ 0(SI),AX + SHLQ $1,AX + MULQ 16(SI) + MOVQ AX,R11 + MOVQ DX,R12 + MOVQ 0(SI),AX + SHLQ $1,AX + MULQ 24(SI) + MOVQ AX,R13 + MOVQ DX,R14 + MOVQ 0(SI),AX + SHLQ $1,AX + MULQ 32(SI) + MOVQ AX,R15 + MOVQ DX,BX + MOVQ 8(SI),AX + MULQ 8(SI) + ADDQ AX,R11 + ADCQ DX,R12 + MOVQ 8(SI),AX + SHLQ $1,AX + MULQ 16(SI) + ADDQ AX,R13 + ADCQ DX,R14 + MOVQ 8(SI),AX + SHLQ $1,AX + MULQ 24(SI) + ADDQ AX,R15 + ADCQ DX,BX + MOVQ 8(SI),DX + IMUL3Q $38,DX,AX + MULQ 32(SI) + ADDQ AX,CX + ADCQ DX,R8 + MOVQ 16(SI),AX + MULQ 16(SI) + ADDQ AX,R15 + ADCQ DX,BX + MOVQ 16(SI),DX + IMUL3Q $38,DX,AX + MULQ 24(SI) + ADDQ AX,CX + ADCQ DX,R8 + MOVQ 16(SI),DX + IMUL3Q $38,DX,AX + MULQ 32(SI) + ADDQ AX,R9 + ADCQ DX,R10 + MOVQ 24(SI),DX + IMUL3Q $19,DX,AX + MULQ 24(SI) + ADDQ AX,R9 + ADCQ DX,R10 + MOVQ 24(SI),DX + IMUL3Q $38,DX,AX + MULQ 32(SI) + ADDQ AX,R11 + ADCQ DX,R12 + MOVQ 32(SI),DX + IMUL3Q $19,DX,AX + MULQ 32(SI) + ADDQ AX,R13 + ADCQ DX,R14 + MOVQ $REDMASK51,SI + SHLQ $13,R8:CX + ANDQ SI,CX + SHLQ $13,R10:R9 + ANDQ SI,R9 + ADDQ R8,R9 + SHLQ $13,R12:R11 + ANDQ SI,R11 + ADDQ R10,R11 + SHLQ $13,R14:R13 + ANDQ SI,R13 + ADDQ R12,R13 + SHLQ $13,BX:R15 + ANDQ SI,R15 + ADDQ R14,R15 + IMUL3Q $19,BX,DX + ADDQ DX,CX + MOVQ CX,DX + SHRQ $51,DX + ADDQ R9,DX + ANDQ SI,CX + MOVQ DX,R8 + SHRQ $51,DX + ADDQ R11,DX + ANDQ SI,R8 + MOVQ DX,R9 + SHRQ $51,DX + ADDQ R13,DX + ANDQ SI,R9 + MOVQ DX,AX + SHRQ $51,DX + ADDQ R15,DX + ANDQ SI,AX + MOVQ DX,R10 + SHRQ $51,DX + IMUL3Q $19,DX,DX + ADDQ DX,CX + ANDQ SI,R10 + MOVQ CX,0(DI) + MOVQ R8,8(DI) + MOVQ R9,16(DI) + MOVQ AX,24(DI) + MOVQ R10,32(DI) + RET diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/hkdf/hkdf.go b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/hkdf/hkdf.go new file mode 100644 index 000000000..5bc246355 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/golang.org/x/crypto/hkdf/hkdf.go @@ -0,0 +1,75 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package hkdf implements the HMAC-based Extract-and-Expand Key Derivation +// Function (HKDF) as defined in RFC 5869. +// +// HKDF is a cryptographic key derivation function (KDF) with the goal of +// expanding limited input keying material into one or more cryptographically +// strong secret keys. +// +// RFC 5869: https://tools.ietf.org/html/rfc5869 +package hkdf // import "golang.org/x/crypto/hkdf" + +import ( + "crypto/hmac" + "errors" + "hash" + "io" +) + +type hkdf struct { + expander hash.Hash + size int + + info []byte + counter byte + + prev []byte + cache []byte +} + +func (f *hkdf) Read(p []byte) (int, error) { + // Check whether enough data can be generated + need := len(p) + remains := len(f.cache) + int(255-f.counter+1)*f.size + if remains < need { + return 0, errors.New("hkdf: entropy limit reached") + } + // Read from the cache, if enough data is present + n := copy(p, f.cache) + p = p[n:] + + // Fill the buffer + for len(p) > 0 { + f.expander.Reset() + f.expander.Write(f.prev) + f.expander.Write(f.info) + f.expander.Write([]byte{f.counter}) + f.prev = f.expander.Sum(f.prev[:0]) + f.counter++ + + // Copy the new batch into p + f.cache = f.prev + n = copy(p, f.cache) + p = p[n:] + } + // Save leftovers for next run + f.cache = f.cache[n:] + + return need, nil +} + +// New returns a new HKDF using the given hash, the secret keying material to expand +// and optional salt and info fields. +func New(hash func() hash.Hash, secret, salt, info []byte) io.Reader { + if salt == nil { + salt = make([]byte, hash().Size()) + } + extractor := hmac.New(hash, salt) + extractor.Write(secret) + prk := extractor.Sum(nil) + + return &hkdf{hmac.New(hash, prk), extractor.Size(), info, 1, nil, nil} +} diff --git a/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go b/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go new file mode 100644 index 000000000..ed006aa2b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go @@ -0,0 +1,57 @@ +package quic + +import ( + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type windowUpdateQueue struct { + mutex sync.Mutex + + queue map[protocol.StreamID]bool // used as a set + callback func(wire.Frame) + cryptoStream cryptoStreamI + streamGetter streamGetter +} + +func newWindowUpdateQueue(streamGetter streamGetter, cryptoStream cryptoStreamI, cb func(wire.Frame)) *windowUpdateQueue { + return &windowUpdateQueue{ + queue: make(map[protocol.StreamID]bool), + streamGetter: streamGetter, + cryptoStream: cryptoStream, + callback: cb, + } +} + +func (q *windowUpdateQueue) Add(id protocol.StreamID) { + q.mutex.Lock() + q.queue[id] = true + q.mutex.Unlock() +} + +func (q *windowUpdateQueue) QueueAll() { + q.mutex.Lock() + var offset protocol.ByteCount + for id := range q.queue { + if id == q.cryptoStream.StreamID() { + offset = q.cryptoStream.getWindowUpdate() + } else { + str, err := q.streamGetter.GetOrOpenReceiveStream(id) + if err != nil || str == nil { // the stream can be nil if it was completed before dequeing the window update + continue + } + offset = str.getWindowUpdate() + } + if offset == 0 { // can happen if we received a final offset, right after queueing the window update + continue + } + q.callback(&wire.MaxStreamDataFrame{ + StreamID: id, + ByteOffset: offset, + }) + delete(q.queue, id) + } + q.mutex.Unlock() +} diff --git a/vendor/golang.org/x/crypto/curve25519/curve25519.go b/vendor/golang.org/x/crypto/curve25519/curve25519.go index 2d14c2a78..cb8fbc57b 100644 --- a/vendor/golang.org/x/crypto/curve25519/curve25519.go +++ b/vendor/golang.org/x/crypto/curve25519/curve25519.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// We have a implementation in amd64 assembly so this code is only run on +// We have an implementation in amd64 assembly so this code is only run on // non-amd64 platforms. The amd64 assembly does not support gccgo. // +build !amd64 gccgo appengine diff --git a/vendor/manifest b/vendor/manifest index 5f0d211c0..5872f5f66 100644 --- a/vendor/manifest +++ b/vendor/manifest @@ -34,6 +34,14 @@ "branch": "master", "notests": true }, + { + "importpath": "github.com/bifurcation/mint", + "repository": "https://github.com/bifurcation/mint", + "vcs": "git", + "revision": "64af8ab8ccb81bd5d4eab356f79ba0939117d9f6", + "branch": "master", + "notests": true + }, { "importpath": "github.com/codahale/aesnicheck", "repository": "https://github.com/codahale/aesnicheck", @@ -137,7 +145,7 @@ "importpath": "github.com/lucas-clemente/quic-go", "repository": "https://github.com/lucas-clemente/quic-go", "vcs": "git", - "revision": "a9e2a28315406f825cdfe41f8652110addeb84a5", + "revision": "d71850eb2ff581620f2f5742b558a97de22c13f6", "branch": "master", "notests": true }, @@ -212,7 +220,7 @@ "importpath": "golang.org/x/crypto/curve25519", "repository": "https://go.googlesource.com/crypto", "vcs": "git", - "revision": "2faea1465de239e4babd8f5905cc25b781712442", + "revision": "94eea52f7b742c7cbe0b03b22f0c4c8631ece122", "branch": "master", "path": "curve25519", "notests": true