mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-23 22:27:38 -05:00
820 lines
21 KiB
Go
820 lines
21 KiB
Go
|
package mint
|
||
|
|
||
|
import (
|
||
|
"crypto"
|
||
|
"crypto/x509"
|
||
|
"encoding/hex"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net"
|
||
|
"reflect"
|
||
|
"sync"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
// WouldBlock is the sentinel error reported when an operation on a
// non-blocking connection cannot make progress without more input.
// Callers compare against it directly (err == WouldBlock).
var (
	WouldBlock = fmt.Errorf("Would have blocked")
)
|
||
|
|
||
|
// Certificate couples a certificate chain with the private key used to sign
// on its behalf. (*Config).Init builds one of these (a single self-signed
// certificate plus its key) for servers that have none configured.
type Certificate struct {
	// Chain is the certificate chain; Chain[0] is presumably the leaf
	// (ValidForServer only checks that Certificates[0].Chain is non-empty).
	Chain []*x509.Certificate

	// PrivateKey is the signing key paired with the chain.
	PrivateKey crypto.Signer
}
|
||
|
|
||
|
type PreSharedKey struct {
|
||
|
CipherSuite CipherSuite
|
||
|
IsResumption bool
|
||
|
Identity []byte
|
||
|
Key []byte
|
||
|
NextProto string
|
||
|
ReceivedAt time.Time
|
||
|
ExpiresAt time.Time
|
||
|
TicketAgeAdd uint32
|
||
|
}
|
||
|
|
||
|
type PreSharedKeyCache interface {
|
||
|
Get(string) (PreSharedKey, bool)
|
||
|
Put(string, PreSharedKey)
|
||
|
Size() int
|
||
|
}
|
||
|
|
||
|
// PSKMapCache is a plain map-backed PreSharedKeyCache with no internal
// locking. Put uses a pointer receiver, so pass a *PSKMapCache where a
// PreSharedKeyCache is required.
type PSKMapCache map[string]PreSharedKey
|
||
|
|
||
|
// A CookieHandler does two things:
|
||
|
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
|
||
|
// - validates this byte string echoed by the client in the ClientHello
|
||
|
type CookieHandler interface {
|
||
|
Generate(*Conn) ([]byte, error)
|
||
|
Validate(*Conn, []byte) bool
|
||
|
}
|
||
|
|
||
|
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
|
||
|
psk, ok = cache[key]
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (cache *PSKMapCache) Put(key string, psk PreSharedKey) {
|
||
|
(*cache)[key] = psk
|
||
|
}
|
||
|
|
||
|
func (cache PSKMapCache) Size() int {
|
||
|
return len(cache)
|
||
|
}
|
||
|
|
||
|
// Config is the struct used to pass configuration settings to a TLS client or
|
||
|
// server instance. The settings for client and server are pretty different,
|
||
|
// but we just throw them all in here.
|
||
|
type Config struct {
|
||
|
// Client fields
|
||
|
ServerName string
|
||
|
|
||
|
// Server fields
|
||
|
SendSessionTickets bool
|
||
|
TicketLifetime uint32
|
||
|
TicketLen int
|
||
|
EarlyDataLifetime uint32
|
||
|
AllowEarlyData bool
|
||
|
// Require the client to echo a cookie.
|
||
|
RequireCookie bool
|
||
|
// If cookies are required and no CookieHandler is set, a default cookie handler is used.
|
||
|
// The default cookie handler uses 32 random bytes as a cookie.
|
||
|
CookieHandler CookieHandler
|
||
|
RequireClientAuth bool
|
||
|
|
||
|
// Shared fields
|
||
|
Certificates []*Certificate
|
||
|
AuthCertificate func(chain []CertificateEntry) error
|
||
|
CipherSuites []CipherSuite
|
||
|
Groups []NamedGroup
|
||
|
SignatureSchemes []SignatureScheme
|
||
|
NextProtos []string
|
||
|
PSKs PreSharedKeyCache
|
||
|
PSKModes []PSKKeyExchangeMode
|
||
|
NonBlocking bool
|
||
|
|
||
|
// The same config object can be shared among different connections, so it
|
||
|
// needs its own mutex
|
||
|
mutex sync.RWMutex
|
||
|
}
|
||
|
|
||
|
// Clone returns a shallow clone of c. It is safe to clone a Config that is
|
||
|
// being used concurrently by a TLS client or server.
|
||
|
func (c *Config) Clone() *Config {
|
||
|
c.mutex.Lock()
|
||
|
defer c.mutex.Unlock()
|
||
|
|
||
|
return &Config{
|
||
|
ServerName: c.ServerName,
|
||
|
|
||
|
SendSessionTickets: c.SendSessionTickets,
|
||
|
TicketLifetime: c.TicketLifetime,
|
||
|
TicketLen: c.TicketLen,
|
||
|
EarlyDataLifetime: c.EarlyDataLifetime,
|
||
|
AllowEarlyData: c.AllowEarlyData,
|
||
|
RequireCookie: c.RequireCookie,
|
||
|
RequireClientAuth: c.RequireClientAuth,
|
||
|
|
||
|
Certificates: c.Certificates,
|
||
|
AuthCertificate: c.AuthCertificate,
|
||
|
CipherSuites: c.CipherSuites,
|
||
|
Groups: c.Groups,
|
||
|
SignatureSchemes: c.SignatureSchemes,
|
||
|
NextProtos: c.NextProtos,
|
||
|
PSKs: c.PSKs,
|
||
|
PSKModes: c.PSKModes,
|
||
|
NonBlocking: c.NonBlocking,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (c *Config) Init(isClient bool) error {
|
||
|
c.mutex.Lock()
|
||
|
defer c.mutex.Unlock()
|
||
|
|
||
|
// Set defaults
|
||
|
if len(c.CipherSuites) == 0 {
|
||
|
c.CipherSuites = defaultSupportedCipherSuites
|
||
|
}
|
||
|
if len(c.Groups) == 0 {
|
||
|
c.Groups = defaultSupportedGroups
|
||
|
}
|
||
|
if len(c.SignatureSchemes) == 0 {
|
||
|
c.SignatureSchemes = defaultSignatureSchemes
|
||
|
}
|
||
|
if c.TicketLen == 0 {
|
||
|
c.TicketLen = defaultTicketLen
|
||
|
}
|
||
|
if !reflect.ValueOf(c.PSKs).IsValid() {
|
||
|
c.PSKs = &PSKMapCache{}
|
||
|
}
|
||
|
if len(c.PSKModes) == 0 {
|
||
|
c.PSKModes = defaultPSKModes
|
||
|
}
|
||
|
|
||
|
// If there is no certificate, generate one
|
||
|
if !isClient && len(c.Certificates) == 0 {
|
||
|
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
|
||
|
priv, err := newSigningKey(RSA_PSS_SHA256)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
c.Certificates = []*Certificate{
|
||
|
{
|
||
|
Chain: []*x509.Certificate{cert},
|
||
|
PrivateKey: priv,
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *Config) ValidForServer() bool {
|
||
|
return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) ||
|
||
|
(len(c.Certificates) > 0 &&
|
||
|
len(c.Certificates[0].Chain) > 0 &&
|
||
|
c.Certificates[0].PrivateKey != nil)
|
||
|
}
|
||
|
|
||
|
func (c *Config) ValidForClient() bool {
|
||
|
return len(c.ServerName) > 0
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
defaultSupportedCipherSuites = []CipherSuite{
|
||
|
TLS_AES_128_GCM_SHA256,
|
||
|
TLS_AES_256_GCM_SHA384,
|
||
|
}
|
||
|
|
||
|
defaultSupportedGroups = []NamedGroup{
|
||
|
P256,
|
||
|
P384,
|
||
|
FFDHE2048,
|
||
|
X25519,
|
||
|
}
|
||
|
|
||
|
defaultSignatureSchemes = []SignatureScheme{
|
||
|
RSA_PSS_SHA256,
|
||
|
RSA_PSS_SHA384,
|
||
|
RSA_PSS_SHA512,
|
||
|
ECDSA_P256_SHA256,
|
||
|
ECDSA_P384_SHA384,
|
||
|
ECDSA_P521_SHA512,
|
||
|
}
|
||
|
|
||
|
defaultTicketLen = 16
|
||
|
|
||
|
defaultPSKModes = []PSKKeyExchangeMode{
|
||
|
PSKModeKE,
|
||
|
PSKModeDHEKE,
|
||
|
}
|
||
|
)
|
||
|
|
||
|
type ConnectionState struct {
|
||
|
HandshakeState string // string representation of the handshake state.
|
||
|
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
|
||
|
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement
|
||
|
NextProto string // Selected ALPN proto
|
||
|
}
|
||
|
|
||
|
// Conn implements the net.Conn interface, as with "crypto/tls"
|
||
|
// * Read, Write, and Close are provided locally
|
||
|
// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
|
||
|
type Conn struct {
|
||
|
config *Config
|
||
|
conn net.Conn
|
||
|
isClient bool
|
||
|
|
||
|
EarlyData []byte
|
||
|
|
||
|
state StateConnected
|
||
|
hState HandshakeState
|
||
|
handshakeMutex sync.Mutex
|
||
|
handshakeAlert Alert
|
||
|
handshakeComplete bool
|
||
|
|
||
|
readBuffer []byte
|
||
|
in, out *RecordLayer
|
||
|
hIn, hOut *HandshakeLayer
|
||
|
|
||
|
extHandler AppExtensionHandler
|
||
|
}
|
||
|
|
||
|
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
|
||
|
c := &Conn{conn: conn, config: config, isClient: isClient}
|
||
|
c.in = NewRecordLayer(c.conn)
|
||
|
c.out = NewRecordLayer(c.conn)
|
||
|
c.hIn = NewHandshakeLayer(c.in)
|
||
|
c.hIn.nonblocking = c.config.NonBlocking
|
||
|
c.hOut = NewHandshakeLayer(c.out)
|
||
|
return c
|
||
|
}
|
||
|
|
||
|
// Read up
|
||
|
func (c *Conn) consumeRecord() error {
|
||
|
pt, err := c.in.ReadRecord()
|
||
|
if pt == nil {
|
||
|
logf(logTypeIO, "extendBuffer returns error %v", err)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
switch pt.contentType {
|
||
|
case RecordTypeHandshake:
|
||
|
logf(logTypeHandshake, "Received post-handshake message")
|
||
|
// We do not support fragmentation of post-handshake handshake messages.
|
||
|
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
|
||
|
start := 0
|
||
|
for start < len(pt.fragment) {
|
||
|
if len(pt.fragment[start:]) < handshakeHeaderLen {
|
||
|
return fmt.Errorf("Post-handshake handshake message too short for header")
|
||
|
}
|
||
|
|
||
|
hm := &HandshakeMessage{}
|
||
|
hm.msgType = HandshakeType(pt.fragment[start])
|
||
|
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
|
||
|
|
||
|
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen {
|
||
|
return fmt.Errorf("Post-handshake handshake message too short for body")
|
||
|
}
|
||
|
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen]
|
||
|
|
||
|
// Advance state machine
|
||
|
state, actions, alert := c.state.Next(hm)
|
||
|
|
||
|
if alert != AlertNoAlert {
|
||
|
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||
|
c.sendAlert(alert)
|
||
|
return io.EOF
|
||
|
}
|
||
|
|
||
|
for _, action := range actions {
|
||
|
alert = c.takeAction(action)
|
||
|
if alert != AlertNoAlert {
|
||
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||
|
c.sendAlert(alert)
|
||
|
return io.EOF
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// XXX: If we want to support more advanced cases, e.g., post-handshake
|
||
|
// authentication, we'll need to allow transitions other than
|
||
|
// Connected -> Connected
|
||
|
var connected bool
|
||
|
c.state, connected = state.(StateConnected)
|
||
|
if !connected {
|
||
|
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
|
||
|
c.sendAlert(alert)
|
||
|
return io.EOF
|
||
|
}
|
||
|
|
||
|
start += handshakeHeaderLen + hmLen
|
||
|
}
|
||
|
case RecordTypeAlert:
|
||
|
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
|
||
|
if len(pt.fragment) != 2 {
|
||
|
c.sendAlert(AlertUnexpectedMessage)
|
||
|
return io.EOF
|
||
|
}
|
||
|
if Alert(pt.fragment[1]) == AlertCloseNotify {
|
||
|
return io.EOF
|
||
|
}
|
||
|
|
||
|
switch pt.fragment[0] {
|
||
|
case AlertLevelWarning:
|
||
|
// drop on the floor
|
||
|
case AlertLevelError:
|
||
|
return Alert(pt.fragment[1])
|
||
|
default:
|
||
|
c.sendAlert(AlertUnexpectedMessage)
|
||
|
return io.EOF
|
||
|
}
|
||
|
|
||
|
case RecordTypeApplicationData:
|
||
|
c.readBuffer = append(c.readBuffer, pt.fragment...)
|
||
|
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
|
||
|
}
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Read application data up to the size of buffer. Handshake and alert records
|
||
|
// are consumed by the Conn object directly.
|
||
|
func (c *Conn) Read(buffer []byte) (int, error) {
|
||
|
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
|
||
|
if alert := c.Handshake(); alert != AlertNoAlert {
|
||
|
return 0, alert
|
||
|
}
|
||
|
|
||
|
if len(buffer) == 0 {
|
||
|
return 0, nil
|
||
|
}
|
||
|
|
||
|
// Lock the input channel
|
||
|
c.in.Lock()
|
||
|
defer c.in.Unlock()
|
||
|
for len(c.readBuffer) == 0 {
|
||
|
err := c.consumeRecord()
|
||
|
|
||
|
// err can be nil if consumeRecord processed a non app-data
|
||
|
// record.
|
||
|
if err != nil {
|
||
|
if c.config.NonBlocking || err != WouldBlock {
|
||
|
logf(logTypeIO, "conn.Read returns err=%v", err)
|
||
|
return 0, err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var read int
|
||
|
n := len(buffer)
|
||
|
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
|
||
|
if len(c.readBuffer) <= n {
|
||
|
buffer = buffer[:len(c.readBuffer)]
|
||
|
copy(buffer, c.readBuffer)
|
||
|
read = len(c.readBuffer)
|
||
|
c.readBuffer = c.readBuffer[:0]
|
||
|
} else {
|
||
|
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
|
||
|
copy(buffer[:n], c.readBuffer[:n])
|
||
|
c.readBuffer = c.readBuffer[n:]
|
||
|
read = n
|
||
|
}
|
||
|
|
||
|
logf(logTypeVerbose, "Returning %v", string(buffer))
|
||
|
return read, nil
|
||
|
}
|
||
|
|
||
|
// Write application data
|
||
|
func (c *Conn) Write(buffer []byte) (int, error) {
|
||
|
// Lock the output channel
|
||
|
c.out.Lock()
|
||
|
defer c.out.Unlock()
|
||
|
|
||
|
// Send full-size fragments
|
||
|
var start int
|
||
|
sent := 0
|
||
|
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||
|
err := c.out.WriteRecord(&TLSPlaintext{
|
||
|
contentType: RecordTypeApplicationData,
|
||
|
fragment: buffer[start : start+maxFragmentLen],
|
||
|
})
|
||
|
|
||
|
if err != nil {
|
||
|
return sent, err
|
||
|
}
|
||
|
sent += maxFragmentLen
|
||
|
}
|
||
|
|
||
|
// Send a final partial fragment if necessary
|
||
|
if start < len(buffer) {
|
||
|
err := c.out.WriteRecord(&TLSPlaintext{
|
||
|
contentType: RecordTypeApplicationData,
|
||
|
fragment: buffer[start:],
|
||
|
})
|
||
|
|
||
|
if err != nil {
|
||
|
return sent, err
|
||
|
}
|
||
|
sent += len(buffer[start:])
|
||
|
}
|
||
|
return sent, nil
|
||
|
}
|
||
|
|
||
|
// sendAlert sends a TLS alert message.
|
||
|
// c.out.Mutex <= L.
|
||
|
func (c *Conn) sendAlert(err Alert) error {
|
||
|
c.handshakeMutex.Lock()
|
||
|
defer c.handshakeMutex.Unlock()
|
||
|
|
||
|
var level int
|
||
|
switch err {
|
||
|
case AlertNoRenegotiation, AlertCloseNotify:
|
||
|
level = AlertLevelWarning
|
||
|
default:
|
||
|
level = AlertLevelError
|
||
|
}
|
||
|
|
||
|
buf := []byte{byte(err), byte(level)}
|
||
|
c.out.WriteRecord(&TLSPlaintext{
|
||
|
contentType: RecordTypeAlert,
|
||
|
fragment: buf,
|
||
|
})
|
||
|
|
||
|
// close_notify and end_of_early_data are not actually errors
|
||
|
if level == AlertLevelWarning {
|
||
|
return &net.OpError{Op: "local error", Err: err}
|
||
|
}
|
||
|
|
||
|
return c.Close()
|
||
|
}
|
||
|
|
||
|
// Close closes the connection.
|
||
|
func (c *Conn) Close() error {
|
||
|
// XXX crypto/tls has an interlock with Write here. Do we need that?
|
||
|
|
||
|
return c.conn.Close()
|
||
|
}
|
||
|
|
||
|
// LocalAddr returns the local network address.
|
||
|
func (c *Conn) LocalAddr() net.Addr {
|
||
|
return c.conn.LocalAddr()
|
||
|
}
|
||
|
|
||
|
// RemoteAddr returns the remote network address.
|
||
|
func (c *Conn) RemoteAddr() net.Addr {
|
||
|
return c.conn.RemoteAddr()
|
||
|
}
|
||
|
|
||
|
// SetDeadline sets the read and write deadlines associated with the connection.
|
||
|
// A zero value for t means Read and Write will not time out.
|
||
|
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
||
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||
|
return c.conn.SetDeadline(t)
|
||
|
}
|
||
|
|
||
|
// SetReadDeadline sets the read deadline on the underlying connection.
|
||
|
// A zero value for t means Read will not time out.
|
||
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||
|
return c.conn.SetReadDeadline(t)
|
||
|
}
|
||
|
|
||
|
// SetWriteDeadline sets the write deadline on the underlying connection.
|
||
|
// A zero value for t means Write will not time out.
|
||
|
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
||
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||
|
return c.conn.SetWriteDeadline(t)
|
||
|
}
|
||
|
|
||
|
func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
|
||
|
label := "[server]"
|
||
|
if c.isClient {
|
||
|
label = "[client]"
|
||
|
}
|
||
|
|
||
|
switch action := actionGeneric.(type) {
|
||
|
case SendHandshakeMessage:
|
||
|
err := c.hOut.WriteMessage(action.Message)
|
||
|
if err != nil {
|
||
|
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
|
||
|
return AlertInternalError
|
||
|
}
|
||
|
|
||
|
case RekeyIn:
|
||
|
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet)
|
||
|
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||
|
if err != nil {
|
||
|
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
|
||
|
return AlertInternalError
|
||
|
}
|
||
|
|
||
|
case RekeyOut:
|
||
|
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet)
|
||
|
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||
|
if err != nil {
|
||
|
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
|
||
|
return AlertInternalError
|
||
|
}
|
||
|
|
||
|
case SendEarlyData:
|
||
|
logf(logTypeHandshake, "%s Sending early data...", label)
|
||
|
_, err := c.Write(c.EarlyData)
|
||
|
if err != nil {
|
||
|
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
|
||
|
return AlertInternalError
|
||
|
}
|
||
|
|
||
|
case ReadPastEarlyData:
|
||
|
logf(logTypeHandshake, "%s Reading past early data...", label)
|
||
|
// Scan past all records that fail to decrypt
|
||
|
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||
|
if err == nil {
|
||
|
break
|
||
|
}
|
||
|
_, ok := err.(DecryptError)
|
||
|
|
||
|
for ok {
|
||
|
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||
|
if err == nil {
|
||
|
break
|
||
|
}
|
||
|
_, ok = err.(DecryptError)
|
||
|
}
|
||
|
|
||
|
case ReadEarlyData:
|
||
|
logf(logTypeHandshake, "%s Reading early data...", label)
|
||
|
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||
|
if err != nil {
|
||
|
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
|
||
|
return AlertInternalError
|
||
|
}
|
||
|
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
|
||
|
|
||
|
for t == RecordTypeApplicationData {
|
||
|
// Read a record into the buffer. Note that this is safe
|
||
|
// in blocking mode because we read the record in in
|
||
|
// PeekRecordType.
|
||
|
pt, err := c.in.ReadRecord()
|
||
|
if err != nil {
|
||
|
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
|
||
|
return AlertInternalError
|
||
|
}
|
||
|
|
||
|
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
|
||
|
c.EarlyData = append(c.EarlyData, pt.fragment...)
|
||
|
|
||
|
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||
|
if err != nil {
|
||
|
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
|
||
|
return AlertInternalError
|
||
|
}
|
||
|
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
|
||
|
}
|
||
|
logf(logTypeHandshake, "%s Done reading early data", label)
|
||
|
|
||
|
case StorePSK:
|
||
|
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
|
||
|
if c.isClient {
|
||
|
// Clients look up PSKs based on server name
|
||
|
c.config.PSKs.Put(c.config.ServerName, action.PSK)
|
||
|
} else {
|
||
|
// Servers look them up based on the identity in the extension
|
||
|
c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK)
|
||
|
}
|
||
|
|
||
|
default:
|
||
|
logf(logTypeHandshake, "%s Unknown actionuction type", label)
|
||
|
return AlertInternalError
|
||
|
}
|
||
|
|
||
|
return AlertNoAlert
|
||
|
}
|
||
|
|
||
|
func (c *Conn) HandshakeSetup() Alert {
|
||
|
var state HandshakeState
|
||
|
var actions []HandshakeAction
|
||
|
var alert Alert
|
||
|
|
||
|
if err := c.config.Init(c.isClient); err != nil {
|
||
|
logf(logTypeHandshake, "Error initializing config: %v", err)
|
||
|
return AlertInternalError
|
||
|
}
|
||
|
|
||
|
// Set things up
|
||
|
caps := Capabilities{
|
||
|
CipherSuites: c.config.CipherSuites,
|
||
|
Groups: c.config.Groups,
|
||
|
SignatureSchemes: c.config.SignatureSchemes,
|
||
|
PSKs: c.config.PSKs,
|
||
|
PSKModes: c.config.PSKModes,
|
||
|
AllowEarlyData: c.config.AllowEarlyData,
|
||
|
RequireCookie: c.config.RequireCookie,
|
||
|
CookieHandler: c.config.CookieHandler,
|
||
|
RequireClientAuth: c.config.RequireClientAuth,
|
||
|
NextProtos: c.config.NextProtos,
|
||
|
Certificates: c.config.Certificates,
|
||
|
ExtensionHandler: c.extHandler,
|
||
|
}
|
||
|
opts := ConnectionOptions{
|
||
|
ServerName: c.config.ServerName,
|
||
|
NextProtos: c.config.NextProtos,
|
||
|
EarlyData: c.EarlyData,
|
||
|
}
|
||
|
|
||
|
if caps.RequireCookie && caps.CookieHandler == nil {
|
||
|
caps.CookieHandler = &defaultCookieHandler{}
|
||
|
}
|
||
|
|
||
|
if c.isClient {
|
||
|
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil)
|
||
|
if alert != AlertNoAlert {
|
||
|
logf(logTypeHandshake, "Error initializing client state: %v", alert)
|
||
|
return alert
|
||
|
}
|
||
|
|
||
|
for _, action := range actions {
|
||
|
alert = c.takeAction(action)
|
||
|
if alert != AlertNoAlert {
|
||
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||
|
return alert
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
state = ServerStateStart{Caps: caps, conn: c}
|
||
|
}
|
||
|
|
||
|
c.hState = state
|
||
|
|
||
|
return AlertNoAlert
|
||
|
}
|
||
|
|
||
|
// Handshake causes a TLS handshake on the connection. The `isClient` member
|
||
|
// determines whether a client or server handshake is performed. If a
|
||
|
// handshake has already been performed, then its result will be returned.
|
||
|
func (c *Conn) Handshake() Alert {
|
||
|
label := "[server]"
|
||
|
if c.isClient {
|
||
|
label = "[client]"
|
||
|
}
|
||
|
|
||
|
// TODO Lock handshakeMutex
|
||
|
// TODO Remove CloseNotify hack
|
||
|
if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify {
|
||
|
logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert)
|
||
|
return c.handshakeAlert
|
||
|
}
|
||
|
if c.handshakeComplete {
|
||
|
return AlertNoAlert
|
||
|
}
|
||
|
|
||
|
var alert Alert
|
||
|
if c.hState == nil {
|
||
|
logf(logTypeHandshake, "%s First time through handshake, setting up", label)
|
||
|
alert = c.HandshakeSetup()
|
||
|
if alert != AlertNoAlert {
|
||
|
return alert
|
||
|
}
|
||
|
} else {
|
||
|
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
|
||
|
}
|
||
|
|
||
|
state := c.hState
|
||
|
_, connected := state.(StateConnected)
|
||
|
|
||
|
var actions []HandshakeAction
|
||
|
|
||
|
for !connected {
|
||
|
// Read a handshake message
|
||
|
hm, err := c.hIn.ReadMessage()
|
||
|
if err == WouldBlock {
|
||
|
logf(logTypeHandshake, "%s Would block reading message: %v", label, err)
|
||
|
return AlertWouldBlock
|
||
|
}
|
||
|
if err != nil {
|
||
|
logf(logTypeHandshake, "%s Error reading message: %v", label, err)
|
||
|
c.sendAlert(AlertCloseNotify)
|
||
|
return AlertCloseNotify
|
||
|
}
|
||
|
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
|
||
|
|
||
|
// Advance the state machine
|
||
|
state, actions, alert = state.Next(hm)
|
||
|
|
||
|
if alert != AlertNoAlert {
|
||
|
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||
|
return alert
|
||
|
}
|
||
|
|
||
|
for index, action := range actions {
|
||
|
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
|
||
|
alert = c.takeAction(action)
|
||
|
if alert != AlertNoAlert {
|
||
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||
|
c.sendAlert(alert)
|
||
|
return alert
|
||
|
}
|
||
|
}
|
||
|
|
||
|
c.hState = state
|
||
|
logf(logTypeHandshake, "state is now %s", c.GetHsState())
|
||
|
|
||
|
_, connected = state.(StateConnected)
|
||
|
}
|
||
|
|
||
|
c.state = state.(StateConnected)
|
||
|
|
||
|
// Send NewSessionTicket if acting as server
|
||
|
if !c.isClient && c.config.SendSessionTickets {
|
||
|
actions, alert := c.state.NewSessionTicket(
|
||
|
c.config.TicketLen,
|
||
|
c.config.TicketLifetime,
|
||
|
c.config.EarlyDataLifetime)
|
||
|
|
||
|
for _, action := range actions {
|
||
|
alert = c.takeAction(action)
|
||
|
if alert != AlertNoAlert {
|
||
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||
|
c.sendAlert(alert)
|
||
|
return alert
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
c.handshakeComplete = true
|
||
|
return AlertNoAlert
|
||
|
}
|
||
|
|
||
|
func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
|
||
|
if !c.handshakeComplete {
|
||
|
return fmt.Errorf("Cannot update keys until after handshake")
|
||
|
}
|
||
|
|
||
|
request := KeyUpdateNotRequested
|
||
|
if requestUpdate {
|
||
|
request = KeyUpdateRequested
|
||
|
}
|
||
|
|
||
|
// Create the key update and update state
|
||
|
actions, alert := c.state.KeyUpdate(request)
|
||
|
if alert != AlertNoAlert {
|
||
|
c.sendAlert(alert)
|
||
|
return fmt.Errorf("Alert while generating key update: %v", alert)
|
||
|
}
|
||
|
|
||
|
// Take actions (send key update and rekey)
|
||
|
for _, action := range actions {
|
||
|
alert = c.takeAction(action)
|
||
|
if alert != AlertNoAlert {
|
||
|
c.sendAlert(alert)
|
||
|
return fmt.Errorf("Alert during key update actions: %v", alert)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *Conn) GetHsState() string {
|
||
|
return reflect.TypeOf(c.hState).Name()
|
||
|
}
|
||
|
|
||
|
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||
|
_, connected := c.hState.(StateConnected)
|
||
|
if !connected {
|
||
|
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
|
||
|
}
|
||
|
|
||
|
if c.state.exporterSecret == nil {
|
||
|
return nil, fmt.Errorf("Internal error: no exporter secret")
|
||
|
}
|
||
|
|
||
|
h0 := c.state.cryptoParams.Hash.New().Sum(nil)
|
||
|
tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0)
|
||
|
|
||
|
hc := c.state.cryptoParams.Hash.New().Sum(context)
|
||
|
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
|
||
|
}
|
||
|
|
||
|
func (c *Conn) State() ConnectionState {
|
||
|
state := ConnectionState{
|
||
|
HandshakeState: c.GetHsState(),
|
||
|
}
|
||
|
|
||
|
if c.handshakeComplete {
|
||
|
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
|
||
|
state.NextProto = c.state.Params.NextProto
|
||
|
}
|
||
|
|
||
|
return state
|
||
|
}
|
||
|
|
||
|
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
|
||
|
if c.hState != nil {
|
||
|
return fmt.Errorf("Can't set extension handler after setup")
|
||
|
}
|
||
|
|
||
|
c.extHandler = h
|
||
|
return nil
|
||
|
}
|