0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2025-04-01 02:42:35 -05:00

Merge branch 'master' into diagnostics

# Conflicts:
#	plugins.go
#	vendor/manifest
This commit is contained in:
Matthew Holt 2018-02-16 22:42:14 -07:00
commit 269a8b5fce
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5
270 changed files with 28231 additions and 4887 deletions

4
.gitignore vendored
View file

@ -16,4 +16,6 @@ Caddyfile
og_static/
.vscode/
.vscode/
*.bat

View file

@ -1,5 +1,5 @@
<p align="center">
<a href="https://caddyserver.com"><img src="https://cloud.githubusercontent.com/assets/1128849/25305033/12916fce-2731-11e7-86ec-580d4d31cb16.png" alt="Caddy" width="400"></a>
<a href="https://caddyserver.com"><img src="https://user-images.githubusercontent.com/1128849/36137292-bebc223a-1051-11e8-9a81-4ea9054c96ac.png" alt="Caddy" width="400"></a>
</p>
<h3 align="center">Every Site on HTTPS <!-- Serve Confidently --></h3>
<p align="center">Caddy is a general-purpose HTTP/2 web server that serves HTTPS by default.</p>
@ -59,7 +59,7 @@ customize your build in the browser
pre-built, vanilla binaries
## Build
To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.8 or newer). Follow these instruction for fast building:
To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.9 or newer). Follow these instruction for fast building:
- Get source `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds`
- Now `cd` to `$GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go`

View file

@ -78,8 +78,18 @@ var (
mu sync.Mutex
)
func init() {
OnProcessExit = append(OnProcessExit, func() {
if PidFile != "" {
os.Remove(PidFile)
}
})
}
// Instance contains the state of servers created as a result of
// calling Start and can be used to access or control those servers.
// It is literally an instance of a server type. Instance values
// should NOT be copied. Use *Instance for safety.
type Instance struct {
// serverType is the name of the instance's server type
serverType string
@ -90,10 +100,11 @@ type Instance struct {
// wg is used to wait for all servers to shut down
wg *sync.WaitGroup
// context is the context created for this instance.
// context is the context created for this instance,
// used to coordinate the setting up of the server type
context Context
// servers is the list of servers with their listeners.
// servers is the list of servers with their listeners
servers []ServerListener
// these callbacks execute when certain events occur
@ -102,6 +113,18 @@ type Instance struct {
onRestart []func() error // before restart commences
onShutdown []func() error // stopping, even as part of a restart
onFinalShutdown []func() error // stopping, not as part of a restart
// storing values on an instance is preferable to
// global state because these will get garbage-
// collected after in-process reloads when the
// old instances are destroyed; use StorageMu
// to access this value safely
Storage map[interface{}]interface{}
StorageMu sync.RWMutex
}
func Instances() []*Instance {
return instances
}
// Servers returns the ServerListeners in i.
@ -197,7 +220,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) {
}
// create new instance; if the restart fails, it is simply discarded
newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg}
newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg, Storage: make(map[interface{}]interface{})}
// attempt to start new instance
err := startWithListenerFds(newCaddyfile, newInst, restartFds)
@ -456,7 +479,7 @@ func (i *Instance) Caddyfile() Input {
//
// This function blocks until all the servers are listening.
func Start(cdyfile Input) (*Instance, error) {
inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup)}
inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})}
err := startWithListenerFds(cdyfile, inst, nil)
if err != nil {
return inst, err
@ -469,11 +492,34 @@ func Start(cdyfile Input) (*Instance, error) {
}
func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]restartTriple) error {
// save this instance in the list now so that
// plugins can access it if need be, for example
// the caddytls package, so it can perform cert
// renewals while starting up; we just have to
// remove the instance from the list later if
// it fails
instancesMu.Lock()
instances = append(instances, inst)
instancesMu.Unlock()
var err error
defer func() {
if err != nil {
instancesMu.Lock()
for i, otherInst := range instances {
if otherInst == inst {
instances = append(instances[:i], instances[i+1:]...)
break
}
}
instancesMu.Unlock()
}
}()
if cdyfile == nil {
cdyfile = CaddyfileInput{}
}
err := ValidateAndExecuteDirectives(cdyfile, inst, false)
err = ValidateAndExecuteDirectives(cdyfile, inst, false)
if err != nil {
return err
}
@ -505,10 +551,6 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
return err
}
instancesMu.Lock()
instances = append(instances, inst)
instancesMu.Unlock()
// run any AfterStartup callbacks if this is not
// part of a restart; then show file descriptor notice
if restartFds == nil {
@ -547,7 +589,7 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bool) error {
// If parsing only inst will be nil, create an instance for this function call only.
if justValidate {
inst = &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup)}
inst = &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})}
}
stypeName := cdyfile.ServerType()
@ -564,14 +606,14 @@ func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bo
return err
}
inst.context = stype.NewContext()
inst.context = stype.NewContext(inst)
if inst.context == nil {
return fmt.Errorf("server type %s produced a nil Context", stypeName)
}
sblocks, err = inst.context.InspectServerBlocks(cdyfile.Path(), sblocks)
if err != nil {
return err
return fmt.Errorf("error inspecting server blocks: %v", err)
}
diagnostics.Set("num_server_blocks", len(sblocks))

View file

@ -148,7 +148,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
case "HEAD":
resp, err = fcgiBackend.Head(env)
case "GET":
resp, err = fcgiBackend.Get(env)
resp, err = fcgiBackend.Get(env, r.Body, contentLength)
case "OPTIONS":
resp, err = fcgiBackend.Options(env)
default:

View file

@ -460,12 +460,12 @@ func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Res
}
// Get issues a GET request to the fcgi responder.
func (c *FCGIClient) Get(p map[string]string) (resp *http.Response, err error) {
func (c *FCGIClient) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) {
p["REQUEST_METHOD"] = "GET"
p["CONTENT_LENGTH"] = "0"
p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10)
return c.Request(p, nil)
return c.Request(p, body)
}
// Head issues a HEAD request to the fcgi responder.

View file

@ -140,7 +140,8 @@ func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[
}
resp, err = fcgi.PostForm(fcgiParams, values)
} else {
resp, err = fcgi.Get(fcgiParams)
rd := bytes.NewReader(data)
resp, err = fcgi.Get(fcgiParams, rd, int64(rd.Len()))
}
default:

View file

@ -27,7 +27,7 @@ func activateHTTPS(cctx caddy.Context) error {
operatorPresent := !caddy.Started()
if !caddy.Quiet && operatorPresent {
fmt.Print("Activating privacy features...")
fmt.Print("Activating privacy features... ")
}
ctx := cctx.(*httpContext)
@ -69,7 +69,7 @@ func activateHTTPS(cctx caddy.Context) error {
}
if !caddy.Quiet && operatorPresent {
fmt.Println(" done.")
fmt.Println("done.")
}
return nil
@ -160,23 +160,37 @@ func hostHasOtherPort(allConfigs []*SiteConfig, thisConfigIdx int, otherPort str
// to listen on HTTPPort. The TLS field of cfg must not be nil.
func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
redirPort := cfg.Addr.Port
if redirPort == DefaultHTTPSPort {
redirPort = "" // default port is redundant
if redirPort == HTTPSPort {
// By default, HTTPSPort should be DefaultHTTPSPort,
// which of course doesn't need to be explicitly stated
// in the Location header. Even if HTTPSPort is changed
// so that it is no longer DefaultHTTPSPort, we shouldn't
// append it to the URL in the Location because changing
// the HTTPS port is assumed to be an internal-only change
// (in other words, we assume port forwarding is going on);
// but redirects go back to a presumably-external client.
// (If redirect clients are also internal, that is more
// advanced, and the user should configure HTTP->HTTPS
// redirects themselves.)
redirPort = ""
}
redirMiddleware := func(next Handler) Handler {
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
// Construct the URL to which to redirect. Note that the Host in a request might
// contain a port, but we just need the hostname; we'll set the port if needed.
// Construct the URL to which to redirect. Note that the Host in a
// request might contain a port, but we just need the hostname from
// it; and we'll set the port if needed.
toURL := "https://"
requestHost, _, err := net.SplitHostPort(r.Host)
if err != nil {
requestHost = r.Host // Host did not contain a port; great
requestHost = r.Host // Host did not contain a port, so use the whole value
}
if redirPort == "" {
toURL += requestHost
} else {
toURL += net.JoinHostPort(requestHost, redirPort)
}
toURL += r.URL.RequestURI()
w.Header().Set("Connection", "close")
@ -184,9 +198,11 @@ func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
return 0, nil
})
}
host := cfg.Addr.Host
port := HTTPPort
addr := net.JoinHostPort(host, port)
return &SiteConfig{
Addr: Address{Original: addr, Host: host, Port: port},
ListenHost: cfg.ListenHost,

View file

@ -53,7 +53,7 @@ func TestRedirPlaintextHost(t *testing.T) {
},
{
Host: "foohost",
Port: "443", // since this is the default HTTPS port, should not be included in Location value
Port: HTTPSPort, // since this is the 'default' HTTPS port, should not be included in Location value
},
{
Host: "*.example.com",

View file

@ -91,11 +91,13 @@ func hideCaddyfile(cctx caddy.Context) error {
return nil
}
func newContext() caddy.Context {
return &httpContext{keysToSiteConfigs: make(map[string]*SiteConfig)}
func newContext(inst *caddy.Instance) caddy.Context {
return &httpContext{instance: inst, keysToSiteConfigs: make(map[string]*SiteConfig)}
}
type httpContext struct {
instance *caddy.Instance
// keysToSiteConfigs maps an address at the top of a
// server block (a "key") to its SiteConfig. Not all
// SiteConfigs will be represented here, only ones
@ -115,12 +117,14 @@ func (h *httpContext) saveConfig(key string, cfg *SiteConfig) {
// executing directives and otherwise prepares the directives to
// be parsed and executed.
func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
siteAddrs := make(map[string]string)
// For each address in each server block, make a new config
for _, sb := range serverBlocks {
for _, key := range sb.Keys {
key = strings.ToLower(key)
if _, dup := h.keysToSiteConfigs[key]; dup {
return serverBlocks, fmt.Errorf("duplicate site address: %s", key)
return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
}
addr, err := standardizeAddress(key)
if err != nil {
@ -136,6 +140,23 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
addr.Port = Port
}
// Make sure the adjusted site address is distinct
addrCopy := addr // make copy so we don't disturb the original, carefully-parsed address struct
if addrCopy.Port == "" && Port == DefaultPort {
addrCopy.Port = Port
}
addrStr := strings.ToLower(addrCopy.String())
if otherSiteKey, dup := siteAddrs[addrStr]; dup {
err := fmt.Errorf("duplicate site address: %s", addrStr)
if (addrCopy.Host == Host && Host != DefaultHost) ||
(addrCopy.Port == Port && Port != DefaultPort) {
err = fmt.Errorf("site defined as %s is a duplicate of %s because of modified "+
"default host and/or port values (usually via -host or -port flags)", key, otherSiteKey)
}
return serverBlocks, err
}
siteAddrs[addrStr] = key
// If default HTTP or HTTPS ports have been customized,
// make sure the ACME challenge ports match
var altHTTPPort, altTLSSNIPort string
@ -146,15 +167,19 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
altTLSSNIPort = HTTPSPort
}
// Make our caddytls.Config, which has a pointer to the
// instance's certificate cache and enough information
// to use automatic HTTPS when the time comes
caddytlsConfig := caddytls.NewConfig(h.instance)
caddytlsConfig.Hostname = addr.Host
caddytlsConfig.AltHTTPPort = altHTTPPort
caddytlsConfig.AltTLSSNIPort = altTLSSNIPort
// Save the config to our master list, and key it for lookups
cfg := &SiteConfig{
Addr: addr,
Root: Root,
TLS: &caddytls.Config{
Hostname: addr.Host,
AltHTTPPort: altHTTPPort,
AltTLSSNIPort: altTLSSNIPort,
},
Addr: addr,
Root: Root,
TLS: caddytlsConfig,
originCaddyfile: sourceFile,
IndexPages: staticfiles.DefaultIndexPages,
}

View file

@ -137,7 +137,7 @@ func TestAddressString(t *testing.T) {
func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
Port = "9999"
filename := "Testfile"
ctx := newContext().(*httpContext)
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
input := strings.NewReader(`localhost`)
sblocks, err := caddyfile.Parse(filename, input, nil)
if err != nil {
@ -153,9 +153,26 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
}
}
// See discussion on PR #2015
func TestInspectServerBlocksWithAdjustedAddress(t *testing.T) {
Port = DefaultPort
Host = "example.com"
filename := "Testfile"
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
input := strings.NewReader("example.com {\n}\n:2015 {\n}")
sblocks, err := caddyfile.Parse(filename, input, nil)
if err != nil {
t.Fatalf("Expected no error setting up test, got: %v", err)
}
_, err = ctx.InspectServerBlocks(filename, sblocks)
if err == nil {
t.Fatalf("Expected an error because site definitions should overlap, got: %v", err)
}
}
func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) {
filename := "Testfile"
ctx := newContext().(*httpContext)
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
input := strings.NewReader("localhost {\n}\nLOCALHOST {\n}")
sblocks, err := caddyfile.Parse(filename, input, nil)
if err != nil {
@ -207,7 +224,7 @@ func TestDirectivesList(t *testing.T) {
}
func TestContextSaveConfig(t *testing.T) {
ctx := newContext().(*httpContext)
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
ctx.saveConfig("foo", new(SiteConfig))
if _, ok := ctx.keysToSiteConfigs["foo"]; !ok {
t.Error("Expected config to be saved, but it wasn't")
@ -226,7 +243,7 @@ func TestContextSaveConfig(t *testing.T) {
// Test to make sure we are correctly hiding the Caddyfile
func TestHideCaddyfile(t *testing.T) {
ctx := newContext().(*httpContext)
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
ctx.saveConfig("test", &SiteConfig{
Root: Root,
originCaddyfile: "Testfile",

View file

@ -392,7 +392,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
if vhost == nil {
// check for ACME challenge even if vhost is nil;
// could be a new host coming online soon
if caddytls.HTTPChallengeHandler(w, r, "localhost", caddytls.DefaultHTTPAlternatePort) {
if caddytls.HTTPChallengeHandler(w, r, "localhost") {
return 0, nil
}
// otherwise, log the error and write a message to the client
@ -408,7 +408,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
// we still check for ACME challenge if the vhost exists,
// because we must apply its HTTP challenge config settings
if s.proxyHTTPChallenge(vhost, w, r) {
if caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost) {
return 0, nil
}
@ -416,31 +416,25 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
// the URL path, so a request to example.com/foo/blog on the site
// defined as example.com/foo appears as /blog instead of /foo/blog.
if pathPrefix != "/" {
r.URL.Path = strings.TrimPrefix(r.URL.Path, pathPrefix)
if !strings.HasPrefix(r.URL.Path, "/") {
r.URL.Path = "/" + r.URL.Path
}
r.URL = trimPathPrefix(r.URL, pathPrefix)
}
return vhost.middlewareChain.ServeHTTP(w, r)
}
// proxyHTTPChallenge solves the ACME HTTP challenge if r is the HTTP
// request for the challenge. If it is, and if the request has been
// fulfilled (response written), true is returned; false otherwise.
// If you don't have a vhost, just call the challenge handler directly.
func (s *Server) proxyHTTPChallenge(vhost *SiteConfig, w http.ResponseWriter, r *http.Request) bool {
if vhost.Addr.Port != caddytls.HTTPChallengePort {
return false
func trimPathPrefix(u *url.URL, prefix string) *url.URL {
// We need to use URL.EscapedPath() when trimming the pathPrefix as
// URL.Path is ambiguous about / or %2f - see docs. See #1927
trimmed := strings.TrimPrefix(u.EscapedPath(), prefix)
if !strings.HasPrefix(trimmed, "/") {
trimmed = "/" + trimmed
}
if vhost.TLS != nil && vhost.TLS.Manual {
return false
trimmedURL, err := url.Parse(trimmed)
if err != nil {
log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmed, err)
return u
}
altPort := caddytls.DefaultHTTPAlternatePort
if vhost.TLS != nil && vhost.TLS.AltHTTPPort != "" {
altPort = vhost.TLS.AltHTTPPort
}
return caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost, altPort)
return trimmedURL
}
// Address returns the address s was assigned to listen on.

View file

@ -16,6 +16,7 @@ package httpserver
import (
"net/http"
"net/url"
"testing"
"time"
)
@ -126,6 +127,94 @@ func TestMakeHTTPServerWithTimeouts(t *testing.T) {
}
}
func TestTrimPathPrefix(t *testing.T) {
for i, pt := range []struct {
path string
prefix string
expected string
shouldFail bool
}{
{
path: "/my/path",
prefix: "/my",
expected: "/path",
shouldFail: false,
},
{
path: "/my/%2f/path",
prefix: "/my",
expected: "/%2f/path",
shouldFail: false,
},
{
path: "/my/path",
prefix: "/my/",
expected: "/path",
shouldFail: false,
},
{
path: "/my///path",
prefix: "/my",
expected: "/path",
shouldFail: true,
},
{
path: "/my///path",
prefix: "/my",
expected: "///path",
shouldFail: false,
},
{
path: "/my/path///slash",
prefix: "/my",
expected: "/path///slash",
shouldFail: false,
},
{
path: "/my/%2f/path/%2f",
prefix: "/my",
expected: "/%2f/path/%2f",
shouldFail: false,
}, {
path: "/my/%20/path",
prefix: "/my",
expected: "/%20/path",
shouldFail: false,
}, {
path: "/path",
prefix: "",
expected: "/path",
shouldFail: false,
}, {
path: "/path/my/",
prefix: "/my",
expected: "/path/my/",
shouldFail: false,
}, {
path: "",
prefix: "/my",
expected: "/",
shouldFail: false,
}, {
path: "/apath",
prefix: "",
expected: "/apath",
shouldFail: false,
},
} {
u, _ := url.Parse(pt.path)
if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.EscapedPath() != want {
if !pt.shouldFail {
t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.EscapedPath())
}
} else if pt.shouldFail {
t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.EscapedPath())
}
}
}
func TestMakeHTTPServerWithHeaderLimit(t *testing.T) {
for name, c := range map[string]struct {
group []*SiteConfig

View file

@ -16,6 +16,7 @@ package requestid
import (
"context"
"log"
"net/http"
"github.com/google/uuid"
@ -24,12 +25,29 @@ import (
// Handler is a middleware handler
type Handler struct {
Next httpserver.Handler
Next httpserver.Handler
HeaderName string // (optional) header from which to read an existing ID
}
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
reqid := uuid.New().String()
c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid)
var reqid uuid.UUID
uuidFromHeader := r.Header.Get(h.HeaderName)
if h.HeaderName != "" && uuidFromHeader != "" {
// use the ID in the header field if it exists
var err error
reqid, err = uuid.Parse(uuidFromHeader)
if err != nil {
log.Printf("[NOTICE] Parsing request ID from %s header: %v", h.HeaderName, err)
reqid = uuid.New()
}
} else {
// otherwise, create a new one
reqid = uuid.New()
}
// set the request ID on the context
c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid.String())
r = r.WithContext(c)
return h.Next.ServeHTTP(w, r)

View file

@ -15,34 +15,53 @@
package requestid
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/google/uuid"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestRequestID(t *testing.T) {
request, err := http.NewRequest("GET", "http://localhost/", nil)
func TestRequestIDHandler(t *testing.T) {
handler := Handler{
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string)
if value == "" {
t.Error("Request ID should not be empty")
}
return 0, nil
}),
}
req, err := http.NewRequest("GET", "http://localhost/", nil)
if err != nil {
t.Fatal("Could not create HTTP request:", err)
}
rec := httptest.NewRecorder()
reqid := uuid.New().String()
c := context.WithValue(request.Context(), httpserver.RequestIDCtxKey, reqid)
request = request.WithContext(c)
// See caddyhttp/replacer.go
value, _ := request.Context().Value(httpserver.RequestIDCtxKey).(string)
if value == "" {
t.Fatal("Request ID should not be empty")
}
if value != reqid {
t.Fatal("Request ID does not match")
}
handler.ServeHTTP(rec, req)
}
func TestRequestIDFromHeader(t *testing.T) {
headerName := "X-Request-ID"
headerValue := "71a75329-d9f9-4d25-957e-e689a7b68d78"
handler := Handler{
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string)
if value != headerValue {
t.Errorf("Request ID should be '%s' but got '%s'", headerValue, value)
}
return 0, nil
}),
HeaderName: headerName,
}
req, err := http.NewRequest("GET", "http://localhost/", nil)
if err != nil {
t.Fatal("Could not create HTTP request:", err)
}
req.Header.Set(headerName, headerValue)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
}

View file

@ -27,14 +27,19 @@ func init() {
}
func setup(c *caddy.Controller) error {
var headerName string
for c.Next() {
if c.NextArg() {
return c.ArgErr() //no arg expected.
headerName = c.Val()
}
if c.NextArg() {
return c.ArgErr()
}
}
httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return Handler{Next: next}
return Handler{Next: next, HeaderName: headerName}
})
return nil

View file

@ -45,7 +45,15 @@ func TestSetup(t *testing.T) {
}
func TestSetupWithArg(t *testing.T) {
c := caddy.NewTestController("http", `requestid abc`)
c := caddy.NewTestController("http", `requestid X-Request-ID`)
err := setup(c)
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
}
func TestSetupWithTooManyArgs(t *testing.T) {
c := caddy.NewTestController("http", `requestid foo bar`)
err := setup(c)
if err == nil {
t.Errorf("Expected an error, got: %v", err)

View file

@ -107,6 +107,10 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request) (int, err
if d.IsDir() {
// ensure there is a trailing slash
if urlCopy.Path[len(urlCopy.Path)-1] != '/' {
for strings.HasPrefix(urlCopy.Path, "//") {
// prevent path-based open redirects
urlCopy.Path = strings.TrimPrefix(urlCopy.Path, "/")
}
urlCopy.Path += "/"
http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently)
return http.StatusMovedPermanently, nil
@ -131,6 +135,10 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request) (int, err
}
if redir {
for strings.HasPrefix(urlCopy.Path, "//") {
// prevent path-based open redirects
urlCopy.Path = strings.TrimPrefix(urlCopy.Path, "/")
}
http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently)
return http.StatusMovedPermanently, nil
}

View file

@ -77,9 +77,9 @@ func TestServeHTTP(t *testing.T) {
{
url: "https://foo/dirwithindex/",
expectedStatus: http.StatusOK,
expectedBodyContent: testFiles[webrootDirwithindexIndeHTML],
expectedBodyContent: testFiles[webrootDirwithindexIndexHTML],
expectedEtag: `"2n9cw"`,
expectedContentLength: strconv.Itoa(len(testFiles[webrootDirwithindexIndeHTML])),
expectedContentLength: strconv.Itoa(len(testFiles[webrootDirwithindexIndexHTML])),
},
// Test 4 - access folder with index file without trailing slash
{
@ -235,16 +235,38 @@ func TestServeHTTP(t *testing.T) {
expectedBodyContent: movedPermanently,
},
{
// Test 27 - Check etag
url: "https://foo/notindex.html",
expectedStatus: http.StatusOK,
expectedBodyContent: testFiles[webrootNotIndexHTML],
expectedEtag: `"2n9cm"`,
expectedContentLength: strconv.Itoa(len(testFiles[webrootNotIndexHTML])),
},
{
// Test 28 - Prevent path-based open redirects (directory)
url: "https://foo//example.com%2f..",
expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/example.com/../",
expectedBodyContent: movedPermanently,
},
{
// Test 29 - Prevent path-based open redirects (file)
url: "https://foo//example.com%2f../dirwithindex/index.html",
expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/example.com/../dirwithindex/",
expectedBodyContent: movedPermanently,
},
{
// Test 29 - Prevent path-based open redirects (extra leading slashes)
url: "https://foo///example.com%2f..",
expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/example.com/../",
expectedBodyContent: movedPermanently,
},
}
for i, test := range tests {
// set up response writer and rewuest
// set up response writer and request
responseRecorder := httptest.NewRecorder()
request, err := http.NewRequest("GET", test.url, nil)
if err != nil {
@ -518,7 +540,7 @@ var (
webrootNotIndexHTML = filepath.Join(webrootName, "notindex.html")
webrootDirFile2HTML = filepath.Join(webrootName, "dir", "file2.html")
webrootDirHiddenHTML = filepath.Join(webrootName, "dir", "hidden.html")
webrootDirwithindexIndeHTML = filepath.Join(webrootName, "dirwithindex", "index.html")
webrootDirwithindexIndexHTML = filepath.Join(webrootName, "dirwithindex", "index.html")
webrootSubGzippedHTML = filepath.Join(webrootName, "sub", "gzipped.html")
webrootSubGzippedHTMLGz = filepath.Join(webrootName, "sub", "gzipped.html.gz")
webrootSubGzippedHTMLBr = filepath.Join(webrootName, "sub", "gzipped.html.br")
@ -544,7 +566,7 @@ var testFiles = map[string]string{
webrootFile1HTML: "<h1>file1.html</h1>",
webrootNotIndexHTML: "<h1>notindex.html</h1>",
webrootDirFile2HTML: "<h1>dir/file2.html</h1>",
webrootDirwithindexIndeHTML: "<h1>dirwithindex/index.html</h1>",
webrootDirwithindexIndexHTML: "<h1>dirwithindex/index.html</h1>",
webrootDirHiddenHTML: "<h1>dir/hidden.html</h1>",
webrootSubGzippedHTML: "<h1>gzipped.html</h1>",
webrootSubGzippedHTMLGz: "1.gzipped.html.gz",

View file

@ -15,9 +15,11 @@
package caddytls
import (
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"log"
"strings"
@ -27,24 +29,104 @@ import (
"golang.org/x/crypto/ocsp"
)
// certCache stores certificates in memory,
// keying certificates by name. Certificates
// should not overlap in the names they serve,
// because a name only maps to one certificate.
var certCache = make(map[string]Certificate)
var certCacheMu sync.RWMutex
// certificateCache is to be an instance-wide cache of certs
// that site-specific TLS configs can refer to. Using a
// central map like this avoids duplication of certs in
// memory when the cert is used by multiple sites, and makes
// maintenance easier. Because these are not to be global,
// the cache will get garbage collected after a config reload
// (a new instance will take its place).
type certificateCache struct {
sync.RWMutex
cache map[string]Certificate // keyed by certificate hash
}
// replaceCertificate replaces oldCert with newCert in the cache, and
// updates all configs that are pointing to the old certificate to
// point to the new one instead. newCert must already be loaded into
// the cache (this method does NOT load it into the cache).
//
// Note that all the names on the old certificate will be deleted
// from the name lookup maps of each config, then all the names on
// the new certificate will be added to the lookup maps as long as
// they do not overwrite any entries.
//
// The newCert may be modified and its cache entry updated.
//
// This method is safe for concurrent use.
func (certCache *certificateCache) replaceCertificate(oldCert, newCert Certificate) error {
certCache.Lock()
defer certCache.Unlock()
// have all the configs that are pointing to the old
// certificate point to the new certificate instead
for _, cfg := range oldCert.configs {
// first delete all the name lookup entries that
// pointed to the old certificate
for name, certKey := range cfg.Certificates {
if certKey == oldCert.Hash {
delete(cfg.Certificates, name)
}
}
// then add name lookup entries for the names
// on the new certificate, but don't overwrite
// entries that may already exist, not only as
// a courtesy, but importantly: because if we
// overwrote a value here, and this config no
// longer pointed to a certain certificate in
// the cache, that certificate's list of configs
// referring to it would be incorrect; so just
// insert entries, don't overwrite any
for _, name := range newCert.Names {
if _, ok := cfg.Certificates[name]; !ok {
cfg.Certificates[name] = newCert.Hash
}
}
}
// since caching a new certificate attaches only the config
// that loaded it, the new certificate needs to be given the
// list of all the configs that use it, so copy the list
// over from the old certificate to the new certificate
// in the cache
newCert.configs = oldCert.configs
certCache.cache[newCert.Hash] = newCert
// finally, delete the old certificate from the cache
delete(certCache.cache, oldCert.Hash)
return nil
}
// reloadManagedCertificate reloads the certificate corresponding to the name(s)
// on oldCert into the cache, from storage. This also replaces the old certificate
// with the new one, so that all configurations that used the old cert now point
// to the new cert.
func (certCache *certificateCache) reloadManagedCertificate(oldCert Certificate) error {
// get the certificate from storage and cache it
newCert, err := oldCert.configs[0].CacheManagedCertificate(oldCert.Names[0])
if err != nil {
return fmt.Errorf("unable to reload certificate for %v into cache: %v", oldCert.Names, err)
}
// and replace the old certificate with the new one
err = certCache.replaceCertificate(oldCert, newCert)
if err != nil {
return fmt.Errorf("replacing certificate %v: %v", oldCert.Names, err)
}
return nil
}
// Certificate is a tls.Certificate with associated metadata tacked on.
// Even if the metadata can be obtained by parsing the certificate,
// we can be more efficient by extracting the metadata once so it's
// just there, ready to use.
// we are more efficient by extracting the metadata onto this struct.
type Certificate struct {
tls.Certificate
// Names is the list of names this certificate is written for.
// The first is the CommonName (if any), the rest are SAN.
// This should be the exact list of keys by which this cert
// is accessed in the cache, careful to avoid overlap.
Names []string
// NotAfter is when the certificate expires.
@ -53,59 +135,21 @@ type Certificate struct {
// OCSP contains the certificate's parsed OCSP response.
OCSP *ocsp.Response
// Config is the configuration with which the certificate was
// loaded or obtained and with which it should be maintained.
Config *Config
}
// The hex-encoded hash of this cert's chain's bytes.
Hash string
// getCertificate gets a certificate that matches name (a server name)
// from the in-memory cache. If there is no exact match for name, it
// will be checked against names of the form '*.example.com' (wildcard
// certificates) according to RFC 6125. If a match is found, matched will
// be true. If no matches are found, matched will be false and a default
// certificate will be returned with defaulted set to true. If no default
// certificate is set, defaulted will be set to false.
//
// The logic in this function is adapted from the Go standard library,
// which is by the Go Authors.
//
// This function is safe for concurrent use.
func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
var ok bool
// Not going to trim trailing dots here since RFC 3546 says,
// "The hostname is represented ... without a trailing dot."
// Just normalize to lowercase.
name = strings.ToLower(name)
certCacheMu.RLock()
defer certCacheMu.RUnlock()
// exact match? great, let's use it
if cert, ok = certCache[name]; ok {
matched = true
return
}
// try replacing labels in the name with wildcards until we get a match
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok = certCache[candidate]; ok {
matched = true
return
}
}
// if nothing matches, use the default certificate or bust
cert, defaulted = certCache[""]
return
// configs is the list of configs that use or refer to
// The first one is assumed to be the config that is
// "in charge" of this certificate (i.e. determines
// whether it is managed, how it is managed, etc).
// This field will be populated by cacheCertificate.
// Only meddle with it if you know what you're doing!
configs []*Config
}
// CacheManagedCertificate loads the certificate for domain into the
// cache, flagging it as Managed and, if onDemand is true, as "OnDemand"
// (meaning that it was obtained or loaded during a TLS handshake).
// cache, from the TLS storage for managed certificates. It returns a
// copy of the Certificate that was put into the cache.
//
// This method is safe for concurrent use.
func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) {
@ -117,39 +161,24 @@ func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) {
if err != nil {
return Certificate{}, err
}
cert, err := makeCertificate(siteData.Cert, siteData.Key)
cert, err := makeCertificateWithOCSP(siteData.Cert, siteData.Key)
if err != nil {
return cert, err
}
cert.Config = cfg
cacheCertificate(cert)
return cert, nil
return cfg.cacheCertificate(cert), nil
}
// cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile
// and keyFile, which must be in PEM format. It stores the certificate in
// memory after evicting any other entries in the cache keyed by the names
// on this certificate. In other words, it replaces existing certificates keyed
// by the names on this certificate. The Managed and OnDemand flags of the
// certificate will be set to false.
// the in-memory cache.
//
// This function is safe for concurrent use.
func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
cert, err := makeCertificateFromDisk(certFile, keyFile)
func (cfg *Config) cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
cert, err := makeCertificateFromDiskWithOCSP(certFile, keyFile)
if err != nil {
return err
}
// since this is manually managed, this call might be part of a reload after
// the owner renewed a certificate; so clear cache of any previous cert first,
// otherwise the renewed certificate may never be loaded
certCacheMu.Lock()
for _, name := range cert.Names {
delete(certCache, name)
}
certCacheMu.Unlock()
cacheCertificate(cert)
cfg.cacheCertificate(cert)
return nil
}
@ -157,20 +186,20 @@ func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
// of the certificate and key, then caches it in memory.
//
// This function is safe for concurrent use.
func cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error {
cert, err := makeCertificate(certBytes, keyBytes)
func (cfg *Config) cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error {
cert, err := makeCertificateWithOCSP(certBytes, keyBytes)
if err != nil {
return err
}
cacheCertificate(cert)
cfg.cacheCertificate(cert)
return nil
}
// makeCertificateFromDisk makes a Certificate by loading the
// makeCertificateFromDiskWithOCSP makes a Certificate by loading the
// certificate and key files. It fills out all the fields in
// the certificate except for the Managed and OnDemand flags.
// (It is up to the caller to set those.)
func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) {
// (It is up to the caller to set those.) It staples OCSP.
func makeCertificateFromDiskWithOCSP(certFile, keyFile string) (Certificate, error) {
certPEMBlock, err := ioutil.ReadFile(certFile)
if err != nil {
return Certificate{}, err
@ -179,13 +208,14 @@ func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) {
if err != nil {
return Certificate{}, err
}
return makeCertificate(certPEMBlock, keyPEMBlock)
return makeCertificateWithOCSP(certPEMBlock, keyPEMBlock)
}
// makeCertificate turns a certificate PEM bundle and a key PEM block into
// a Certificate, with OCSP and other relevant metadata tagged with it,
// except for the OnDemand and Managed flags. It is up to the caller to
// set those properties.
// a Certificate with necessary metadata from parsing its bytes filled into
// its struct fields for convenience (except for the OnDemand and Managed
// flags; it is up to the caller to set those properties!). This function
// does NOT staple OCSP.
func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
var cert Certificate
@ -195,16 +225,26 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
return cert, err
}
// Extract relevant metadata and staple OCSP
// Extract necessary metadata
err = fillCertFromLeaf(&cert, tlsCert)
if err != nil {
return cert, err
}
return cert, nil
}
// makeCertificateWithOCSP is the same as makeCertificate except that it also
// staples OCSP to the certificate.
func makeCertificateWithOCSP(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
cert, err := makeCertificate(certPEMBlock, keyPEMBlock)
if err != nil {
return cert, err
}
err = stapleOCSP(&cert, certPEMBlock)
if err != nil {
log.Printf("[WARNING] Stapling OCSP: %v", err)
}
return cert, nil
}
@ -243,65 +283,104 @@ func fillCertFromLeaf(cert *Certificate, tlsCert tls.Certificate) error {
return errors.New("certificate has no names")
}
// save the hash of this certificate (chain) and
// expiration date, for necessity and efficiency
cert.Hash = hashCertificateChain(cert.Certificate.Certificate)
cert.NotAfter = leaf.NotAfter
return nil
}
// cacheCertificate adds cert to the in-memory cache. If the cache is
// empty, cert will be used as the default certificate. If the cache is
// full, random entries are deleted until there is room to map all the
// names on the certificate.
// hashCertificateChain computes the unique hash of certChain,
// which is the chain of DER-encoded bytes. It returns the
// hex encoding of the hash.
func hashCertificateChain(certChain [][]byte) string {
h := sha256.New()
for _, certInChain := range certChain {
h.Write(certInChain)
}
return fmt.Sprintf("%x", h.Sum(nil))
}
// managedCertInStorageExpiresSoon returns true if cert (being a
// managed certificate) is expiring within RenewDurationBefore.
// It returns false if there was an error checking the expiration
// of the certificate as found in storage, or if the certificate
// in storage is NOT expiring soon. A certificate that is expiring
// soon in our cache but is not expiring soon in storage probably
// means that another instance renewed the certificate in the
// meantime, and it would be a good idea to simply load the cert
// into our cache rather than repeating the renewal process again.
func managedCertInStorageExpiresSoon(cert Certificate) (bool, error) {
if len(cert.configs) == 0 {
return false, fmt.Errorf("no configs for certificate")
}
storage, err := cert.configs[0].StorageFor(cert.configs[0].CAUrl)
if err != nil {
return false, err
}
siteData, err := storage.LoadSite(cert.Names[0])
if err != nil {
return false, err
}
tlsCert, err := tls.X509KeyPair(siteData.Cert, siteData.Key)
if err != nil {
return false, err
}
leaf, err := x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil {
return false, err
}
timeLeft := leaf.NotAfter.Sub(time.Now().UTC())
return timeLeft < RenewDurationBefore, nil
}
// cacheCertificate adds cert to the in-memory cache. If a certificate
// with the same hash is already cached, it is NOT overwritten; instead,
// cfg is added to the existing certificate's list of configs if not
// already in the list. Then all the names on cert are used to add
// entries to cfg.Certificates (the config's name lookup map).
// Then the certificate is stored/updated in the cache. It returns
// a copy of the certificate that ends up being stored in the cache.
//
// This certificate will be keyed to the names in cert.Names. Any names
// already used as a cache key will NOT be replaced by this cert; in
// other words, no overlap is allowed, and this certificate will not
// service those pre-existing names.
// It is VERY important, even for some test cases, that the Hash field
// of the cert be set properly.
//
// This function is safe for concurrent use.
func cacheCertificate(cert Certificate) {
if cert.Config == nil {
cert.Config = new(Config)
func (cfg *Config) cacheCertificate(cert Certificate) Certificate {
cfg.certCache.Lock()
defer cfg.certCache.Unlock()
// if this certificate already exists in the cache,
// use it instead of overwriting it -- very important!
if existingCert, ok := cfg.certCache.cache[cert.Hash]; ok {
cert = existingCert
}
certCacheMu.Lock()
if _, ok := certCache[""]; !ok {
// use as default - must be *appended* to end of list, or bad things happen!
cert.Names = append(cert.Names, "")
}
for len(certCache)+len(cert.Names) > 10000 {
// for simplicity, just remove random elements
for key := range certCache {
if key == "" { // ... but not the default cert
continue
}
delete(certCache, key)
// attach this config to the certificate so we know which
// configs are referencing/using the certificate, but don't
// duplicate entries
var found bool
for _, c := range cert.configs {
if c == cfg {
found = true
break
}
}
for i := 0; i < len(cert.Names); i++ {
name := cert.Names[i]
if _, ok := certCache[name]; ok {
// do not allow certificates to overlap in the names they serve;
// this ambiguity causes problems because it is confusing while
// maintaining certificates; see OCSP maintenance code and
// https://caddy.community/t/random-ocsp-response-errors-for-random-clients/2473?u=matt.
log.Printf("[NOTICE] There is already a certificate loaded for %s, "+
"so certificate for %v will not service that name",
name, cert.Names)
cert.Names = append(cert.Names[:i], cert.Names[i+1:]...)
i--
continue
}
certCache[name] = cert
if !found {
cert.configs = append(cert.configs, cfg)
}
certCacheMu.Unlock()
}
// uncacheCertificate deletes name's certificate from the
// cache. If name is not a key in the certificate cache,
// this function does nothing.
func uncacheCertificate(name string) {
certCacheMu.Lock()
delete(certCache, name)
certCacheMu.Unlock()
// key the certificate by all its names for this config only,
// this is how we find the certificate during handshakes
// (yes, if certs overlap in the names they serve, one will
// overwrite another here, but that's just how it goes)
for _, name := range cert.Names {
cfg.Certificates[name] = cert.Hash
}
// store the certificate
cfg.certCache.cache[cert.Hash] = cert
return cert
}

View file

@ -17,57 +17,71 @@ package caddytls
import "testing"
func TestUnexportedGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
// When cache is empty
if _, matched, defaulted := getCertificate("example.com"); matched || defaulted {
if _, matched, defaulted := cfg.getCertificate("example.com"); matched || defaulted {
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
}
// When cache has one certificate in it (also is default)
defaultCert := Certificate{Names: []string{"example.com", ""}}
certCache[""] = defaultCert
certCache["example.com"] = defaultCert
if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
// When cache has one certificate in it
firstCert := Certificate{Names: []string{"example.com"}}
certCache.cache["0xdeadbeef"] = firstCert
cfg.Certificates["example.com"] = "0xdeadbeef"
if cert, matched, defaulted := cfg.getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
if cert, matched, defaulted := cfg.getCertificate("example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for 'example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
// When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}}
if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
certCache.cache["0xb01dface"] = Certificate{Names: []string{"*.example.com"}}
cfg.Certificates["*.example.com"] = "0xb01dface"
if cert, matched, defaulted := cfg.getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
// When no certificate matches, the default is returned
if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted {
// When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted {
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
}
// When no certificate matches and SNI is NOT provided, a random is returned
if cert, matched, defaulted := cfg.getCertificate(""); matched || !defaulted {
t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
} else if cert.Names[0] != "example.com" {
t.Errorf("Expected default cert, got: %v", cert)
}
}
func TestCacheCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}})
if _, ok := certCache["example.com"]; !ok {
t.Error("Expected first cert to be cached by key 'example.com', but it wasn't")
cfg.cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}, Hash: "foobar"})
if len(certCache.cache) != 1 {
t.Errorf("Expected length of certificate cache to be 1")
}
if _, ok := certCache["sub.example.com"]; !ok {
t.Error("Expected first cert to be cached by key 'sub.example.com', but it wasn't")
if _, ok := certCache.cache["foobar"]; !ok {
t.Error("Expected first cert to be cached by key 'foobar', but it wasn't")
}
if cert, ok := certCache[""]; !ok || cert.Names[2] != "" {
t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't")
if _, ok := cfg.Certificates["example.com"]; !ok {
t.Error("Expected first cert to be keyed by 'example.com', but it wasn't")
}
if _, ok := cfg.Certificates["sub.example.com"]; !ok {
t.Error("Expected first cert to be keyed by 'sub.example.com', but it wasn't")
}
cacheCertificate(Certificate{Names: []string{"example2.com"}})
if _, ok := certCache["example2.com"]; !ok {
t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't")
// different config, but using same cache; and has cert with overlapping name,
// but different hash
cfg2 := &Config{Certificates: make(map[string]string), certCache: certCache}
cfg2.cacheCertificate(Certificate{Names: []string{"example.com"}, Hash: "barbaz"})
if _, ok := certCache.cache["barbaz"]; !ok {
t.Error("Expected second cert to be cached by key 'barbaz.com', but it wasn't")
}
if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" {
t.Error("Expected second cert to NOT be cached as default, but it was")
if hash, ok := cfg2.Certificates["example.com"]; !ok {
t.Error("Expected second cert to be keyed by 'example.com', but it wasn't")
} else if hash != "barbaz" {
t.Errorf("Expected second cert to map to 'barbaz' but it was %s instead", hash)
}
}

View file

@ -40,7 +40,7 @@ type ACMEClient struct {
AllowPrompts bool
config *Config
acmeClient *acme.Client
locker Locker
storage Storage
}
// newACMEClient creates a new ACMEClient given an email and whether
@ -122,10 +122,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
AllowPrompts: allowPrompts,
config: config,
acmeClient: client,
locker: &syncLock{
nameLocks: make(map[string]*sync.WaitGroup),
nameLocksMu: sync.Mutex{},
},
storage: storage,
}
if config.DNSProvider == "" {
@ -161,7 +158,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
// See if TLS challenge needs to be handled by our own facilities
if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSniSolver{})
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache})
}
// Disable any challenges that should not be used
@ -210,13 +207,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
// Callers who have access to a Config value should use the ObtainCert
// method on that instead of this lower-level method.
func (c *ACMEClient) Obtain(name string) error {
// Get access to ACME storage
storage, err := c.config.StorageFor(c.config.CAUrl)
if err != nil {
return err
}
waiter, err := c.locker.TryLock(name)
waiter, err := c.storage.TryLock(name)
if err != nil {
return err
}
@ -226,7 +217,7 @@ func (c *ACMEClient) Obtain(name string) error {
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
}
defer func() {
if err := c.locker.Unlock(name); err != nil {
if err := c.storage.Unlock(name); err != nil {
log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err)
}
}()
@ -269,7 +260,7 @@ Attempts:
}
// Success - immediately save the certificate resource
err = saveCertResource(storage, certificate)
err = saveCertResource(c.storage, certificate)
if err != nil {
return fmt.Errorf("error saving assets for %v: %v", name, err)
}
@ -282,35 +273,30 @@ Attempts:
return nil
}
// Renew renews the managed certificate for name. This function is
// safe for concurrent use.
// Renew renews the managed certificate for name. It puts the renewed
// certificate into storage (not the cache). This function is safe for
// concurrent use.
//
// Callers who have access to a Config value should use the RenewCert
// method on that instead of this lower-level method.
func (c *ACMEClient) Renew(name string) error {
// Get access to ACME storage
storage, err := c.config.StorageFor(c.config.CAUrl)
if err != nil {
return err
}
waiter, err := c.locker.TryLock(name)
waiter, err := c.storage.TryLock(name)
if err != nil {
return err
}
if waiter != nil {
log.Printf("[INFO] Certificate for %s is already being renewed elsewhere and stored; waiting", name)
waiter.Wait()
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
return nil // assume that the worker that renewed the cert succeeded; avoid hammering this path over and over
}
defer func() {
if err := c.locker.Unlock(name); err != nil {
if err := c.storage.Unlock(name); err != nil {
log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err)
}
}()
// Prepare for renewal (load PEM cert, key, and meta)
siteData, err := storage.LoadSite(name)
siteData, err := c.storage.LoadSite(name)
if err != nil {
return err
}
@ -357,18 +343,13 @@ func (c *ACMEClient) Renew(name string) error {
go diagnostics.Increment("acme_certificates_obtained")
go diagnostics.Increment("acme_certificates_renewed")
return saveCertResource(storage, newCertMeta)
return saveCertResource(c.storage, newCertMeta)
}
// Revoke revokes the certificate for name and deltes
// Revoke revokes the certificate for name and deletes
// it from storage.
func (c *ACMEClient) Revoke(name string) error {
storage, err := c.config.StorageFor(c.config.CAUrl)
if err != nil {
return err
}
siteExists, err := storage.SiteExists(name)
siteExists, err := c.storage.SiteExists(name)
if err != nil {
return err
}
@ -377,7 +358,7 @@ func (c *ACMEClient) Revoke(name string) error {
return errors.New("no certificate and key for " + name)
}
siteData, err := storage.LoadSite(name)
siteData, err := c.storage.LoadSite(name)
if err != nil {
return err
}
@ -387,7 +368,7 @@ func (c *ACMEClient) Revoke(name string) error {
return err
}
err = storage.DeleteSite(name)
err = c.storage.DeleteSite(name)
if err != nil {
return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error())
}

View file

@ -93,16 +93,17 @@ type Config struct {
// an ACME challenge
ListenHost string
// The alternate port (ONLY port, not host)
// to use for the ACME HTTP challenge; this
// port will be used if we proxy challenges
// coming in on port 80 to this alternate port
// The alternate port (ONLY port, not host) to
// use for the ACME HTTP challenge; if non-empty,
// this port will be used instead of
// HTTPChallengePort to spin up a listener for
// the HTTP challenge
AltHTTPPort string
// The alternate port (ONLY port, not host)
// to use for the ACME TLS-SNI challenge.
// The system must forward the standard port
// for the TLS-SNI challenge to this port.
// The system must forward TLSSNIChallengePort
// to this port for challenge to succeed
AltTLSSNIPort string
// The string identifier of the DNS provider
@ -134,7 +135,12 @@ type Config struct {
// Protocol Negotiation (ALPN).
ALPN []string
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
// The map of hostname to certificate hash. This is used to complete
// handshakes and serve the right certificate given the SNI.
Certificates map[string]string
certCache *certificateCache // pointer to the Instance's certificate store
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
}
// OnDemandState contains some state relevant for providing
@ -155,6 +161,25 @@ type OnDemandState struct {
AskURL *url.URL
}
// NewConfig returns a new Config with a pointer to the instance's
// certificate cache. You will usually need to set Other fields on
// the returned Config for successful practical use.
func NewConfig(inst *caddy.Instance) *Config {
inst.StorageMu.RLock()
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
inst.StorageMu.RUnlock()
if !ok || certCache == nil {
certCache = &certificateCache{cache: make(map[string]Certificate)}
inst.StorageMu.Lock()
inst.Storage[CertCacheInstStorageKey] = certCache
inst.StorageMu.Unlock()
}
cfg := new(Config)
cfg.Certificates = make(map[string]string)
cfg.certCache = certCache
return cfg
}
// ObtainCert obtains a certificate for name using c, as long
// as a certificate does not already exist in storage for that
// name. The name must qualify and c must be flagged as Managed.
@ -330,7 +355,9 @@ func (c *Config) buildStandardTLSConfig() error {
// MakeTLSConfig makes a tls.Config from configs. The returned
// tls.Config is programmed to load the matching caddytls.Config
// based on the hostname in SNI, but that's all.
// based on the hostname in SNI, but that's all. This is used
// to create a single TLS configuration for a listener (a group
// of sites).
func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
if len(configs) == 0 {
return nil, nil
@ -358,15 +385,28 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
}
// convert each caddytls.Config into a tls.Config
// convert this caddytls.Config into a tls.Config
if err := cfg.buildStandardTLSConfig(); err != nil {
return nil, err
}
// Key this config by its hostname (overwriting
// configs with the same hostname pattern); during
// TLS handshakes, configs are loaded based on
// the hostname pattern, according to client's SNI.
// if an existing config with this hostname was already
// configured, then they must be identical (or at least
// compatible), otherwise that is a configuration error
if otherConfig, ok := configMap[cfg.Hostname]; ok {
if err := assertConfigsCompatible(cfg, otherConfig); err != nil {
return nil, fmt.Errorf("incompabile TLS configurations for the same SNI "+
"name (%s) on the same listener: %v",
cfg.Hostname, err)
}
}
// key this config by its hostname (overwrites
// configs with the same hostname pattern; should
// be OK since we already asserted they are roughly
// the same); during TLS handshakes, configs are
// loaded based on the hostname pattern, according
// to client's SNI
configMap[cfg.Hostname] = cfg
}
@ -383,6 +423,63 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
}, nil
}
// assertConfigsCompatible returns an error if the two Configs
// do not have the same (or roughly compatible) configurations.
// If one of the tlsConfig pointers on either Config is nil,
// an error will be returned. If both are nil, no error.
func assertConfigsCompatible(cfg1, cfg2 *Config) error {
c1, c2 := cfg1.tlsConfig, cfg2.tlsConfig
if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) {
return fmt.Errorf("one config is not made")
}
if c1 == nil && c2 == nil {
return nil
}
if len(c1.CipherSuites) != len(c2.CipherSuites) {
return fmt.Errorf("different number of allowed cipher suites")
}
for i, ciph := range c1.CipherSuites {
if c2.CipherSuites[i] != ciph {
return fmt.Errorf("different cipher suites or different order")
}
}
if len(c1.CurvePreferences) != len(c2.CurvePreferences) {
return fmt.Errorf("different number of allowed cipher suites")
}
for i, curve := range c1.CurvePreferences {
if c2.CurvePreferences[i] != curve {
return fmt.Errorf("different curve preferences or different order")
}
}
if len(c1.NextProtos) != len(c2.NextProtos) {
return fmt.Errorf("different number of ALPN (NextProtos) values")
}
for i, proto := range c1.NextProtos {
if c2.NextProtos[i] != proto {
return fmt.Errorf("different ALPN (NextProtos) values or different order")
}
}
if c1.PreferServerCipherSuites != c2.PreferServerCipherSuites {
return fmt.Errorf("one prefers server cipher suites, the other does not")
}
if c1.MinVersion != c2.MinVersion {
return fmt.Errorf("minimum TLS version mismatch")
}
if c1.MaxVersion != c2.MaxVersion {
return fmt.Errorf("maximum TLS version mismatch")
}
if c1.ClientAuth != c2.ClientAuth {
return fmt.Errorf("client authentication policy mismatch")
}
return nil
}
// ConfigGetter gets a Config keyed by key.
type ConfigGetter func(c *caddy.Controller) *Config
@ -522,7 +619,7 @@ var supportedCurvesMap = map[string]tls.CurveID{
"P521": tls.CurveP521,
}
// List of all the curves we want to use by default
// List of all the curves we want to use by default.
//
// This list should only include curves which are fast by design (e.g. X25519)
// and those for which an optimized assembly implementation exists (e.g. P256).
@ -548,4 +645,8 @@ const (
// be capable of proxying or forwarding the request to this
// alternate port.
DefaultHTTPAlternatePort = "5033"
// CertCacheInstStorageKey is the name of the key for
// accessing the certificate storage on the *caddy.Instance.
CertCacheInstStorageKey = "tls_cert_cache"
)

View file

@ -237,15 +237,17 @@ func makeSelfSignedCert(config *Config) error {
return fmt.Errorf("could not create certificate: %v", err)
}
cacheCertificate(Certificate{
chain := [][]byte{derBytes}
config.cacheCertificate(Certificate{
Certificate: tls.Certificate{
Certificate: [][]byte{derBytes},
Certificate: chain,
PrivateKey: privKey,
Leaf: cert,
},
Names: cert.DNSNames,
NotAfter: cert.NotAfter,
Config: config,
Hash: hashCertificateChain(chain),
})
return nil

View file

@ -38,9 +38,9 @@ var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme")
// Storage instance backed by the local disk. The resulting Storage
// instance is guaranteed to be non-nil if there is no error.
func NewFileStorage(caURL *url.URL) (Storage, error) {
return &FileStorage{
Path: filepath.Join(storageBasePath, caURL.Host),
}, nil
storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)}
storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage}
return storage, nil
}
// FileStorage facilitates forming file paths derived from a root
@ -48,6 +48,7 @@ func NewFileStorage(caURL *url.URL) (Storage, error) {
// cross-platform way or persisting ACME assets on the file system.
type FileStorage struct {
Path string
Locker
}
// sites gets the directory that stores site certificate and keys.

127
caddytls/filestoragesync.go Normal file
View file

@ -0,0 +1,127 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"fmt"
"os"
"sync"
"time"
"github.com/mholt/caddy"
)
func init() {
// be sure to remove lock files when exiting the process!
caddy.OnProcessExit = append(caddy.OnProcessExit, func() {
fileStorageNameLocksMu.Lock()
defer fileStorageNameLocksMu.Unlock()
for key, fw := range fileStorageNameLocks {
os.Remove(fw.filename)
delete(fileStorageNameLocks, key)
}
})
}
// fileStorageLock facilitates ACME-related locking by using
// the associated FileStorage, so multiple processes can coordinate
// renewals on the certificates on a shared file system.
type fileStorageLock struct {
caURL string
storage *FileStorage
}
// TryLock attempts to get a lock for name, otherwise it returns
// a Waiter value to wait until the other process is finished.
func (s *fileStorageLock) TryLock(name string) (Waiter, error) {
fileStorageNameLocksMu.Lock()
defer fileStorageNameLocksMu.Unlock()
// see if lock already exists within this process
fw, ok := fileStorageNameLocks[s.caURL+name]
if ok {
// lock already created within process, let caller wait on it
return fw, nil
}
// attempt to persist lock to disk by creating lock file
fw = &fileWaiter{
filename: s.storage.siteCertFile(name) + ".lock",
wg: new(sync.WaitGroup),
}
// parent dir must exist
if err := os.MkdirAll(s.storage.site(name), 0700); err != nil {
return nil, err
}
lf, err := os.OpenFile(fw.filename, os.O_CREATE|os.O_EXCL, 0644)
if err != nil {
if os.IsExist(err) {
// another process has the lock; use it to wait
return fw, nil
}
// otherwise, this was some unexpected error
return nil, err
}
lf.Close()
// looks like we get the lock
fw.wg.Add(1)
fileStorageNameLocks[s.caURL+name] = fw
return nil, nil
}
// Unlock unlocks name.
func (s *fileStorageLock) Unlock(name string) error {
fileStorageNameLocksMu.Lock()
defer fileStorageNameLocksMu.Unlock()
fw, ok := fileStorageNameLocks[s.caURL+name]
if !ok {
return fmt.Errorf("FileStorage: no lock to release for %s", name)
}
os.Remove(fw.filename)
fw.wg.Done()
delete(fileStorageNameLocks, s.caURL+name)
return nil
}
// fileWaiter waits for a file to disappear; it polls
// the file system to check for the existence of a file.
// It also has a WaitGroup which will be faster than
// polling, for when locking need only happen within this
// process.
type fileWaiter struct {
filename string
wg *sync.WaitGroup
}
// Wait waits until the lock is released.
func (fw *fileWaiter) Wait() {
start := time.Now()
fw.wg.Wait()
for time.Since(start) < 1*time.Hour {
_, err := os.Stat(fw.filename)
if os.IsNotExist(err) {
return
}
time.Sleep(1 * time.Second)
}
}
var fileStorageNameLocks = make(map[string]*fileWaiter) // keyed by CA + name
var fileStorageNameLocksMu sync.Mutex
var _ Locker = &fileStorageLock{}
var _ Waiter = &fileWaiter{}

View file

@ -61,15 +61,15 @@ func (cg configGroup) getConfig(name string) *Config {
}
}
// as a fallback, try a config that serves all names
// try a config that serves all names (this
// is basically the same as a config defined
// for "*" -- I think -- but the above loop
// doesn't try an empty string)
if config, ok := cg[""]; ok {
return config
}
// as a last resort, use a random config
// (even if the config isn't for that hostname,
// it should help us serve clients without SNI
// or at least defer TLS alerts to the cert)
// no matches, so just serve up a random config
for _, config := range cg {
return config
}
@ -121,6 +121,86 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
return &cert.Certificate, err
}
// getCertificate gets a certificate that matches name (a server name)
// from the in-memory cache, according to the lookup table associated with
// cfg. The lookup then points to a certificate in the Instance certificate
// cache.
//
// If there is no exact match for name, it will be checked against names of
// the form '*.example.com' (wildcard certificates) according to RFC 6125.
// If a match is found, matched will be true. If no matches are found, matched
// will be false and a "default" certificate will be returned with defaulted
// set to true. If defaulted is false, then no certificates were available.
//
// The logic in this function is adapted from the Go standard library,
// which is by the Go Authors.
//
// This function is safe for concurrent use.
func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defaulted bool) {
var certKey string
var ok bool
// Not going to trim trailing dots here since RFC 3546 says,
// "The hostname is represented ... without a trailing dot."
// Just normalize to lowercase.
name = strings.ToLower(name)
cfg.certCache.RLock()
defer cfg.certCache.RUnlock()
// exact match? great, let's use it
if certKey, ok = cfg.Certificates[name]; ok {
cert = cfg.certCache.cache[certKey]
matched = true
return
}
// try replacing labels in the name with wildcards until we get a match
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if certKey, ok = cfg.Certificates[candidate]; ok {
cert = cfg.certCache.cache[certKey]
matched = true
return
}
}
// check the certCache directly to see if the SNI name is
// already the key of the certificate it wants! this is vital
// for supporting the TLS-SNI challenge, since the tlsSNISolver
// just puts the temporary certificate in the instance cache,
// with no regard for configs; this also means that the SNI
// can contain the hash of a specific cert (chain) it wants
// and we will still be able to serve it up
// (this behavior, by the way, could be controversial as to
// whether it complies with RFC 6066 about SNI, but I think
// it does soooo...)
// NOTE/TODO: TLS-SNI challenge is changing, as of Jan. 2018
// but what will be different, if it ever returns, is unclear
if directCert, ok := cfg.certCache.cache[name]; ok {
cert = directCert
matched = true
return
}
// if nothing matches and SNI was not provided, use a random
// certificate; at least there's a chance this older client
// can connect, and in the future we won't need this provision
// (if SNI is present, it's probably best to just raise a TLS
// alert by not serving a certificate)
if name == "" {
for _, certKey := range cfg.Certificates {
defaulted = true
cert = cfg.certCache.cache[certKey]
return
}
}
return
}
// getCertDuringHandshake will get a certificate for name. It first tries
// the in-memory cache. If no certificate for name is in the cache, the
// config most closely corresponding to name will be loaded. If that config
@ -134,7 +214,7 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
// This function is safe for concurrent use.
func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
// First check our in-memory cache to see if we've already loaded it
cert, matched, defaulted := getCertificate(name)
cert, matched, defaulted := cfg.getCertificate(name)
if matched {
return cert, nil
}
@ -277,7 +357,7 @@ func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) {
obtainCertWaitChans[name] = wait
obtainCertWaitChansMu.Unlock()
// do the obtain
// obtain the certificate
log.Printf("[INFO] Obtaining new certificate for %s", name)
err := cfg.ObtainCert(name, false)
@ -336,9 +416,9 @@ func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certific
// quite common considering not all certs have issuer URLs that support it.
log.Printf("[ERROR] Getting OCSP for %s: %v", name, err)
}
certCacheMu.Lock()
certCache[name] = cert
certCacheMu.Unlock()
cfg.certCache.Lock()
cfg.certCache.cache[cert.Hash] = cert
cfg.certCache.Unlock()
}
}
@ -367,29 +447,22 @@ func (cfg *Config) renewDynamicCertificate(name string, currentCert Certificate)
obtainCertWaitChans[name] = wait
obtainCertWaitChansMu.Unlock()
// do the renew and reload the certificate
// renew and reload the certificate
log.Printf("[INFO] Renewing certificate for %s", name)
err := cfg.RenewCert(name, false)
if err == nil {
// immediately flush this certificate from the cache so
// the name doesn't overlap when we try to replace it,
// which would fail, because overlapping existing cert
// names isn't allowed
certCacheMu.Lock()
for _, certName := range currentCert.Names {
delete(certCache, certName)
}
certCacheMu.Unlock()
// even though the recursive nature of the dynamic cert loading
// would just call this function anyway, we do it here to
// make the replacement as atomic as possible. (TODO: similar
// to the note in maintain.go, it'd be nice if the clearing of
// the cache entries above and this load function were truly
// atomic...)
_, err := currentCert.Config.CacheManagedCertificate(name)
// make the replacement as atomic as possible.
newCert, err := currentCert.configs[0].CacheManagedCertificate(name)
if err != nil {
log.Printf("[ERROR] loading renewed certificate: %v", err)
log.Printf("[ERROR] loading renewed certificate for %s: %v", name, err)
} else {
// replace the old certificate with the new one
err = cfg.certCache.replaceCertificate(currentCert, newCert)
if err != nil {
log.Printf("[ERROR] Replacing certificate for %s: %v", name, err)
}
}
}

View file

@ -21,9 +21,8 @@ import (
)
func TestGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
cfg := new(Config)
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
@ -38,33 +37,40 @@ func TestGetCertificate(t *testing.T) {
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
}
// When cache has one certificate in it (also is default)
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
certCache[""] = defaultCert
certCache["example.com"] = defaultCert
// When cache has one certificate in it
firstCert := Certificate{Names: []string{"example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
cfg.cacheCertificate(firstCert)
if cert, err := cfg.GetCertificate(hello); err != nil {
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
}
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
if _, err := cfg.GetCertificate(helloNoSNI); err != nil {
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
}
// When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
wildcardCert := Certificate{
Names: []string{"*.example.com"},
Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}},
Hash: "(don't overwrite the first one)",
}
cfg.cacheCertificate(wildcardCert)
if cert, err := cfg.GetCertificate(helloSub); err != nil {
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
}
// When no certificate matches, the default is returned
if cert, err := cfg.GetCertificate(helloNoMatch); err != nil {
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Expected default cert with no matches, got: %v", cert)
// When cache is NOT empty but there's no SNI
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
t.Errorf("Expected random certificate with no error when no SNI, got err: %v", err)
} else if cert == nil || len(cert.Leaf.DNSNames) == 0 {
t.Errorf("Expected random cert with no matches, got: %v", cert)
}
// When no certificate matches, raise an alert
if _, err := cfg.GetCertificate(helloNoMatch); err == nil {
t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err)
}
}

View file

@ -27,10 +27,11 @@ import (
const challengeBasePath = "/.well-known/acme-challenge"
// HTTPChallengeHandler proxies challenge requests to ACME client if the
// request path starts with challengeBasePath. It returns true if it
// handled the request and no more needs to be done; it returns false
// if this call was a no-op and the request still needs handling.
func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, altPort string) bool {
// request path starts with challengeBasePath, if the HTTP challenge is not
// disabled, and if we are known to be obtaining a certificate for the name.
// It returns true if it handled the request and no more needs to be done;
// it returns false if this call was a no-op and the request still needs handling.
func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost string) bool {
if !strings.HasPrefix(r.URL.Path, challengeBasePath) {
return false
}
@ -50,7 +51,11 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, al
listenHost = "localhost"
}
upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, altPort))
// always proxy to the DefaultHTTPAlternatePort because obviously the
// ACME challenge request already got into one of our HTTP handlers, so
// it means we must have started a HTTP listener on the alternate
// port instead; which is only accessible via listenHost
upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, DefaultHTTPAlternatePort))
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
log.Printf("[ERROR] ACME proxy handler: %v", err)

View file

@ -39,7 +39,7 @@ func TestHTTPChallengeHandlerNoOp(t *testing.T) {
t.Fatalf("Could not craft request, got error: %v", err)
}
rw := httptest.NewRecorder()
if HTTPChallengeHandler(rw, req, "", DefaultHTTPAlternatePort) {
if HTTPChallengeHandler(rw, req, "") {
t.Errorf("Got true with this URL, but shouldn't have: %s", url)
}
}
@ -76,7 +76,7 @@ func TestHTTPChallengeHandlerSuccess(t *testing.T) {
}
rw := httptest.NewRecorder()
HTTPChallengeHandler(rw, req, "", DefaultHTTPAlternatePort)
HTTPChallengeHandler(rw, req, "")
if !proxySuccess {
t.Fatal("Expected request to be proxied, but it wasn't")

View file

@ -87,103 +87,127 @@ func maintainAssets(stopChan chan struct{}) {
// RenewManagedCertificates renews managed certificates,
// including ones loaded on-demand.
func RenewManagedCertificates(allowPrompts bool) (err error) {
var renewQueue, deleteQueue []Certificate
visitedNames := make(map[string]struct{})
certCacheMu.RLock()
for name, cert := range certCache {
if !cert.Config.Managed || cert.Config.SelfSigned {
for _, inst := range caddy.Instances() {
inst.StorageMu.RLock()
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
inst.StorageMu.RUnlock()
if !ok || certCache == nil {
continue
}
// the list of names on this cert should never be empty...
if cert.Names == nil || len(cert.Names) == 0 {
log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v - removing from cache", name, cert.Names)
deleteQueue = append(deleteQueue, cert)
continue
}
// we use the queues for a very important reason: to do any and all
// operations that could require an exclusive write lock outside
// of the read lock! otherwise we get a deadlock, yikes. in other
// words, our first iteration through the certificate cache does NOT
// perform any operations--only queues them--so that more fine-grained
// write locks may be obtained during the actual operations.
var renewQueue, reloadQueue, deleteQueue []Certificate
// skip names whose certificate we've already renewed
if _, ok := visitedNames[name]; ok {
continue
}
for _, name := range cert.Names {
visitedNames[name] = struct{}{}
}
// if its time is up or ending soon, we need to try to renew it
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < RenewDurationBefore {
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
if cert.Config == nil {
log.Printf("[ERROR] %s: No associated TLS config; unable to renew", name)
certCache.RLock()
for certKey, cert := range certCache.cache {
if len(cert.configs) == 0 {
// this is bad if this happens, probably a programmer error (oops)
log.Printf("[ERROR] No associated TLS config for certificate with names %v; unable to manage", cert.Names)
continue
}
if !cert.configs[0].Managed || cert.configs[0].SelfSigned {
continue
}
// queue for renewal when we aren't in a read lock anymore
// (the TLS-SNI challenge will need a write lock in order to
// present the certificate, so we renew outside of read lock)
renewQueue = append(renewQueue, cert)
}
}
certCacheMu.RUnlock()
// Perform renewals that are queued
for _, cert := range renewQueue {
// Get the name which we should use to renew this certificate;
// we only support managing certificates with one name per cert,
// so this should be easy. We can't rely on cert.Config.Hostname
// because it may be a wildcard value from the Caddyfile (e.g.
// *.something.com) which, as of Jan. 2017, is not supported by ACME.
var renewName string
for _, name := range cert.Names {
if name != "" {
renewName = name
break
}
}
// perform renewal
err := cert.Config.RenewCert(renewName, allowPrompts)
if err != nil {
if allowPrompts {
// Certificate renewal failed and the operator is present. See a discussion
// about this in issue 642. For a while, we only stopped if the certificate
// was expired, but in reality, there is no difference between reporting
// it now versus later, except that there's somebody present to deal with
// it right now.
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < RenewDurationBeforeAtStartup {
// See issue 1680. Only fail at startup if the certificate is dangerously
// close to expiration.
return err
}
}
log.Printf("[ERROR] %v", err)
if cert.Config.OnDemand {
// loaded dynamically, removed dynamically
// the list of names on this cert should never be empty... programmer error?
if cert.Names == nil || len(cert.Names) == 0 {
log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v - removing from cache", certKey, cert.Names)
deleteQueue = append(deleteQueue, cert)
continue
}
} else {
// if time is up or expires soon, we need to try to renew it
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < RenewDurationBefore {
// see if the certificate in storage has already been renewed, possibly by another
// instance of Caddy that didn't coordinate with this one; if so, just load it (this
// might happen if another instance already renewed it - kinda sloppy but checking disk
// first is a simple way to possibly drastically reduce rate limit problems)
storedCertExpiring, err := managedCertInStorageExpiresSoon(cert)
if err != nil {
// hmm, weird, but not a big deal, maybe it was deleted or something
log.Printf("[NOTICE] Error while checking if certificate for %v in storage is also expiring soon: %v",
cert.Names, err)
} else if !storedCertExpiring {
// if the certificate is NOT expiring soon and there was no error, then we
// are good to just reload the certificate from storage instead of repeating
// a likely-unnecessary renewal procedure
reloadQueue = append(reloadQueue, cert)
continue
}
// the certificate in storage has not been renewed yet, so we will do it
// NOTE 1: This is not correct 100% of the time, if multiple Caddy instances
// happen to run their maintenance checks at approximately the same times;
// both might start renewal at about the same time and do two renewals and one
// will overwrite the other. Hence TLS storage plugins. This is sort of a TODO.
// NOTE 2: It is super-important to note that the TLS-SNI challenge requires
// a write lock on the cache in order to complete its challenge, so it is extra
// vital that this renew operation does not happen inside our read lock!
renewQueue = append(renewQueue, cert)
}
}
certCache.RUnlock()
// Reload certificates that merely need to be updated in memory
for _, oldCert := range reloadQueue {
timeLeft := oldCert.NotAfter.Sub(time.Now().UTC())
log.Printf("[INFO] Certificate for %v expires in %v, but is already renewed in storage; reloading stored certificate",
oldCert.Names, timeLeft)
err = certCache.reloadManagedCertificate(oldCert)
if err != nil {
if allowPrompts {
return err // operator is present, so report error immediately
}
log.Printf("[ERROR] Loading renewed certificate: %v", err)
}
}
// Renewal queue
for _, oldCert := range renewQueue {
timeLeft := oldCert.NotAfter.Sub(time.Now().UTC())
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", oldCert.Names, timeLeft)
// Get the name which we should use to renew this certificate;
// we only support managing certificates with one name per cert,
// so this should be easy. We can't rely on cert.Config.Hostname
// because it may be a wildcard value from the Caddyfile (e.g.
// *.something.com) which, as of Jan. 2017, is not supported by ACME.
// TODO: ^ ^ ^ (wildcards)
renewName := oldCert.Names[0]
// perform renewal
err := oldCert.configs[0].RenewCert(renewName, allowPrompts)
if err != nil {
if allowPrompts {
// Certificate renewal failed and the operator is present. See a discussion
// about this in issue 642. For a while, we only stopped if the certificate
// was expired, but in reality, there is no difference between reporting
// it now versus later, except that there's somebody present to deal with
// it right now. Follow-up: See issue 1680. Only fail in this case if the
// certificate is dangerously close to expiration.
timeLeft := oldCert.NotAfter.Sub(time.Now().UTC())
if timeLeft < RenewDurationBeforeAtStartup {
return err
}
}
log.Printf("[ERROR] %v", err)
if oldCert.configs[0].OnDemand {
// loaded dynamically, remove dynamically
deleteQueue = append(deleteQueue, oldCert)
}
continue
}
// successful renewal, so update in-memory cache by loading
// renewed certificate so it will be used with handshakes
// we must delete all the names this cert services from the cache
// so that we can replace the certificate, because replacing names
// already in the cache is not allowed, to avoid later conflicts
// with renewals.
// TODO: It would be nice if this whole operation were idempotent;
// i.e. a thread-safe function to replace a certificate in the cache,
// see also handshake.go for on-demand maintenance.
certCacheMu.Lock()
for _, name := range cert.Names {
delete(certCache, name)
}
certCacheMu.Unlock()
// put the certificate in the cache
_, err := cert.Config.CacheManagedCertificate(cert.Names[0])
err = certCache.reloadManagedCertificate(oldCert)
if err != nil {
if allowPrompts {
return err // operator is present, so report error immediately
@ -191,15 +215,22 @@ func RenewManagedCertificates(allowPrompts bool) (err error) {
log.Printf("[ERROR] %v", err)
}
}
}
// Apply queued deletion changes to the cache
for _, cert := range deleteQueue {
certCacheMu.Lock()
for _, name := range cert.Names {
delete(certCache, name)
// Deletion queue
for _, cert := range deleteQueue {
certCache.Lock()
// remove any pointers to this certificate from Configs
for _, cfg := range cert.configs {
for name, certKey := range cfg.Certificates {
if certKey == cert.Hash {
delete(cfg.Certificates, name)
}
}
}
// then delete the certificate from the cache
delete(certCache.cache, cert.Hash)
certCache.Unlock()
}
certCacheMu.Unlock()
}
return nil
@ -212,91 +243,75 @@ func RenewManagedCertificates(allowPrompts bool) (err error) {
// Ryan Sleevi's recommendations for good OCSP support:
// https://gist.github.com/sleevi/5efe9ef98961ecfb4da8
func UpdateOCSPStaples() {
// Create a temporary place to store updates
// until we release the potentially long-lived
// read lock and use a short-lived write lock.
type ocspUpdate struct {
rawBytes []byte
parsed *ocsp.Response
}
updated := make(map[string]ocspUpdate)
// A single SAN certificate maps to multiple names, so we use this
// set to make sure we don't waste cycles checking OCSP for the same
// certificate multiple times.
visited := make(map[string]struct{})
certCacheMu.RLock()
for name, cert := range certCache {
// skip this certificate if we've already visited it,
// and if not, mark all the names as visited
if _, ok := visited[name]; ok {
continue
}
for _, n := range cert.Names {
visited[n] = struct{}{}
}
// no point in updating OCSP for expired certificates
if time.Now().After(cert.NotAfter) {
for _, inst := range caddy.Instances() {
inst.StorageMu.RLock()
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
inst.StorageMu.RUnlock()
if !ok || certCache == nil {
continue
}
var lastNextUpdate time.Time
if cert.OCSP != nil {
lastNextUpdate = cert.OCSP.NextUpdate
if freshOCSP(cert.OCSP) {
// no need to update staple if ours is still fresh
// Create a temporary place to store updates
// until we release the potentially long-lived
// read lock and use a short-lived write lock
// on the certificate cache.
type ocspUpdate struct {
rawBytes []byte
parsed *ocsp.Response
}
updated := make(map[string]ocspUpdate)
certCache.RLock()
for certHash, cert := range certCache.cache {
// no point in updating OCSP for expired certificates
if time.Now().After(cert.NotAfter) {
continue
}
}
err := stapleOCSP(&cert, nil)
if err != nil {
var lastNextUpdate time.Time
if cert.OCSP != nil {
// if there was no staple before, that's fine; otherwise we should log the error
log.Printf("[ERROR] Checking OCSP: %v", err)
lastNextUpdate = cert.OCSP.NextUpdate
if freshOCSP(cert.OCSP) {
continue // no need to update staple if ours is still fresh
}
}
continue
}
// By this point, we've obtained the latest OCSP response.
// If there was no staple before, or if the response is updated, make
// sure we apply the update to all names on the certificate.
if cert.OCSP != nil && (lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate) {
log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
for _, n := range cert.Names {
// BUG: If this certificate has names on it that appear on another
// certificate in the cache, AND the other certificate is keyed by
// that name in the cache, then this method of 'queueing' the staple
// update will cause this certificate's new OCSP to be stapled to
// a different certificate! See:
// https://caddy.community/t/random-ocsp-response-errors-for-random-clients/2473?u=matt
// This problem should be avoided if names on certificates in the
// cache don't overlap with regards to the cache keys.
// (This is isn't a bug anymore, since we're careful when we add
// certificates to the cache by skipping keying when key already exists.)
updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
err := stapleOCSP(&cert, nil)
if err != nil {
if cert.OCSP != nil {
// if there was no staple before, that's fine; otherwise we should log the error
log.Printf("[ERROR] Checking OCSP: %v", err)
}
continue
}
// By this point, we've obtained the latest OCSP response.
// If there was no staple before, or if the response is updated, make
// sure we apply the update to all names on the certificate.
if cert.OCSP != nil && (lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate) {
log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
updated[certHash] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
}
}
}
certCacheMu.RUnlock()
certCache.RUnlock()
// This write lock should be brief since we have all the info we need now.
certCacheMu.Lock()
for name, update := range updated {
cert := certCache[name]
cert.OCSP = update.parsed
cert.Certificate.OCSPStaple = update.rawBytes
certCache[name] = cert
// These write locks should be brief since we have all the info we need now.
for certKey, update := range updated {
certCache.Lock()
cert := certCache.cache[certKey]
cert.OCSP = update.parsed
cert.Certificate.OCSPStaple = update.rawBytes
certCache.cache[certKey] = cert
certCache.Unlock()
}
}
certCacheMu.Unlock()
}
// DeleteOldStapleFiles deletes cached OCSP staples that have expired.
// TODO: Should we do this for certificates too?
func DeleteOldStapleFiles() {
// TODO: Upgrade caddytls.Storage to support OCSP operations too
files, err := ioutil.ReadDir(ocspFolder)
if err != nil {
// maybe just hasn't been created yet; no big deal

View file

@ -38,6 +38,7 @@ func init() {
// are specified by the user in the config file. All the automatic HTTPS
// stuff comes later outside of this function.
func setupTLS(c *caddy.Controller) error {
// obtain the configGetter, which loads the config we're, uh, configuring
configGetter, ok := configGetters[c.ServerType()]
if !ok {
return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType())
@ -47,6 +48,14 @@ func setupTLS(c *caddy.Controller) error {
return fmt.Errorf("no caddytls.Config to set up for %s", c.Key)
}
// the certificate cache is tied to the current caddy.Instance; get a pointer to it
certCache, ok := c.Get(CertCacheInstStorageKey).(*certificateCache)
if !ok || certCache == nil {
certCache = &certificateCache{cache: make(map[string]Certificate)}
c.Set(CertCacheInstStorageKey, certCache)
}
config.certCache = certCache
config.Enabled = true
for c.Next() {
@ -237,7 +246,7 @@ func setupTLS(c *caddy.Controller) error {
// load a single certificate and key, if specified
if certificateFile != "" && keyFile != "" {
err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
err := config.cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
if err != nil {
return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err)
}
@ -246,7 +255,7 @@ func setupTLS(c *caddy.Controller) error {
// load a directory of certificates, if specified
if loadDir != "" {
err := loadCertsInDir(c, loadDir)
err := loadCertsInDir(config, c, loadDir)
if err != nil {
return err
}
@ -273,7 +282,7 @@ func setupTLS(c *caddy.Controller) error {
// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
//
// This function may write to the log as it walks the directory tree.
func loadCertsInDir(c *caddy.Controller, dir string) error {
func loadCertsInDir(cfg *Config, c *caddy.Controller, dir string) error {
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
@ -336,7 +345,7 @@ func loadCertsInDir(c *caddy.Controller, dir string) error {
return c.Errf("%s: no private key block found", path)
}
err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
err = cfg.cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
if err != nil {
return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err)
}

View file

@ -46,9 +46,12 @@ func TestMain(m *testing.M) {
}
func TestSetupParseBasic(t *testing.T) {
cfg := new(Config)
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", `tls `+certFile+` `+keyFile+``)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c)
if err != nil {
@ -124,9 +127,12 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
must_staple
alpn http/1.1
}`
cfg := new(Config)
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c)
if err != nil {
@ -158,9 +164,11 @@ func TestSetupDefaultWithOptionalParams(t *testing.T) {
params := `tls {
ciphers RSA-3DES-EDE-CBC-SHA
}`
cfg := new(Config)
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c)
if err != nil {
@ -176,9 +184,12 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
params := `tls ` + certFile + ` ` + keyFile + ` {
protocols ssl tls
}`
cfg := new(Config)
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c)
if err == nil {
t.Errorf("Expected errors, but no error returned")
@ -191,6 +202,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
cfg = new(Config)
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c = caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err = setupTLS(c)
if err == nil {
t.Error("Expected errors, but no error returned")
@ -215,6 +227,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
cfg = new(Config)
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c = caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err = setupTLS(c)
if err == nil {
t.Error("Expected errors, but no error returned")
@ -226,7 +239,8 @@ func TestSetupParseWithClientAuth(t *testing.T) {
params := `tls ` + certFile + ` ` + keyFile + ` {
clients
}`
cfg := new(Config)
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params)
err := setupTLS(c)
@ -259,9 +273,11 @@ func TestSetupParseWithClientAuth(t *testing.T) {
clients verify_if_given
}`, tls.VerifyClientCertIfGiven, true, noCAs},
} {
cfg := new(Config)
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", caseData.params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c)
if caseData.expectedErr {
if err == nil {
@ -311,9 +327,11 @@ func TestSetupParseWithCAUrl(t *testing.T) {
ca 1 2
}`, true, ""},
} {
cfg := new(Config)
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", caseData.params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c)
if caseData.expectedErr {
if err == nil {
@ -335,9 +353,11 @@ func TestSetupParseWithKeyType(t *testing.T) {
params := `tls {
key_type p384
}`
cfg := new(Config)
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c)
if err != nil {
@ -353,9 +373,11 @@ func TestSetupParseWithCurves(t *testing.T) {
params := `tls {
curves x25519 p256 p384 p521
}`
cfg := new(Config)
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c)
if err != nil {
@ -380,9 +402,11 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) {
params := `tls {
protocols tls1.2
}`
cfg := new(Config)
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c)
if err != nil {

View file

@ -107,6 +107,10 @@ type Storage interface {
// in StoreUser. The result is an empty string if there are no
// persisted users in storage.
MostRecentUserEmail() string
// Locker is necessary because synchronizing certificate maintenance
// depends on how storage is implemented.
Locker
}
// ErrNotExist is returned by Storage implementations when

View file

@ -1,57 +0,0 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"fmt"
"sync"
)
var _ Locker = &syncLock{}
type syncLock struct {
nameLocks map[string]*sync.WaitGroup
nameLocksMu sync.Mutex
}
// TryLock attempts to get a lock for name, otherwise it returns
// a Waiter value to wait until the other process is finished.
func (s *syncLock) TryLock(name string) (Waiter, error) {
s.nameLocksMu.Lock()
defer s.nameLocksMu.Unlock()
wg, ok := s.nameLocks[name]
if ok {
// lock already obtained, let caller wait on it
return wg, nil
}
// caller gets lock
wg = new(sync.WaitGroup)
wg.Add(1)
s.nameLocks[name] = wg
return nil, nil
}
// Unlock unlocks name.
func (s *syncLock) Unlock(name string) error {
s.nameLocksMu.Lock()
defer s.nameLocksMu.Unlock()
wg, ok := s.nameLocks[name]
if !ok {
return fmt.Errorf("FileStorage: no lock to release for %s", name)
}
wg.Done()
delete(s.nameLocks, name)
return nil
}

View file

@ -88,30 +88,38 @@ func Revoke(host string) error {
return client.Revoke(host)
}
// tlsSniSolver is a type that can solve tls-sni challenges using
// tlsSNISolver is a type that can solve TLS-SNI challenges using
// an existing listener and our custom, in-memory certificate cache.
type tlsSniSolver struct{}
type tlsSNISolver struct {
certCache *certificateCache
}
// Present adds the challenge certificate to the cache.
func (s tlsSniSolver) Present(domain, token, keyAuth string) error {
func (s tlsSNISolver) Present(domain, token, keyAuth string) error {
cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
if err != nil {
return err
}
cacheCertificate(Certificate{
certHash := hashCertificateChain(cert.Certificate)
s.certCache.Lock()
s.certCache.cache[acmeDomain] = Certificate{
Certificate: cert,
Names: []string{acmeDomain},
})
Hash: certHash, // perhaps not necesssary
}
s.certCache.Unlock()
return nil
}
// CleanUp removes the challenge certificate from the cache.
func (s tlsSniSolver) CleanUp(domain, token, keyAuth string) error {
func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error {
_, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
if err != nil {
return err
}
uncacheCertificate(acmeDomain)
s.certCache.Lock()
delete(s.certCache.cache, acmeDomain)
s.certCache.Unlock()
return nil
}

View file

@ -103,6 +103,20 @@ func (c *Controller) Context() Context {
return c.instance.context
}
// Get safely gets a value from the Instance's storage.
func (c *Controller) Get(key interface{}) interface{} {
c.instance.StorageMu.RLock()
defer c.instance.StorageMu.RUnlock()
return c.instance.Storage[key]
}
// Set safely sets a value on the Instance's storage.
func (c *Controller) Set(key, val interface{}) {
c.instance.StorageMu.Lock()
c.instance.Storage[key] = val
c.instance.StorageMu.Unlock()
}
// NewTestController creates a new Controller for
// the server type and input specified. The filename
// is "Testfile". If the server type is not empty and
@ -113,12 +127,12 @@ func (c *Controller) Context() Context {
// Used only for testing, but exported so plugins can
// use this for convenience.
func NewTestController(serverType, input string) *Controller {
var ctx Context
testInst := &Instance{serverType: serverType, Storage: make(map[interface{}]interface{})}
if stype, err := getServerType(serverType); err == nil {
ctx = stype.NewContext()
testInst.context = stype.NewContext(testInst)
}
return &Controller{
instance: &Instance{serverType: serverType, context: ctx},
instance: testInst,
Dispenser: caddyfile.NewDispenser("Testfile", strings.NewReader(input)),
OncePerServerBlock: func(f func() error) error { return f() },
}

View file

@ -91,6 +91,7 @@ Install the systemd service unit configuration file, reload the systemd daemon,
and start caddy:
```bash
wget https://raw.githubusercontent.com/mholt/caddy/master/dist/init/linux-systemd/caddy.service
sudo cp caddy.service /etc/systemd/system/
sudo chown root:root /etc/systemd/system/caddy.service
sudo chmod 644 /etc/systemd/system/caddy.service

View file

@ -30,8 +30,8 @@ LimitNPROC=512
; Use private /tmp and /var/tmp, which are discarded after caddy stops.
PrivateTmp=true
; Use a minimal /dev
PrivateDevices=true
; Use a minimal /dev (May bring additional security if switched to 'true', but it may not work on Raspberry Pi's or other devices, so it has been disabled in this dist.)
PrivateDevices=false
; Hide /home, /root, and /run/user. Nobody will steal your SSH-keys.
ProtectHome=true
; Make /usr, /boot, /etc and possibly some more folders read-only.

View file

@ -19,6 +19,7 @@ import (
"log"
"net"
"sort"
"sync"
"github.com/mholt/caddy/caddyfile"
)
@ -38,7 +39,7 @@ var (
// eventHooks is a map of hook name to Hook. All hooks plugins
// must have a name.
eventHooks = make(map[string]EventHook)
eventHooks = sync.Map{}
// parsingCallbacks maps server type to map of directive
// to list of callback functions. These aren't really
@ -98,11 +99,15 @@ func ListPlugins() map[string][]string {
p["caddyfile_loaders"] = append(p["caddyfile_loaders"], defaultCaddyfileLoader.name)
}
// event hook plugins
if len(eventHooks) > 0 {
for name := range eventHooks {
p["event_hooks"] = append(p["event_hooks"], name)
}
// List the event hook plugins
hooks := ""
eventHooks.Range(func(k, _ interface{}) bool {
hooks += " hook." + k.(string) + "\n"
return true
})
if hooks != "" {
str += "\nEvent hook plugins:\n"
str += hooks
}
// alphabetize the rest of the plugins
@ -220,7 +225,7 @@ type ServerType struct {
// startup phases before this one. It's a way to keep
// each set of server instances separate and to reduce
// the amount of global state you need.
NewContext func() Context
NewContext func(inst *Instance) Context
}
// Plugin is a type which holds information about a plugin.
@ -277,23 +282,23 @@ func RegisterEventHook(name string, hook EventHook) {
if name == "" {
panic("event hook must have a name")
}
if _, dup := eventHooks[name]; dup {
_, dup := eventHooks.LoadOrStore(name, hook)
if dup {
panic("hook named " + name + " already registered")
}
eventHooks[name] = hook
}
// EmitEvent executes the different hooks passing the EventType as an
// argument. This is a blocking function. Hook developers should
// use 'go' keyword if they don't want to block Caddy.
func EmitEvent(event EventName, info interface{}) {
for name, hook := range eventHooks {
err := hook(event, info)
eventHooks.Range(func(k, v interface{}) bool {
err := v.(EventHook)(event, info)
if err != nil {
log.Printf("error on '%s' hook: %v", name, err)
log.Printf("error on '%s' hook: %v", k.(string), err)
}
}
return true
})
}
// ParsingCallback is a function that is called after
@ -412,6 +417,14 @@ func loadCaddyfileInput(serverType string) (Input, error) {
return caddyfileToUse, nil
}
// OnProcessExit is a list of functions to run when the process
// exits -- they are ONLY for cleanup and should not block,
// return errors, or do anything fancy. They will be run with
// every signal, even if "shutdown callbacks" are not executed.
// This variable must only be modified in the main goroutine
// from init() functions.
var OnProcessExit []func()
// caddyfileLoader pairs the name of a loader to the loader.
type caddyfileLoader struct {
name string

View file

@ -44,16 +44,17 @@ func trapSignalsCrossPlatform() {
if i > 0 {
log.Println("[INFO] SIGINT: Force quit")
if PidFile != "" {
os.Remove(PidFile)
for _, f := range OnProcessExit {
f() // important cleanup actions only
}
os.Exit(2)
}
log.Println("[INFO] SIGINT: Shutting down")
if PidFile != "" {
os.Remove(PidFile)
// important cleanup actions before shutdown callbacks
for _, f := range OnProcessExit {
f()
}
go func() {

View file

@ -33,22 +33,22 @@ func trapSignalsPosix() {
switch sig {
case syscall.SIGQUIT:
log.Println("[INFO] SIGQUIT: Quitting process immediately")
if PidFile != "" {
os.Remove(PidFile)
for _, f := range OnProcessExit {
f() // only perform important cleanup actions
}
os.Exit(0)
case syscall.SIGTERM:
log.Println("[INFO] SIGTERM: Shutting down servers then terminating")
exitCode := executeShutdownCallbacks("SIGTERM")
for _, f := range OnProcessExit {
f() // only perform important cleanup actions
}
err := Stop()
if err != nil {
log.Printf("[ERROR] SIGTERM stop: %v", err)
exitCode = 3
}
if PidFile != "" {
os.Remove(PidFile)
}
os.Exit(exitCode)
case syscall.SIGUSR1:

21
vendor/github.com/bifurcation/mint/LICENSE.md generated vendored Normal file
View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016 Richard Barnes
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

99
vendor/github.com/bifurcation/mint/alert.go generated vendored Normal file
View file

@ -0,0 +1,99 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mint
import "strconv"
type Alert uint8
const (
// alert level
AlertLevelWarning = 1
AlertLevelError = 2
)
const (
AlertCloseNotify Alert = 0
AlertUnexpectedMessage Alert = 10
AlertBadRecordMAC Alert = 20
AlertDecryptionFailed Alert = 21
AlertRecordOverflow Alert = 22
AlertDecompressionFailure Alert = 30
AlertHandshakeFailure Alert = 40
AlertBadCertificate Alert = 42
AlertUnsupportedCertificate Alert = 43
AlertCertificateRevoked Alert = 44
AlertCertificateExpired Alert = 45
AlertCertificateUnknown Alert = 46
AlertIllegalParameter Alert = 47
AlertUnknownCA Alert = 48
AlertAccessDenied Alert = 49
AlertDecodeError Alert = 50
AlertDecryptError Alert = 51
AlertProtocolVersion Alert = 70
AlertInsufficientSecurity Alert = 71
AlertInternalError Alert = 80
AlertInappropriateFallback Alert = 86
AlertUserCanceled Alert = 90
AlertNoRenegotiation Alert = 100
AlertMissingExtension Alert = 109
AlertUnsupportedExtension Alert = 110
AlertCertificateUnobtainable Alert = 111
AlertUnrecognizedName Alert = 112
AlertBadCertificateStatsResponse Alert = 113
AlertBadCertificateHashValue Alert = 114
AlertUnknownPSKIdentity Alert = 115
AlertNoApplicationProtocol Alert = 120
AlertWouldBlock Alert = 254
AlertNoAlert Alert = 255
)
var alertText = map[Alert]string{
AlertCloseNotify: "close notify",
AlertUnexpectedMessage: "unexpected message",
AlertBadRecordMAC: "bad record MAC",
AlertDecryptionFailed: "decryption failed",
AlertRecordOverflow: "record overflow",
AlertDecompressionFailure: "decompression failure",
AlertHandshakeFailure: "handshake failure",
AlertBadCertificate: "bad certificate",
AlertUnsupportedCertificate: "unsupported certificate",
AlertCertificateRevoked: "revoked certificate",
AlertCertificateExpired: "expired certificate",
AlertCertificateUnknown: "unknown certificate",
AlertIllegalParameter: "illegal parameter",
AlertUnknownCA: "unknown certificate authority",
AlertAccessDenied: "access denied",
AlertDecodeError: "error decoding message",
AlertDecryptError: "error decrypting message",
AlertProtocolVersion: "protocol version not supported",
AlertInsufficientSecurity: "insufficient security level",
AlertInternalError: "internal error",
AlertInappropriateFallback: "inappropriate fallback",
AlertUserCanceled: "user canceled",
AlertMissingExtension: "missing extension",
AlertUnsupportedExtension: "unsupported extension",
AlertCertificateUnobtainable: "certificate unobtainable",
AlertUnrecognizedName: "unrecognized name",
AlertBadCertificateStatsResponse: "bad certificate status response",
AlertBadCertificateHashValue: "bad certificate hash value",
AlertUnknownPSKIdentity: "unknown PSK identity",
AlertNoApplicationProtocol: "no application protocol",
AlertNoRenegotiation: "no renegotiation",
AlertWouldBlock: "would have blocked",
AlertNoAlert: "no alert",
}
func (e Alert) String() string {
s, ok := alertText[e]
if ok {
return s
}
return "alert(" + strconv.Itoa(int(e)) + ")"
}
func (e Alert) Error() string {
return e.String()
}

View file

@ -0,0 +1,42 @@
package main
import (
"flag"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"github.com/bifurcation/mint"
)
var url string
func main() {
url := flag.String("url", "https://localhost:4430", "URL to send request")
flag.Parse()
mintdial := func(network, addr string) (net.Conn, error) {
return mint.Dial(network, addr, nil)
}
tr := &http.Transport{
DialTLS: mintdial,
DisableCompression: true,
}
client := &http.Client{Transport: tr}
response, err := client.Get(*url)
if err != nil {
fmt.Println("err:", err)
return
}
defer response.Body.Close()
contents, err := ioutil.ReadAll(response.Body)
if err != nil {
fmt.Printf("%s", err)
os.Exit(1)
}
fmt.Printf("%s\n", string(contents))
}

View file

@ -0,0 +1,37 @@
package main
import (
"flag"
"fmt"
"github.com/bifurcation/mint"
)
var addr string
func main() {
flag.StringVar(&addr, "addr", "localhost:4430", "port")
flag.Parse()
conn, err := mint.Dial("tcp", addr, nil)
if err != nil {
fmt.Println("TLS handshake failed:", err)
return
}
request := "GET / HTTP/1.0\r\n\r\n"
conn.Write([]byte(request))
response := ""
buffer := make([]byte, 1024)
var read int
for err == nil {
read, err = conn.Read(buffer)
fmt.Println(" ~~ read: ", read)
response += string(buffer)
}
fmt.Println("err:", err)
fmt.Println("Received from server:")
fmt.Println(response)
}

View file

@ -0,0 +1,226 @@
package main
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"io/ioutil"
"log"
"net/http"
"github.com/bifurcation/mint"
"golang.org/x/net/http2"
)
var (
port string
serverName string
certFile string
keyFile string
responseFile string
h2 bool
sendTickets bool
)
type responder []byte
func (rsp responder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Write(rsp)
}
// ParsePrivateKeyDER parses a PKCS #1, PKCS #8, or elliptic curve
// PEM-encoded private key.
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) {
keyDER, _ := pem.Decode(keyPEM)
if keyDER == nil {
return nil, err
}
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER.Bytes)
if err != nil {
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER.Bytes)
if err != nil {
generalKey, err = x509.ParseECPrivateKey(keyDER.Bytes)
if err != nil {
// We don't include the actual error into
// the final error. The reason might be
// we don't want to leak any info about
// the private key.
return nil, fmt.Errorf("No successful private key decoder")
}
}
}
switch generalKey.(type) {
case *rsa.PrivateKey:
return generalKey.(*rsa.PrivateKey), nil
case *ecdsa.PrivateKey:
return generalKey.(*ecdsa.PrivateKey), nil
}
// should never reach here
return nil, fmt.Errorf("Should be unreachable")
}
// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object,
// either a raw x509 certificate or a PKCS #7 structure possibly containing
// multiple certificates, from the top of certsPEM, which itself may
// contain multiple PEM encoded certificate objects.
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) {
block, rest := pem.Decode(certsPEM)
if block == nil {
return nil, rest, nil
}
cert, err := x509.ParseCertificate(block.Bytes)
var certs = []*x509.Certificate{cert}
return certs, rest, err
}
// ParseCertificatesPEM parses a sequence of PEM-encoded certificate and returns them,
// can handle PEM encoded PKCS #7 structures.
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
var certs []*x509.Certificate
var err error
certsPEM = bytes.TrimSpace(certsPEM)
for len(certsPEM) > 0 {
var cert []*x509.Certificate
cert, certsPEM, err = ParseOneCertificateFromPEM(certsPEM)
if err != nil {
return nil, err
} else if cert == nil {
break
}
certs = append(certs, cert...)
}
if len(certsPEM) > 0 {
return nil, fmt.Errorf("Trailing PEM data")
}
return certs, nil
}
func main() {
flag.StringVar(&port, "port", "4430", "port")
flag.StringVar(&serverName, "host", "example.com", "hostname")
flag.StringVar(&certFile, "cert", "", "certificate chain in PEM or DER")
flag.StringVar(&keyFile, "key", "", "private key in PEM format")
flag.StringVar(&responseFile, "response", "", "file to serve")
flag.BoolVar(&h2, "h2", false, "whether to use HTTP/2 (exclusively)")
flag.BoolVar(&sendTickets, "tickets", true, "whether to send session tickets")
flag.Parse()
var certChain []*x509.Certificate
var priv crypto.Signer
var response []byte
var err error
// Load the key and certificate chain
if certFile != "" {
certs, err := ioutil.ReadFile(certFile)
if err != nil {
log.Fatalf("Error: %v", err)
} else {
certChain, err = ParseCertificatesPEM(certs)
if err != nil {
certChain, err = x509.ParseCertificates(certs)
if err != nil {
log.Fatalf("Error parsing certificates: %v", err)
}
}
}
}
if keyFile != "" {
keyPEM, err := ioutil.ReadFile(keyFile)
if err != nil {
log.Fatalf("Error: %v", err)
} else {
priv, err = ParsePrivateKeyPEM(keyPEM)
if priv == nil || err != nil {
log.Fatalf("Error parsing private key: %v", err)
}
}
}
if err != nil {
log.Fatalf("Error: %v", err)
}
// Load response file
if responseFile != "" {
log.Printf("Loading response file: %v", responseFile)
response, err = ioutil.ReadFile(responseFile)
if err != nil {
log.Fatalf("Error: %v", err)
}
} else {
response = []byte("Welcome to the TLS 1.3 zone!")
}
handler := responder(response)
config := mint.Config{
SendSessionTickets: true,
ServerName: serverName,
NextProtos: []string{"http/1.1"},
}
if h2 {
config.NextProtos = []string{"h2"}
}
config.SendSessionTickets = sendTickets
if certChain != nil && priv != nil {
log.Printf("Loading cert: %v key: %v", certFile, keyFile)
config.Certificates = []*mint.Certificate{
{
Chain: certChain,
PrivateKey: priv,
},
}
}
config.Init(false)
service := "0.0.0.0:" + port
srv := &http.Server{Handler: handler}
log.Printf("Listening on port %v", port)
// Need the inner loop here because the h1 server errors on a dropped connection
// Need the outer loop here because the h2 server is per-connection
for {
listener, err := mint.Listen("tcp", service, &config)
if err != nil {
log.Printf("Listen Error: %v", err)
continue
}
if !h2 {
alert := srv.Serve(listener)
if alert != mint.AlertNoAlert {
log.Printf("Serve Error: %v", err)
}
} else {
srv2 := new(http2.Server)
opts := &http2.ServeConnOpts{
Handler: handler,
BaseConfig: srv,
}
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("Accept error: %v", err)
continue
}
go srv2.ServeConn(conn, opts)
}
}
}
}

View file

@ -0,0 +1,65 @@
package main
import (
"flag"
"log"
"net"
"github.com/bifurcation/mint"
)
var port string
func main() {
var config mint.Config
config.SendSessionTickets = true
config.ServerName = "localhost"
config.Init(false)
flag.StringVar(&port, "port", "4430", "port")
flag.Parse()
service := "0.0.0.0:" + port
listener, err := mint.Listen("tcp", service, &config)
if err != nil {
log.Fatalf("server: listen: %s", err)
}
log.Print("server: listening")
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("server: accept: %s", err)
break
}
defer conn.Close()
log.Printf("server: accepted from %s", conn.RemoteAddr())
go handleClient(conn)
}
}
func handleClient(conn net.Conn) {
defer conn.Close()
buf := make([]byte, 10)
for {
log.Print("server: conn: waiting")
n, err := conn.Read(buf)
if err != nil {
if err != nil {
log.Printf("server: conn: read: %s", err)
}
break
}
n, err = conn.Write([]byte("hello world"))
log.Printf("server: conn: wrote %d bytes", n)
if err != nil {
log.Printf("server: write: %s", err)
break
}
break
}
log.Println("server: conn: closed")
}

View file

@ -0,0 +1,942 @@
package mint
import (
"bytes"
"crypto"
"hash"
"time"
)
// Client State Machine
//
// START <----+
// Send ClientHello | | Recv HelloRetryRequest
// / v |
// | WAIT_SH ---+
// Can | | Recv ServerHello
// send | V
// early | WAIT_EE
// data | | Recv EncryptedExtensions
// | +--------+--------+
// | Using | | Using certificate
// | PSK | v
// | | WAIT_CERT_CR
// | | Recv | | Recv CertificateRequest
// | | Certificate | v
// | | | WAIT_CERT
// | | | | Recv Certificate
// | | v v
// | | WAIT_CV
// | | | Recv CertificateVerify
// | +> WAIT_FINISHED <+
// | | Recv Finished
// \ |
// | [Send EndOfEarlyData]
// | [Send Certificate [+ CertificateVerify]]
// | Send Finished
// Can send v
// app data --> CONNECTED
// after
// here
//
// State Instructions
// START Send(CH); [RekeyOut; SendEarlyData]
// WAIT_SH Send(CH) || RekeyIn
// WAIT_EE {}
// WAIT_CERT_CR {}
// WAIT_CERT {}
// WAIT_CV {}
// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut;
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
type ClientStateStart struct {
Caps Capabilities
Opts ConnectionOptions
Params ConnectionParameters
cookie []byte
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
}
func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm != nil {
logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message")
return nil, nil, AlertUnexpectedMessage
}
// key_shares
offeredDH := map[NamedGroup][]byte{}
ks := KeyShareExtension{
HandshakeType: HandshakeTypeClientHello,
Shares: make([]KeyShareEntry, len(state.Caps.Groups)),
}
for i, group := range state.Caps.Groups {
pub, priv, err := newKeyShare(group)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err)
return nil, nil, AlertInternalError
}
ks.Shares[i].Group = group
ks.Shares[i].KeyExchange = pub
offeredDH[group] = priv
}
logf(logTypeHandshake, "opts: %+v", state.Opts)
// supported_versions, supported_groups, signature_algorithms, server_name
sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}}
sni := ServerNameExtension(state.Opts.ServerName)
sg := SupportedGroupsExtension{Groups: state.Caps.Groups}
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
state.Params.ServerName = state.Opts.ServerName
// Application Layer Protocol Negotiation
var alpn *ALPNExtension
if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) {
alpn = &ALPNExtension{Protocols: state.Opts.NextProtos}
}
// Construct base ClientHello
ch := &ClientHelloBody{
CipherSuites: state.Caps.CipherSuites,
}
_, err := prng.Read(ch.Random[:])
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err)
return nil, nil, AlertInternalError
}
for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} {
err := ch.Extensions.Add(ext)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err)
return nil, nil, AlertInternalError
}
}
// XXX: These optional extensions can't be folded into the above because Go
// interface-typed values are never reported as nil
if alpn != nil {
err := ch.Extensions.Add(alpn)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
return nil, nil, AlertInternalError
}
}
if state.cookie != nil {
err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie})
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
return nil, nil, AlertInternalError
}
}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
}
}
// Handle PSK and EarlyData just before transmitting, so that we can
// calculate the PSK binder value
var psk *PreSharedKeyExtension
var ed *EarlyDataExtension
var offeredPSK PreSharedKey
var earlyHash crypto.Hash
var earlySecret []byte
var clientEarlyTrafficKeys keySet
var clientHello *HandshakeMessage
if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok {
offeredPSK = key
// Narrow ciphersuites to ones that match PSK hash
params, ok := cipherSuiteMap[key.CipherSuite]
if !ok {
logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite")
return nil, nil, AlertInternalError
}
compatibleSuites := []CipherSuite{}
for _, suite := range ch.CipherSuites {
if cipherSuiteMap[suite].Hash == params.Hash {
compatibleSuites = append(compatibleSuites, suite)
}
}
ch.CipherSuites = compatibleSuites
// Signal early data if we're going to do it
if len(state.Opts.EarlyData) > 0 {
state.Params.ClientSendingEarlyData = true
ed = &EarlyDataExtension{}
err = ch.Extensions.Add(ed)
if err != nil {
logf(logTypeHandshake, "Error adding early data extension: %v", err)
return nil, nil, AlertInternalError
}
}
// Signal supported PSK key exchange modes
if len(state.Caps.PSKModes) == 0 {
logf(logTypeHandshake, "PSK selected, but no PSKModes")
return nil, nil, AlertInternalError
}
kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes}
err = ch.Extensions.Add(kem)
if err != nil {
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err)
return nil, nil, AlertInternalError
}
// Add the shim PSK extension to the ClientHello
logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity)
psk = &PreSharedKeyExtension{
HandshakeType: HandshakeTypeClientHello,
Identities: []PSKIdentity{
{
Identity: key.Identity,
ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd,
},
},
Binders: []PSKBinderEntry{
// Note: Stub to get the length fields right
{Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())},
},
}
ch.Extensions.Add(psk)
// Compute the binder key
h0 := params.Hash.New().Sum(nil)
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
earlyHash = params.Hash
earlySecret = HkdfExtract(params.Hash, zero, key.Key)
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
binderLabel := labelExternalBinder
if key.IsResumption {
binderLabel = labelResumptionBinder
}
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey)
// Compute the binder value
trunc, err := ch.Truncated()
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err)
return nil, nil, AlertInternalError
}
truncHash := params.Hash.New()
truncHash.Write(trunc)
binder := computeFinishedData(params, binderKey, truncHash.Sum(nil))
// Replace the PSK extension
psk.Binders[0].Binder = binder
ch.Extensions.Add(psk)
// If we got here, the earlier marshal succeeded (in ch.Truncated()), so
// this one should too.
clientHello, _ = HandshakeMessageFromBody(ch)
// Compute early traffic keys
h := params.Hash.New()
h.Write(clientHello.Marshal())
chHash := h.Sum(nil)
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
} else if len(state.Opts.EarlyData) > 0 {
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
return nil, nil, AlertInternalError
} else {
clientHello, err = HandshakeMessageFromBody(ch)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err)
return nil, nil, AlertInternalError
}
}
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
nextState := ClientStateWaitSH{
Caps: state.Caps,
Opts: state.Opts,
Params: state.Params,
OfferedDH: offeredDH,
OfferedPSK: offeredPSK,
earlySecret: earlySecret,
earlyHash: earlyHash,
firstClientHello: state.firstClientHello,
helloRetryRequest: state.helloRetryRequest,
clientHello: clientHello,
}
toSend := []HandshakeAction{
SendHandshakeMessage{clientHello},
}
if state.Params.ClientSendingEarlyData {
toSend = append(toSend, []HandshakeAction{
RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys},
SendEarlyData{},
}...)
}
return nextState, toSend, AlertNoAlert
}
type ClientStateWaitSH struct {
Caps Capabilities
Opts ConnectionOptions
Params ConnectionParameters
OfferedDH map[NamedGroup][]byte
OfferedPSK PreSharedKey
PSK []byte
earlySecret []byte
earlyHash crypto.Hash
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
clientHello *HandshakeMessage
}
func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message")
return nil, nil, AlertUnexpectedMessage
}
bodyGeneric, err := hm.ToBody()
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
switch body := bodyGeneric.(type) {
case *HelloRetryRequestBody:
hrr := body
if state.helloRetryRequest != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest")
return nil, nil, AlertUnexpectedMessage
}
// Check that the version sent by the server is the one we support
if hrr.Version != supportedVersion {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version)
return nil, nil, AlertProtocolVersion
}
// Check that the server provided a supported ciphersuite
supportedCipherSuite := false
for _, suite := range state.Caps.CipherSuites {
supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite)
}
if !supportedCipherSuite {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite)
return nil, nil, AlertHandshakeFailure
}
// Narrow the supported ciphersuites to the server-provided one
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite}
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
// The only thing we know how to respond to in an HRR is the Cookie
// extension, so if there is either no Cookie extension or anything other
// than a Cookie extension, we have to fail.
serverCookie := new(CookieExtension)
foundCookie := hrr.Extensions.Find(serverCookie)
if !foundCookie || len(hrr.Extensions) != 1 {
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions))
return nil, nil, AlertIllegalParameter
}
// Hash the body into a pseudo-message
// XXX: Ignoring some errors here
params := cipherSuiteMap[hrr.CipherSuite]
h := params.Hash.New()
h.Write(state.clientHello.Marshal())
firstClientHello := &HandshakeMessage{
msgType: HandshakeTypeMessageHash,
body: h.Sum(nil),
}
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
return ClientStateStart{
Caps: state.Caps,
Opts: state.Opts,
cookie: serverCookie.Cookie,
firstClientHello: firstClientHello,
helloRetryRequest: hm,
}.Next(nil)
case *ServerHelloBody:
sh := body
// Check that the version sent by the server is the one we support
if sh.Version != supportedVersion {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version)
return nil, nil, AlertProtocolVersion
}
// Check that the server provided a supported ciphersuite
supportedCipherSuite := false
for _, suite := range state.Caps.CipherSuites {
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
}
if !supportedCipherSuite {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
return nil, nil, AlertHandshakeFailure
}
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
// Do PSK or key agreement depending on extensions
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
foundPSK := sh.Extensions.Find(&serverPSK)
foundKeyShare := sh.Extensions.Find(&serverKeyShare)
if foundPSK && (serverPSK.SelectedIdentity == 0) {
state.Params.UsingPSK = true
}
var dhSecret []byte
if foundKeyShare {
sks := serverKeyShare.Shares[0]
priv, ok := state.OfferedDH[sks.Group]
if !ok {
logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group")
return nil, nil, AlertIllegalParameter
}
state.Params.UsingDH = true
dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv)
}
suite := sh.CipherSuite
state.Params.CipherSuite = suite
params, ok := cipherSuiteMap[suite]
if !ok {
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite)
return nil, nil, AlertHandshakeFailure
}
// Start up the handshake hash
handshakeHash := params.Hash.New()
handshakeHash.Write(state.firstClientHello.Marshal())
handshakeHash.Write(state.helloRetryRequest.Marshal())
handshakeHash.Write(state.clientHello.Marshal())
handshakeHash.Write(hm.Marshal())
// Compute handshake secrets
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
var earlySecret []byte
if state.Params.UsingPSK {
if params.Hash != state.earlyHash {
logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]",
state.earlyHash, suite, params.Hash)
}
earlySecret = state.earlySecret
} else {
earlySecret = HkdfExtract(params.Hash, zero, zero)
}
if dhSecret == nil {
dhSecret = zero
}
h0 := params.Hash.New().Sum(nil)
h2 := handshakeHash.Sum(nil)
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret)
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
nextState := ClientStateWaitEE{
Caps: state.Caps,
Params: state.Params,
cryptoParams: params,
handshakeHash: handshakeHash,
certificates: state.Caps.Certificates,
masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
}
toSend := []HandshakeAction{
RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys},
}
return nextState, toSend, AlertNoAlert
}
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType)
return nil, nil, AlertUnexpectedMessage
}
type ClientStateWaitEE struct {
Caps Capabilities
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions {
logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
ee := EncryptedExtensionsBody{}
_, err := ee.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
serverALPN := ALPNExtension{}
serverEarlyData := EarlyDataExtension{}
gotALPN := ee.Extensions.Find(&serverALPN)
state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData)
if gotALPN && len(serverALPN.Protocols) > 0 {
state.Params.NextProto = serverALPN.Protocols[0]
}
state.handshakeHash.Write(hm.Marshal())
if state.Params.UsingPSK {
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
nextState := ClientStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
nextState := ClientStateWaitCertCR{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
type ClientStateWaitCertCR struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil {
logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
bodyGeneric, err := hm.ToBody()
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
state.handshakeHash.Write(hm.Marshal())
switch body := bodyGeneric.(type) {
case *CertificateBody:
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]")
nextState := ClientStateWaitCV{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificate: body,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
case *CertificateRequestBody:
// A certificate request in the handshake should have a zero-length context
if len(body.CertificateRequestContext) > 0 {
logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err)
return nil, nil, AlertIllegalParameter
}
state.Params.UsingClientAuth = true
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]")
nextState := ClientStateWaitCert{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificateRequest: body,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
return nil, nil, AlertUnexpectedMessage
}
type ClientStateWaitCert struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
serverCertificateRequest *CertificateRequestBody
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeCertificate {
logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
cert := &CertificateBody{}
_, err := cert.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
state.handshakeHash.Write(hm.Marshal())
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]")
nextState := ClientStateWaitCV{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificate: cert,
serverCertificateRequest: state.serverCertificateRequest,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
type ClientStateWaitCV struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
serverCertificate *CertificateBody
serverCertificateRequest *CertificateRequestBody
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
certVerify := CertificateVerifyBody{}
_, err := certVerify.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
hcv := state.handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey
if err := certVerify.Verify(serverPublicKey, hcv); err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify")
return nil, nil, AlertHandshakeFailure
}
if state.AuthCertificate != nil {
err := state.AuthCertificate(state.serverCertificate.CertificateList)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate")
return nil, nil, AlertBadCertificate
}
} else {
logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate")
}
state.handshakeHash.Write(hm.Marshal())
logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]")
nextState := ClientStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificateRequest: state.serverCertificateRequest,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
type ClientStateWaitFinished struct {
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
serverCertificateRequest *CertificateRequestBody
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeFinished {
logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
// Verify server's Finished
h3 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3)
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
_, err := fin.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
if !bytes.Equal(fin.VerifyData, serverFinishedData) {
logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]",
fin.VerifyData, serverFinishedData)
return nil, nil, AlertHandshakeFailure
}
// Update the handshake hash with the Finished
state.handshakeHash.Write(hm.Marshal())
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal())
h4 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4)
// Compute traffic secrets and keys
clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4)
serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4)
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret)
serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret)
exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4)
logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
// Assemble client's second flight
toSend := []HandshakeAction{}
if state.Params.UsingEarlyData {
// Note: We only send EOED if the server is actually going to use the early
// data. Otherwise, it will never see it, and the transcripts will
// mismatch.
// EOED marshal is infallible
eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{})
toSend = append(toSend, SendHandshakeMessage{eoedm})
state.handshakeHash.Write(eoedm.Marshal())
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
}
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys})
if state.Params.UsingClientAuth {
// Extract constraints from certicateRequest
schemes := SignatureAlgorithmsExtension{}
gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes)
if !gotSchemes {
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
return nil, nil, AlertIllegalParameter
}
// Select a certificate
cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates)
if err != nil {
// XXX: Signal this to the application layer?
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
certificate := &CertificateBody{}
certm, err := HandshakeMessageFromBody(certificate)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certm})
state.handshakeHash.Write(certm.Marshal())
} else {
// Create and send Certificate, CertificateVerify
certificate := &CertificateBody{
CertificateList: make([]CertificateEntry, len(cert.Chain)),
}
for i, entry := range cert.Chain {
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
}
certm, err := HandshakeMessageFromBody(certificate)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certm})
state.handshakeHash.Write(certm.Marshal())
hcv := state.handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
certificateVerify := &CertificateVerifyBody{Algorithm: certScheme}
logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash)
err = certificateVerify.Sign(cert.PrivateKey, hcv)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err)
return nil, nil, AlertInternalError
}
certvm, err := HandshakeMessageFromBody(certificateVerify)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certvm})
state.handshakeHash.Write(certvm.Marshal())
}
}
// Compute the client's Finished message
h5 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
fin = &FinishedBody{
VerifyDataLen: len(clientFinishedData),
VerifyData: clientFinishedData,
}
finm, err := HandshakeMessageFromBody(fin)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err)
return nil, nil, AlertInternalError
}
// Compute the resumption secret
state.handshakeHash.Write(finm.Marshal())
h6 := state.handshakeHash.Sum(nil)
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
toSend = append(toSend, []HandshakeAction{
SendHandshakeMessage{finm},
RekeyIn{Label: "application", KeySet: serverTrafficKeys},
RekeyOut{Label: "application", KeySet: clientTrafficKeys},
}...)
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
nextState := StateConnected{
Params: state.Params,
isClient: true,
cryptoParams: state.cryptoParams,
resumptionSecret: resumptionSecret,
clientTrafficSecret: clientTrafficSecret,
serverTrafficSecret: serverTrafficSecret,
exporterSecret: exporterSecret,
}
return nextState, toSend, AlertNoAlert
}

152
vendor/github.com/bifurcation/mint/common.go generated vendored Normal file
View file

@ -0,0 +1,152 @@
package mint
import (
"fmt"
"strconv"
)
var (
supportedVersion uint16 = 0x7f15 // draft-21
// Flags for some minor compat issues
allowWrongVersionNumber = true
allowPKCS1 = true
)
// enum {...} ContentType;
type RecordType byte
const (
RecordTypeAlert RecordType = 21
RecordTypeHandshake RecordType = 22
RecordTypeApplicationData RecordType = 23
)
// enum {...} HandshakeType;
type HandshakeType byte
const (
// Omitted: *_RESERVED
HandshakeTypeClientHello HandshakeType = 1
HandshakeTypeServerHello HandshakeType = 2
HandshakeTypeNewSessionTicket HandshakeType = 4
HandshakeTypeEndOfEarlyData HandshakeType = 5
HandshakeTypeHelloRetryRequest HandshakeType = 6
HandshakeTypeEncryptedExtensions HandshakeType = 8
HandshakeTypeCertificate HandshakeType = 11
HandshakeTypeCertificateRequest HandshakeType = 13
HandshakeTypeCertificateVerify HandshakeType = 15
HandshakeTypeServerConfiguration HandshakeType = 17
HandshakeTypeFinished HandshakeType = 20
HandshakeTypeKeyUpdate HandshakeType = 24
HandshakeTypeMessageHash HandshakeType = 254
)
// uint8 CipherSuite[2];
type CipherSuite uint16
const (
// XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero
// value for this type so that we can detect when a field is set.
CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000
TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301
TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302
TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303
TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304
TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305
)
func (c CipherSuite) String() string {
switch c {
case CIPHER_SUITE_UNKNOWN:
return "unknown"
case TLS_AES_128_GCM_SHA256:
return "TLS_AES_128_GCM_SHA256"
case TLS_AES_256_GCM_SHA384:
return "TLS_AES_256_GCM_SHA384"
case TLS_CHACHA20_POLY1305_SHA256:
return "TLS_CHACHA20_POLY1305_SHA256"
case TLS_AES_128_CCM_SHA256:
return "TLS_AES_128_CCM_SHA256"
case TLS_AES_256_CCM_8_SHA256:
return "TLS_AES_256_CCM_8_SHA256"
}
// cannot use %x here, since it calls String(), leading to infinite recursion
return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16))
}
// enum {...} SignatureScheme
type SignatureScheme uint16
const (
// RSASSA-PKCS1-v1_5 algorithms
RSA_PKCS1_SHA1 SignatureScheme = 0x0201
RSA_PKCS1_SHA256 SignatureScheme = 0x0401
RSA_PKCS1_SHA384 SignatureScheme = 0x0501
RSA_PKCS1_SHA512 SignatureScheme = 0x0601
// ECDSA algorithms
ECDSA_P256_SHA256 SignatureScheme = 0x0403
ECDSA_P384_SHA384 SignatureScheme = 0x0503
ECDSA_P521_SHA512 SignatureScheme = 0x0603
// RSASSA-PSS algorithms
RSA_PSS_SHA256 SignatureScheme = 0x0804
RSA_PSS_SHA384 SignatureScheme = 0x0805
RSA_PSS_SHA512 SignatureScheme = 0x0806
// EdDSA algorithms
Ed25519 SignatureScheme = 0x0807
Ed448 SignatureScheme = 0x0808
)
// enum {...} ExtensionType
type ExtensionType uint16
const (
ExtensionTypeServerName ExtensionType = 0
ExtensionTypeSupportedGroups ExtensionType = 10
ExtensionTypeSignatureAlgorithms ExtensionType = 13
ExtensionTypeALPN ExtensionType = 16
ExtensionTypeKeyShare ExtensionType = 40
ExtensionTypePreSharedKey ExtensionType = 41
ExtensionTypeEarlyData ExtensionType = 42
ExtensionTypeSupportedVersions ExtensionType = 43
ExtensionTypeCookie ExtensionType = 44
ExtensionTypePSKKeyExchangeModes ExtensionType = 45
ExtensionTypeTicketEarlyDataInfo ExtensionType = 46
)
// enum {...} NamedGroup
type NamedGroup uint16
const (
// Elliptic Curve Groups.
P256 NamedGroup = 23
P384 NamedGroup = 24
P521 NamedGroup = 25
// ECDH functions.
X25519 NamedGroup = 29
X448 NamedGroup = 30
// Finite field groups.
FFDHE2048 NamedGroup = 256
FFDHE3072 NamedGroup = 257
FFDHE4096 NamedGroup = 258
FFDHE6144 NamedGroup = 259
FFDHE8192 NamedGroup = 260
)
// enum {...} PskKeyExchangeMode;
type PSKKeyExchangeMode uint8
const (
PSKModeKE PSKKeyExchangeMode = 0
PSKModeDHEKE PSKKeyExchangeMode = 1
)
// enum {
// update_not_requested(0), update_requested(1), (255)
// } KeyUpdateRequest;
type KeyUpdateRequest uint8
const (
KeyUpdateNotRequested KeyUpdateRequest = 0
KeyUpdateRequested KeyUpdateRequest = 1
)

819
vendor/github.com/bifurcation/mint/conn.go generated vendored Normal file
View file

@ -0,0 +1,819 @@
package mint
import (
"crypto"
"crypto/x509"
"encoding/hex"
"fmt"
"io"
"net"
"reflect"
"sync"
"time"
)
var WouldBlock = fmt.Errorf("Would have blocked")
type Certificate struct {
Chain []*x509.Certificate
PrivateKey crypto.Signer
}
type PreSharedKey struct {
CipherSuite CipherSuite
IsResumption bool
Identity []byte
Key []byte
NextProto string
ReceivedAt time.Time
ExpiresAt time.Time
TicketAgeAdd uint32
}
type PreSharedKeyCache interface {
Get(string) (PreSharedKey, bool)
Put(string, PreSharedKey)
Size() int
}
type PSKMapCache map[string]PreSharedKey
// A CookieHandler does two things:
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
// - validates this byte string echoed by the client in the ClientHello
type CookieHandler interface {
Generate(*Conn) ([]byte, error)
Validate(*Conn, []byte) bool
}
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
psk, ok = cache[key]
return
}
func (cache *PSKMapCache) Put(key string, psk PreSharedKey) {
(*cache)[key] = psk
}
func (cache PSKMapCache) Size() int {
return len(cache)
}
// Config is the struct used to pass configuration settings to a TLS client or
// server instance. The settings for client and server are pretty different,
// but we just throw them all in here.
type Config struct {
// Client fields
ServerName string
// Server fields
SendSessionTickets bool
TicketLifetime uint32
TicketLen int
EarlyDataLifetime uint32
AllowEarlyData bool
// Require the client to echo a cookie.
RequireCookie bool
// If cookies are required and no CookieHandler is set, a default cookie handler is used.
// The default cookie handler uses 32 random bytes as a cookie.
CookieHandler CookieHandler
RequireClientAuth bool
// Shared fields
Certificates []*Certificate
AuthCertificate func(chain []CertificateEntry) error
CipherSuites []CipherSuite
Groups []NamedGroup
SignatureSchemes []SignatureScheme
NextProtos []string
PSKs PreSharedKeyCache
PSKModes []PSKKeyExchangeMode
NonBlocking bool
// The same config object can be shared among different connections, so it
// needs its own mutex
mutex sync.RWMutex
}
// Clone returns a shallow clone of c. It is safe to clone a Config that is
// being used concurrently by a TLS client or server.
func (c *Config) Clone() *Config {
c.mutex.Lock()
defer c.mutex.Unlock()
return &Config{
ServerName: c.ServerName,
SendSessionTickets: c.SendSessionTickets,
TicketLifetime: c.TicketLifetime,
TicketLen: c.TicketLen,
EarlyDataLifetime: c.EarlyDataLifetime,
AllowEarlyData: c.AllowEarlyData,
RequireCookie: c.RequireCookie,
RequireClientAuth: c.RequireClientAuth,
Certificates: c.Certificates,
AuthCertificate: c.AuthCertificate,
CipherSuites: c.CipherSuites,
Groups: c.Groups,
SignatureSchemes: c.SignatureSchemes,
NextProtos: c.NextProtos,
PSKs: c.PSKs,
PSKModes: c.PSKModes,
NonBlocking: c.NonBlocking,
}
}
func (c *Config) Init(isClient bool) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// Set defaults
if len(c.CipherSuites) == 0 {
c.CipherSuites = defaultSupportedCipherSuites
}
if len(c.Groups) == 0 {
c.Groups = defaultSupportedGroups
}
if len(c.SignatureSchemes) == 0 {
c.SignatureSchemes = defaultSignatureSchemes
}
if c.TicketLen == 0 {
c.TicketLen = defaultTicketLen
}
if !reflect.ValueOf(c.PSKs).IsValid() {
c.PSKs = &PSKMapCache{}
}
if len(c.PSKModes) == 0 {
c.PSKModes = defaultPSKModes
}
// If there is no certificate, generate one
if !isClient && len(c.Certificates) == 0 {
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
priv, err := newSigningKey(RSA_PSS_SHA256)
if err != nil {
return err
}
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
if err != nil {
return err
}
c.Certificates = []*Certificate{
{
Chain: []*x509.Certificate{cert},
PrivateKey: priv,
},
}
}
return nil
}
func (c *Config) ValidForServer() bool {
return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) ||
(len(c.Certificates) > 0 &&
len(c.Certificates[0].Chain) > 0 &&
c.Certificates[0].PrivateKey != nil)
}
func (c *Config) ValidForClient() bool {
return len(c.ServerName) > 0
}
var (
defaultSupportedCipherSuites = []CipherSuite{
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
}
defaultSupportedGroups = []NamedGroup{
P256,
P384,
FFDHE2048,
X25519,
}
defaultSignatureSchemes = []SignatureScheme{
RSA_PSS_SHA256,
RSA_PSS_SHA384,
RSA_PSS_SHA512,
ECDSA_P256_SHA256,
ECDSA_P384_SHA384,
ECDSA_P521_SHA512,
}
defaultTicketLen = 16
defaultPSKModes = []PSKKeyExchangeMode{
PSKModeKE,
PSKModeDHEKE,
}
)
type ConnectionState struct {
HandshakeState string // string representation of the handshake state.
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement
NextProto string // Selected ALPN proto
}
// Conn implements the net.Conn interface, as with "crypto/tls"
// * Read, Write, and Close are provided locally
// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
type Conn struct {
config *Config
conn net.Conn
isClient bool
EarlyData []byte
state StateConnected
hState HandshakeState
handshakeMutex sync.Mutex
handshakeAlert Alert
handshakeComplete bool
readBuffer []byte
in, out *RecordLayer
hIn, hOut *HandshakeLayer
extHandler AppExtensionHandler
}
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
c := &Conn{conn: conn, config: config, isClient: isClient}
c.in = NewRecordLayer(c.conn)
c.out = NewRecordLayer(c.conn)
c.hIn = NewHandshakeLayer(c.in)
c.hIn.nonblocking = c.config.NonBlocking
c.hOut = NewHandshakeLayer(c.out)
return c
}
// Read up
func (c *Conn) consumeRecord() error {
pt, err := c.in.ReadRecord()
if pt == nil {
logf(logTypeIO, "extendBuffer returns error %v", err)
return err
}
switch pt.contentType {
case RecordTypeHandshake:
logf(logTypeHandshake, "Received post-handshake message")
// We do not support fragmentation of post-handshake handshake messages.
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
start := 0
for start < len(pt.fragment) {
if len(pt.fragment[start:]) < handshakeHeaderLen {
return fmt.Errorf("Post-handshake handshake message too short for header")
}
hm := &HandshakeMessage{}
hm.msgType = HandshakeType(pt.fragment[start])
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen {
return fmt.Errorf("Post-handshake handshake message too short for body")
}
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen]
// Advance state machine
state, actions, alert := c.state.Next(hm)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error in state transition: %v", alert)
c.sendAlert(alert)
return io.EOF
}
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return io.EOF
}
}
// XXX: If we want to support more advanced cases, e.g., post-handshake
// authentication, we'll need to allow transitions other than
// Connected -> Connected
var connected bool
c.state, connected = state.(StateConnected)
if !connected {
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
c.sendAlert(alert)
return io.EOF
}
start += handshakeHeaderLen + hmLen
}
case RecordTypeAlert:
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
if len(pt.fragment) != 2 {
c.sendAlert(AlertUnexpectedMessage)
return io.EOF
}
if Alert(pt.fragment[1]) == AlertCloseNotify {
return io.EOF
}
switch pt.fragment[0] {
case AlertLevelWarning:
// drop on the floor
case AlertLevelError:
return Alert(pt.fragment[1])
default:
c.sendAlert(AlertUnexpectedMessage)
return io.EOF
}
case RecordTypeApplicationData:
c.readBuffer = append(c.readBuffer, pt.fragment...)
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
}
return err
}
// Read application data up to the size of buffer. Handshake and alert records
// are consumed by the Conn object directly.
func (c *Conn) Read(buffer []byte) (int, error) {
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
if alert := c.Handshake(); alert != AlertNoAlert {
return 0, alert
}
if len(buffer) == 0 {
return 0, nil
}
// Lock the input channel
c.in.Lock()
defer c.in.Unlock()
for len(c.readBuffer) == 0 {
err := c.consumeRecord()
// err can be nil if consumeRecord processed a non app-data
// record.
if err != nil {
if c.config.NonBlocking || err != WouldBlock {
logf(logTypeIO, "conn.Read returns err=%v", err)
return 0, err
}
}
}
var read int
n := len(buffer)
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
if len(c.readBuffer) <= n {
buffer = buffer[:len(c.readBuffer)]
copy(buffer, c.readBuffer)
read = len(c.readBuffer)
c.readBuffer = c.readBuffer[:0]
} else {
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
copy(buffer[:n], c.readBuffer[:n])
c.readBuffer = c.readBuffer[n:]
read = n
}
logf(logTypeVerbose, "Returning %v", string(buffer))
return read, nil
}
// Write application data
func (c *Conn) Write(buffer []byte) (int, error) {
// Lock the output channel
c.out.Lock()
defer c.out.Unlock()
// Send full-size fragments
var start int
sent := 0
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
err := c.out.WriteRecord(&TLSPlaintext{
contentType: RecordTypeApplicationData,
fragment: buffer[start : start+maxFragmentLen],
})
if err != nil {
return sent, err
}
sent += maxFragmentLen
}
// Send a final partial fragment if necessary
if start < len(buffer) {
err := c.out.WriteRecord(&TLSPlaintext{
contentType: RecordTypeApplicationData,
fragment: buffer[start:],
})
if err != nil {
return sent, err
}
sent += len(buffer[start:])
}
return sent, nil
}
// sendAlert sends a TLS alert message.
// c.out.Mutex <= L.
func (c *Conn) sendAlert(err Alert) error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
var level int
switch err {
case AlertNoRenegotiation, AlertCloseNotify:
level = AlertLevelWarning
default:
level = AlertLevelError
}
buf := []byte{byte(err), byte(level)}
c.out.WriteRecord(&TLSPlaintext{
contentType: RecordTypeAlert,
fragment: buf,
})
// close_notify and end_of_early_data are not actually errors
if level == AlertLevelWarning {
return &net.OpError{Op: "local error", Err: err}
}
return c.Close()
}
// Close closes the connection.
func (c *Conn) Close() error {
// XXX crypto/tls has an interlock with Write here. Do we need that?
return c.conn.Close()
}
// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline sets the read and write deadlines associated with the connection.
// A zero value for t means Read and Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
// SetReadDeadline sets the read deadline on the underlying connection.
// A zero value for t means Read will not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline on the underlying connection.
// A zero value for t means Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
label := "[server]"
if c.isClient {
label = "[client]"
}
switch action := actionGeneric.(type) {
case SendHandshakeMessage:
err := c.hOut.WriteMessage(action.Message)
if err != nil {
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
return AlertInternalError
}
case RekeyIn:
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet)
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
if err != nil {
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
return AlertInternalError
}
case RekeyOut:
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet)
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
if err != nil {
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
return AlertInternalError
}
case SendEarlyData:
logf(logTypeHandshake, "%s Sending early data...", label)
_, err := c.Write(c.EarlyData)
if err != nil {
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
return AlertInternalError
}
case ReadPastEarlyData:
logf(logTypeHandshake, "%s Reading past early data...", label)
// Scan past all records that fail to decrypt
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
if err == nil {
break
}
_, ok := err.(DecryptError)
for ok {
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
if err == nil {
break
}
_, ok = err.(DecryptError)
}
case ReadEarlyData:
logf(logTypeHandshake, "%s Reading early data...", label)
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
if err != nil {
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
for t == RecordTypeApplicationData {
// Read a record into the buffer. Note that this is safe
// in blocking mode because we read the record in in
// PeekRecordType.
pt, err := c.in.ReadRecord()
if err != nil {
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
c.EarlyData = append(c.EarlyData, pt.fragment...)
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
if err != nil {
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
}
logf(logTypeHandshake, "%s Done reading early data", label)
case StorePSK:
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
if c.isClient {
// Clients look up PSKs based on server name
c.config.PSKs.Put(c.config.ServerName, action.PSK)
} else {
// Servers look them up based on the identity in the extension
c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK)
}
default:
logf(logTypeHandshake, "%s Unknown actionuction type", label)
return AlertInternalError
}
return AlertNoAlert
}
func (c *Conn) HandshakeSetup() Alert {
var state HandshakeState
var actions []HandshakeAction
var alert Alert
if err := c.config.Init(c.isClient); err != nil {
logf(logTypeHandshake, "Error initializing config: %v", err)
return AlertInternalError
}
// Set things up
caps := Capabilities{
CipherSuites: c.config.CipherSuites,
Groups: c.config.Groups,
SignatureSchemes: c.config.SignatureSchemes,
PSKs: c.config.PSKs,
PSKModes: c.config.PSKModes,
AllowEarlyData: c.config.AllowEarlyData,
RequireCookie: c.config.RequireCookie,
CookieHandler: c.config.CookieHandler,
RequireClientAuth: c.config.RequireClientAuth,
NextProtos: c.config.NextProtos,
Certificates: c.config.Certificates,
ExtensionHandler: c.extHandler,
}
opts := ConnectionOptions{
ServerName: c.config.ServerName,
NextProtos: c.config.NextProtos,
EarlyData: c.EarlyData,
}
if caps.RequireCookie && caps.CookieHandler == nil {
caps.CookieHandler = &defaultCookieHandler{}
}
if c.isClient {
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error initializing client state: %v", alert)
return alert
}
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
return alert
}
}
} else {
state = ServerStateStart{Caps: caps, conn: c}
}
c.hState = state
return AlertNoAlert
}
// Handshake causes a TLS handshake on the connection. The `isClient` member
// determines whether a client or server handshake is performed. If a
// handshake has already been performed, then its result will be returned.
func (c *Conn) Handshake() Alert {
label := "[server]"
if c.isClient {
label = "[client]"
}
// TODO Lock handshakeMutex
// TODO Remove CloseNotify hack
if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify {
logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert)
return c.handshakeAlert
}
if c.handshakeComplete {
return AlertNoAlert
}
var alert Alert
if c.hState == nil {
logf(logTypeHandshake, "%s First time through handshake, setting up", label)
alert = c.HandshakeSetup()
if alert != AlertNoAlert {
return alert
}
} else {
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
}
state := c.hState
_, connected := state.(StateConnected)
var actions []HandshakeAction
for !connected {
// Read a handshake message
hm, err := c.hIn.ReadMessage()
if err == WouldBlock {
logf(logTypeHandshake, "%s Would block reading message: %v", label, err)
return AlertWouldBlock
}
if err != nil {
logf(logTypeHandshake, "%s Error reading message: %v", label, err)
c.sendAlert(AlertCloseNotify)
return AlertCloseNotify
}
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
// Advance the state machine
state, actions, alert = state.Next(hm)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error in state transition: %v", alert)
return alert
}
for index, action := range actions {
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return alert
}
}
c.hState = state
logf(logTypeHandshake, "state is now %s", c.GetHsState())
_, connected = state.(StateConnected)
}
c.state = state.(StateConnected)
// Send NewSessionTicket if acting as server
if !c.isClient && c.config.SendSessionTickets {
actions, alert := c.state.NewSessionTicket(
c.config.TicketLen,
c.config.TicketLifetime,
c.config.EarlyDataLifetime)
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return alert
}
}
}
c.handshakeComplete = true
return AlertNoAlert
}
func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
if !c.handshakeComplete {
return fmt.Errorf("Cannot update keys until after handshake")
}
request := KeyUpdateNotRequested
if requestUpdate {
request = KeyUpdateRequested
}
// Create the key update and update state
actions, alert := c.state.KeyUpdate(request)
if alert != AlertNoAlert {
c.sendAlert(alert)
return fmt.Errorf("Alert while generating key update: %v", alert)
}
// Take actions (send key update and rekey)
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
c.sendAlert(alert)
return fmt.Errorf("Alert during key update actions: %v", alert)
}
}
return nil
}
func (c *Conn) GetHsState() string {
return reflect.TypeOf(c.hState).Name()
}
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
_, connected := c.hState.(StateConnected)
if !connected {
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
}
if c.state.exporterSecret == nil {
return nil, fmt.Errorf("Internal error: no exporter secret")
}
h0 := c.state.cryptoParams.Hash.New().Sum(nil)
tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0)
hc := c.state.cryptoParams.Hash.New().Sum(context)
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
}
func (c *Conn) State() ConnectionState {
state := ConnectionState{
HandshakeState: c.GetHsState(),
}
if c.handshakeComplete {
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
state.NextProto = c.state.Params.NextProto
}
return state
}
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
if c.hState != nil {
return fmt.Errorf("Can't set extension handler after setup")
}
c.extHandler = h
return nil
}

654
vendor/github.com/bifurcation/mint/crypto.go generated vendored Normal file
View file

@ -0,0 +1,654 @@
package mint
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/hmac"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"math/big"
"time"
"golang.org/x/crypto/curve25519"
// Blank includes to ensure hash support
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
)
var prng = rand.Reader
type aeadFactory func(key []byte) (cipher.AEAD, error)
type CipherSuiteParams struct {
Suite CipherSuite
Cipher aeadFactory // Cipher factory
Hash crypto.Hash // Hash function
KeyLen int // Key length in octets
IvLen int // IV length in octets
}
type signatureAlgorithm uint8
const (
signatureAlgorithmUnknown = iota
signatureAlgorithmRSA_PKCS1
signatureAlgorithmRSA_PSS
signatureAlgorithmECDSA
)
var (
hashMap = map[SignatureScheme]crypto.Hash{
RSA_PKCS1_SHA1: crypto.SHA1,
RSA_PKCS1_SHA256: crypto.SHA256,
RSA_PKCS1_SHA384: crypto.SHA384,
RSA_PKCS1_SHA512: crypto.SHA512,
ECDSA_P256_SHA256: crypto.SHA256,
ECDSA_P384_SHA384: crypto.SHA384,
ECDSA_P521_SHA512: crypto.SHA512,
RSA_PSS_SHA256: crypto.SHA256,
RSA_PSS_SHA384: crypto.SHA384,
RSA_PSS_SHA512: crypto.SHA512,
}
sigMap = map[SignatureScheme]signatureAlgorithm{
RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1,
RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1,
RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1,
RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1,
ECDSA_P256_SHA256: signatureAlgorithmECDSA,
ECDSA_P384_SHA384: signatureAlgorithmECDSA,
ECDSA_P521_SHA512: signatureAlgorithmECDSA,
RSA_PSS_SHA256: signatureAlgorithmRSA_PSS,
RSA_PSS_SHA384: signatureAlgorithmRSA_PSS,
RSA_PSS_SHA512: signatureAlgorithmRSA_PSS,
}
curveMap = map[SignatureScheme]NamedGroup{
ECDSA_P256_SHA256: P256,
ECDSA_P384_SHA384: P384,
ECDSA_P521_SHA512: P521,
}
newAESGCM = func(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
// TLS always uses 12-byte nonces
return cipher.NewGCMWithNonceSize(block, 12)
}
cipherSuiteMap = map[CipherSuite]CipherSuiteParams{
TLS_AES_128_GCM_SHA256: {
Suite: TLS_AES_128_GCM_SHA256,
Cipher: newAESGCM,
Hash: crypto.SHA256,
KeyLen: 16,
IvLen: 12,
},
TLS_AES_256_GCM_SHA384: {
Suite: TLS_AES_256_GCM_SHA384,
Cipher: newAESGCM,
Hash: crypto.SHA384,
KeyLen: 32,
IvLen: 12,
},
}
x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{
RSA_PKCS1_SHA1: x509.SHA1WithRSA,
RSA_PKCS1_SHA256: x509.SHA256WithRSA,
RSA_PKCS1_SHA384: x509.SHA384WithRSA,
RSA_PKCS1_SHA512: x509.SHA512WithRSA,
ECDSA_P256_SHA256: x509.ECDSAWithSHA256,
ECDSA_P384_SHA384: x509.ECDSAWithSHA384,
ECDSA_P521_SHA512: x509.ECDSAWithSHA512,
}
defaultRSAKeySize = 2048
)
func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) {
switch group {
case P256:
crv = elliptic.P256()
case P384:
crv = elliptic.P384()
case P521:
crv = elliptic.P521()
}
return
}
func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) {
switch key.Curve.Params().Name {
case elliptic.P256().Params().Name:
g = P256
case elliptic.P384().Params().Name:
g = P384
case elliptic.P521().Params().Name:
g = P521
}
return
}
func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) {
size = 0
switch group {
case X25519:
size = 32
case P256:
size = 65
case P384:
size = 97
case P521:
size = 133
case FFDHE2048:
size = 256
case FFDHE3072:
size = 384
case FFDHE4096:
size = 512
case FFDHE6144:
size = 768
case FFDHE8192:
size = 1024
}
return
}
func primeFromNamedGroup(group NamedGroup) (p *big.Int) {
switch group {
case FFDHE2048:
p = finiteFieldPrime2048
case FFDHE3072:
p = finiteFieldPrime3072
case FFDHE4096:
p = finiteFieldPrime4096
case FFDHE6144:
p = finiteFieldPrime6144
case FFDHE8192:
p = finiteFieldPrime8192
}
return
}
func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool {
sigType := sigMap[alg]
switch key.(type) {
case *rsa.PrivateKey:
return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS
case *ecdsa.PrivateKey:
return sigType == signatureAlgorithmECDSA
default:
return false
}
}
func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) {
primeLen := len(p.Bytes())
for {
// g = 2 for all ffdhe groups
priv, err = rand.Int(prng, p)
if err != nil {
return
}
pub = big.NewInt(0)
pub.Exp(big.NewInt(2), priv, p)
if len(pub.Bytes()) == primeLen {
return
}
}
}
func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) {
switch group {
case P256, P384, P521:
var x, y *big.Int
crv := curveFromNamedGroup(group)
priv, x, y, err = elliptic.GenerateKey(crv, prng)
if err != nil {
return
}
pub = elliptic.Marshal(crv, x, y)
return
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
p := primeFromNamedGroup(group)
x, X, err2 := ffdheKeyShareFromPrime(p)
if err2 != nil {
err = err2
return
}
priv = x.Bytes()
pubBytes := X.Bytes()
numBytes := keyExchangeSizeFromNamedGroup(group)
pub = make([]byte, numBytes)
copy(pub[numBytes-len(pubBytes):], pubBytes)
return
case X25519:
var private, public [32]byte
_, err = prng.Read(private[:])
if err != nil {
return
}
curve25519.ScalarBaseMult(&public, &private)
priv = private[:]
pub = public[:]
return
default:
return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group)
}
}
func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) {
switch group {
case P256, P384, P521:
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
}
crv := curveFromNamedGroup(group)
pubX, pubY := elliptic.Unmarshal(crv, pub)
x, _ := crv.Params().ScalarMult(pubX, pubY, priv)
xBytes := x.Bytes()
numBytes := len(crv.Params().P.Bytes())
ret := make([]byte, numBytes)
copy(ret[numBytes-len(xBytes):], xBytes)
return ret, nil
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
numBytes := keyExchangeSizeFromNamedGroup(group)
if len(pub) != numBytes {
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
}
p := primeFromNamedGroup(group)
x := big.NewInt(0).SetBytes(priv)
Y := big.NewInt(0).SetBytes(pub)
ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes()
ret := make([]byte, numBytes)
copy(ret[numBytes-len(ZBytes):], ZBytes)
return ret, nil
case X25519:
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
}
var private, public, ret [32]byte
copy(private[:], priv)
copy(public[:], pub)
curve25519.ScalarMult(&ret, &private, &public)
return ret[:], nil
default:
return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group)
}
}
func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
switch sig {
case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256,
RSA_PKCS1_SHA384, RSA_PKCS1_SHA512,
RSA_PSS_SHA256, RSA_PSS_SHA384,
RSA_PSS_SHA512:
return rsa.GenerateKey(prng, defaultRSAKeySize)
case ECDSA_P256_SHA256:
return ecdsa.GenerateKey(elliptic.P256(), prng)
case ECDSA_P384_SHA384:
return ecdsa.GenerateKey(elliptic.P384(), prng)
case ECDSA_P521_SHA512:
return ecdsa.GenerateKey(elliptic.P521(), prng)
default:
return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig)
}
}
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
sigAlg, ok := x509AlgMap[alg]
if !ok {
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
}
if len(name) == 0 {
return nil, fmt.Errorf("tls.selfsigned: No name provided")
}
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
if err != nil {
return nil, err
}
template := &x509.Certificate{
SerialNumber: serial,
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 0, 1),
SignatureAlgorithm: sigAlg,
Subject: pkix.Name{CommonName: name},
DNSNames: []string{name},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
if err != nil {
return nil, err
}
// It is safe to ignore the error here because we're parsing known-good data
cert, _ := x509.ParseCertificate(der)
return cert, nil
}
// XXX(rlb): Copied from crypto/x509
type ecdsaSignature struct {
R, S *big.Int
}
func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) {
var opts crypto.SignerOpts
hash := hashMap[alg]
if hash == crypto.SHA1 {
return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
}
sigType := sigMap[alg]
var realInput []byte
switch key := privateKey.(type) {
case *rsa.PrivateKey:
switch {
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size())
opts = hash
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
fallthrough
case sigType == signatureAlgorithmRSA_PSS:
logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size())
opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
default:
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key")
}
h := hash.New()
h.Write(sigInput)
realInput = h.Sum(nil)
case *ecdsa.PrivateKey:
if sigType != signatureAlgorithmECDSA {
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key")
}
algGroup := curveMap[alg]
keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey))
if algGroup != keyGroup {
return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination")
}
h := hash.New()
h.Write(sigInput)
realInput = h.Sum(nil)
default:
return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type")
}
sig, err := privateKey.Sign(prng, realInput, opts)
logf(logTypeCrypto, "signature: %x", sig)
return sig, err
}
func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error {
hash := hashMap[alg]
if hash == crypto.SHA1 {
return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
}
sigType := sigMap[alg]
switch pub := publicKey.(type) {
case *rsa.PublicKey:
switch {
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size())
h := hash.New()
h.Write(sigInput)
realInput := h.Sum(nil)
return rsa.VerifyPKCS1v15(pub, hash, realInput, sig)
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
fallthrough
case sigType == signatureAlgorithmRSA_PSS:
logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size())
opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
h := hash.New()
h.Write(sigInput)
realInput := h.Sum(nil)
return rsa.VerifyPSS(pub, hash, realInput, sig, opts)
default:
return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key")
}
case *ecdsa.PublicKey:
if sigType != signatureAlgorithmECDSA {
return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key")
}
if curveMap[alg] != namedGroupFromECDSAKey(pub) {
return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key")
}
ecdsaSig := new(ecdsaSignature)
if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
return err
} else if len(rest) != 0 {
return fmt.Errorf("tls.verify: trailing data after ECDSA signature")
}
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values")
}
h := hash.New()
h.Write(sigInput)
realInput := h.Sum(nil)
if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) {
return fmt.Errorf("tls.verify: ECDSA verification failure")
}
return nil
default:
return fmt.Errorf("tls.verify: Unsupported key type")
}
}
// 0
// |
// v
// PSK -> HKDF-Extract = Early Secret
// |
// +-----> Derive-Secret(.,
// | "ext binder" |
// | "res binder",
// | "")
// | = binder_key
// |
// +-----> Derive-Secret(., "c e traffic",
// | ClientHello)
// | = client_early_traffic_secret
// |
// +-----> Derive-Secret(., "e exp master",
// | ClientHello)
// | = early_exporter_master_secret
// v
// Derive-Secret(., "derived", "")
// |
// v
// (EC)DHE -> HKDF-Extract = Handshake Secret
// |
// +-----> Derive-Secret(., "c hs traffic",
// | ClientHello...ServerHello)
// | = client_handshake_traffic_secret
// |
// +-----> Derive-Secret(., "s hs traffic",
// | ClientHello...ServerHello)
// | = server_handshake_traffic_secret
// v
// Derive-Secret(., "derived", "")
// |
// v
// 0 -> HKDF-Extract = Master Secret
// |
// +-----> Derive-Secret(., "c ap traffic",
// | ClientHello...server Finished)
// | = client_application_traffic_secret_0
// |
// +-----> Derive-Secret(., "s ap traffic",
// | ClientHello...server Finished)
// | = server_application_traffic_secret_0
// |
// +-----> Derive-Secret(., "exp master",
// | ClientHello...server Finished)
// | = exporter_master_secret
// |
// +-----> Derive-Secret(., "res master",
// ClientHello...client Finished)
// = resumption_master_secret
// From RFC 5869
// PRK = HMAC-Hash(salt, IKM)
func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte {
salt := saltIn
// if [salt is] not provided, it is set to a string of HashLen zeros
if salt == nil {
salt = bytes.Repeat([]byte{0}, hash.Size())
}
h := hmac.New(hash.New, salt)
h.Write(input)
out := h.Sum(nil)
logf(logTypeCrypto, "HKDF Extract:\n")
logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt)
logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input)
logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out)
return out
}
const (
labelExternalBinder = "ext binder"
labelResumptionBinder = "res binder"
labelEarlyTrafficSecret = "c e traffic"
labelEarlyExporterSecret = "e exp master"
labelClientHandshakeTrafficSecret = "c hs traffic"
labelServerHandshakeTrafficSecret = "s hs traffic"
labelClientApplicationTrafficSecret = "c ap traffic"
labelServerApplicationTrafficSecret = "s ap traffic"
labelExporterSecret = "exp master"
labelResumptionSecret = "res master"
labelDerived = "derived"
labelFinished = "finished"
labelResumption = "resumption"
)
// struct HkdfLabel {
// uint16 length;
// opaque label<9..255>;
// opaque hash_value<0..255>;
// };
func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte {
label := "tls13 " + labelIn
labelLen := len(label)
hashLen := len(hashValue)
hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen)
hkdfLabel[0] = byte(outLen >> 8)
hkdfLabel[1] = byte(outLen)
hkdfLabel[2] = byte(labelLen)
copy(hkdfLabel[3:3+labelLen], []byte(label))
hkdfLabel[3+labelLen] = byte(hashLen)
copy(hkdfLabel[3+labelLen+1:], hashValue)
return hkdfLabel
}
func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte {
out := []byte{}
T := []byte{}
i := byte(1)
for len(out) < outLen {
block := append(T, info...)
block = append(block, i)
h := hmac.New(hash.New, prk)
h.Write(block)
T = h.Sum(nil)
out = append(out, T...)
i++
}
return out[:outLen]
}
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte {
info := hkdfEncodeLabel(label, hashValue, outLen)
derived := HkdfExpand(hash, secret, info, outLen)
logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen)
logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret)
logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue)
logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info)
logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived)
return derived
}
func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte {
return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size())
}
func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte {
macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size())
mac := hmac.New(params.Hash.New, macKey)
mac.Write(input)
return mac.Sum(nil)
}
type keySet struct {
cipher aeadFactory
key []byte
iv []byte
}
func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
logf(logTypeCrypto, "making traffic keys: secret=%x", secret)
return keySet{
cipher: params.Cipher,
key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen),
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
}
}

586
vendor/github.com/bifurcation/mint/extensions.go generated vendored Normal file
View file

@ -0,0 +1,586 @@
package mint
import (
"bytes"
"fmt"
"github.com/bifurcation/mint/syntax"
)
type ExtensionBody interface {
Type() ExtensionType
Marshal() ([]byte, error)
Unmarshal(data []byte) (int, error)
}
// struct {
// ExtensionType extension_type;
// opaque extension_data<0..2^16-1>;
// } Extension;
type Extension struct {
ExtensionType ExtensionType
ExtensionData []byte `tls:"head=2"`
}
func (ext Extension) Marshal() ([]byte, error) {
return syntax.Marshal(ext)
}
func (ext *Extension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, ext)
}
type ExtensionList []Extension
type extensionListInner struct {
List []Extension `tls:"head=2"`
}
func (el ExtensionList) Marshal() ([]byte, error) {
return syntax.Marshal(extensionListInner{el})
}
func (el *ExtensionList) Unmarshal(data []byte) (int, error) {
var list extensionListInner
read, err := syntax.Unmarshal(data, &list)
if err != nil {
return 0, err
}
*el = list.List
return read, nil
}
func (el *ExtensionList) Add(src ExtensionBody) error {
data, err := src.Marshal()
if err != nil {
return err
}
if el == nil {
el = new(ExtensionList)
}
// If one already exists with this type, replace it
for i := range *el {
if (*el)[i].ExtensionType == src.Type() {
(*el)[i].ExtensionData = data
return nil
}
}
// Otherwise append
*el = append(*el, Extension{
ExtensionType: src.Type(),
ExtensionData: data,
})
return nil
}
func (el ExtensionList) Find(dst ExtensionBody) bool {
for _, ext := range el {
if ext.ExtensionType == dst.Type() {
_, err := dst.Unmarshal(ext.ExtensionData)
return err == nil
}
}
return false
}
// struct {
// NameType name_type;
// select (name_type) {
// case host_name: HostName;
// } name;
// } ServerName;
//
// enum {
// host_name(0), (255)
// } NameType;
//
// opaque HostName<1..2^16-1>;
//
// struct {
// ServerName server_name_list<1..2^16-1>
// } ServerNameList;
//
// But we only care about the case where there's a single DNS hostname. We
// will never create anything else, and throw if we receive something else
//
// 2 1 2
// | listLen | NameType | nameLen | name |
type ServerNameExtension string
type serverNameInner struct {
NameType uint8
HostName []byte `tls:"head=2,min=1"`
}
type serverNameListInner struct {
ServerNameList []serverNameInner `tls:"head=2,min=1"`
}
func (sni ServerNameExtension) Type() ExtensionType {
return ExtensionTypeServerName
}
func (sni ServerNameExtension) Marshal() ([]byte, error) {
list := serverNameListInner{
ServerNameList: []serverNameInner{{
NameType: 0x00, // host_name
HostName: []byte(sni),
}},
}
return syntax.Marshal(list)
}
func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) {
var list serverNameListInner
read, err := syntax.Unmarshal(data, &list)
if err != nil {
return 0, err
}
// Syntax requires at least one entry
// Entries beyond the first are ignored
if nameType := list.ServerNameList[0].NameType; nameType != 0x00 {
return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType)
}
*sni = ServerNameExtension(list.ServerNameList[0].HostName)
return read, nil
}
// struct {
// NamedGroup group;
// opaque key_exchange<1..2^16-1>;
// } KeyShareEntry;
//
// struct {
// select (Handshake.msg_type) {
// case client_hello:
// KeyShareEntry client_shares<0..2^16-1>;
//
// case hello_retry_request:
// NamedGroup selected_group;
//
// case server_hello:
// KeyShareEntry server_share;
// };
// } KeyShare;
type KeyShareEntry struct {
Group NamedGroup
KeyExchange []byte `tls:"head=2,min=1"`
}
func (kse KeyShareEntry) SizeValid() bool {
return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group)
}
type KeyShareExtension struct {
HandshakeType HandshakeType
SelectedGroup NamedGroup
Shares []KeyShareEntry
}
type KeyShareClientHelloInner struct {
ClientShares []KeyShareEntry `tls:"head=2,min=0"`
}
type KeyShareHelloRetryInner struct {
SelectedGroup NamedGroup
}
type KeyShareServerHelloInner struct {
ServerShare KeyShareEntry
}
func (ks KeyShareExtension) Type() ExtensionType {
return ExtensionTypeKeyShare
}
func (ks KeyShareExtension) Marshal() ([]byte, error) {
switch ks.HandshakeType {
case HandshakeTypeClientHello:
for _, share := range ks.Shares {
if !share.SizeValid() {
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
}
return syntax.Marshal(KeyShareClientHelloInner{ks.Shares})
case HandshakeTypeHelloRetryRequest:
if len(ks.Shares) > 0 {
return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest")
}
return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup})
case HandshakeTypeServerHello:
if len(ks.Shares) != 1 {
return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share")
}
if !ks.Shares[0].SizeValid() {
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]})
default:
return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed")
}
}
func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) {
switch ks.HandshakeType {
case HandshakeTypeClientHello:
var inner KeyShareClientHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
for _, share := range inner.ClientShares {
if !share.SizeValid() {
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
}
ks.Shares = inner.ClientShares
return read, nil
case HandshakeTypeHelloRetryRequest:
var inner KeyShareHelloRetryInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
ks.SelectedGroup = inner.SelectedGroup
return read, nil
case HandshakeTypeServerHello:
var inner KeyShareServerHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
if !inner.ServerShare.SizeValid() {
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
ks.Shares = []KeyShareEntry{inner.ServerShare}
return read, nil
default:
return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed")
}
}
// struct {
// NamedGroup named_group_list<2..2^16-1>;
// } NamedGroupList;
type SupportedGroupsExtension struct {
Groups []NamedGroup `tls:"head=2,min=2"`
}
func (sg SupportedGroupsExtension) Type() ExtensionType {
return ExtensionTypeSupportedGroups
}
func (sg SupportedGroupsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sg)
}
func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sg)
}
// struct {
// SignatureScheme supported_signature_algorithms<2..2^16-2>;
// } SignatureSchemeList
type SignatureAlgorithmsExtension struct {
Algorithms []SignatureScheme `tls:"head=2,min=2"`
}
func (sa SignatureAlgorithmsExtension) Type() ExtensionType {
return ExtensionTypeSignatureAlgorithms
}
func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sa)
}
func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sa)
}
// struct {
// opaque identity<1..2^16-1>;
// uint32 obfuscated_ticket_age;
// } PskIdentity;
//
// opaque PskBinderEntry<32..255>;
//
// struct {
// select (Handshake.msg_type) {
// case client_hello:
// PskIdentity identities<7..2^16-1>;
// PskBinderEntry binders<33..2^16-1>;
//
// case server_hello:
// uint16 selected_identity;
// };
//
// } PreSharedKeyExtension;
type PSKIdentity struct {
Identity []byte `tls:"head=2,min=1"`
ObfuscatedTicketAge uint32
}
type PSKBinderEntry struct {
Binder []byte `tls:"head=1,min=32"`
}
type PreSharedKeyExtension struct {
HandshakeType HandshakeType
Identities []PSKIdentity
Binders []PSKBinderEntry
SelectedIdentity uint16
}
type preSharedKeyClientInner struct {
Identities []PSKIdentity `tls:"head=2,min=7"`
Binders []PSKBinderEntry `tls:"head=2,min=33"`
}
type preSharedKeyServerInner struct {
SelectedIdentity uint16
}
func (psk PreSharedKeyExtension) Type() ExtensionType {
return ExtensionTypePreSharedKey
}
func (psk PreSharedKeyExtension) Marshal() ([]byte, error) {
switch psk.HandshakeType {
case HandshakeTypeClientHello:
return syntax.Marshal(preSharedKeyClientInner{
Identities: psk.Identities,
Binders: psk.Binders,
})
case HandshakeTypeServerHello:
if len(psk.Identities) > 0 || len(psk.Binders) > 0 {
return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index")
}
return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity})
default:
return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported")
}
}
func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) {
switch psk.HandshakeType {
case HandshakeTypeClientHello:
var inner preSharedKeyClientInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
if len(inner.Identities) != len(inner.Binders) {
return 0, fmt.Errorf("Lengths of identities and binders not equal")
}
psk.Identities = inner.Identities
psk.Binders = inner.Binders
return read, nil
case HandshakeTypeServerHello:
var inner preSharedKeyServerInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
psk.SelectedIdentity = inner.SelectedIdentity
return read, nil
default:
return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported")
}
}
func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) {
for i, localID := range psk.Identities {
if bytes.Equal(localID.Identity, id) {
return psk.Binders[i].Binder, true
}
}
return nil, false
}
// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
//
// struct {
// PskKeyExchangeMode ke_modes<1..255>;
// } PskKeyExchangeModes;
type PSKKeyExchangeModesExtension struct {
KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"`
}
func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType {
return ExtensionTypePSKKeyExchangeModes
}
func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) {
return syntax.Marshal(pkem)
}
func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, pkem)
}
// struct {
// } EarlyDataIndication;
type EarlyDataExtension struct{}
func (ed EarlyDataExtension) Type() ExtensionType {
return ExtensionTypeEarlyData
}
func (ed EarlyDataExtension) Marshal() ([]byte, error) {
return []byte{}, nil
}
func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) {
return 0, nil
}
// struct {
// uint32 max_early_data_size;
// } TicketEarlyDataInfo;
type TicketEarlyDataInfoExtension struct {
MaxEarlyDataSize uint32
}
func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType {
return ExtensionTypeTicketEarlyDataInfo
}
func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) {
return syntax.Marshal(tedi)
}
func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, tedi)
}
// opaque ProtocolName<1..2^8-1>;
//
// struct {
// ProtocolName protocol_name_list<2..2^16-1>
// } ProtocolNameList;
type ALPNExtension struct {
Protocols []string
}
type protocolNameInner struct {
Name []byte `tls:"head=1,min=1"`
}
type alpnExtensionInner struct {
Protocols []protocolNameInner `tls:"head=2,min=2"`
}
func (alpn ALPNExtension) Type() ExtensionType {
return ExtensionTypeALPN
}
func (alpn ALPNExtension) Marshal() ([]byte, error) {
protocols := make([]protocolNameInner, len(alpn.Protocols))
for i, protocol := range alpn.Protocols {
protocols[i] = protocolNameInner{[]byte(protocol)}
}
return syntax.Marshal(alpnExtensionInner{protocols})
}
func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
var inner alpnExtensionInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
alpn.Protocols = make([]string, len(inner.Protocols))
for i, protocol := range inner.Protocols {
alpn.Protocols[i] = string(protocol.Name)
}
return read, nil
}
// struct {
// ProtocolVersion versions<2..254>;
// } SupportedVersions;
type SupportedVersionsExtension struct {
Versions []uint16 `tls:"head=1,min=2,max=254"`
}
func (sv SupportedVersionsExtension) Type() ExtensionType {
return ExtensionTypeSupportedVersions
}
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sv)
}
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sv)
}
// struct {
// opaque cookie<1..2^16-1>;
// } Cookie;
type CookieExtension struct {
Cookie []byte `tls:"head=2,min=1"`
}
func (c CookieExtension) Type() ExtensionType {
return ExtensionTypeCookie
}
func (c CookieExtension) Marshal() ([]byte, error) {
return syntax.Marshal(c)
}
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, c)
}
// defaultCookieLength is the default length of a cookie
const defaultCookieLength = 32
type defaultCookieHandler struct {
data []byte
}
var _ CookieHandler = &defaultCookieHandler{}
// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data
func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) {
h.data = make([]byte, defaultCookieLength)
if _, err := prng.Read(h.data); err != nil {
return nil, err
}
return h.data, nil
}
func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool {
return bytes.Equal(h.data, data)
}

147
vendor/github.com/bifurcation/mint/ffdhe.go generated vendored Normal file
View file

@ -0,0 +1,147 @@
package mint
import (
"encoding/hex"
"math/big"
)
var (
finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B423861285C97FFFFFFFFFFFFFFFF"
finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex)
finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes)
finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF"
finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex)
finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes)
finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" +
"FFFFFFFFFFFFFFFF"
finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex)
finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes)
finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
"A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF"
finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex)
finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes)
finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
"A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" +
"1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" +
"0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" +
"CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" +
"2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" +
"BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" +
"51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" +
"D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" +
"1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" +
"FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" +
"97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" +
"D68C8BB7C5C6424CFFFFFFFFFFFFFFFF"
finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex)
finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes)
)

98
vendor/github.com/bifurcation/mint/frame-reader.go generated vendored Normal file
View file

@ -0,0 +1,98 @@
// Read a generic "framed" packet consisting of a header and a
// This is used for both TLS Records and TLS Handshake Messages
package mint
type framing interface {
headerLen() int
defaultReadLen() int
frameLen(hdr []byte) (int, error)
}
const (
kFrameReaderHdr = 0
kFrameReaderBody = 1
)
type frameNextAction func(f *frameReader) error
type frameReader struct {
details framing
state uint8
header []byte
body []byte
working []byte
writeOffset int
remainder []byte
}
func newFrameReader(d framing) *frameReader {
hdr := make([]byte, d.headerLen())
return &frameReader{
d,
kFrameReaderHdr,
hdr,
nil,
hdr,
0,
nil,
}
}
func dup(a []byte) []byte {
r := make([]byte, len(a))
copy(r, a)
return r
}
func (f *frameReader) needed() int {
tmp := (len(f.working) - f.writeOffset) - len(f.remainder)
if tmp < 0 {
return 0
}
return tmp
}
func (f *frameReader) addChunk(in []byte) {
// Append to the buffer.
logf(logTypeFrameReader, "Appending %v", len(in))
f.remainder = append(f.remainder, in...)
}
func (f *frameReader) process() (hdr []byte, body []byte, err error) {
for f.needed() == 0 {
logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset)
// Fill out our working block
copied := copy(f.working[f.writeOffset:], f.remainder)
f.remainder = f.remainder[copied:]
f.writeOffset += copied
if f.writeOffset < len(f.working) {
logf(logTypeFrameReader, "Read would have blocked 1")
return nil, nil, WouldBlock
}
// Reset the write offset, because we are now full.
f.writeOffset = 0
// We have read a full frame
if f.state == kFrameReaderBody {
logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder))
f.state = kFrameReaderHdr
f.working = f.header
return dup(f.header), dup(f.body), nil
}
// We have read the header
bodyLen, err := f.details.frameLen(f.header)
if err != nil {
return nil, nil, err
}
logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen)
f.body = make([]byte, bodyLen)
f.working = f.body
f.writeOffset = 0
f.state = kFrameReaderBody
}
logf(logTypeFrameReader, "Read would have blocked 2")
return nil, nil, WouldBlock
}

253
vendor/github.com/bifurcation/mint/handshake-layer.go generated vendored Normal file
View file

@ -0,0 +1,253 @@
package mint
import (
"fmt"
"io"
"net"
)
const (
handshakeHeaderLen = 4 // handshake message header length
maxHandshakeMessageLen = 1 << 24 // max handshake message length
)
// struct {
// HandshakeType msg_type; /* handshake type */
// uint24 length; /* bytes in message */
// select (HandshakeType) {
// ...
// } body;
// } Handshake;
//
// We do the select{...} part in a different layer, so we treat the
// actual message body as opaque:
//
// struct {
// HandshakeType msg_type;
// opaque msg<0..2^24-1>
// } Handshake;
//
// TODO: File a spec bug
type HandshakeMessage struct {
// Omitted: length
msgType HandshakeType
body []byte
}
// Note: This could be done with the `syntax` module, using the simplified
// syntax as discussed above. However, since this is so simple, there's not
// much benefit to doing so.
func (hm *HandshakeMessage) Marshal() []byte {
if hm == nil {
return []byte{}
}
msgLen := len(hm.body)
data := make([]byte, 4+len(hm.body))
data[0] = byte(hm.msgType)
data[1] = byte(msgLen >> 16)
data[2] = byte(msgLen >> 8)
data[3] = byte(msgLen)
copy(data[4:], hm.body)
return data
}
func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
var body HandshakeMessageBody
switch hm.msgType {
case HandshakeTypeClientHello:
body = new(ClientHelloBody)
case HandshakeTypeServerHello:
body = new(ServerHelloBody)
case HandshakeTypeHelloRetryRequest:
body = new(HelloRetryRequestBody)
case HandshakeTypeEncryptedExtensions:
body = new(EncryptedExtensionsBody)
case HandshakeTypeCertificate:
body = new(CertificateBody)
case HandshakeTypeCertificateRequest:
body = new(CertificateRequestBody)
case HandshakeTypeCertificateVerify:
body = new(CertificateVerifyBody)
case HandshakeTypeFinished:
body = &FinishedBody{VerifyDataLen: len(hm.body)}
case HandshakeTypeNewSessionTicket:
body = new(NewSessionTicketBody)
case HandshakeTypeKeyUpdate:
body = new(KeyUpdateBody)
case HandshakeTypeEndOfEarlyData:
body = new(EndOfEarlyDataBody)
default:
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
}
_, err := body.Unmarshal(hm.body)
return body, err
}
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
data, err := body.Marshal()
if err != nil {
return nil, err
}
return &HandshakeMessage{
msgType: body.Type(),
body: data,
}, nil
}
type HandshakeLayer struct {
nonblocking bool // Should we operate in nonblocking mode
conn *RecordLayer // Used for reading/writing records
frame *frameReader // The buffered frame reader
}
type handshakeLayerFrameDetails struct{}
func (d handshakeLayerFrameDetails) headerLen() int {
return handshakeHeaderLen
}
func (d handshakeLayerFrameDetails) defaultReadLen() int {
return handshakeHeaderLen + maxFragmentLen
}
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
logf(logTypeIO, "Header=%x", hdr)
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil
}
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer {
h := HandshakeLayer{}
h.conn = r
h.frame = newFrameReader(&handshakeLayerFrameDetails{})
return &h
}
func (h *HandshakeLayer) readRecord() error {
logf(logTypeIO, "Trying to read record")
pt, err := h.conn.ReadRecord()
if err != nil {
return err
}
if pt.contentType != RecordTypeHandshake &&
pt.contentType != RecordTypeAlert {
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
}
if pt.contentType == RecordTypeAlert {
logf(logTypeIO, "read alert %v", pt.fragment[1])
if len(pt.fragment) < 2 {
h.sendAlert(AlertUnexpectedMessage)
return io.EOF
}
return Alert(pt.fragment[1])
}
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment))
h.frame.addChunk(pt.fragment)
return nil
}
// sendAlert sends a TLS alert message.
func (h *HandshakeLayer) sendAlert(err Alert) error {
tmp := make([]byte, 2)
tmp[0] = AlertLevelError
tmp[1] = byte(err)
h.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeAlert,
fragment: tmp},
)
// closeNotify is a special case in that it isn't an error:
if err != AlertCloseNotify {
return &net.OpError{Op: "local error", Err: err}
}
return nil
}
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
var hdr, body []byte
var err error
for {
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder))
if h.frame.needed() > 0 {
logf(logTypeHandshake, "Trying to read a new record")
err = h.readRecord()
}
if err != nil && (h.nonblocking || err != WouldBlock) {
return nil, err
}
hdr, body, err = h.frame.process()
if err == nil {
break
}
if err != nil && (h.nonblocking || err != WouldBlock) {
return nil, err
}
}
logf(logTypeHandshake, "read handshake message")
hm := &HandshakeMessage{}
hm.msgType = HandshakeType(hdr[0])
hm.body = make([]byte, len(body))
copy(hm.body, body)
return hm, nil
}
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error {
return h.WriteMessages([]*HandshakeMessage{hm})
}
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error {
for _, hm := range hms {
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
}
// Write out headers and bodies
buffer := []byte{}
for _, msg := range hms {
msgLen := len(msg.body)
if msgLen > maxHandshakeMessageLen {
return fmt.Errorf("tls.handshakelayer: Message too large to send")
}
buffer = append(buffer, msg.Marshal()...)
}
// Send full-size fragments
var start int
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
err := h.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buffer[start : start+maxFragmentLen],
})
if err != nil {
return err
}
}
// Send a final partial fragment if necessary
if start < len(buffer) {
err := h.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buffer[start:],
})
if err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,450 @@
package mint
import (
"bytes"
"crypto"
"crypto/x509"
"encoding/binary"
"fmt"
"github.com/bifurcation/mint/syntax"
)
type HandshakeMessageBody interface {
Type() HandshakeType
Marshal() ([]byte, error)
Unmarshal(data []byte) (int, error)
}
// struct {
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
// Random random;
// opaque legacy_session_id<0..32>;
// CipherSuite cipher_suites<2..2^16-2>;
// opaque legacy_compression_methods<1..2^8-1>;
// Extension extensions<0..2^16-1>;
// } ClientHello;
type ClientHelloBody struct {
// Omitted: clientVersion
// Omitted: legacySessionID
// Omitted: legacyCompressionMethods
Random [32]byte
CipherSuites []CipherSuite
Extensions ExtensionList
}
type clientHelloBodyInner struct {
LegacyVersion uint16
Random [32]byte
LegacySessionID []byte `tls:"head=1,max=32"`
CipherSuites []CipherSuite `tls:"head=2,min=2"`
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
Extensions []Extension `tls:"head=2"`
}
func (ch ClientHelloBody) Type() HandshakeType {
return HandshakeTypeClientHello
}
func (ch ClientHelloBody) Marshal() ([]byte, error) {
return syntax.Marshal(clientHelloBodyInner{
LegacyVersion: 0x0303,
Random: ch.Random,
LegacySessionID: []byte{},
CipherSuites: ch.CipherSuites,
LegacyCompressionMethods: []byte{0},
Extensions: ch.Extensions,
})
}
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
var inner clientHelloBodyInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
// We are strict about these things because we only support 1.3
if inner.LegacyVersion != 0x0303 {
return 0, fmt.Errorf("tls.clienthello: Incorrect version number")
}
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
}
ch.Random = inner.Random
ch.CipherSuites = inner.CipherSuites
ch.Extensions = inner.Extensions
return read, nil
}
// TODO: File a spec bug to clarify this
func (ch ClientHelloBody) Truncated() ([]byte, error) {
if len(ch.Extensions) == 0 {
return nil, fmt.Errorf("tls.clienthello.truncate: No extensions")
}
pskExt := ch.Extensions[len(ch.Extensions)-1]
if pskExt.ExtensionType != ExtensionTypePreSharedKey {
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
}
chm, err := HandshakeMessageFromBody(&ch)
if err != nil {
return nil, err
}
chData := chm.Marshal()
psk := PreSharedKeyExtension{
HandshakeType: HandshakeTypeClientHello,
}
_, err = psk.Unmarshal(pskExt.ExtensionData)
if err != nil {
return nil, err
}
// Marshal just the binders so that we know how much to truncate
binders := struct {
Binders []PSKBinderEntry `tls:"head=2,min=33"`
}{Binders: psk.Binders}
binderData, _ := syntax.Marshal(binders)
binderLen := len(binderData)
chLen := len(chData)
return chData[:chLen-binderLen], nil
}
// struct {
// ProtocolVersion server_version;
// CipherSuite cipher_suite;
// Extension extensions<2..2^16-1>;
// } HelloRetryRequest;
type HelloRetryRequestBody struct {
Version uint16
CipherSuite CipherSuite
Extensions ExtensionList `tls:"head=2,min=2"`
}
func (hrr HelloRetryRequestBody) Type() HandshakeType {
return HandshakeTypeHelloRetryRequest
}
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) {
return syntax.Marshal(hrr)
}
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, hrr)
}
// struct {
// ProtocolVersion version;
// Random random;
// CipherSuite cipher_suite;
// Extension extensions<0..2^16-1>;
// } ServerHello;
type ServerHelloBody struct {
Version uint16
Random [32]byte
CipherSuite CipherSuite
Extensions ExtensionList `tls:"head=2"`
}
func (sh ServerHelloBody) Type() HandshakeType {
return HandshakeTypeServerHello
}
func (sh ServerHelloBody) Marshal() ([]byte, error) {
return syntax.Marshal(sh)
}
func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sh)
}
// struct {
// opaque verify_data[verify_data_length];
// } Finished;
//
// verifyDataLen is not a field in the TLS struct, but we add it here so
// that calling code can tell us how much data to expect when we marshal /
// unmarshal. (We could add this to the marshal/unmarshal methods, but let's
// try to keep the signature consistent for now.)
//
// For similar reasons, we don't use the `syntax` module here, because this
// struct doesn't map well to standard TLS presentation language concepts.
//
// TODO: File a spec bug
type FinishedBody struct {
VerifyDataLen int
VerifyData []byte
}
func (fin FinishedBody) Type() HandshakeType {
return HandshakeTypeFinished
}
func (fin FinishedBody) Marshal() ([]byte, error) {
if len(fin.VerifyData) != fin.VerifyDataLen {
return nil, fmt.Errorf("tls.finished: data length mismatch")
}
body := make([]byte, len(fin.VerifyData))
copy(body, fin.VerifyData)
return body, nil
}
func (fin *FinishedBody) Unmarshal(data []byte) (int, error) {
if len(data) < fin.VerifyDataLen {
return 0, fmt.Errorf("tls.finished: Malformed finished; too short")
}
fin.VerifyData = make([]byte, fin.VerifyDataLen)
copy(fin.VerifyData, data[:fin.VerifyDataLen])
return fin.VerifyDataLen, nil
}
// struct {
// Extension extensions<0..2^16-1>;
// } EncryptedExtensions;
//
// Marshal() and Unmarshal() are handled by ExtensionList
type EncryptedExtensionsBody struct {
Extensions ExtensionList `tls:"head=2"`
}
func (ee EncryptedExtensionsBody) Type() HandshakeType {
return HandshakeTypeEncryptedExtensions
}
func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) {
return syntax.Marshal(ee)
}
func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, ee)
}
// opaque ASN1Cert<1..2^24-1>;
//
// struct {
// ASN1Cert cert_data;
// Extension extensions<0..2^16-1>
// } CertificateEntry;
//
// struct {
// opaque certificate_request_context<0..2^8-1>;
// CertificateEntry certificate_list<0..2^24-1>;
// } Certificate;
type CertificateEntry struct {
CertData *x509.Certificate
Extensions ExtensionList
}
type CertificateBody struct {
CertificateRequestContext []byte
CertificateList []CertificateEntry
}
type certificateEntryInner struct {
CertData []byte `tls:"head=3,min=1"`
Extensions ExtensionList `tls:"head=2"`
}
type certificateBodyInner struct {
CertificateRequestContext []byte `tls:"head=1"`
CertificateList []certificateEntryInner `tls:"head=3"`
}
func (c CertificateBody) Type() HandshakeType {
return HandshakeTypeCertificate
}
func (c CertificateBody) Marshal() ([]byte, error) {
inner := certificateBodyInner{
CertificateRequestContext: c.CertificateRequestContext,
CertificateList: make([]certificateEntryInner, len(c.CertificateList)),
}
for i, entry := range c.CertificateList {
inner.CertificateList[i] = certificateEntryInner{
CertData: entry.CertData.Raw,
Extensions: entry.Extensions,
}
}
return syntax.Marshal(inner)
}
func (c *CertificateBody) Unmarshal(data []byte) (int, error) {
inner := certificateBodyInner{}
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return read, err
}
c.CertificateRequestContext = inner.CertificateRequestContext
c.CertificateList = make([]CertificateEntry, len(inner.CertificateList))
for i, entry := range inner.CertificateList {
c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData)
if err != nil {
return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err)
}
c.CertificateList[i].Extensions = entry.Extensions
}
return read, nil
}
// struct {
// SignatureScheme algorithm;
// opaque signature<0..2^16-1>;
// } CertificateVerify;
type CertificateVerifyBody struct {
Algorithm SignatureScheme
Signature []byte `tls:"head=2"`
}
func (cv CertificateVerifyBody) Type() HandshakeType {
return HandshakeTypeCertificateVerify
}
func (cv CertificateVerifyBody) Marshal() ([]byte, error) {
return syntax.Marshal(cv)
}
func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, cv)
}
func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte {
// TODO: Change context for client auth
// TODO: Put this in a const
const context = "TLS 1.3, server CertificateVerify"
sigInput := bytes.Repeat([]byte{0x20}, 64)
sigInput = append(sigInput, []byte(context)...)
sigInput = append(sigInput, []byte{0}...)
sigInput = append(sigInput, data...)
return sigInput
}
func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) {
sigInput := cv.EncodeSignatureInput(handshakeHash)
cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput)
logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
return
}
func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error {
sigInput := cv.EncodeSignatureInput(handshakeHash)
logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
return verify(cv.Algorithm, publicKey, sigInput, cv.Signature)
}
// struct {
// opaque certificate_request_context<0..2^8-1>;
// Extension extensions<2..2^16-1>;
// } CertificateRequest;
type CertificateRequestBody struct {
CertificateRequestContext []byte `tls:"head=1"`
Extensions ExtensionList `tls:"head=2"`
}
func (cr CertificateRequestBody) Type() HandshakeType {
return HandshakeTypeCertificateRequest
}
func (cr CertificateRequestBody) Marshal() ([]byte, error) {
return syntax.Marshal(cr)
}
func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, cr)
}
// struct {
// uint32 ticket_lifetime;
// uint32 ticket_age_add;
// opaque ticket_nonce<1..255>;
// opaque ticket<1..2^16-1>;
// Extension extensions<0..2^16-2>;
// } NewSessionTicket;
type NewSessionTicketBody struct {
TicketLifetime uint32
TicketAgeAdd uint32
TicketNonce []byte `tls:"head=1,min=1"`
Ticket []byte `tls:"head=2,min=1"`
Extensions ExtensionList `tls:"head=2"`
}
const ticketNonceLen = 16
func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) {
buf := make([]byte, 4+ticketNonceLen+ticketLen)
_, err := prng.Read(buf)
if err != nil {
return nil, err
}
tkt := &NewSessionTicketBody{
TicketLifetime: ticketLifetime,
TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]),
TicketNonce: buf[4 : 4+ticketNonceLen],
Ticket: buf[4+ticketNonceLen:],
}
return tkt, err
}
func (tkt NewSessionTicketBody) Type() HandshakeType {
return HandshakeTypeNewSessionTicket
}
func (tkt NewSessionTicketBody) Marshal() ([]byte, error) {
return syntax.Marshal(tkt)
}
func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, tkt)
}
// enum {
// update_not_requested(0), update_requested(1), (255)
// } KeyUpdateRequest;
//
// struct {
// KeyUpdateRequest request_update;
// } KeyUpdate;
type KeyUpdateBody struct {
KeyUpdateRequest KeyUpdateRequest
}
func (ku KeyUpdateBody) Type() HandshakeType {
return HandshakeTypeKeyUpdate
}
func (ku KeyUpdateBody) Marshal() ([]byte, error) {
return syntax.Marshal(ku)
}
func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, ku)
}
// struct {} EndOfEarlyData;
type EndOfEarlyDataBody struct{}
func (eoed EndOfEarlyDataBody) Type() HandshakeType {
return HandshakeTypeEndOfEarlyData
}
func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) {
return []byte{}, nil
}
func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) {
return 0, nil
}

55
vendor/github.com/bifurcation/mint/log.go generated vendored Normal file
View file

@ -0,0 +1,55 @@
package mint
import (
"fmt"
"log"
"os"
"strings"
)
// We use this environment variable to control logging. It should be a
// comma-separated list of log tags (see below) or "*" to enable all logging.
const logConfigVar = "MINT_LOG"
// Pre-defined log types
const (
logTypeCrypto = "crypto"
logTypeHandshake = "handshake"
logTypeNegotiation = "negotiation"
logTypeIO = "io"
logTypeFrameReader = "frame"
logTypeVerbose = "verbose"
)
var (
logFunction = log.Printf
logAll = false
logSettings = map[string]bool{}
)
func init() {
parseLogEnv(os.Environ())
}
func parseLogEnv(env []string) {
for _, stmt := range env {
if strings.HasPrefix(stmt, logConfigVar+"=") {
val := stmt[len(logConfigVar)+1:]
if val == "*" {
logAll = true
} else {
for _, t := range strings.Split(val, ",") {
logSettings[t] = true
}
}
}
}
}
func logf(tag string, format string, args ...interface{}) {
if logAll || logSettings[tag] {
fullFormat := fmt.Sprintf("[%s] %s", tag, format)
logFunction(fullFormat, args...)
}
}

217
vendor/github.com/bifurcation/mint/negotiation.go generated vendored Normal file
View file

@ -0,0 +1,217 @@
package mint
import (
"bytes"
"encoding/hex"
"fmt"
"time"
)
func VersionNegotiation(offered, supported []uint16) (bool, uint16) {
for _, offeredVersion := range offered {
for _, supportedVersion := range supported {
logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion)
if offeredVersion == supportedVersion {
// XXX: Should probably be highest supported version, but for now, we
// only support one version, so it doesn't really matter.
return true, offeredVersion
}
}
}
return false, 0
}
func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) {
for _, share := range keyShares {
for _, group := range groups {
if group != share.Group {
continue
}
pub, priv, err := newKeyShare(share.Group)
if err != nil {
// If we encounter an error, just keep looking
continue
}
dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv)
if err != nil {
// If we encounter an error, just keep looking
continue
}
return true, group, pub, dhSecret
}
}
return false, 0, nil, nil
}
const (
ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds
)
func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) {
logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size())
for i, id := range identities {
identityHex := hex.EncodeToString(id.Identity)
psk, ok := psks.Get(identityHex)
if !ok {
logf(logTypeNegotiation, "No PSK for identity %x", identityHex)
continue
}
// For resumption, make sure the ticket age is correct
if psk.IsResumption {
extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd
knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond)
ticketAgeDelta := knownTicketAge - extTicketAge
if knownTicketAge < extTicketAge {
ticketAgeDelta = extTicketAge - knownTicketAge
}
if ticketAgeDelta > ticketAgeTolerance {
logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity)
logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]",
extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance)
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity)
}
}
params, ok := cipherSuiteMap[psk.CipherSuite]
if !ok {
err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite)
return false, 0, nil, CipherSuiteParams{}, err
}
// Compute binder
binderLabel := labelExternalBinder
if psk.IsResumption {
binderLabel = labelResumptionBinder
}
h0 := params.Hash.New().Sum(nil)
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
earlySecret := HkdfExtract(params.Hash, zero, psk.Key)
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
// context = ClientHello[truncated]
// context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated]
ctxHash := params.Hash.New()
ctxHash.Write(context)
binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil))
if !bytes.Equal(binder, binders[i].Binder) {
logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder)
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity)
}
logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity)
return true, i, &psk, params, nil
}
logf(logTypeNegotiation, "Failed to find a usable PSK")
return false, 0, nil, CipherSuiteParams{}, nil
}
func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) {
logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes)
dhAllowed := false
dhRequired := true
for _, mode := range modes {
dhAllowed = dhAllowed || (mode == PSKModeDHEKE)
dhRequired = dhRequired && (mode == PSKModeDHEKE)
}
// Use PSK if we can meet DH requirement and modes were provided
usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0)
// Use DH if allowed
usingDH := canDoDH && (dhAllowed || !usingPSK)
logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK)
return usingDH, usingPSK
}
func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) {
// Select for server name if provided
candidates := certs
if serverName != nil {
candidatesByName := []*Certificate{}
for _, cert := range certs {
for _, name := range cert.Chain[0].DNSNames {
if len(*serverName) > 0 && name == *serverName {
candidatesByName = append(candidatesByName, cert)
}
}
}
if len(candidatesByName) == 0 {
return nil, 0, fmt.Errorf("No certificates available for server name")
}
candidates = candidatesByName
}
// Select for signature scheme
for _, cert := range candidates {
for _, scheme := range signatureSchemes {
if !schemeValidForKey(scheme, cert.PrivateKey) {
continue
}
return cert, scheme, nil
}
}
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
}
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool {
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData)
return usingEarlyData
}
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {
for _, s1 := range offered {
if psk != nil {
if s1 == psk.CipherSuite {
return s1, nil
}
continue
}
for _, s2 := range supported {
if s1 == s2 {
return s1, nil
}
}
}
return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil)
}
func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) {
for _, p1 := range offered {
if psk != nil {
if p1 != psk.NextProto {
continue
}
}
for _, p2 := range supported {
if p1 == p2 {
return p1, nil
}
}
}
// If the client offers ALPN on resumption, it must match the earlier one
var err error
if psk != nil && psk.IsResumption && (len(offered) > 0) {
err = fmt.Errorf("ALPN for PSK not provided")
}
return "", err
}

296
vendor/github.com/bifurcation/mint/record-layer.go generated vendored Normal file
View file

@ -0,0 +1,296 @@
package mint
import (
"bytes"
"crypto/cipher"
"fmt"
"io"
"sync"
)
const (
sequenceNumberLen = 8 // sequence number length
recordHeaderLen = 5 // record header length
maxFragmentLen = 1 << 14 // max number of bytes in a record
)
type DecryptError string
func (err DecryptError) Error() string {
return string(err)
}
// struct {
// ContentType type;
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */
// uint16 length;
// opaque fragment[TLSPlaintext.length];
// } TLSPlaintext;
type TLSPlaintext struct {
// Omitted: record_version (static)
// Omitted: length (computed from fragment)
contentType RecordType
fragment []byte
}
type RecordLayer struct {
sync.Mutex
conn io.ReadWriter // The underlying connection
frame *frameReader // The buffered frame reader
nextData []byte // The next record to send
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
cachedError error // Error on the last record read
ivLength int // Length of the seq and nonce fields
seq []byte // Zero-padded sequence number
nonce []byte // Buffer for per-record nonces
cipher cipher.AEAD // AEAD cipher
}
type recordLayerFrameDetails struct{}
func (d recordLayerFrameDetails) headerLen() int {
return recordHeaderLen
}
func (d recordLayerFrameDetails) defaultReadLen() int {
return recordHeaderLen + maxFragmentLen
}
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
return (int(hdr[3]) << 8) | int(hdr[4]), nil
}
func NewRecordLayer(conn io.ReadWriter) *RecordLayer {
r := RecordLayer{}
r.conn = conn
r.frame = newFrameReader(recordLayerFrameDetails{})
r.ivLength = 0
return &r
}
func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error {
var err error
r.cipher, err = cipher(key)
if err != nil {
return err
}
r.ivLength = len(iv)
r.seq = bytes.Repeat([]byte{0}, r.ivLength)
r.nonce = make([]byte, r.ivLength)
copy(r.nonce, iv)
return nil
}
func (r *RecordLayer) incrementSequenceNumber() {
if r.ivLength == 0 {
return
}
for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- {
r.seq[i]++
r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i]
if r.seq[i] != 0 {
return
}
}
// Not allowed to let sequence number wrap.
// Instead, must renegotiate before it does.
// Not likely enough to bother.
panic("TLS: sequence number wraparound")
}
func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
// Expand the fragment to hold contentType, padding, and overhead
originalLen := len(pt.fragment)
plaintextLen := originalLen + 1 + padLen
ciphertextLen := plaintextLen + r.cipher.Overhead()
// Assemble the revised plaintext
out := &TLSPlaintext{
contentType: RecordTypeApplicationData,
fragment: make([]byte, ciphertextLen),
}
copy(out.fragment, pt.fragment)
out.fragment[originalLen] = byte(pt.contentType)
for i := 1; i <= padLen; i++ {
out.fragment[originalLen+i] = 0
}
// Encrypt the fragment
payload := out.fragment[:plaintextLen]
r.cipher.Seal(payload[:0], r.nonce, payload, nil)
return out
}
func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
if len(pt.fragment) < r.cipher.Overhead() {
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead())
return nil, 0, DecryptError(msg)
}
decryptLen := len(pt.fragment) - r.cipher.Overhead()
out := &TLSPlaintext{
contentType: pt.contentType,
fragment: make([]byte, decryptLen),
}
// Decrypt
_, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil)
if err != nil {
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
}
// Find the padding boundary
padLen := 0
for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ {
}
// Transfer the content type
newLen := decryptLen - padLen - 1
out.contentType = RecordType(out.fragment[newLen])
// Truncate the message to remove contentType, padding, overhead
out.fragment = out.fragment[:newLen]
return out, padLen, nil
}
func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
var pt *TLSPlaintext
var err error
for {
pt, err = r.nextRecord()
if err == nil {
break
}
if !block || err != WouldBlock {
return 0, err
}
}
return pt.contentType, nil
}
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
pt, err := r.nextRecord()
// Consume the cached record if there was one
r.cachedRecord = nil
r.cachedError = nil
return pt, err
}
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
if r.cachedRecord != nil {
logf(logTypeIO, "Returning cached record")
return r.cachedRecord, r.cachedError
}
// Loop until one of three things happens:
//
// 1. We get a frame
// 2. We try to read off the socket and get nothing, in which case
// return WouldBlock
// 3. We get an error.
err := WouldBlock
var header, body []byte
for err != nil {
if r.frame.needed() > 0 {
buf := make([]byte, recordHeaderLen+maxFragmentLen)
n, err := r.conn.Read(buf)
if err != nil {
logf(logTypeIO, "Error reading, %v", err)
return nil, err
}
if n == 0 {
return nil, WouldBlock
}
logf(logTypeIO, "Read %v bytes", n)
buf = buf[:n]
r.frame.addChunk(buf)
}
header, body, err = r.frame.process()
// Loop around on WouldBlock to see if some
// data is now available.
if err != nil && err != WouldBlock {
return nil, err
}
}
pt := &TLSPlaintext{}
// Validate content type
switch RecordType(header[0]) {
default:
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData:
pt.contentType = RecordType(header[0])
}
// Validate version
if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
}
// Validate size < max
size := (int(header[3]) << 8) + int(header[4])
if size > maxFragmentLen+256 {
return nil, fmt.Errorf("tls.record: Ciphertext size too big")
}
pt.fragment = make([]byte, size)
copy(pt.fragment, body)
// Attempt to decrypt fragment
if r.cipher != nil {
pt, _, err = r.decrypt(pt)
if err != nil {
return nil, err
}
}
// Check that plaintext length is not too long
if len(pt.fragment) > maxFragmentLen {
return nil, fmt.Errorf("tls.record: Plaintext size too big")
}
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment)
r.cachedRecord = pt
r.incrementSequenceNumber()
return pt, nil
}
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
return r.WriteRecordWithPadding(pt, 0)
}
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
if r.cipher != nil {
pt = r.encrypt(pt, padLen)
} else if padLen > 0 {
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
}
if len(pt.fragment) > maxFragmentLen {
return fmt.Errorf("tls.record: Record size too big")
}
length := len(pt.fragment)
header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)}
record := append(header, pt.fragment...)
logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment)
r.incrementSequenceNumber()
_, err := r.conn.Write(record)
return err
}

View file

@ -0,0 +1,898 @@
package mint
import (
"bytes"
"hash"
"reflect"
)
// Server State Machine
//
// START <-----+
// Recv ClientHello | | Send HelloRetryRequest
// v |
// RECVD_CH ----+
// | Select parameters
// | Send ServerHello
// v
// NEGOTIATED
// | Send EncryptedExtensions
// | [Send CertificateRequest]
// Can send | [Send Certificate + CertificateVerify]
// app data --> | Send Finished
// after +--------+--------+
// here No 0-RTT | | 0-RTT
// | v
// | WAIT_EOED <---+
// | Recv | | | Recv
// | EndOfEarlyData | | | early data
// | | +-----+
// +> WAIT_FLIGHT2 <-+
// |
// +--------+--------+
// No auth | | Client auth
// | |
// | v
// | WAIT_CERT
// | Recv | | Recv Certificate
// | empty | v
// | Certificate | WAIT_CV
// | | | Recv
// | v | CertificateVerify
// +-> WAIT_FINISHED <---+
// | Recv Finished
// v
// CONNECTED
//
// NB: Not using state RECVD_CH
//
// State Instructions
// START {}
// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)]
// WAIT_EOED RekeyIn;
// WAIT_FLIGHT2 {}
// WAIT_CERT_CR {}
// WAIT_CERT {}
// WAIT_CV {}
// WAIT_FINISHED RekeyIn; RekeyOut;
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
type ServerStateStart struct {
Caps Capabilities
conn *Conn
cookieSent bool
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
}
func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeClientHello {
logf(logTypeHandshake, "[ServerStateStart] unexpected message")
return nil, nil, AlertUnexpectedMessage
}
ch := &ClientHelloBody{}
_, err := ch.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
clientHello := hm
connParams := ConnectionParameters{}
supportedVersions := new(SupportedVersionsExtension)
serverName := new(ServerNameExtension)
supportedGroups := new(SupportedGroupsExtension)
signatureAlgorithms := new(SignatureAlgorithmsExtension)
clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello}
clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello}
clientEarlyData := &EarlyDataExtension{}
clientALPN := new(ALPNExtension)
clientPSKModes := new(PSKKeyExchangeModesExtension)
clientCookie := new(CookieExtension)
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
gotSupportedVersions := ch.Extensions.Find(supportedVersions)
gotServerName := ch.Extensions.Find(serverName)
gotSupportedGroups := ch.Extensions.Find(supportedGroups)
gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms)
gotEarlyData := ch.Extensions.Find(clientEarlyData)
ch.Extensions.Find(clientKeyShares)
ch.Extensions.Find(clientPSK)
ch.Extensions.Find(clientALPN)
ch.Extensions.Find(clientPSKModes)
ch.Extensions.Find(clientCookie)
if gotServerName {
connParams.ServerName = string(*serverName)
}
// If the client didn't send supportedVersions or doesn't support 1.3,
// then we're done here.
if !gotSupportedVersions {
logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions")
return nil, nil, AlertProtocolVersion
}
versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion})
if !versionOK {
logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version")
return nil, nil, AlertProtocolVersion
}
if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) {
logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch")
return nil, nil, AlertAccessDenied
}
// Figure out if we can do DH
canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups)
// Figure out if we can do PSK
canDoPSK := false
var selectedPSK int
var psk *PreSharedKey
var params CipherSuiteParams
if len(clientPSK.Identities) > 0 {
contextBase := []byte{}
if state.helloRetryRequest != nil {
chBytes := state.firstClientHello.Marshal()
hrrBytes := state.helloRetryRequest.Marshal()
contextBase = append(chBytes, hrrBytes...)
}
chTrunc, err := ch.Truncated()
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err)
return nil, nil, AlertDecodeError
}
context := append(contextBase, chTrunc...)
canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err)
return nil, nil, AlertInternalError
}
}
// Figure out if we actually should do DH / PSK
connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes)
// Select a ciphersuite
connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err)
return nil, nil, AlertHandshakeFailure
}
// Send a cookie if required
// NB: Need to do this here because it's after ciphersuite selection, which
// has to be after PSK selection.
// XXX: Doing this statefully for now, could be stateless
var cookieData []byte
if state.Caps.RequireCookie && !state.cookieSent {
var err error
cookieData, err = state.Caps.CookieHandler.Generate(state.conn)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err)
return nil, nil, AlertInternalError
}
}
if cookieData != nil {
// Ignoring errors because everything here is newly constructed, so there
// shouldn't be marshal errors
hrr := &HelloRetryRequestBody{
Version: supportedVersion,
CipherSuite: connParams.CipherSuite,
}
hrr.Extensions.Add(&CookieExtension{Cookie: cookieData})
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
}
}
helloRetryRequest, err := HandshakeMessageFromBody(hrr)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err)
return nil, nil, AlertInternalError
}
params := cipherSuiteMap[connParams.CipherSuite]
h := params.Hash.New()
h.Write(clientHello.Marshal())
firstClientHello := &HandshakeMessage{
msgType: HandshakeTypeMessageHash,
body: h.Sum(nil),
}
nextState := ServerStateStart{
Caps: state.Caps,
conn: state.conn,
cookieSent: true,
firstClientHello: firstClientHello,
helloRetryRequest: helloRetryRequest,
}
toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}}
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]")
return nextState, toSend, AlertNoAlert
}
// If we've got no entropy to make keys from, fail
if !connParams.UsingDH && !connParams.UsingPSK {
logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated")
return nil, nil, AlertHandshakeFailure
}
var pskSecret []byte
var cert *Certificate
var certScheme SignatureScheme
if connParams.UsingPSK {
pskSecret = psk.Key
} else {
psk = nil
// If we're not using a PSK mode, then we need to have certain extensions
if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms {
logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)",
gotServerName, gotSupportedGroups, gotSignatureAlgorithms)
return nil, nil, AlertMissingExtension
}
// Select a certificate
name := string(*serverName)
var err error
cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err)
return nil, nil, AlertAccessDenied
}
}
if !connParams.UsingDH {
dhSecret = nil
}
// Figure out if we're going to do early data
var clientEarlyTrafficSecret []byte
connParams.ClientSendingEarlyData = gotEarlyData
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData)
if connParams.UsingEarlyData {
h := params.Hash.New()
h.Write(clientHello.Marshal())
chHash := h.Sum(nil)
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
earlySecret := HkdfExtract(params.Hash, zero, pskSecret)
clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
}
// Select a next protocol
connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err)
return nil, nil, AlertNoApplicationProtocol
}
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]")
return ServerStateNegotiated{
Caps: state.Caps,
Params: connParams,
dhGroup: dhGroup,
dhPublic: dhPublic,
dhSecret: dhSecret,
pskSecret: pskSecret,
selectedPSK: selectedPSK,
cert: cert,
certScheme: certScheme,
clientEarlyTrafficSecret: clientEarlyTrafficSecret,
firstClientHello: state.firstClientHello,
helloRetryRequest: state.helloRetryRequest,
clientHello: clientHello,
}.Next(nil)
}
type ServerStateNegotiated struct {
Caps Capabilities
Params ConnectionParameters
dhGroup NamedGroup
dhPublic []byte
dhSecret []byte
pskSecret []byte
clientEarlyTrafficSecret []byte
selectedPSK int
cert *Certificate
certScheme SignatureScheme
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
clientHello *HandshakeMessage
}
func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
// Create the ServerHello
sh := &ServerHelloBody{
Version: supportedVersion,
CipherSuite: state.Params.CipherSuite,
}
_, err := prng.Read(sh.Random[:])
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err)
return nil, nil, AlertInternalError
}
if state.Params.UsingDH {
logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension")
err = sh.Extensions.Add(&KeyShareExtension{
HandshakeType: HandshakeTypeServerHello,
Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}},
})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err)
return nil, nil, AlertInternalError
}
}
if state.Params.UsingPSK {
logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension")
err = sh.Extensions.Add(&PreSharedKeyExtension{
HandshakeType: HandshakeTypeServerHello,
SelectedIdentity: uint16(state.selectedPSK),
})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err)
return nil, nil, AlertInternalError
}
}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
}
}
serverHello, err := HandshakeMessageFromBody(sh)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err)
return nil, nil, AlertInternalError
}
// Look up crypto params
params, ok := cipherSuiteMap[sh.CipherSuite]
if !ok {
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite)
return nil, nil, AlertHandshakeFailure
}
// Start up the handshake hash
handshakeHash := params.Hash.New()
handshakeHash.Write(state.firstClientHello.Marshal())
handshakeHash.Write(state.helloRetryRequest.Marshal())
handshakeHash.Write(state.clientHello.Marshal())
handshakeHash.Write(serverHello.Marshal())
// Compute handshake secrets
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
var earlySecret []byte
if state.Params.UsingPSK {
earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret)
} else {
earlySecret = HkdfExtract(params.Hash, zero, zero)
}
if state.dhSecret == nil {
state.dhSecret = zero
}
h0 := params.Hash.New().Sum(nil)
h2 := handshakeHash.Sum(nil)
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret)
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret)
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret)
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
// Send an EncryptedExtensions message (even if it's empty)
eeList := ExtensionList{}
if state.Params.NextProto != "" {
logf(logTypeHandshake, "[server] sending ALPN extension")
err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err)
return nil, nil, AlertInternalError
}
}
if state.Params.UsingEarlyData {
logf(logTypeHandshake, "[server] sending EDI extension")
err = eeList.Add(&EarlyDataExtension{})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err)
return nil, nil, AlertInternalError
}
}
ee := &EncryptedExtensionsBody{eeList}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
}
}
eem, err := HandshakeMessageFromBody(ee)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err)
return nil, nil, AlertInternalError
}
handshakeHash.Write(eem.Marshal())
toSend := []HandshakeAction{
SendHandshakeMessage{serverHello},
RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys},
SendHandshakeMessage{eem},
}
// Authenticate with a certificate if required
if !state.Params.UsingPSK {
// Send a CertificateRequest message if we want client auth
if state.Caps.RequireClientAuth {
state.Params.UsingClientAuth = true
// XXX: We don't support sending any constraints besides a list of
// supported signature algorithms
cr := &CertificateRequestBody{}
schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
err := cr.Extensions.Add(schemes)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err)
return nil, nil, AlertInternalError
}
crm, err := HandshakeMessageFromBody(cr)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err)
return nil, nil, AlertInternalError
}
//TODO state.state.serverCertificateRequest = cr
toSend = append(toSend, SendHandshakeMessage{crm})
handshakeHash.Write(crm.Marshal())
}
// Create and send Certificate, CertificateVerify
certificate := &CertificateBody{
CertificateList: make([]CertificateEntry, len(state.cert.Chain)),
}
for i, entry := range state.cert.Chain {
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
}
certm, err := HandshakeMessageFromBody(certificate)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certm})
handshakeHash.Write(certm.Marshal())
certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme}
logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash)
hcv := handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
err = certificateVerify.Sign(state.cert.PrivateKey, hcv)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err)
return nil, nil, AlertInternalError
}
certvm, err := HandshakeMessageFromBody(certificateVerify)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certvm})
handshakeHash.Write(certvm.Marshal())
}
// Compute secrets resulting from the server's first flight
h3 := handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3)
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
// Assemble the Finished message
fin := &FinishedBody{
VerifyDataLen: len(serverFinishedData),
VerifyData: serverFinishedData,
}
finm, _ := HandshakeMessageFromBody(fin)
toSend = append(toSend, SendHandshakeMessage{finm})
handshakeHash.Write(finm.Marshal())
// Compute traffic secrets
h4 := handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4)
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4)
clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4)
serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4)
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret)
toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys})
exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4)
logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
if state.Params.UsingEarlyData {
clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret)
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]")
nextState := ServerStateWaitEOED{
AuthCertificate: state.Caps.AuthCertificate,
Params: state.Params,
cryptoParams: params,
handshakeHash: handshakeHash,
masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
clientTrafficSecret: clientTrafficSecret,
serverTrafficSecret: serverTrafficSecret,
exporterSecret: exporterSecret,
}
toSend = append(toSend, []HandshakeAction{
RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys},
ReadEarlyData{},
}...)
return nextState, toSend, AlertNoAlert
}
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]")
toSend = append(toSend, []HandshakeAction{
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
ReadPastEarlyData{},
}...)
waitFlight2 := ServerStateWaitFlight2{
AuthCertificate: state.Caps.AuthCertificate,
Params: state.Params,
cryptoParams: params,
handshakeHash: handshakeHash,
masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
clientTrafficSecret: clientTrafficSecret,
serverTrafficSecret: serverTrafficSecret,
exporterSecret: exporterSecret,
}
nextState, moreToSend, alert := waitFlight2.Next(nil)
toSend = append(toSend, moreToSend...)
return nextState, toSend, alert
}
type ServerStateWaitEOED struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData {
logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
if len(hm.body) > 0 {
logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]")
return nil, nil, AlertDecodeError
}
state.handshakeHash.Write(hm.Marshal())
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]")
toSend := []HandshakeAction{
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
}
waitFlight2 := ServerStateWaitFlight2{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
nextState, moreToSend, alert := waitFlight2.Next(nil)
toSend = append(toSend, moreToSend...)
return nextState, toSend, alert
}
type ServerStateWaitFlight2 struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm != nil {
logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
if state.Params.UsingClientAuth {
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]")
nextState := ServerStateWaitCert{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]")
nextState := ServerStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
handshakeHash: state.handshakeHash,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
type ServerStateWaitCert struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeCertificate {
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
cert := &CertificateBody{}
_, err := cert.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
return nil, nil, AlertDecodeError
}
state.handshakeHash.Write(hm.Marshal())
if len(cert.CertificateList) == 0 {
logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate")
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]")
nextState := ServerStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
handshakeHash: state.handshakeHash,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]")
nextState := ServerStateWaitCV{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
handshakeHash: state.handshakeHash,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
clientCertificate: cert,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
type ServerStateWaitCV struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
clientCertificate *CertificateBody
}
func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm))
return nil, nil, AlertUnexpectedMessage
}
certVerify := &CertificateVerifyBody{}
_, err := certVerify.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err)
return nil, nil, AlertDecodeError
}
// Verify client signature over handshake hash
hcv := state.handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey
if err := certVerify.Verify(clientPublicKey, hcv); err != nil {
logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err)
return nil, nil, AlertHandshakeFailure
}
if state.AuthCertificate != nil {
err := state.AuthCertificate(state.clientCertificate.CertificateList)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate")
return nil, nil, AlertBadCertificate
}
} else {
logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate")
}
// If it passes, record the certificateVerify in the transcript hash
state.handshakeHash.Write(hm.Marshal())
logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]")
nextState := ServerStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
handshakeHash: state.handshakeHash,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
type ServerStateWaitFinished struct {
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeFinished {
logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()}
_, err := fin.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err)
return nil, nil, AlertDecodeError
}
// Verify client Finished data
h5 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
if !bytes.Equal(fin.VerifyData, clientFinishedData) {
logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify")
return nil, nil, AlertHandshakeFailure
}
// Compute the resumption secret
state.handshakeHash.Write(hm.Marshal())
h6 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6)
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
// Compute client traffic keys
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]")
nextState := StateConnected{
Params: state.Params,
isClient: false,
cryptoParams: state.cryptoParams,
resumptionSecret: resumptionSecret,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
toSend := []HandshakeAction{
RekeyIn{Label: "application", KeySet: clientTrafficKeys},
}
return nextState, toSend, AlertNoAlert
}

230
vendor/github.com/bifurcation/mint/state-machine.go generated vendored Normal file
View file

@ -0,0 +1,230 @@
package mint
import (
"time"
)
// Marker interface for actions that an implementation should take based on
// state transitions.
type HandshakeAction interface{}
type SendHandshakeMessage struct {
Message *HandshakeMessage
}
type SendEarlyData struct{}
type ReadEarlyData struct{}
type ReadPastEarlyData struct{}
type RekeyIn struct {
Label string
KeySet keySet
}
type RekeyOut struct {
Label string
KeySet keySet
}
type StorePSK struct {
PSK PreSharedKey
}
type HandshakeState interface {
Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert)
}
type AppExtensionHandler interface {
Send(hs HandshakeType, el *ExtensionList) error
Receive(hs HandshakeType, el *ExtensionList) error
}
// Capabilities objects represent the capabilities of a TLS client or server,
// as an input to TLS negotiation
type Capabilities struct {
// For both client and server
CipherSuites []CipherSuite
Groups []NamedGroup
SignatureSchemes []SignatureScheme
PSKs PreSharedKeyCache
Certificates []*Certificate
AuthCertificate func(chain []CertificateEntry) error
ExtensionHandler AppExtensionHandler
// For client
PSKModes []PSKKeyExchangeMode
// For server
NextProtos []string
AllowEarlyData bool
RequireCookie bool
CookieHandler CookieHandler
RequireClientAuth bool
}
// ConnectionOptions objects represent per-connection settings for a client
// initiating a connection
type ConnectionOptions struct {
ServerName string
NextProtos []string
EarlyData []byte
}
// ConnectionParameters objects represent the parameters negotiated for a
// connection.
type ConnectionParameters struct {
UsingPSK bool
UsingDH bool
ClientSendingEarlyData bool
UsingEarlyData bool
UsingClientAuth bool
CipherSuite CipherSuite
ServerName string
NextProto string
}
// StateConnected is symmetric between client and server
type StateConnected struct {
Params ConnectionParameters
isClient bool
cryptoParams CipherSuiteParams
resumptionSecret []byte
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
var trafficKeys keySet
if state.isClient {
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
} else {
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
}
kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
return nil, AlertInternalError
}
toSend := []HandshakeAction{
SendHandshakeMessage{kum},
RekeyOut{Label: "update", KeySet: trafficKeys},
}
return toSend, AlertNoAlert
}
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
tkt, err := NewSessionTicket(length, lifetime)
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
return nil, AlertInternalError
}
err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime})
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err)
return nil, AlertInternalError
}
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size())
newPSK := PreSharedKey{
CipherSuite: state.cryptoParams.Suite,
IsResumption: true,
Identity: tkt.Ticket,
Key: resumptionKey,
NextProto: state.Params.NextProto,
ReceivedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second),
TicketAgeAdd: tkt.TicketAgeAdd,
}
tktm, err := HandshakeMessageFromBody(tkt)
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
return nil, AlertInternalError
}
toSend := []HandshakeAction{
StorePSK{newPSK},
SendHandshakeMessage{tktm},
}
return toSend, AlertNoAlert
}
func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil {
logf(logTypeHandshake, "[StateConnected] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
bodyGeneric, err := hm.ToBody()
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
switch body := bodyGeneric.(type) {
case *KeyUpdateBody:
var trafficKeys keySet
if !state.isClient {
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
} else {
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
}
toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}}
// If requested, roll outbound keys and send a KeyUpdate
if body.KeyUpdateRequest == KeyUpdateRequested {
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
if alert != AlertNoAlert {
return nil, nil, alert
}
toSend = append(toSend, moreToSend...)
}
return state, toSend, AlertNoAlert
case *NewSessionTicketBody:
// XXX: Allow NewSessionTicket in both directions?
if !state.isClient {
return nil, nil, AlertUnexpectedMessage
}
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
psk := PreSharedKey{
CipherSuite: state.cryptoParams.Suite,
IsResumption: true,
Identity: body.Ticket,
Key: resumptionKey,
NextProto: state.Params.NextProto,
ReceivedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second),
TicketAgeAdd: body.TicketAgeAdd,
}
toSend := []HandshakeAction{StorePSK{psk}}
return state, toSend, AlertNoAlert
}
logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType)
return nil, nil, AlertUnexpectedMessage
}

243
vendor/github.com/bifurcation/mint/syntax/decode.go generated vendored Normal file
View file

@ -0,0 +1,243 @@
package syntax
import (
"bytes"
"fmt"
"reflect"
"runtime"
)
func Unmarshal(data []byte, v interface{}) (int, error) {
// Check for well-formedness.
// Avoids filling out half a data structure
// before discovering a JSON syntax error.
d := decodeState{}
d.Write(data)
return d.unmarshal(v)
}
// These are the options that can be specified in the struct tag. Right now,
// all of them apply to variable-length vectors and nothing else
type decOpts struct {
head uint // length of length in bytes
min uint // minimum size in bytes
max uint // maximum size in bytes
}
type decodeState struct {
bytes.Buffer
}
func (d *decodeState) unmarshal(v interface{}) (read int, err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(r)
}
if s, ok := r.(string); ok {
panic(s)
}
err = r.(error)
}
}()
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)")
}
read = d.value(rv)
return read, nil
}
func (e *decodeState) value(v reflect.Value) int {
return valueDecoder(v)(e, v, decOpts{})
}
type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int
func valueDecoder(v reflect.Value) decoderFunc {
return typeDecoder(v.Type().Elem())
}
func typeDecoder(t reflect.Type) decoderFunc {
// Note: Omits the caching / wait-group things that encoding/json uses
return newTypeDecoder(t)
}
func newTypeDecoder(t reflect.Type) decoderFunc {
// Note: Does not support Marshaler, so don't need the allowAddr argument
switch t.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return uintDecoder
case reflect.Array:
return newArrayDecoder(t)
case reflect.Slice:
return newSliceDecoder(t)
case reflect.Struct:
return newStructDecoder(t)
default:
panic(fmt.Errorf("Unsupported type (%s)", t))
}
}
///// Specific decoders below
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
var uintLen int
switch v.Elem().Kind() {
case reflect.Uint8:
uintLen = 1
case reflect.Uint16:
uintLen = 2
case reflect.Uint32:
uintLen = 4
case reflect.Uint64:
uintLen = 8
}
buf := make([]byte, uintLen)
n, err := d.Read(buf)
if err != nil {
panic(err)
}
if n != uintLen {
panic(fmt.Errorf("Insufficient data to read uint"))
}
val := uint64(0)
for _, b := range buf {
val = (val << 8) + uint64(b)
}
v.Elem().SetUint(val)
return uintLen
}
//////////
type arrayDecoder struct {
elemDec decoderFunc
}
func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
n := v.Elem().Type().Len()
read := 0
for i := 0; i < n; i += 1 {
read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts)
}
return read
}
func newArrayDecoder(t reflect.Type) decoderFunc {
dec := &arrayDecoder{typeDecoder(t.Elem())}
return dec.decode
}
//////////
type sliceDecoder struct {
elementType reflect.Type
elementDec decoderFunc
}
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
if opts.head == 0 {
panic(fmt.Errorf("Cannot decode a slice without a header length"))
}
lengthBytes := make([]byte, opts.head)
n, err := d.Read(lengthBytes)
if err != nil {
panic(err)
}
if uint(n) != opts.head {
panic(fmt.Errorf("Not enough data to read header"))
}
length := uint(0)
for _, b := range lengthBytes {
length = (length << 8) + uint(b)
}
if opts.max > 0 && length > opts.max {
panic(fmt.Errorf("Length of vector exceeds declared max"))
}
if length < opts.min {
panic(fmt.Errorf("Length of vector below declared min"))
}
data := make([]byte, length)
n, err = d.Read(data)
if err != nil {
panic(err)
}
if uint(n) != length {
panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length))
}
elemBuf := &decodeState{}
elemBuf.Write(data)
elems := []reflect.Value{}
read := int(opts.head)
for elemBuf.Len() > 0 {
elem := reflect.New(sd.elementType)
read += sd.elementDec(elemBuf, elem, opts)
elems = append(elems, elem)
}
v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems)))
for i := 0; i < len(elems); i += 1 {
v.Elem().Index(i).Set(elems[i].Elem())
}
return read
}
func newSliceDecoder(t reflect.Type) decoderFunc {
dec := &sliceDecoder{
elementType: t.Elem(),
elementDec: typeDecoder(t.Elem()),
}
return dec.decode
}
//////////
type structDecoder struct {
fieldOpts []decOpts
fieldDecs []decoderFunc
}
func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
read := 0
for i := range sd.fieldDecs {
read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i])
}
return read
}
func newStructDecoder(t reflect.Type) decoderFunc {
n := t.NumField()
sd := structDecoder{
fieldOpts: make([]decOpts, n),
fieldDecs: make([]decoderFunc, n),
}
for i := 0; i < n; i += 1 {
f := t.Field(i)
tag := f.Tag.Get("tls")
tagOpts := parseTag(tag)
sd.fieldOpts[i] = decOpts{
head: tagOpts["head"],
max: tagOpts["max"],
min: tagOpts["min"],
}
sd.fieldDecs[i] = typeDecoder(f.Type)
}
return sd.decode
}

187
vendor/github.com/bifurcation/mint/syntax/encode.go generated vendored Normal file
View file

@ -0,0 +1,187 @@
package syntax
import (
"bytes"
"fmt"
"reflect"
"runtime"
)
func Marshal(v interface{}) ([]byte, error) {
e := &encodeState{}
err := e.marshal(v, encOpts{})
if err != nil {
return nil, err
}
return e.Bytes(), nil
}
// These are the options that can be specified in the struct tag. Right now,
// all of them apply to variable-length vectors and nothing else
type encOpts struct {
head uint // length of length in bytes
min uint // minimum size in bytes
max uint // maximum size in bytes
}
type encodeState struct {
bytes.Buffer
}
func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(r)
}
if s, ok := r.(string); ok {
panic(s)
}
err = r.(error)
}
}()
e.reflectValue(reflect.ValueOf(v), opts)
return nil
}
func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
valueEncoder(v)(e, v, opts)
}
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
func valueEncoder(v reflect.Value) encoderFunc {
if !v.IsValid() {
panic(fmt.Errorf("Cannot encode an invalid value"))
}
return typeEncoder(v.Type())
}
func typeEncoder(t reflect.Type) encoderFunc {
// Note: Omits the caching / wait-group things that encoding/json uses
return newTypeEncoder(t)
}
func newTypeEncoder(t reflect.Type) encoderFunc {
// Note: Does not support Marshaler, so don't need the allowAddr argument
switch t.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return uintEncoder
case reflect.Array:
return newArrayEncoder(t)
case reflect.Slice:
return newSliceEncoder(t)
case reflect.Struct:
return newStructEncoder(t)
default:
panic(fmt.Errorf("Unsupported type (%s)", t))
}
}
///// Specific encoders below
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
u := v.Uint()
switch v.Type().Kind() {
case reflect.Uint8:
e.WriteByte(byte(u))
case reflect.Uint16:
e.Write([]byte{byte(u >> 8), byte(u)})
case reflect.Uint32:
e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
case reflect.Uint64:
e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32),
byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
}
}
//////////
type arrayEncoder struct {
elemEnc encoderFunc
}
func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
n := v.Len()
for i := 0; i < n; i += 1 {
ae.elemEnc(e, v.Index(i), opts)
}
}
func newArrayEncoder(t reflect.Type) encoderFunc {
enc := &arrayEncoder{typeEncoder(t.Elem())}
return enc.encode
}
//////////
type sliceEncoder struct {
ae *arrayEncoder
}
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if opts.head == 0 {
panic(fmt.Errorf("Cannot encode a slice without a header length"))
}
arrayState := &encodeState{}
se.ae.encode(arrayState, v, opts)
n := uint(arrayState.Len())
if opts.max > 0 && n > opts.max {
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
}
if n>>(8*opts.head) > 0 {
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
}
if n < opts.min {
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
}
for i := int(opts.head - 1); i >= 0; i -= 1 {
e.WriteByte(byte(n >> (8 * uint(i))))
}
e.Write(arrayState.Bytes())
}
func newSliceEncoder(t reflect.Type) encoderFunc {
enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}}
return enc.encode
}
//////////
type structEncoder struct {
fieldOpts []encOpts
fieldEncs []encoderFunc
}
func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
for i := range se.fieldEncs {
se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i])
}
}
func newStructEncoder(t reflect.Type) encoderFunc {
n := t.NumField()
se := structEncoder{
fieldOpts: make([]encOpts, n),
fieldEncs: make([]encoderFunc, n),
}
for i := 0; i < n; i += 1 {
f := t.Field(i)
tag := f.Tag.Get("tls")
tagOpts := parseTag(tag)
se.fieldOpts[i] = encOpts{
head: tagOpts["head"],
max: tagOpts["max"],
min: tagOpts["min"],
}
se.fieldEncs[i] = typeEncoder(f.Type)
}
return se.encode
}

30
vendor/github.com/bifurcation/mint/syntax/tags.go generated vendored Normal file
View file

@ -0,0 +1,30 @@
package syntax
import (
"strconv"
"strings"
)
// `tls:"head=2,min=2,max=255"`
type tagOptions map[string]uint
// parseTag parses a struct field's "tls" tag as a comma-separated list of
// name=value pairs, where the values MUST be unsigned integers
func parseTag(tag string) tagOptions {
opts := tagOptions{}
for _, token := range strings.Split(tag, ",") {
if strings.Index(token, "=") == -1 {
continue
}
parts := strings.Split(token, "=")
if len(parts[0]) == 0 {
continue
}
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
opts[parts[0]] = uint(val)
}
}
return opts
}

168
vendor/github.com/bifurcation/mint/tls.go generated vendored Normal file
View file

@ -0,0 +1,168 @@
package mint
// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls
import (
"errors"
"net"
"strings"
"time"
)
// Server returns a new TLS server side connection
// using conn as the underlying transport.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Server(conn net.Conn, config *Config) *Conn {
return NewConn(conn, config, false)
}
// Client returns a new TLS client side connection
// using conn as the underlying transport.
// The config cannot be nil: users must set either ServerName or
// InsecureSkipVerify in the config.
func Client(conn net.Conn, config *Config) *Conn {
return NewConn(conn, config, true)
}
// A listener implements a network listener (net.Listener) for TLS connections.
type Listener struct {
net.Listener
config *Config
}
// Accept waits for and returns the next incoming TLS connection.
// The returned connection c is a *tls.Conn.
func (l *Listener) Accept() (c net.Conn, err error) {
c, err = l.Listener.Accept()
if err != nil {
return
}
server := Server(c, l.config)
err = server.Handshake()
if err == AlertNoAlert {
err = nil
}
c = server
return
}
// NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with Server.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func NewListener(inner net.Listener, config *Config) net.Listener {
l := new(Listener)
l.Listener = inner
l.config = config
return l
}
// Listen creates a TLS listener accepting connections on the
// given network address using net.Listen.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Listen(network, laddr string, config *Config) (net.Listener, error) {
if config == nil || !config.ValidForServer() {
return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
}
l, err := net.Listen(network, laddr)
if err != nil {
return nil, err
}
return NewListener(l, config), nil
}
type TimeoutError struct{}
func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" }
func (TimeoutError) Timeout() bool { return true }
func (TimeoutError) Temporary() bool { return true }
// DialWithDialer connects to the given network address using dialer.Dial and
// then initiates a TLS handshake, returning the resulting TLS connection. Any
// timeout or deadline given in the dialer apply to connection and TLS
// handshake as a whole.
//
// DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of Config for the defaults.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
// We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now.
timeout := dialer.Timeout
if !dialer.Deadline.IsZero() {
deadlineTimeout := dialer.Deadline.Sub(time.Now())
if timeout == 0 || deadlineTimeout < timeout {
timeout = deadlineTimeout
}
}
var errChannel chan error
if timeout != 0 {
errChannel = make(chan error, 2)
time.AfterFunc(timeout, func() {
errChannel <- TimeoutError{}
})
}
rawConn, err := dialer.Dial(network, addr)
if err != nil {
return nil, err
}
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
if config == nil {
config = &Config{}
}
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
// Make a copy to avoid polluting argument or default.
c := config.Clone()
c.ServerName = hostname
config = c
}
conn := Client(rawConn, config)
if timeout == 0 {
err = conn.Handshake()
if err == AlertNoAlert {
err = nil
}
} else {
go func() {
errChannel <- conn.Handshake()
}()
err = <-errChannel
if err == AlertNoAlert {
err = nil
}
}
if err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
// Dial connects to the given network address using net.Dial
// and then initiates a TLS handshake, returning the resulting
// TLS connection.
// Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Dial(network, addr string, config *Config) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config)
}

View file

@ -1,32 +0,0 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
)
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet) error
ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error
SendingAllowed() bool
GetStopWaitingFrame(force bool) *frames.StopWaitingFrame
DequeuePacketForRetransmission() (packet *Packet)
GetLeastUnacked() protocol.PacketNumber
GetAlarmTimeout() time.Time
OnAlarm()
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
SetLowerLimit(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *frames.AckFrame
}

View file

@ -3,7 +3,7 @@ package quic
import (
"sync"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var bufferPool sync.Pool

View file

@ -10,32 +10,39 @@ import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
type client struct {
mutex sync.Mutex
listenErr error
mutex sync.Mutex
conn connection
hostname string
errorChan chan struct{}
handshakeChan <-chan handshakeEvent
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
versionNegotiated bool // has the server accepted our version
receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
tlsConf *tls.Config
config *Config
versionNegotiated bool // has version negotiation completed yet
tlsConf *tls.Config
config *Config
tls handshake.MintTLS // only used when using TLS
connectionID protocol.ConnectionID
version protocol.VersionNumber
initialVersion protocol.VersionNumber
version protocol.VersionNumber
session packetHandler
}
var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = utils.GenerateConnectionID
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
)
@ -53,71 +60,6 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
return Dial(udpConn, udpAddr, addr, tlsConf, config)
}
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
}
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func DialNonFWSecure(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
connID, err := utils.GenerateConnectionID()
if err != nil {
return nil, err
}
var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
}
clientConfig := populateClientConfig(config)
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: clientConfig.Versions[0],
errorChan: make(chan struct{}),
}
err = c.createNewSession(nil)
if err != nil {
return nil, err
}
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
return c.session.(NonFWSession), c.establishSecureConnection()
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func Dial(
@ -127,15 +69,39 @@ func Dial(
tlsConf *tls.Config,
config *Config,
) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
connID, err := generateConnectionID()
if err != nil {
return nil, err
}
err = sess.WaitUntilHandshakeComplete()
if err != nil {
var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
}
clientConfig := populateClientConfig(config)
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: clientConfig.Versions[0],
versionNegotiationChan: make(chan struct{}),
}
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
if err := c.dial(); err != nil {
return nil, err
}
return sess, nil
return c.session, nil
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
@ -153,6 +119,10 @@ func populateClientConfig(config *Config) *Config {
if config.HandshakeTimeout != 0 {
handshakeTimeout = config.HandshakeTimeout
}
idleTimeout := protocol.DefaultIdleTimeout
if config.IdleTimeout != 0 {
idleTimeout = config.IdleTimeout
}
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
@ -166,32 +136,109 @@ func populateClientConfig(config *Config) *Config {
return &Config{
Versions: versions,
HandshakeTimeout: handshakeTimeout,
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
IdleTimeout: idleTimeout,
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
KeepAlive: config.KeepAlive,
}
}
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
func (c *client) establishSecureConnection() error {
func (c *client) dial() error {
var err error
if c.version.UsesTLS() {
err = c.dialTLS()
} else {
err = c.dialGQUIC()
}
if err == errCloseSessionForNewVersion {
return c.dial()
}
return err
}
func (c *client) dialGQUIC() error {
if err := c.createNewGQUICSession(); err != nil {
return err
}
go c.listen()
return c.establishSecureConnection()
}
func (c *client) dialTLS() error {
params := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission,
// TODO(#523): make these values configurable
MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
}
csc := handshake.NewCryptoStreamConn(nil)
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil {
return err
}
mintConf.ExtensionHandler = extHandler
mintConf.ServerName = c.hostname
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
go c.listen()
if err := c.establishSecureConnection(); err != nil {
if err != handshake.ErrCloseSessionForRetry {
return err
}
utils.Infof("Received a Retry packet. Recreating session.")
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
if err := c.establishSecureConnection(); err != nil {
return err
}
}
return nil
}
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
// - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
func (c *client) establishSecureConnection() error {
var runErr error
errorChan := make(chan struct{})
go func() {
runErr = c.session.run() // returns as soon as the session is closed
close(errorChan)
utils.Infof("Connection %x closed.", c.connectionID)
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
c.conn.Close()
}
}()
// wait until the server accepts the QUIC version (or an error occurs)
select {
case <-errorChan:
return runErr
case <-c.versionNegotiationChan:
}
select {
case <-c.errorChan:
return c.listenErr
case ev := <-c.handshakeChan:
if ev.err != nil {
return ev.err
}
if ev.encLevel != protocol.EncryptionSecure {
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
}
return nil
case <-errorChan:
return runErr
case err := <-c.session.handshakeStatus():
return err
}
}
// Listen listens
// Listen listens on the underlying connection and passes packets on for handling.
// It returns when the connection is closed.
func (c *client) listen() {
var err error
@ -205,13 +252,15 @@ func (c *client) listen() {
n, addr, err = c.conn.Read(data)
if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
c.session.Close(err)
c.mutex.Lock()
if c.session != nil {
c.session.Close(err)
}
c.mutex.Unlock()
}
break
}
data = data[:n]
c.handlePacket(addr, data)
c.handlePacket(addr, data[:n])
}
}
@ -219,10 +268,14 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
rcvTime := time.Now()
r := bytes.NewReader(packet)
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer)
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
if err != nil {
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
// drop this packet if we can't parse the Public Header
// drop this packet if we can't parse the header
return
}
// reject packets with truncated connection id if we didn't request truncation
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
return
}
hdr.Raw = packet[:len(packet)-r.Len()]
@ -230,6 +283,11 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
c.mutex.Lock()
defer c.mutex.Unlock()
// reject packets with the wrong connection ID
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
return
}
if hdr.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
@ -238,44 +296,48 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
utils.Infof("Received a spoofed Public Reset. Ignoring.")
return
}
pr, err := parsePublicReset(r)
pr, err := wire.ParsePublicReset(r)
if err != nil {
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.")
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
return
}
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber)))
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
return
}
// ignore delayed / duplicated version negotiation packets
if c.versionNegotiated && hdr.VersionFlag {
return
}
// handle Version Negotiation Packets
if hdr.IsVersionNegotiation {
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
return
}
// this is the first packet after the client sent a packet with the VersionFlag set
// if the server doesn't send a version negotiation packet, it supports the suggested version
if !hdr.VersionFlag && !c.versionNegotiated {
c.versionNegotiated = true
}
if hdr.VersionFlag {
// version negotiation packets have no payload
if err := c.handlePacketWithVersionFlag(hdr); err != nil {
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
c.session.Close(err)
}
return
}
// this is the first packet we are receiving
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated {
c.versionNegotiated = true
close(c.versionNegotiationChan)
}
// TODO: validate packet number and connection ID on Retry packets (for IETF QUIC)
c.session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
publicHeader: hdr,
data: packet[len(packet)-r.Len():],
rcvTime: rcvTime,
remoteAddr: remoteAddr,
header: hdr,
data: packet[len(packet)-r.Len():],
rcvTime: rcvTime,
})
}
func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
for _, v := range hdr.SupportedVersions {
if v == c.version {
// the version negotiation packet contains the version that we offered
@ -285,51 +347,57 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
}
}
newVersion := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if newVersion == protocol.VersionUnsupported {
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok {
return qerr.InvalidVersion
}
c.receivedVersionNegotiationPacket = true
c.negotiatedVersions = hdr.SupportedVersions
// switch to negotiated version
c.initialVersion = c.version
c.version = newVersion
c.versionNegotiated = true
var err error
c.connectionID, err = utils.GenerateConnectionID()
if err != nil {
return err
}
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
c.session.Close(errCloseSessionForNewVersion)
return c.createNewSession(hdr.SupportedVersions)
return nil
}
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
var err error
c.session, c.handshakeChan, err = newClientSession(
func (c *client) createNewGQUICSession() (err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.session, err = newClientSession(
c.conn,
c.hostname,
c.version,
c.connectionID,
c.tlsConf,
c.config,
negotiatedVersions,
c.initialVersion,
c.negotiatedVersions,
)
if err != nil {
return err
}
go func() {
// session.run() returns as soon as the session is closed
err := c.session.run()
if err == errCloseSessionForNewVersion {
return
}
c.listenErr = err
close(c.errorChan)
utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close()
}()
return nil
return err
}
func (c *client) createNewTLSSession(
paramsChan <-chan handshake.TransportParameters,
version protocol.VersionNumber,
) (err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.session, err = newTLSClientSession(
c.conn,
c.hostname,
c.version,
c.connectionID,
c.config,
c.tls,
paramsChan,
1,
)
return err
}

View file

@ -1,58 +0,0 @@
package crypto
import (
"crypto/cipher"
"errors"
"github.com/lucas-clemente/aes12"
"github.com/lucas-clemente/quic-go/protocol"
)
type aeadAESGCM struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
// NewAEADAESGCM creates a AEAD using AES-GCM with 12 bytes tag size
//
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
// tag size, and couples the cipher and aes packages closely.
// See https://github.com/lucas-clemente/aes12.
func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
}
encrypterCipher, err := aes12.NewCipher(myKey)
if err != nil {
return nil, err
}
encrypter, err := aes12.NewGCM(encrypterCipher)
if err != nil {
return nil, err
}
decrypterCipher, err := aes12.NewCipher(otherKey)
if err != nil {
return nil, err
}
decrypter, err := aes12.NewGCM(decrypterCipher)
if err != nil {
return nil, err
}
return &aeadAESGCM{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData)
}

View file

@ -1,14 +0,0 @@
package crypto
import (
"encoding/binary"
"github.com/lucas-clemente/quic-go/protocol"
)
func makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}

View file

@ -1,76 +0,0 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
// StkSource is used to create and verify source address tokens
type StkSource interface {
// NewToken creates a new token
NewToken([]byte) ([]byte, error)
// DecodeToken decodes a token
DecodeToken([]byte) ([]byte, error)
}
type stkSource struct {
aead cipher.AEAD
}
const stkKeySize = 16
// Chrome currently sets this to 12, but discusses changing it to 16. We start
// at 16 :)
const stkNonceSize = 16
// NewStkSource creates a source for source address tokens
func NewStkSource() (StkSource, error) {
secret := make([]byte, 32)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
key, err := deriveKey(secret)
if err != nil {
return nil, err
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize)
if err != nil {
return nil, err
}
return &stkSource{aead: aead}, nil
}
func (s *stkSource) NewToken(data []byte) ([]byte, error) {
nonce := make([]byte, stkNonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
return s.aead.Seal(nonce, nonce, data, nil), nil
}
func (s *stkSource) DecodeToken(p []byte) ([]byte, error) {
if len(p) < stkNonceSize {
return nil, fmt.Errorf("STK too short: %d", len(p))
}
nonce := p[:stkNonceSize]
return s.aead.Open(nil, nonce, p[stkNonceSize:], nil)
}
func deriveKey(secret []byte) ([]byte, error) {
r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key"))
key := make([]byte, stkKeySize)
if _, err := io.ReadFull(r, key); err != nil {
return nil, err
}
return key, nil
}

View file

@ -0,0 +1,41 @@
package quic
import (
"io"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoStreamI interface {
StreamID() protocol.StreamID
io.Reader
io.Writer
handleStreamFrame(*wire.StreamFrame) error
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
setReadOffset(protocol.ByteCount)
// methods needed for flow control
getWindowUpdate() protocol.ByteCount
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type cryptoStream struct {
*stream
}
var _ cryptoStreamI = &cryptoStream{}
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
str := newStream(version.CryptoStreamID(), sender, flowController, version)
return &cryptoStream{str}
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
s.receiveStream.readOffset = offset
s.receiveStream.frameQueue.readPosition = offset
}

View file

@ -7,12 +7,15 @@ import (
"net/http"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
func main() {
verbose := flag.Bool("v", false, "verbose")
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
flag.Parse()
urls := flag.Args()
@ -23,8 +26,17 @@ func main() {
}
utils.SetLogTimeFormat("")
versions := protocol.SupportedVersions
if *tls {
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
}
roundTripper := &h2quic.RoundTripper{
QuicConfig: &quic.Config{Versions: versions},
}
defer roundTripper.Close()
hclient := &http.Client{
Transport: &h2quic.RoundTripper{},
Transport: roundTripper,
}
var wg sync.WaitGroup

View file

@ -17,7 +17,9 @@ import (
_ "net/http/pprof"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
@ -121,6 +123,7 @@ func main() {
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
www := flag.String("www", "/var/www", "www data")
tcp := flag.Bool("tcp", false, "also listen on TCP")
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
flag.Parse()
if *verbose {
@ -130,6 +133,11 @@ func main() {
}
utils.SetLogTimeFormat("")
versions := protocol.SupportedVersions
if *tls {
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
}
certFile := *certPath + "/fullchain.pem"
keyFile := *certPath + "/privkey.pem"
@ -148,7 +156,11 @@ func main() {
if *tcp {
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
} else {
err = h2quic.ListenAndServeQUIC(bCap, certFile, keyFile, nil)
server := h2quic.Server{
Server: &http.Server{Addr: bCap},
QuicConfig: &quic.Config{Versions: versions},
}
err = server.ListenAndServeTLS(certFile, keyFile)
}
if err != nil {
fmt.Println(err)

View file

@ -1,240 +0,0 @@
package flowcontrol
import (
"errors"
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
type flowControlManager struct {
connectionParameters handshake.ConnectionParametersManager
rttStats *congestion.RTTStats
streamFlowController map[protocol.StreamID]*flowController
connFlowController *flowController
mutex sync.RWMutex
}
var _ FlowControlManager = &flowControlManager{}
var errMapAccess = errors.New("Error accessing the flowController map.")
// NewFlowControlManager creates a new flow control manager
func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager {
return &flowControlManager{
connectionParameters: connectionParameters,
rttStats: rttStats,
streamFlowController: make(map[protocol.StreamID]*flowController),
connFlowController: newFlowController(0, false, connectionParameters, rttStats),
}
}
// NewStream creates new flow controllers for a stream
// it does nothing if the stream already exists
func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesToConnection bool) {
f.mutex.Lock()
defer f.mutex.Unlock()
if _, ok := f.streamFlowController[streamID]; ok {
return
}
f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connectionParameters, f.rttStats)
}
// RemoveStream removes a closed stream from flow control
func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) {
f.mutex.Lock()
delete(f.streamFlowController, streamID)
f.mutex.Unlock()
}
// ResetStream should be called when receiving a RstStreamFrame
// it updates the byte offset to the value in the RstStreamFrame
// streamID must not be 0 here
func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
f.mutex.Lock()
defer f.mutex.Unlock()
streamFlowController, err := f.getFlowController(streamID)
if err != nil {
return err
}
increment, err := streamFlowController.UpdateHighestReceived(byteOffset)
if err != nil {
return qerr.StreamDataAfterTermination
}
if streamFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
}
if streamFlowController.ContributesToConnection() {
f.connFlowController.IncrementHighestReceived(increment)
if f.connFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
}
}
return nil
}
// UpdateHighestReceived updates the highest received byte offset for a stream
// it adds the number of additional bytes to connection level flow control
// streamID must not be 0 here
func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
f.mutex.Lock()
defer f.mutex.Unlock()
streamFlowController, err := f.getFlowController(streamID)
if err != nil {
return err
}
// UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered
// this error can be ignored here
increment, _ := streamFlowController.UpdateHighestReceived(byteOffset)
if streamFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
}
if streamFlowController.ContributesToConnection() {
f.connFlowController.IncrementHighestReceived(increment)
if f.connFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
}
}
return nil
}
// streamID must not be 0 here
func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error {
f.mutex.Lock()
defer f.mutex.Unlock()
fc, err := f.getFlowController(streamID)
if err != nil {
return err
}
fc.AddBytesRead(n)
if fc.ContributesToConnection() {
f.connFlowController.AddBytesRead(n)
}
return nil
}
func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) {
f.mutex.Lock()
defer f.mutex.Unlock()
// get WindowUpdates for streams
for id, fc := range f.streamFlowController {
if necessary, newIncrement, offset := fc.MaybeUpdateWindow(); necessary {
res = append(res, WindowUpdate{StreamID: id, Offset: offset})
if fc.ContributesToConnection() && newIncrement != 0 {
f.connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier))
}
}
}
// get a WindowUpdate for the connection
if necessary, _, offset := f.connFlowController.MaybeUpdateWindow(); necessary {
res = append(res, WindowUpdate{StreamID: 0, Offset: offset})
}
return
}
func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
f.mutex.RLock()
defer f.mutex.RUnlock()
// StreamID can be 0 when retransmitting
if streamID == 0 {
return f.connFlowController.receiveWindow, nil
}
flowController, err := f.getFlowController(streamID)
if err != nil {
return 0, err
}
return flowController.receiveWindow, nil
}
// streamID must not be 0 here
func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
f.mutex.Lock()
defer f.mutex.Unlock()
fc, err := f.getFlowController(streamID)
if err != nil {
return err
}
fc.AddBytesSent(n)
if fc.ContributesToConnection() {
f.connFlowController.AddBytesSent(n)
}
return nil
}
// must not be called with StreamID 0
func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) {
f.mutex.RLock()
defer f.mutex.RUnlock()
fc, err := f.getFlowController(streamID)
if err != nil {
return 0, err
}
res := fc.SendWindowSize()
if fc.ContributesToConnection() {
res = utils.MinByteCount(res, f.connFlowController.SendWindowSize())
}
return res, nil
}
func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
f.mutex.RLock()
defer f.mutex.RUnlock()
return f.connFlowController.SendWindowSize()
}
// streamID may be 0 here
func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) {
f.mutex.Lock()
defer f.mutex.Unlock()
var fc *flowController
if streamID == 0 {
fc = f.connFlowController
} else {
var err error
fc, err = f.getFlowController(streamID)
if err != nil {
return false, err
}
}
return fc.UpdateSendWindow(offset), nil
}
func (f *flowControlManager) getFlowController(streamID protocol.StreamID) (*flowController, error) {
streamFlowController, ok := f.streamFlowController[streamID]
if !ok {
return nil, errMapAccess
}
return streamFlowController, nil
}

View file

@ -1,198 +0,0 @@
package flowcontrol
import (
"errors"
"time"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
type flowController struct {
streamID protocol.StreamID
contributesToConnection bool // does the stream contribute to connection level flow control
connectionParameters handshake.ConnectionParametersManager
rttStats *congestion.RTTStats
bytesSent protocol.ByteCount
sendWindow protocol.ByteCount
lastWindowUpdateTime time.Time
bytesRead protocol.ByteCount
highestReceived protocol.ByteCount
receiveWindow protocol.ByteCount
receiveWindowIncrement protocol.ByteCount
maxReceiveWindowIncrement protocol.ByteCount
}
// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously
var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset")
// newFlowController gets a new flow controller
func newFlowController(streamID protocol.StreamID, contributesToConnection bool, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController {
fc := flowController{
streamID: streamID,
contributesToConnection: contributesToConnection,
connectionParameters: connectionParameters,
rttStats: rttStats,
}
if streamID == 0 {
fc.receiveWindow = connectionParameters.GetReceiveConnectionFlowControlWindow()
fc.receiveWindowIncrement = fc.receiveWindow
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow()
} else {
fc.receiveWindow = connectionParameters.GetReceiveStreamFlowControlWindow()
fc.receiveWindowIncrement = fc.receiveWindow
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow()
}
return &fc
}
func (c *flowController) ContributesToConnection() bool {
return c.contributesToConnection
}
func (c *flowController) getSendWindow() protocol.ByteCount {
if c.sendWindow == 0 {
if c.streamID == 0 {
return c.connectionParameters.GetSendConnectionFlowControlWindow()
}
return c.connectionParameters.GetSendStreamFlowControlWindow()
}
return c.sendWindow
}
func (c *flowController) AddBytesSent(n protocol.ByteCount) {
c.bytesSent += n
}
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
// it returns true if the window was actually updated
func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool {
if newOffset > c.sendWindow {
c.sendWindow = newOffset
return true
}
return false
}
func (c *flowController) SendWindowSize() protocol.ByteCount {
sendWindow := c.getSendWindow()
if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here
return 0
}
return sendWindow - c.bytesSent
}
func (c *flowController) SendWindowOffset() protocol.ByteCount {
return c.getSendWindow()
}
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
// Should **only** be used for the stream-level FlowController
// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
// This error occurs every time StreamFrames get reordered and has to be ignored in that case
// It should only be treated as an error when resetting a stream
func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) {
if byteOffset == c.highestReceived {
return 0, nil
}
if byteOffset > c.highestReceived {
increment := byteOffset - c.highestReceived
c.highestReceived = byteOffset
return increment, nil
}
return 0, ErrReceivedSmallerByteOffset
}
// IncrementHighestReceived adds an increment to the highestReceived value
// Should **only** be used for the connection-level FlowController
func (c *flowController) IncrementHighestReceived(increment protocol.ByteCount) {
c.highestReceived += increment
}
func (c *flowController) AddBytesRead(n protocol.ByteCount) {
// pretend we sent a WindowUpdate when reading the first byte
// this way auto-tuning of the window increment already works for the first WindowUpdate
if c.bytesRead == 0 {
c.lastWindowUpdateTime = time.Now()
}
c.bytesRead += n
}
// MaybeUpdateWindow updates the receive window, if necessary
// if the receive window increment is changed, the new value is returned, otherwise a 0
// the last return value is the new offset of the receive window
func (c *flowController) MaybeUpdateWindow() (bool, protocol.ByteCount /* new increment */, protocol.ByteCount /* new offset */) {
diff := c.receiveWindow - c.bytesRead
// Chromium implements the same threshold
if diff < (c.receiveWindowIncrement / 2) {
var newWindowIncrement protocol.ByteCount
oldWindowIncrement := c.receiveWindowIncrement
c.maybeAdjustWindowIncrement()
if c.receiveWindowIncrement != oldWindowIncrement {
newWindowIncrement = c.receiveWindowIncrement
}
c.lastWindowUpdateTime = time.Now()
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement
return true, newWindowIncrement, c.receiveWindow
}
return false, 0, 0
}
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
func (c *flowController) maybeAdjustWindowIncrement() {
if c.lastWindowUpdateTime.IsZero() {
return
}
rtt := c.rttStats.SmoothedRTT()
if rtt == 0 {
return
}
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime)
// interval between the window updates is sufficiently large, no need to increase the increment
if timeSinceLastWindowUpdate >= 2*rtt {
return
}
oldWindowSize := c.receiveWindowIncrement
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement)
// debug log, if the window size was actually increased
if oldWindowSize < c.receiveWindowIncrement {
newWindowSize := c.receiveWindowIncrement / (1 << 10)
if c.streamID == 0 {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize)
} else {
utils.Debugf("Increasing receive flow control window increment for stream %d to %d kB", c.streamID, newWindowSize)
}
}
}
// EnsureMinimumWindowIncrement sets a minimum window increment
// it is intended be used for the connection-level flow controller
// it should make sure that the connection-level window is increased when a stream-level window grows
func (c *flowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) {
if inc > c.receiveWindowIncrement {
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement)
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
}
}
func (c *flowController) CheckFlowControlViolation() bool {
return c.highestReceived > c.receiveWindow
}

View file

@ -1,26 +0,0 @@
package flowcontrol
import "github.com/lucas-clemente/quic-go/protocol"
// WindowUpdate provides the data for WindowUpdateFrames.
type WindowUpdate struct {
StreamID protocol.StreamID
Offset protocol.ByteCount
}
// A FlowControlManager manages the flow control
type FlowControlManager interface {
NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool)
RemoveStream(streamID protocol.StreamID)
// methods needed for receiving data
ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error
GetWindowUpdates() []WindowUpdate
GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error)
// methods needed for sending data
AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error
SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error)
RemainingConnectionWindowSize() protocol.ByteCount
UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error)
}

View file

@ -1,9 +0,0 @@
package frames
import "github.com/lucas-clemente/quic-go/protocol"
// AckRange is an ACK range
type AckRange struct {
FirstPacketNumber protocol.PacketNumber
LastPacketNumber protocol.PacketNumber
}

View file

@ -1,44 +0,0 @@
package frames
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
// A BlockedFrame in QUIC
type BlockedFrame struct {
StreamID protocol.StreamID
}
//Write writes a BlockedFrame frame
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x05)
utils.WriteUint32(b, uint32(f.StreamID))
return nil
}
// MinLength of a written frame
func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
return 1 + 4, nil
}
// ParseBlockedFrame parses a BLOCKED frame
func ParseBlockedFrame(r *bytes.Reader) (*BlockedFrame, error) {
frame := &BlockedFrame{}
// read the TypeByte
_, err := r.ReadByte()
if err != nil {
return nil, err
}
sid, err := utils.ReadUint32(r)
if err != nil {
return nil, err
}
frame.StreamID = protocol.StreamID(sid)
return frame, nil
}

View file

@ -1,73 +0,0 @@
package frames
import (
"bytes"
"errors"
"io"
"math"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
// A ConnectionCloseFrame in QUIC
type ConnectionCloseFrame struct {
ErrorCode qerr.ErrorCode
ReasonPhrase string
}
// ParseConnectionCloseFrame reads a CONNECTION_CLOSE frame
func ParseConnectionCloseFrame(r *bytes.Reader) (*ConnectionCloseFrame, error) {
frame := &ConnectionCloseFrame{}
// read the TypeByte
_, err := r.ReadByte()
if err != nil {
return nil, err
}
errorCode, err := utils.ReadUint32(r)
if err != nil {
return nil, err
}
frame.ErrorCode = qerr.ErrorCode(errorCode)
reasonPhraseLen, err := utils.ReadUint16(r)
if err != nil {
return nil, err
}
if reasonPhraseLen > uint16(protocol.MaxPacketSize) {
return nil, qerr.Error(qerr.InvalidConnectionCloseData, "reason phrase too long")
}
reasonPhrase := make([]byte, reasonPhraseLen)
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
return nil, err
}
frame.ReasonPhrase = string(reasonPhrase)
return frame, nil
}
// MinLength of a written frame
func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil
}
// Write writes an CONNECTION_CLOSE frame.
func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x02)
utils.WriteUint32(b, uint32(f.ErrorCode))
if len(f.ReasonPhrase) > math.MaxUint16 {
return errors.New("ConnectionFrame: ReasonPhrase too long")
}
reasonPhraseLen := uint16(len(f.ReasonPhrase))
utils.WriteUint16(b, reasonPhraseLen)
b.WriteString(f.ReasonPhrase)
return nil
}

View file

@ -1,13 +0,0 @@
package frames
import (
"bytes"
"github.com/lucas-clemente/quic-go/protocol"
)
// A Frame in QUIC
type Frame interface {
Write(b *bytes.Buffer, version protocol.VersionNumber) error
MinLength(version protocol.VersionNumber) (protocol.ByteCount, error)
}

View file

@ -1,28 +0,0 @@
package frames
import "github.com/lucas-clemente/quic-go/internal/utils"
// LogFrame logs a frame, either sent or received
func LogFrame(frame Frame, sent bool) {
if !utils.Debug() {
return
}
dir := "<-"
if sent {
dir = "->"
}
switch f := frame.(type) {
case *StreamFrame:
utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
case *StopWaitingFrame:
if sent {
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen)
} else {
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked)
}
case *AckFrame:
utils.Debugf("\t%s &frames.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String())
default:
utils.Debugf("\t%s %#v", dir, frame)
}
}

View file

@ -1,59 +0,0 @@
package frames
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
// A RstStreamFrame in QUIC
type RstStreamFrame struct {
StreamID protocol.StreamID
ErrorCode uint32
ByteOffset protocol.ByteCount
}
//Write writes a RST_STREAM frame
func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x01)
utils.WriteUint32(b, uint32(f.StreamID))
utils.WriteUint64(b, uint64(f.ByteOffset))
utils.WriteUint32(b, f.ErrorCode)
return nil
}
// MinLength of a written frame
func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
return 1 + 4 + 8 + 4, nil
}
// ParseRstStreamFrame parses a RST_STREAM frame
func ParseRstStreamFrame(r *bytes.Reader) (*RstStreamFrame, error) {
frame := &RstStreamFrame{}
// read the TypeByte
_, err := r.ReadByte()
if err != nil {
return nil, err
}
sid, err := utils.ReadUint32(r)
if err != nil {
return nil, err
}
frame.StreamID = protocol.StreamID(sid)
byteOffset, err := utils.ReadUint64(r)
if err != nil {
return nil, err
}
frame.ByteOffset = protocol.ByteCount(byteOffset)
frame.ErrorCode, err = utils.ReadUint32(r)
if err != nil {
return nil, err
}
return frame, nil
}

View file

@ -1,54 +0,0 @@
package frames
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
// A WindowUpdateFrame in QUIC
type WindowUpdateFrame struct {
StreamID protocol.StreamID
ByteOffset protocol.ByteCount
}
//Write writes a RST_STREAM frame
func (f *WindowUpdateFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
typeByte := uint8(0x04)
b.WriteByte(typeByte)
utils.WriteUint32(b, uint32(f.StreamID))
utils.WriteUint64(b, uint64(f.ByteOffset))
return nil
}
// MinLength of a written frame
func (f *WindowUpdateFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
return 1 + 4 + 8, nil
}
// ParseWindowUpdateFrame parses a RST_STREAM frame
func ParseWindowUpdateFrame(r *bytes.Reader) (*WindowUpdateFrame, error) {
frame := &WindowUpdateFrame{}
// read the TypeByte
_, err := r.ReadByte()
if err != nil {
return nil, err
}
sid, err := utils.ReadUint32(r)
if err != nil {
return nil, err
}
frame.StreamID = protocol.StreamID(sid)
byteOffset, err := utils.ReadUint64(r)
if err != nil {
return nil, err
}
frame.ByteOffset = protocol.ByteCount(byteOffset)
return frame, nil
}

View file

@ -15,8 +15,8 @@ import (
"golang.org/x/net/idna"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
@ -34,10 +34,10 @@ type client struct {
config *quic.Config
opts *roundTripperOpts
hostname string
encryptionLevel protocol.EncryptionLevel
handshakeErr error
dialOnce sync.Once
hostname string
handshakeErr error
dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
session quic.Session
headerStream quic.Stream
@ -51,8 +51,8 @@ type client struct {
var _ http.RoundTripper = &client{}
var defaultQuicConfig = &quic.Config{
RequestConnectionIDTruncation: true,
KeepAlive: true,
RequestConnectionIDOmission: true,
KeepAlive: true,
}
// newClient creates a new client
@ -61,26 +61,31 @@ func newClient(
tlsConfig *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
) *client {
config := defaultQuicConfig
if quicConfig != nil {
config = quicConfig
}
return &client{
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
encryptionLevel: protocol.EncryptionUnencrypted,
tlsConf: tlsConfig,
config: config,
opts: opts,
headerErrored: make(chan struct{}),
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
tlsConf: tlsConfig,
config: config,
opts: opts,
headerErrored: make(chan struct{}),
dialer: dialer,
}
}
// dial dials the connection
func (c *client) dial() error {
var err error
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
if c.dialer != nil {
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
} else {
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
}
if err != nil {
return err
}
@ -90,9 +95,6 @@ func (c *client) dial() error {
if err != nil {
return err
}
if c.headerStream.StreamID() != 3 {
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
}
c.requestWriter = newRequestWriter(c.headerStream)
go c.handleHeaderStream()
return nil
@ -102,45 +104,44 @@ func (c *client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream)
var lastStream protocol.StreamID
var err error
for err == nil {
err = c.readResponse(h2framer, decoder)
}
utils.Debugf("Error handling header stream: %s", err)
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
// stop all running request
close(c.headerErrored)
}
for {
frame, err := h2framer.ReadFrame()
if err != nil {
c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
break
}
lastStream = protocol.StreamID(frame.Header().StreamID)
hframe, ok := frame.(*http2.HeadersFrame)
if !ok {
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")
break
}
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil {
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")
break
}
c.mutex.RLock()
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
break
}
rsp, err := responseFromHeaders(mhframe)
if err != nil {
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
}
responseChan <- rsp
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
frame, err := h2framer.ReadFrame()
if err != nil {
return err
}
hframe, ok := frame.(*http2.HeadersFrame)
if !ok {
return errors.New("not a headers frame")
}
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil {
return fmt.Errorf("cannot read header fields: %s", err.Error())
}
// stop all running request
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
close(c.headerErrored)
c.mutex.RLock()
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
}
rsp, err := responseFromHeaders(mhframe)
if err != nil {
return err
}
responseChan <- rsp
return nil
}
// Roundtrip executes a request and returns a response

View file

@ -13,8 +13,8 @@ import (
"golang.org/x/net/lex/httplex"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
type requestWriter struct {

View file

@ -8,8 +8,8 @@ import (
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
@ -83,7 +83,7 @@ func (w *responseWriter) Write(p []byte) (int, error) {
func (w *responseWriter) Flush() {}
// TODO: Implement a functional CloseNotify method.
// This is a NOP. Use http.Request.Context
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
// test that we implement http.Flusher

View file

@ -41,6 +41,11 @@ type RoundTripper struct {
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddr will be used.
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
clients map[string]roundTripCloser
}
@ -120,7 +125,13 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
if onlyCached {
return nil, ErrNoCachedConn
}
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
client = newClient(
hostname,
r.TLSClientConfig,
&roundTripperOpts{DisableCompression: r.DisableCompression},
r.QuicConfig,
r.Dial,
)
r.clients[hostname] = client
}
return client, nil

View file

@ -7,14 +7,14 @@ import (
"net"
"net/http"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
@ -50,6 +50,7 @@ type Server struct {
listenerMutex sync.Mutex
listener quic.Listener
closed bool
supportedVersionsAsString string
}
@ -88,6 +89,10 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
return errors.New("use of h2quic.Server without http.Server")
}
s.listenerMutex.Lock()
if s.closed {
s.listenerMutex.Unlock()
return errors.New("Server is already closed")
}
if s.listener != nil {
s.listenerMutex.Unlock()
return errors.New("ListenAndServe may only be called once")
@ -122,29 +127,23 @@ func (s *Server) handleHeaderStream(session streamCreator) {
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
return
}
if stream.StreamID() != 3 {
session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3"))
return
}
hpackDecoder := hpack.NewDecoder(4096, nil)
h2framer := http2.NewFramer(nil, stream)
go func() {
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
for {
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
// QuicErrors must originate from stream.Read() returning an error.
// In this case, the session has already logged the error, so we don't
// need to log it again.
if _, ok := err.(*qerr.QuicError); !ok {
utils.Errorf("error handling h2 request: %s", err.Error())
}
session.Close(err)
return
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
for {
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
// QuicErrors must originate from stream.Read() returning an error.
// In this case, the session has already logged the error, so we don't
// need to log it again.
if _, ok := err.(*qerr.QuicError); !ok {
utils.Errorf("error handling h2 request: %s", err.Error())
}
session.Close(err)
return
}
}()
}
}
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
@ -170,8 +169,6 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
return err
}
req.RemoteAddr = session.RemoteAddr().String()
if utils.Debug() {
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
} else {
@ -187,19 +184,25 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
return nil
}
var streamEnded bool
if h2headersFrame.StreamEnded() {
dataStream.(remoteCloser).CloseRemote(0)
streamEnded = true
_, _ = dataStream.Read([]byte{0}) // read the eof
}
reqBody := newRequestBody(dataStream)
req.Body = reqBody
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
// handleRequest should be as non-blocking as possible to minimize
// head-of-line blocking. Potentially blocking code is run in a separate
// goroutine, enabling handleRequest to return before the code is executed.
go func() {
streamEnded := h2headersFrame.StreamEnded()
if streamEnded {
dataStream.(remoteCloser).CloseRemote(0)
streamEnded = true
_, _ = dataStream.Read([]byte{0}) // read the eof
}
req = req.WithContext(dataStream.Context())
reqBody := newRequestBody(dataStream)
req.Body = reqBody
req.RemoteAddr = session.RemoteAddr().String()
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
handler := s.Handler
if handler == nil {
handler = http.DefaultServeMux
@ -225,7 +228,8 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
}
if responseWriter.dataStream != nil {
if !streamEnded && !reqBody.requestRead {
responseWriter.dataStream.Reset(nil)
// in gQUIC, the error code doesn't matter, so just use 0 here
responseWriter.dataStream.CancelRead(0)
}
responseWriter.dataStream.Close()
}
@ -243,6 +247,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
func (s *Server) Close() error {
s.listenerMutex.Lock()
defer s.listenerMutex.Unlock()
s.closed = true
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
@ -279,12 +284,11 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
}
if s.supportedVersionsAsString == "" {
for i, v := range protocol.SupportedVersions {
s.supportedVersionsAsString += strconv.Itoa(int(v))
if i != len(protocol.SupportedVersions)-1 {
s.supportedVersionsAsString += ","
}
var versions []string
for _, v := range protocol.SupportedVersions {
versions = append(versions, v.ToAltSvc())
}
s.supportedVersionsAsString = strings.Join(versions, ",")
}
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
@ -344,6 +348,9 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
}
defer tcpConn.Close()
tlsConn := tls.NewListener(tcpConn, config)
defer tlsConn.Close()
// Start the servers
httpServer := &http.Server{
Addr: addr,
@ -365,7 +372,7 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
hErr := make(chan error)
qErr := make(chan error)
go func() {
hErr <- httpServer.Serve(tcpConn)
hErr <- httpServer.Serve(tlsConn)
}()
go func() {
qErr <- quicServer.Serve(udpConn)

View file

@ -1,265 +0,0 @@
package handshake
import (
"bytes"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
// ConnectionParametersManager negotiates and stores the connection parameters
// A ConnectionParametersManager can be used for a server as well as a client
// For the server:
// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation
// 2. call GetHelloMap to get the values to send in the SHLO
// For the client:
// 1. call GetHelloMap to get the values to send in a CHLO
// 2. call SetFromMap with the values received in the SHLO
type ConnectionParametersManager interface {
SetFromMap(map[Tag][]byte) error
GetHelloMap() (map[Tag][]byte, error)
GetSendStreamFlowControlWindow() protocol.ByteCount
GetSendConnectionFlowControlWindow() protocol.ByteCount
GetReceiveStreamFlowControlWindow() protocol.ByteCount
GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount
GetReceiveConnectionFlowControlWindow() protocol.ByteCount
GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount
GetMaxOutgoingStreams() uint32
GetMaxIncomingStreams() uint32
GetIdleConnectionStateLifetime() time.Duration
TruncateConnectionID() bool
}
type connectionParametersManager struct {
mutex sync.RWMutex
version protocol.VersionNumber
perspective protocol.Perspective
flowControlNegotiated bool
truncateConnectionID bool
maxStreamsPerConnection uint32
maxIncomingDynamicStreamsPerConnection uint32
idleConnectionStateLifetime time.Duration
sendStreamFlowControlWindow protocol.ByteCount
sendConnectionFlowControlWindow protocol.ByteCount
receiveStreamFlowControlWindow protocol.ByteCount
receiveConnectionFlowControlWindow protocol.ByteCount
maxReceiveStreamFlowControlWindow protocol.ByteCount
maxReceiveConnectionFlowControlWindow protocol.ByteCount
}
var _ ConnectionParametersManager = &connectionParametersManager{}
// ErrMalformedTag is returned when the tag value cannot be read
var (
ErrMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
ErrFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported")
)
// NewConnectionParamatersManager creates a new connection parameters manager
func NewConnectionParamatersManager(
pers protocol.Perspective, v protocol.VersionNumber,
maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount,
) ConnectionParametersManager {
h := &connectionParametersManager{
perspective: pers,
version: v,
sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client
sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client
receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
}
if h.perspective == protocol.PerspectiveServer {
h.idleConnectionStateLifetime = protocol.DefaultIdleTimeout
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective
} else {
h.idleConnectionStateLifetime = protocol.MaxIdleTimeoutClient
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the server's perspective
}
return h
}
// SetFromMap reads all params
func (h *connectionParametersManager) SetFromMap(params map[Tag][]byte) error {
h.mutex.Lock()
defer h.mutex.Unlock()
if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer {
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return ErrMalformedTag
}
h.truncateConnectionID = (clientValue == 0)
}
if value, ok := params[TagMSPC]; ok {
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return ErrMalformedTag
}
h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue)
}
if value, ok := params[TagMIDS]; ok {
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return ErrMalformedTag
}
h.maxIncomingDynamicStreamsPerConnection = h.negotiateMaxIncomingDynamicStreamsPerConnection(clientValue)
}
if value, ok := params[TagICSL]; ok {
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return ErrMalformedTag
}
h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(time.Duration(clientValue) * time.Second)
}
if value, ok := params[TagSFCW]; ok {
if h.flowControlNegotiated {
return ErrFlowControlRenegotiationNotSupported
}
sendStreamFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return ErrMalformedTag
}
h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow)
}
if value, ok := params[TagCFCW]; ok {
if h.flowControlNegotiated {
return ErrFlowControlRenegotiationNotSupported
}
sendConnectionFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return ErrMalformedTag
}
h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow)
}
_, containsSFCW := params[TagSFCW]
_, containsCFCW := params[TagCFCW]
if containsCFCW || containsSFCW {
h.flowControlNegotiated = true
}
return nil
}
func (h *connectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 {
return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection)
}
func (h *connectionParametersManager) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 {
return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection)
}
func (h *connectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration {
if h.perspective == protocol.PerspectiveServer {
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutServer)
}
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutClient)
}
// GetHelloMap gets all parameters needed for the Hello message
func (h *connectionParametersManager) GetHelloMap() (map[Tag][]byte, error) {
sfcw := bytes.NewBuffer([]byte{})
utils.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow()))
cfcw := bytes.NewBuffer([]byte{})
utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow()))
mspc := bytes.NewBuffer([]byte{})
utils.WriteUint32(mspc, h.maxStreamsPerConnection)
mids := bytes.NewBuffer([]byte{})
utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection)
icsl := bytes.NewBuffer([]byte{})
utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second))
return map[Tag][]byte{
TagICSL: icsl.Bytes(),
TagMSPC: mspc.Bytes(),
TagMIDS: mids.Bytes(),
TagCFCW: cfcw.Bytes(),
TagSFCW: sfcw.Bytes(),
}, nil
}
// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.sendStreamFlowControlWindow
}
// GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.sendConnectionFlowControlWindow
}
// GetReceiveStreamFlowControlWindow gets the size of the stream-level flow control window for receiving data
func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.receiveStreamFlowControlWindow
}
// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
return h.maxReceiveStreamFlowControlWindow
}
// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data
func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.receiveConnectionFlowControlWindow
}
// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
return h.maxReceiveConnectionFlowControlWindow
}
// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection
func (h *connectionParametersManager) GetMaxOutgoingStreams() uint32 {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.maxIncomingDynamicStreamsPerConnection
}
// GetMaxIncomingStreams get the maximum number of incoming streams per connection
func (h *connectionParametersManager) GetMaxIncomingStreams() uint32 {
h.mutex.RLock()
defer h.mutex.RUnlock()
maxStreams := protocol.MaxIncomingDynamicStreamsPerConnection
return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier))
}
// GetIdleConnectionStateLifetime gets the idle timeout
func (h *connectionParametersManager) GetIdleConnectionStateLifetime() time.Duration {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.idleConnectionStateLifetime
}
// TruncateConnectionID determines if the client requests truncated ConnectionIDs
func (h *connectionParametersManager) TruncateConnectionID() bool {
if h.perspective == protocol.PerspectiveClient {
return false
}
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.truncateConnectionID
}

View file

@ -1,24 +0,0 @@
package handshake
import "github.com/lucas-clemente/quic-go/protocol"
// Sealer seals a packet
type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
// CryptoSetup is a crypto setup
type CryptoSetup interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
HandleCryptoStream() error
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
}
// TransportParameters are parameters sent to the peer during the handshake
type TransportParameters struct {
RequestConnectionIDTruncation bool
}

View file

@ -1,100 +0,0 @@
package handshake
import (
"encoding/asn1"
"fmt"
"net"
"time"
"github.com/lucas-clemente/quic-go/crypto"
)
const (
stkPrefixIP byte = iota
stkPrefixString
)
// An STK is a source address token
type STK struct {
RemoteAddr string
SentTime time.Time
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
Data []byte
Timestamp int64
}
// An STKGenerator generates STKs
type STKGenerator struct {
stkSource crypto.StkSource
}
// NewSTKGenerator initializes a new STKGenerator
func NewSTKGenerator() (*STKGenerator, error) {
stkSource, err := crypto.NewStkSource()
if err != nil {
return nil, err
}
return &STKGenerator{
stkSource: stkSource,
}, nil
}
// NewToken generates a new STK token for a given source address
func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) {
data, err := asn1.Marshal(token{
Data: encodeRemoteAddr(raddr),
Timestamp: time.Now().Unix(),
})
if err != nil {
return nil, err
}
return g.stkSource.NewToken(data)
}
// DecodeToken decodes an STK token
func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) {
// if the client didn't send any STK, DecodeToken will be called with a nil-slice
if len(encrypted) == 0 {
return nil, nil
}
data, err := g.stkSource.DecodeToken(encrypted)
if err != nil {
return nil, err
}
t := &token{}
rest, err := asn1.Unmarshal(data, t)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
return &STK{
RemoteAddr: decodeRemoteAddr(t.Data),
SentTime: time.Unix(t.Timestamp, 0),
}, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
return append([]byte{stkPrefixIP}, udpAddr.IP...)
}
return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...)
}
// decodeRemoteAddr decodes the remote address saved in the STK
func decodeRemoteAddr(data []byte) string {
// data will never be empty for an STK that we generated. Check it to be on the safe side
if len(data) == 0 {
return ""
}
if data[0] == stkPrefixIP {
return net.IP(data[1:]).String()
}
return string(data[1:])
}

View file

@ -1 +0,0 @@
package chrome

View file

@ -1 +0,0 @@
package gquic

View file

@ -1 +0,0 @@
package self

View file

@ -1,14 +1,12 @@
package quicproxy
import (
"bytes"
"net"
"sync"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// Connection is a UDP connection
@ -28,21 +26,43 @@ const (
DirectionIncoming Direction = iota
// DirectionOutgoing is the direction from the server to the client.
DirectionOutgoing
// DirectionBoth is both incoming and outgoing
DirectionBoth
)
func (d Direction) String() string {
switch d {
case DirectionIncoming:
return "incoming"
case DirectionOutgoing:
return "outgoing"
case DirectionBoth:
return "both"
default:
panic("unknown direction")
}
}
func (d Direction) Is(dir Direction) bool {
if d == DirectionBoth || dir == DirectionBoth {
return true
}
return d == dir
}
// DropCallback is a callback that determines which packet gets dropped.
type DropCallback func(Direction, protocol.PacketNumber) bool
type DropCallback func(dir Direction, packetCount uint64) bool
// NoDropper doesn't drop packets.
var NoDropper DropCallback = func(Direction, protocol.PacketNumber) bool {
var NoDropper DropCallback = func(Direction, uint64) bool {
return false
}
// DelayCallback is a callback that determines how much delay to apply to a packet.
type DelayCallback func(Direction, protocol.PacketNumber) time.Duration
type DelayCallback func(dir Direction, packetCount uint64) time.Duration
// NoDelay doesn't apply a delay.
var NoDelay DelayCallback = func(Direction, protocol.PacketNumber) time.Duration {
var NoDelay DelayCallback = func(Direction, uint64) time.Duration {
return 0
}
@ -62,6 +82,8 @@ type Opts struct {
type QuicProxy struct {
mutex sync.Mutex
version protocol.VersionNumber
conn *net.UDPConn
serverAddr *net.UDPAddr
@ -73,7 +95,10 @@ type QuicProxy struct {
}
// NewQuicProxy creates a new UDP proxy
func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) {
func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) {
if opts == nil {
opts = &Opts{}
}
laddr, err := net.ResolveUDPAddr("udp", local)
if err != nil {
return nil, err
@ -103,6 +128,7 @@ func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) {
serverAddr: raddr,
dropPacket: packetDropper,
delayPacket: packetDelayer,
version: version,
}
go p.runProxy()
@ -119,6 +145,7 @@ func (p *QuicProxy) LocalAddr() net.Addr {
return p.conn.LocalAddr()
}
// LocalPort is the UDP port number the proxy is listening on.
func (p *QuicProxy) LocalPort() int {
return p.conn.LocalAddr().(*net.UDPAddr).Port
}
@ -137,7 +164,7 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
// runProxy listens on the proxy address and handles incoming packets.
func (p *QuicProxy) runProxy() error {
for {
buffer := make([]byte, protocol.MaxPacketSize)
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
if err != nil {
return err
@ -159,20 +186,14 @@ func (p *QuicProxy) runProxy() error {
}
p.mutex.Unlock()
atomic.AddUint64(&conn.incomingPacketCounter, 1)
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
r := bytes.NewReader(raw)
hdr, err := quic.ParsePublicHeader(r, protocol.PerspectiveClient)
if err != nil {
return err
}
if p.dropPacket(DirectionIncoming, hdr.PacketNumber) {
if p.dropPacket(DirectionIncoming, packetCount) {
continue
}
// Send the packet to the server
delay := p.delayPacket(DirectionIncoming, hdr.PacketNumber)
delay := p.delayPacket(DirectionIncoming, packetCount)
if delay != 0 {
time.AfterFunc(delay, func() {
// TODO: handle error
@ -190,28 +211,20 @@ func (p *QuicProxy) runProxy() error {
// runConnection handles packets from server to a single client
func (p *QuicProxy) runConnection(conn *connection) error {
for {
buffer := make([]byte, protocol.MaxPacketSize)
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, err := conn.ServerConn.Read(buffer)
if err != nil {
return err
}
raw := buffer[0:n]
// TODO: Switch back to using the public header once Chrome properly sets the type byte.
// r := bytes.NewReader(raw)
// , err := quic.ParsePublicHeader(r, protocol.PerspectiveServer)
// if err != nil {
// return err
// }
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
v := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
packetNumber := protocol.PacketNumber(v)
if p.dropPacket(DirectionOutgoing, packetNumber) {
if p.dropPacket(DirectionOutgoing, packetCount) {
continue
}
delay := p.delayPacket(DirectionOutgoing, packetNumber)
delay := p.delayPacket(DirectionOutgoing, packetCount)
if delay != 0 {
time.AfterFunc(delay, func() {
// TODO: handle error

View file

@ -27,7 +27,7 @@ var _ = BeforeEach(func() {
if len(logFileName) > 0 {
var err error
logFile, err = os.Create("./log.txt")
logFile, err = os.Create(logFileName)
Expect(err).ToNot(HaveOccurred())
log.SetOutput(logFile)
utils.SetLogLevel(utils.LogLevelDebug)

View file

@ -7,7 +7,9 @@ import (
"net/http"
"strconv"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
@ -23,8 +25,9 @@ var (
PRData = GeneratePRData(dataLen)
PRDataLong = GeneratePRData(dataLenLong)
server *h2quic.Server
port string
server *h2quic.Server
stoppedServing chan struct{}
port string
)
func init() {
@ -75,11 +78,16 @@ func GeneratePRData(l int) []byte {
return res
}
func StartQuicServer() {
// StartQuicServer starts a h2quic.Server.
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
func StartQuicServer(versions []protocol.VersionNumber) {
server = &h2quic.Server{
Server: &http.Server{
TLSConfig: testdata.GetTLSConfig(),
},
QuicConfig: &quic.Config{
Versions: versions,
},
}
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
@ -88,14 +96,18 @@ func StartQuicServer() {
Expect(err).NotTo(HaveOccurred())
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
stoppedServing = make(chan struct{})
go func() {
defer GinkgoRecover()
server.Serve(conn)
close(stoppedServing)
}()
}
func StopQuicServer() {
Expect(server.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing).Should(BeClosed())
}
func Port() string {

View file

@ -6,23 +6,55 @@ import (
"net"
"time"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// The StreamID is the ID of a QUIC stream.
type StreamID = protocol.StreamID
// A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber
// A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie
// ConnectionState records basic details about the QUIC connection.
type ConnectionState = handshake.ConnectionState
// An ErrorCode is an application-defined error code.
type ErrorCode = protocol.ApplicationErrorCode
// Stream is the interface implemented by QUIC streams
type Stream interface {
// StreamID returns the stream ID.
StreamID() StreamID
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Reader
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Writer
// Close closes the write-direction of the stream.
// Future calls to Write are not permitted after calling Close.
// It must not be called concurrently with Write.
// It must not be called after calling CancelWrite.
io.Closer
StreamID() protocol.StreamID
// Reset closes the stream with an error.
Reset(error)
// CancelWrite aborts sending on this stream.
// It must not be called after Close.
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
// Write will unblock immediately, and future calls to Write will fail.
CancelWrite(ErrorCode) error
// CancelRead aborts receiving on this stream.
// It will ask the peer to stop transmitting stream data.
// Read will unblock immediately, and future Read calls will fail.
CancelRead(ErrorCode) error
// The context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
// Warning: This API should not be considered stable and might change soon.
@ -43,6 +75,41 @@ type Stream interface {
SetDeadline(t time.Time) error
}
// A ReceiveStream is a unidirectional Receive Stream.
type ReceiveStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Read
io.Reader
// see Stream.CancelRead
CancelRead(ErrorCode) error
// see Stream.SetReadDealine
SetReadDeadline(t time.Time) error
}
// A SendStream is a unidirectional Send Stream.
type SendStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Write
io.Writer
// see Stream.Close
io.Closer
// see Stream.CancelWrite
CancelWrite(ErrorCode) error
// see Stream.Context
Context() context.Context
// see Stream.SetWriteDeadline
SetWriteDeadline(t time.Time) error
}
// StreamError is returned by Read and Write when the peer cancels the stream.
type StreamError interface {
error
Canceled() bool
ErrorCode() ErrorCode
}
// A Session is a QUIC connection between two peers.
type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
@ -64,53 +131,41 @@ type Session interface {
// The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
}
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
// The communication is encrypted, but not yet forward secure.
type NonFWSession interface {
Session
WaitUntilHandshakeComplete() error
}
// An STK is a Source Address token.
// It is issued by the server and sent to the client. For the client, it is an opaque blob.
// The client can send the STK in subsequent handshakes to prove ownership of its IP address.
type STK struct {
// The remote address this token was issued for.
// If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String())
// Otherwise, this is the string representation of the net.Addr (net.Addr.String())
remoteAddr string
// The time that the STK was issued (resolution 1 second)
sentTime time.Time
// ConnectionState returns basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
ConnectionState() ConnectionState
}
// Config contains all configuration data needed for a QUIC server or client.
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
type Config struct {
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon.
Versions []protocol.VersionNumber
// Ask the server to truncate the connection ID sent in the Public Header.
Versions []VersionNumber
// Ask the server to omit the connection ID sent in the Public Header.
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client.
RequestConnectionIDTruncation bool
RequestConnectionIDOmission bool
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 10 seconds.
HandshakeTimeout time.Duration
// AcceptSTK determines if an STK is accepted.
// It is called with stk = nil if the client didn't send an STK.
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours.
// IdleTimeout is the maximum duration that may pass without any incoming network activity.
// This value only applies after the handshake has completed.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 30 seconds.
IdleTimeout time.Duration
// AcceptCookie determines if a Cookie is accepted.
// It is called with cookie = nil if the client didn't send an Cookie.
// If not set, it verifies that the address matches, and that the Cookie was issued within the last 24 hours.
// This option is only valid for the server.
AcceptSTK func(clientAddr net.Addr, stk *STK) bool
AcceptCookie func(clientAddr net.Addr, cookie *Cookie) bool
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
MaxReceiveStreamFlowControlWindow protocol.ByteCount
MaxReceiveStreamFlowControlWindow uint64
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow protocol.ByteCount
MaxReceiveConnectionFlowControlWindow uint64
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool
}

View file

@ -0,0 +1,48 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet) error
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetHandshakeComplete()
// SendingAllowed says if a packet can be sent.
// Sending packets might not be possible because:
// * we're congestion limited
// * we're tracking the maximum number of sent packets
SendingAllowed() bool
// TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets.
TimeUntilSend() time.Time
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
// It always returns a number greater or equal than 1.
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
// Note that the number of packets is only calculated based on the pacing algorithm.
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
ShouldSendNumPackets() int
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() (packet *Packet)
GetLeastUnacked() protocol.PacketNumber
GetAlarmTimeout() time.Time
OnAlarm()
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
IgnoreBelow(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *wire.AckFrame
}

Some files were not shown because too many files have changed in this diff Show more