0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2024-12-23 22:27:38 -05:00
caddy/vendor/github.com/bifurcation/mint/conn.go

820 lines
21 KiB
Go
Raw Normal View History

package mint
import (
"crypto"
"crypto/x509"
"encoding/hex"
"fmt"
"io"
"net"
"reflect"
"sync"
"time"
)
// WouldBlock is the sentinel error reported when an operation would have had
// to wait for more input. Read and Handshake compare errors against it
// directly, so it must be returned as-is (unwrapped).
var WouldBlock = fmt.Errorf("Would have blocked")

type (
	// Certificate pairs a certificate chain with the private key that signs
	// on its behalf.
	Certificate struct {
		Chain      []*x509.Certificate
		PrivateKey crypto.Signer
	}

	// PreSharedKey is a TLS 1.3 pre-shared key together with the metadata
	// kept alongside it (cipher suite, ALPN protocol, validity window, ticket
	// age obfuscation value). NOTE(review): IsResumption presumably marks
	// keys obtained from session tickets; confirm in the state machine.
	PreSharedKey struct {
		CipherSuite  CipherSuite
		IsResumption bool
		Identity     []byte
		Key          []byte
		NextProto    string
		ReceivedAt   time.Time
		ExpiresAt    time.Time
		TicketAgeAdd uint32
	}

	// PreSharedKeyCache stores PSKs under string keys. Within this file,
	// clients key entries by server name and servers by the hex-encoded PSK
	// identity (see the StorePSK case of takeAction).
	PreSharedKeyCache interface {
		Get(string) (PreSharedKey, bool)
		Put(string, PreSharedKey)
		Size() int
	}

	// PSKMapCache is a map-backed PreSharedKeyCache. It has no internal
	// locking of its own.
	PSKMapCache map[string]PreSharedKey

	// A CookieHandler does two things:
	//   - generates a byte string that is sent as a part of a cookie to the
	//     client in the HelloRetryRequest
	//   - validates this byte string echoed by the client in the ClientHello
	CookieHandler interface {
		Generate(*Conn) ([]byte, error)
		Validate(*Conn, []byte) bool
	}
)
// Get returns the PSK stored under key and whether such an entry exists.
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
	psk, ok = cache[key]
	return psk, ok
}

// Put stores psk under key, replacing any existing entry.
//
// Writing to a nil map panics in Go, so a nil cache is allocated on first
// use. This makes the zero value (e.g. `var c PSKMapCache; cfg.PSKs = &c`)
// usable, not just the non-nil map that Config.Init installs.
func (cache *PSKMapCache) Put(key string, psk PreSharedKey) {
	if *cache == nil {
		*cache = make(PSKMapCache)
	}
	(*cache)[key] = psk
}

// Size returns the number of stored PSKs.
func (cache PSKMapCache) Size() int {
	return len(cache)
}
// Config is the struct used to pass configuration settings to a TLS client or
// server instance. The settings for client and server are pretty different,
// but we just throw them all in here.
//
// Init fills in defaults for fields left empty or zero. A Config may be
// shared by several connections; the unexported mutex is taken by Init
// (write lock) and Clone.
type Config struct {
	// Client fields

	// ServerName is required for a client (ValidForClient checks that it is
	// non-empty) and keys the client's PSK cache entries. A server also uses
	// it as the name when Init generates a self-signed certificate.
	ServerName string

	// Server fields

	// SendSessionTickets makes a server issue a NewSessionTicket as soon as
	// its handshake reaches the connected state.
	SendSessionTickets bool
	// TicketLifetime, TicketLen and EarlyDataLifetime are passed to
	// NewSessionTicket when a ticket is issued; TicketLen defaults to 16 in
	// Init. NOTE(review): the lifetime units are not visible here -
	// presumably seconds, as in the TLS 1.3 ticket_lifetime field; confirm.
	TicketLifetime    uint32
	TicketLen         int
	EarlyDataLifetime uint32
	// AllowEarlyData is handed to the handshake as Capabilities.AllowEarlyData.
	AllowEarlyData bool
	// Require the client to echo a cookie.
	RequireCookie bool
	// If cookies are required and no CookieHandler is set, a default cookie handler is used.
	// The default cookie handler uses 32 random bytes as a cookie.
	CookieHandler CookieHandler
	// RequireClientAuth is handed to the handshake as Capabilities.RequireClientAuth.
	RequireClientAuth bool

	// Shared fields

	// Certificates offered to the peer. For a server with none configured,
	// Init generates a key and a self-signed certificate.
	Certificates []*Certificate
	// AuthCertificate is a callback that receives the peer's certificate
	// chain. NOTE(review): HandshakeSetup in this file does not copy it into
	// Capabilities; confirm that it is actually honored elsewhere.
	AuthCertificate func(chain []CertificateEntry) error
	// CipherSuites, Groups, SignatureSchemes and PSKModes fall back to the
	// package defaults in Init when empty.
	CipherSuites     []CipherSuite
	Groups           []NamedGroup
	SignatureSchemes []SignatureScheme
	// NextProtos lists the ALPN protocols offered or accepted.
	NextProtos []string
	// PSKs is the pre-shared key cache; Init installs an empty PSKMapCache
	// when it is nil. Tickets received or issued are stored here.
	PSKs     PreSharedKeyCache
	PSKModes []PSKKeyExchangeMode
	// NonBlocking makes reads that cannot complete fail with WouldBlock
	// (surfaced as AlertWouldBlock during the handshake) instead of waiting.
	NonBlocking bool

	// The same config object can be shared among different connections, so it
	// needs its own mutex
	mutex sync.RWMutex
}
// Clone returns a shallow copy of c. Slices, the PSK cache, the cookie
// handler and callbacks are shared with the original; only the mutex is
// fresh. It is safe to clone a Config that is being used concurrently by a
// TLS client or server.
func (c *Config) Clone() *Config {
	// Clone only reads c, so the shared (read) lock is enough to exclude a
	// concurrent Init, which takes the write lock.
	c.mutex.RLock()
	defer c.mutex.RUnlock()

	return &Config{
		ServerName: c.ServerName,

		SendSessionTickets: c.SendSessionTickets,
		TicketLifetime:     c.TicketLifetime,
		TicketLen:          c.TicketLen,
		EarlyDataLifetime:  c.EarlyDataLifetime,
		AllowEarlyData:     c.AllowEarlyData,
		RequireCookie:      c.RequireCookie,
		// Previously omitted: a clone with RequireCookie set would silently
		// fall back to the default handler instead of the configured one.
		CookieHandler:     c.CookieHandler,
		RequireClientAuth: c.RequireClientAuth,

		Certificates:     c.Certificates,
		AuthCertificate:  c.AuthCertificate,
		CipherSuites:     c.CipherSuites,
		Groups:           c.Groups,
		SignatureSchemes: c.SignatureSchemes,
		NextProtos:       c.NextProtos,
		PSKs:             c.PSKs,
		PSKModes:         c.PSKModes,
		NonBlocking:      c.NonBlocking,
	}
}
// Init replaces every empty or zero setting of c with the package default,
// holding c's write lock throughout. For a server config (isClient false)
// with no certificates it also creates a new signing key and a self-signed
// certificate named after ServerName. It returns any key or certificate
// generation error.
func (c *Config) Init(isClient bool) error {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	if len(c.CipherSuites) == 0 {
		c.CipherSuites = defaultSupportedCipherSuites
	}
	if len(c.Groups) == 0 {
		c.Groups = defaultSupportedGroups
	}
	if len(c.SignatureSchemes) == 0 {
		c.SignatureSchemes = defaultSignatureSchemes
	}
	if c.TicketLen == 0 {
		c.TicketLen = defaultTicketLen
	}
	// Only a nil interface gets the default cache; a non-nil interface
	// (even one wrapping a nil pointer) is kept as-is.
	if c.PSKs == nil {
		c.PSKs = &PSKMapCache{}
	}
	if len(c.PSKModes) == 0 {
		c.PSKModes = defaultPSKModes
	}

	// Clients, and servers that already have a certificate, are done.
	if isClient || len(c.Certificates) > 0 {
		return nil
	}

	logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
	signingKey, err := newSigningKey(RSA_PSS_SHA256)
	if err != nil {
		return err
	}
	selfSigned, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, signingKey)
	if err != nil {
		return err
	}

	c.Certificates = []*Certificate{{
		Chain:      []*x509.Certificate{selfSigned},
		PrivateKey: signingKey,
	}}
	return nil
}
// ValidForServer reports whether c can drive a server handshake. That requires
// either a non-empty PSK cache, or a first certificate that has a non-empty
// chain and a private key.
func (c *Config) ValidForServer() bool {
	if c.PSKs != nil && c.PSKs.Size() > 0 {
		return true
	}
	if len(c.Certificates) == 0 {
		return false
	}
	leading := c.Certificates[0]
	return len(leading.Chain) > 0 && leading.PrivateKey != nil
}

// ValidForClient reports whether c names a server to connect to.
func (c *Config) ValidForClient() bool {
	return c.ServerName != ""
}
// Package defaults that Config.Init applies to fields the caller left empty.
// Element order is preserved exactly as the slices are offered on the wire.

// defaultSupportedCipherSuites is used when Config.CipherSuites is empty.
var defaultSupportedCipherSuites = []CipherSuite{
	TLS_AES_128_GCM_SHA256,
	TLS_AES_256_GCM_SHA384,
}

// defaultSupportedGroups is used when Config.Groups is empty.
var defaultSupportedGroups = []NamedGroup{
	P256,
	P384,
	FFDHE2048,
	X25519,
}

// defaultSignatureSchemes is used when Config.SignatureSchemes is empty.
var defaultSignatureSchemes = []SignatureScheme{
	RSA_PSS_SHA256,
	RSA_PSS_SHA384,
	RSA_PSS_SHA512,
	ECDSA_P256_SHA256,
	ECDSA_P384_SHA384,
	ECDSA_P521_SHA512,
}

// defaultTicketLen is used when Config.TicketLen is zero.
var defaultTicketLen = 16

// defaultPSKModes is used when Config.PSKModes is empty.
var defaultPSKModes = []PSKKeyExchangeMode{
	PSKModeKE,
	PSKModeDHEKE,
}
// ConnectionState is a snapshot of a Conn, as returned by Conn.State.
type ConnectionState struct {
	HandshakeState   string               // type name of the current handshake state (Conn.GetHsState)
	CipherSuite      CipherSuiteParams    // parameters of the negotiated cipher suite; set only once the handshake has completed
	PeerCertificates []*x509.Certificate  // certificate chain presented by remote peer; not populated by State yet. TODO(ekr@rtfm.com): implement
	NextProto        string               // Selected ALPN proto; set only once the handshake has completed
}

// Conn implements the net.Conn interface, as with "crypto/tls"
// * Read, Write, and Close are provided locally
// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
type Conn struct {
	config   *Config
	conn     net.Conn // underlying transport
	isClient bool

	// EarlyData: on the client, the data written by the SendEarlyData
	// action; on the server, the ReadEarlyData action appends received
	// early data here.
	EarlyData []byte

	state          StateConnected // connected-state machine, set when the handshake completes
	hState         HandshakeState // current handshake state; nil until HandshakeSetup runs
	handshakeMutex sync.Mutex     // taken by sendAlert; Handshake has a TODO to take it too
	// handshakeAlert is checked at the top of Handshake.
	// NOTE(review): nothing in this file ever assigns it.
	handshakeAlert    Alert
	handshakeComplete bool

	readBuffer []byte // decrypted application data not yet returned by Read

	in, out    *RecordLayer    // inbound / outbound record layers over conn
	hIn, hOut  *HandshakeLayer // handshake message layers over in / out
	extHandler AppExtensionHandler
}
// NewConn wraps conn in a TLS connection that uses config and plays the client
// role when isClient is true, otherwise the server role. No I/O is done here.
// The handshake runs later, either from Handshake directly or on the first
// Read. The inbound handshake layer inherits config.NonBlocking.
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
	inRecords := NewRecordLayer(conn)
	outRecords := NewRecordLayer(conn)

	handshakeIn := NewHandshakeLayer(inRecords)
	handshakeIn.nonblocking = config.NonBlocking

	return &Conn{
		conn:     conn,
		config:   config,
		isClient: isClient,
		in:       inRecords,
		out:      outRecords,
		hIn:      handshakeIn,
		hOut:     NewHandshakeLayer(outRecords),
	}
}
// consumeRecord reads one record from the inbound record layer and handles
// it according to its content type:
//   - handshake: every post-handshake message packed in the record is fed
//     through the connected-state machine and its actions are executed. The
//     state must still be StateConnected afterwards.
//   - alert: close_notify yields io.EOF and a warning-level alert is
//     ignored. An error-level alert is returned as the Alert itself. A
//     malformed alert triggers an unexpected_message alert and io.EOF.
//   - application data: the plaintext is appended to c.readBuffer.
//
// If ReadRecord produced no record, its error is returned unchanged. Read,
// which checks for WouldBlock, is the only caller in this file and holds
// c.in's lock while calling this. A nil return does not imply that any
// application data was buffered.
func (c *Conn) consumeRecord() error {
	pt, err := c.in.ReadRecord()
	if pt == nil {
		logf(logTypeIO, "extendBuffer returns error %v", err)
		return err
	}

	switch pt.contentType {
	case RecordTypeHandshake:
		logf(logTypeHandshake, "Received post-handshake message")
		// We do not support fragmentation of post-handshake handshake messages.
		// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
		start := 0
		for start < len(pt.fragment) {
			if len(pt.fragment[start:]) < handshakeHeaderLen {
				return fmt.Errorf("Post-handshake handshake message too short for header")
			}

			// Header: 1-byte message type, then a 3-byte big-endian body
			// length (assumes handshakeHeaderLen >= 4 - TODO confirm).
			hm := &HandshakeMessage{}
			hm.msgType = HandshakeType(pt.fragment[start])
			hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])

			if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen {
				return fmt.Errorf("Post-handshake handshake message too short for body")
			}
			hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen]

			// Advance state machine
			state, actions, alert := c.state.Next(hm)

			if alert != AlertNoAlert {
				logf(logTypeHandshake, "Error in state transition: %v", alert)
				c.sendAlert(alert)
				return io.EOF
			}

			for _, action := range actions {
				alert = c.takeAction(action)
				if alert != AlertNoAlert {
					logf(logTypeHandshake, "Error during handshake actions: %v", alert)
					c.sendAlert(alert)
					return io.EOF
				}
			}

			// XXX: If we want to support more advanced cases, e.g., post-handshake
			// authentication, we'll need to allow transitions other than
			// Connected -> Connected
			var connected bool
			c.state, connected = state.(StateConnected)
			if !connected {
				// NOTE(review): alert is AlertNoAlert on this path, so this
				// sends a "no alert" value; a concrete alert is probably intended.
				logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
				c.sendAlert(alert)
				return io.EOF
			}

			// Move on to the next message packed in the same record.
			start += handshakeHeaderLen + hmLen
		}
	case RecordTypeAlert:
		// NOTE(review): this logs readBuffer, not the alert fragment itself.
		logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
		// Alert body is two bytes: fragment[0] = level, fragment[1] = description.
		if len(pt.fragment) != 2 {
			c.sendAlert(AlertUnexpectedMessage)
			return io.EOF
		}
		if Alert(pt.fragment[1]) == AlertCloseNotify {
			return io.EOF
		}

		switch pt.fragment[0] {
		case AlertLevelWarning:
			// drop on the floor
		case AlertLevelError:
			return Alert(pt.fragment[1])
		default:
			c.sendAlert(AlertUnexpectedMessage)
			return io.EOF
		}

	case RecordTypeApplicationData:
		c.readBuffer = append(c.readBuffer, pt.fragment...)
		logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
	}

	// err is whatever ReadRecord returned alongside a non-nil record.
	return err
}
// Read copies decrypted application data into buffer and returns the number
// of bytes copied. It first drives the handshake to completion, returning its
// alert as the error if that fails. While waiting for data it consumes
// handshake and alert records itself. An empty buffer yields (0, nil) once the
// handshake has succeeded. Data that does not fit stays buffered for the next
// call.
func (c *Conn) Read(buffer []byte) (int, error) {
	logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))

	if alert := c.Handshake(); alert != AlertNoAlert {
		return 0, alert
	}
	if len(buffer) == 0 {
		return 0, nil
	}

	// Serialize access to the inbound record layer and c.readBuffer.
	c.in.Lock()
	defer c.in.Unlock()

	// Pull records until application data is buffered. consumeRecord
	// returns nil after a non-application-data record, so just loop again.
	for len(c.readBuffer) == 0 {
		err := c.consumeRecord()
		if err == nil {
			continue
		}
		// Only a WouldBlock in blocking mode is retried; everything else,
		// and any error at all in non-blocking mode, goes to the caller.
		if err == WouldBlock && !c.config.NonBlocking {
			continue
		}
		logf(logTypeIO, "conn.Read returns err=%v", err)
		return 0, err
	}

	logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))

	pending := len(c.readBuffer)
	if pending > len(buffer) {
		logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", pending, len(buffer))
	}
	read := copy(buffer, c.readBuffer)
	if read == pending {
		// Fully drained: truncate in place so consumeRecord's append reuses
		// the backing array.
		c.readBuffer = c.readBuffer[:0]
	} else {
		c.readBuffer = c.readBuffer[read:]
	}

	logf(logTypeVerbose, "Returning %v", string(buffer[:read]))
	return read, nil
}
// Write sends buffer as application data. The data is split into records of at
// most maxFragmentLen bytes, written in order under the outbound record
// layer's lock. It returns the number of bytes whose records were written
// before the first failure, along with that failure. Write does not run the
// handshake itself.
func (c *Conn) Write(buffer []byte) (int, error) {
	c.out.Lock()
	defer c.out.Unlock()

	sent := 0
	for remaining := buffer; len(remaining) > 0; {
		chunk := remaining
		if len(chunk) > maxFragmentLen {
			chunk = chunk[:maxFragmentLen]
		}

		if err := c.out.WriteRecord(&TLSPlaintext{
			contentType: RecordTypeApplicationData,
			fragment:    chunk,
		}); err != nil {
			return sent, err
		}

		sent += len(chunk)
		remaining = remaining[len(chunk):]
	}
	return sent, nil
}
// sendAlert writes a TLS alert record carrying err to the peer.
//
// close_notify and no_renegotiation are sent at warning level and reported
// back as a *net.OpError. Every other alert is sent at error (fatal) level,
// after which the underlying connection is closed and Close's result is
// returned.
//
// The record body is laid out as RFC 8446 section 6 specifies: the
// AlertLevel byte followed by the AlertDescription byte. consumeRecord uses
// the same layout when parsing incoming alerts.
//
// This takes c.handshakeMutex but not c.out's lock.
// NOTE(review): confirm callers cannot race with Write on c.out.
func (c *Conn) sendAlert(err Alert) error {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	var level int
	switch err {
	case AlertNoRenegotiation, AlertCloseNotify:
		level = AlertLevelWarning
	default:
		level = AlertLevelError
	}

	// Level first, then description.
	buf := []byte{byte(level), byte(err)}
	if writeErr := c.out.WriteRecord(&TLSPlaintext{
		contentType: RecordTypeAlert,
		fragment:    buf,
	}); writeErr != nil {
		// Sending an alert is best effort (the connection is usually being
		// torn down), so record the failure rather than masking err.
		logf(logTypeIO, "Error sending alert %v: %v", err, writeErr)
	}

	// close_notify and no_renegotiation are not actually errors
	if level == AlertLevelWarning {
		return &net.OpError{Op: "local error", Err: err}
	}
	return c.Close()
}
// Close closes the underlying transport connection and returns its error.
// It does not send a close_notify alert first.
func (c *Conn) Close() error {
	// XXX crypto/tls has an interlock with Write here. Do we need that?
	return c.conn.Close()
}

// LocalAddr returns the local network address of the underlying connection.
func (c *Conn) LocalAddr() net.Addr {
	return c.conn.LocalAddr()
}

// RemoteAddr returns the remote network address of the underlying connection.
func (c *Conn) RemoteAddr() net.Addr {
	return c.conn.RemoteAddr()
}

// SetDeadline sets the read and write deadlines associated with the connection.
// A zero value for t means Read and Write will not time out.
// NOTE(review): crypto/tls documents that after a Write has timed out the TLS
// state is corrupt and all future writes return the same error. Write in this
// file keeps no sticky error, so later writes may still run on a possibly
// half-written record stream; confirm the intended behavior.
func (c *Conn) SetDeadline(t time.Time) error {
	return c.conn.SetDeadline(t)
}

// SetReadDeadline sets the read deadline on the underlying connection.
// A zero value for t means Read will not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
	return c.conn.SetReadDeadline(t)
}

// SetWriteDeadline sets the write deadline on the underlying connection.
// A zero value for t means Write will not time out.
// NOTE(review): the crypto/tls caveat on SetDeadline (corrupt state, sticky
// write error) is not enforced by Write in this file; confirm.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	return c.conn.SetWriteDeadline(t)
}
// takeAction carries out one HandshakeAction produced by the handshake state
// machine. It returns AlertNoAlert on success, or the alert to report when
// the action fails or is of an unknown type.
//
// Supported actions:
//   - SendHandshakeMessage: write the message via the outbound handshake layer.
//   - RekeyIn / RekeyOut: install new keys on the inbound / outbound record layer.
//   - SendEarlyData: write c.EarlyData as application data (via Write).
//   - ReadPastEarlyData: skip inbound records that fail to decrypt.
//   - ReadEarlyData: append consecutive inbound application-data records to c.EarlyData.
//   - StorePSK: save a new PSK in c.config.PSKs.
func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
	label := "[server]"
	if c.isClient {
		label = "[client]"
	}

	switch action := actionGeneric.(type) {
	case SendHandshakeMessage:
		err := c.hOut.WriteMessage(action.Message)
		if err != nil {
			logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
			return AlertInternalError
		}

	case RekeyIn:
		// NOTE(review): %+v of KeySet may log key material; keep this
		// logging disabled in production.
		logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet)
		err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
		if err != nil {
			logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
			return AlertInternalError
		}

	case RekeyOut:
		logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet)
		err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
		if err != nil {
			logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
			return AlertInternalError
		}

	case SendEarlyData:
		logf(logTypeHandshake, "%s Sending early data...", label)
		_, err := c.Write(c.EarlyData)
		if err != nil {
			logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
			return AlertInternalError
		}

	case ReadPastEarlyData:
		logf(logTypeHandshake, "%s Reading past early data...", label)
		// Keep peeking while records fail to decrypt, presumably early data
		// this side cannot read. Stop at the first record that decrypts, or
		// at any error that is not a DecryptError.
		// NOTE(review): such other errors (e.g. WouldBlock) are dropped here.
		for {
			_, err := c.in.PeekRecordType(!c.config.NonBlocking)
			if _, isDecryptErr := err.(DecryptError); !isDecryptErr {
				break
			}
		}

	case ReadEarlyData:
		logf(logTypeHandshake, "%s Reading early data...", label)
		t, err := c.in.PeekRecordType(!c.config.NonBlocking)
		if err != nil {
			logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
			return AlertInternalError
		}
		logf(logTypeHandshake, "%s Got record type(1): %v", label, t)

		for t == RecordTypeApplicationData {
			// PeekRecordType has already read this record in, so reading
			// it here is safe even in blocking mode.
			pt, readErr := c.in.ReadRecord()
			if readErr != nil {
				logf(logTypeHandshake, "%s Error reading early data record: %v", label, readErr)
				return AlertInternalError
			}

			logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
			c.EarlyData = append(c.EarlyData, pt.fragment...)

			t, err = c.in.PeekRecordType(!c.config.NonBlocking)
			if err != nil {
				logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
				return AlertInternalError
			}
			logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
		}
		logf(logTypeHandshake, "%s Done reading early data", label)

	case StorePSK:
		logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
		if c.isClient {
			// Clients look up PSKs based on server name
			c.config.PSKs.Put(c.config.ServerName, action.PSK)
		} else {
			// Servers look them up based on the identity in the extension
			c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK)
		}

	default:
		logf(logTypeHandshake, "%s Unknown action type: %T", label, actionGeneric)
		return AlertInternalError
	}

	return AlertNoAlert
}
// HandshakeSetup prepares c for its first handshake and stores the starting
// handshake state in c.hState. It first applies config defaults (Config.Init),
// then builds the local Capabilities. A default cookie handler is installed
// when cookies are required but none is configured.
//
// A server simply starts in ServerStateStart. A client also runs the first
// transition of ClientStateStart and executes its actions; c.hState is left
// untouched if that fails. Returns AlertNoAlert on success, otherwise the
// alert describing the failure.
func (c *Conn) HandshakeSetup() Alert {
	if err := c.config.Init(c.isClient); err != nil {
		logf(logTypeHandshake, "Error initializing config: %v", err)
		return AlertInternalError
	}

	cfg := c.config
	caps := Capabilities{
		CipherSuites:      cfg.CipherSuites,
		Groups:            cfg.Groups,
		SignatureSchemes:  cfg.SignatureSchemes,
		PSKs:              cfg.PSKs,
		PSKModes:          cfg.PSKModes,
		AllowEarlyData:    cfg.AllowEarlyData,
		RequireCookie:     cfg.RequireCookie,
		CookieHandler:     cfg.CookieHandler,
		RequireClientAuth: cfg.RequireClientAuth,
		NextProtos:        cfg.NextProtos,
		Certificates:      cfg.Certificates,
		ExtensionHandler:  c.extHandler,
	}
	if caps.RequireCookie && caps.CookieHandler == nil {
		caps.CookieHandler = &defaultCookieHandler{}
	}

	if !c.isClient {
		c.hState = ServerStateStart{Caps: caps, conn: c}
		return AlertNoAlert
	}

	opts := ConnectionOptions{
		ServerName: cfg.ServerName,
		NextProtos: cfg.NextProtos,
		EarlyData:  c.EarlyData,
	}
	initial, actions, alert := ClientStateStart{Caps: caps, Opts: opts}.Next(nil)
	if alert != AlertNoAlert {
		logf(logTypeHandshake, "Error initializing client state: %v", alert)
		return alert
	}
	for _, act := range actions {
		if failed := c.takeAction(act); failed != AlertNoAlert {
			logf(logTypeHandshake, "Error during handshake actions: %v", failed)
			return failed
		}
	}

	c.hState = initial
	return AlertNoAlert
}
// Handshake runs the TLS handshake on the connection, acting as client or
// server according to c.isClient. Calling it again after completion returns
// AlertNoAlert immediately.
//
// In non-blocking mode it returns AlertWouldBlock when more input is needed.
// c.hState keeps the progress, so a later call re-enters the loop from there.
// After the connected state is reached, a server with
// Config.SendSessionTickets issues a session ticket. Any failure is returned
// as the corresponding alert.
//
// NOTE(review): c.handshakeAlert is checked below but never assigned in this
// file, so an earlier failure is not actually replayed on later calls.
func (c *Conn) Handshake() Alert {
	label := "[server]"
	if c.isClient {
		label = "[client]"
	}

	// TODO Lock handshakeMutex
	// TODO Remove CloseNotify hack
	if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify {
		logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert)
		return c.handshakeAlert
	}
	if c.handshakeComplete {
		return AlertNoAlert
	}

	var alert Alert
	if c.hState == nil {
		logf(logTypeHandshake, "%s First time through handshake, setting up", label)
		alert = c.HandshakeSetup()
		if alert != AlertNoAlert {
			return alert
		}
	} else {
		logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
	}

	state := c.hState
	_, connected := state.(StateConnected)

	var actions []HandshakeAction
	for !connected {
		// Read a handshake message
		hm, err := c.hIn.ReadMessage()
		if err == WouldBlock {
			logf(logTypeHandshake, "%s Would block reading message: %v", label, err)
			return AlertWouldBlock
		}
		if err != nil {
			logf(logTypeHandshake, "%s Error reading message: %v", label, err)
			c.sendAlert(AlertCloseNotify)
			return AlertCloseNotify
		}
		logf(logTypeHandshake, "Read message with type: %v", hm.msgType)

		// Advance the state machine
		state, actions, alert = state.Next(hm)
		if alert != AlertNoAlert {
			// NOTE(review): unlike the other failure paths, no alert is sent
			// to the peer here.
			logf(logTypeHandshake, "Error in state transition: %v", alert)
			return alert
		}

		for index, action := range actions {
			logf(logTypeHandshake, "%s taking next action (%d)", label, index)
			alert = c.takeAction(action)
			if alert != AlertNoAlert {
				logf(logTypeHandshake, "Error during handshake actions: %v", alert)
				c.sendAlert(alert)
				return alert
			}
		}

		// Record progress so a WouldBlock on the next read can resume here.
		c.hState = state
		logf(logTypeHandshake, "state is now %s", c.GetHsState())
		_, connected = state.(StateConnected)
	}

	c.state = state.(StateConnected)

	// Send NewSessionTicket if acting as server
	if !c.isClient && c.config.SendSessionTickets {
		ticketActions, ticketAlert := c.state.NewSessionTicket(
			c.config.TicketLen,
			c.config.TicketLifetime,
			c.config.EarlyDataLifetime)
		// This alert used to be shadowed and overwritten unchecked, so a
		// failed ticket was reported as a successful handshake.
		if ticketAlert != AlertNoAlert {
			logf(logTypeHandshake, "Error creating session ticket: %v", ticketAlert)
			c.sendAlert(ticketAlert)
			return ticketAlert
		}
		for _, action := range ticketActions {
			alert = c.takeAction(action)
			if alert != AlertNoAlert {
				logf(logTypeHandshake, "Error during handshake actions: %v", alert)
				c.sendAlert(alert)
				return alert
			}
		}
	}

	c.handshakeComplete = true
	return AlertNoAlert
}
// SendKeyUpdate asks the connected-state machine for a KeyUpdate and executes
// the resulting actions (sending the KeyUpdate and rekeying). When
// requestUpdate is true the message is marked KeyUpdateRequested, presumably
// asking the peer to update its keys as well. It fails before the handshake
// has completed. On any alert, that alert is sent to the peer and an error is
// returned.
func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
	if !c.handshakeComplete {
		return fmt.Errorf("Cannot update keys until after handshake")
	}

	mode := KeyUpdateNotRequested
	if requestUpdate {
		mode = KeyUpdateRequested
	}

	steps, alert := c.state.KeyUpdate(mode)
	if alert != AlertNoAlert {
		c.sendAlert(alert)
		return fmt.Errorf("Alert while generating key update: %v", alert)
	}

	for _, step := range steps {
		if alert = c.takeAction(step); alert != AlertNoAlert {
			c.sendAlert(alert)
			return fmt.Errorf("Alert during key update actions: %v", alert)
		}
	}
	return nil
}
// GetHsState returns the type name of the current handshake state, e.g. the
// name of the StateConnected type once the handshake is done. It returns ""
// if the handshake has not been set up yet (c.hState is nil).
func (c *Conn) GetHsState() string {
	// reflect.TypeOf(nil) returns a nil Type, and calling Name on it panics.
	// Without this guard, State() would panic on a fresh Conn.
	if c.hState == nil {
		return ""
	}
	return reflect.TypeOf(c.hState).Name()
}
// ComputeExporter derives keyLength bytes of exported keying material from
// the connection's exporter secret, using the TLS 1.3 exporter (RFC 8446
// section 7.5):
//
//	TLS-Exporter(label, context, keyLength) =
//	    HKDF-Expand-Label(Derive-Secret(exporter_secret, label, ""),
//	                      "exporter", Hash(context), keyLength)
//
// It fails if the handshake has not reached the connected state or if no
// exporter secret is available.
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
	_, connected := c.hState.(StateConnected)
	if !connected {
		return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
	}

	if c.state.exporterSecret == nil {
		return nil, fmt.Errorf("Internal error: no exporter secret")
	}

	hashAlg := c.state.cryptoParams.Hash

	// Derive-Secret(..., "") takes the hash of the empty transcript.
	emptyHash := hashAlg.New().Sum(nil)
	tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, emptyHash)

	// Hash(context). The previous hashAlg.New().Sum(context) returned
	// context with Hash("") appended. That equals Hash(context) only when
	// context is empty, so callers passing nil/empty context see no change.
	contextHasher := hashAlg.New()
	contextHasher.Write(context) // hash.Hash.Write never returns an error
	contextHash := contextHasher.Sum(nil)

	return HkdfExpandLabel(hashAlg, tmpSecret, "exporter", contextHash, keyLength), nil
}
// State returns a snapshot of the connection. HandshakeState is always
// filled in. The negotiated cipher suite parameters and ALPN protocol are
// filled in only once the handshake has completed.
func (c *Conn) State() ConnectionState {
	var snapshot ConnectionState
	snapshot.HandshakeState = c.GetHsState()
	if !c.handshakeComplete {
		return snapshot
	}

	snapshot.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
	snapshot.NextProto = c.state.Params.NextProto
	return snapshot
}

// SetExtensionHandler installs h, which HandshakeSetup passes to the handshake
// as Capabilities.ExtensionHandler. It must be called before the handshake
// has been set up; afterwards it returns an error and leaves the handler
// unchanged.
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
	if c.hState == nil {
		c.extHandler = h
		return nil
	}
	return fmt.Errorf("Can't set extension handler after setup")
}