mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-30 22:34:15 -05:00
254 lines
6.1 KiB
Go
254 lines
6.1 KiB
Go
|
package mint
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net"
|
||
|
)
|
||
|
|
||
|
const (
	// handshakeHeaderLen is the handshake message header length: one
	// message-type byte followed by a three-byte length.
	handshakeHeaderLen = 4

	// maxHandshakeMessageLen is 2^24, one more than the largest value the
	// three-byte length field of a handshake header can encode.
	maxHandshakeMessageLen = 1 << 24
)
|
||
|
|
||
|
// struct {
|
||
|
// HandshakeType msg_type; /* handshake type */
|
||
|
// uint24 length; /* bytes in message */
|
||
|
// select (HandshakeType) {
|
||
|
// ...
|
||
|
// } body;
|
||
|
// } Handshake;
|
||
|
//
|
||
|
// We do the select{...} part in a different layer, so we treat the
|
||
|
// actual message body as opaque:
|
||
|
//
|
||
|
// struct {
|
||
|
// HandshakeType msg_type;
|
||
|
// opaque msg<0..2^24-1>
|
||
|
// } Handshake;
|
||
|
//
|
||
|
// TODO: File a spec bug
|
||
|
type HandshakeMessage struct {
|
||
|
// Omitted: length
|
||
|
msgType HandshakeType
|
||
|
body []byte
|
||
|
}
|
||
|
|
||
|
// Note: This could be done with the `syntax` module, using the simplified
|
||
|
// syntax as discussed above. However, since this is so simple, there's not
|
||
|
// much benefit to doing so.
|
||
|
func (hm *HandshakeMessage) Marshal() []byte {
|
||
|
if hm == nil {
|
||
|
return []byte{}
|
||
|
}
|
||
|
|
||
|
msgLen := len(hm.body)
|
||
|
data := make([]byte, 4+len(hm.body))
|
||
|
data[0] = byte(hm.msgType)
|
||
|
data[1] = byte(msgLen >> 16)
|
||
|
data[2] = byte(msgLen >> 8)
|
||
|
data[3] = byte(msgLen)
|
||
|
copy(data[4:], hm.body)
|
||
|
return data
|
||
|
}
|
||
|
|
||
|
func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
|
||
|
logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
|
||
|
|
||
|
var body HandshakeMessageBody
|
||
|
switch hm.msgType {
|
||
|
case HandshakeTypeClientHello:
|
||
|
body = new(ClientHelloBody)
|
||
|
case HandshakeTypeServerHello:
|
||
|
body = new(ServerHelloBody)
|
||
|
case HandshakeTypeHelloRetryRequest:
|
||
|
body = new(HelloRetryRequestBody)
|
||
|
case HandshakeTypeEncryptedExtensions:
|
||
|
body = new(EncryptedExtensionsBody)
|
||
|
case HandshakeTypeCertificate:
|
||
|
body = new(CertificateBody)
|
||
|
case HandshakeTypeCertificateRequest:
|
||
|
body = new(CertificateRequestBody)
|
||
|
case HandshakeTypeCertificateVerify:
|
||
|
body = new(CertificateVerifyBody)
|
||
|
case HandshakeTypeFinished:
|
||
|
body = &FinishedBody{VerifyDataLen: len(hm.body)}
|
||
|
case HandshakeTypeNewSessionTicket:
|
||
|
body = new(NewSessionTicketBody)
|
||
|
case HandshakeTypeKeyUpdate:
|
||
|
body = new(KeyUpdateBody)
|
||
|
case HandshakeTypeEndOfEarlyData:
|
||
|
body = new(EndOfEarlyDataBody)
|
||
|
default:
|
||
|
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
|
||
|
}
|
||
|
|
||
|
_, err := body.Unmarshal(hm.body)
|
||
|
return body, err
|
||
|
}
|
||
|
|
||
|
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
|
||
|
data, err := body.Marshal()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return &HandshakeMessage{
|
||
|
msgType: body.Type(),
|
||
|
body: data,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
type HandshakeLayer struct {
|
||
|
nonblocking bool // Should we operate in nonblocking mode
|
||
|
conn *RecordLayer // Used for reading/writing records
|
||
|
frame *frameReader // The buffered frame reader
|
||
|
}
|
||
|
|
||
|
// handshakeLayerFrameDetails describes handshake message framing (header
// size, default read size, and how to decode the body length from a header)
// for the frame reader created in NewHandshakeLayer. It carries no state.
type handshakeLayerFrameDetails struct{}
|
||
|
|
||
|
func (d handshakeLayerFrameDetails) headerLen() int {
|
||
|
return handshakeHeaderLen
|
||
|
}
|
||
|
|
||
|
func (d handshakeLayerFrameDetails) defaultReadLen() int {
|
||
|
return handshakeHeaderLen + maxFragmentLen
|
||
|
}
|
||
|
|
||
|
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||
|
logf(logTypeIO, "Header=%x", hdr)
|
||
|
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil
|
||
|
}
|
||
|
|
||
|
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer {
|
||
|
h := HandshakeLayer{}
|
||
|
h.conn = r
|
||
|
h.frame = newFrameReader(&handshakeLayerFrameDetails{})
|
||
|
return &h
|
||
|
}
|
||
|
|
||
|
func (h *HandshakeLayer) readRecord() error {
|
||
|
logf(logTypeIO, "Trying to read record")
|
||
|
pt, err := h.conn.ReadRecord()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if pt.contentType != RecordTypeHandshake &&
|
||
|
pt.contentType != RecordTypeAlert {
|
||
|
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
|
||
|
}
|
||
|
|
||
|
if pt.contentType == RecordTypeAlert {
|
||
|
logf(logTypeIO, "read alert %v", pt.fragment[1])
|
||
|
if len(pt.fragment) < 2 {
|
||
|
h.sendAlert(AlertUnexpectedMessage)
|
||
|
return io.EOF
|
||
|
}
|
||
|
return Alert(pt.fragment[1])
|
||
|
}
|
||
|
|
||
|
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment))
|
||
|
h.frame.addChunk(pt.fragment)
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// sendAlert sends a TLS alert message.
|
||
|
func (h *HandshakeLayer) sendAlert(err Alert) error {
|
||
|
tmp := make([]byte, 2)
|
||
|
tmp[0] = AlertLevelError
|
||
|
tmp[1] = byte(err)
|
||
|
h.conn.WriteRecord(&TLSPlaintext{
|
||
|
contentType: RecordTypeAlert,
|
||
|
fragment: tmp},
|
||
|
)
|
||
|
|
||
|
// closeNotify is a special case in that it isn't an error:
|
||
|
if err != AlertCloseNotify {
|
||
|
return &net.OpError{Op: "local error", Err: err}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
|
||
|
var hdr, body []byte
|
||
|
var err error
|
||
|
|
||
|
for {
|
||
|
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder))
|
||
|
if h.frame.needed() > 0 {
|
||
|
logf(logTypeHandshake, "Trying to read a new record")
|
||
|
err = h.readRecord()
|
||
|
}
|
||
|
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
hdr, body, err = h.frame.process()
|
||
|
if err == nil {
|
||
|
break
|
||
|
}
|
||
|
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
logf(logTypeHandshake, "read handshake message")
|
||
|
|
||
|
hm := &HandshakeMessage{}
|
||
|
hm.msgType = HandshakeType(hdr[0])
|
||
|
|
||
|
hm.body = make([]byte, len(body))
|
||
|
copy(hm.body, body)
|
||
|
|
||
|
return hm, nil
|
||
|
}
|
||
|
|
||
|
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error {
|
||
|
return h.WriteMessages([]*HandshakeMessage{hm})
|
||
|
}
|
||
|
|
||
|
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error {
|
||
|
for _, hm := range hms {
|
||
|
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
|
||
|
}
|
||
|
|
||
|
// Write out headers and bodies
|
||
|
buffer := []byte{}
|
||
|
for _, msg := range hms {
|
||
|
msgLen := len(msg.body)
|
||
|
if msgLen > maxHandshakeMessageLen {
|
||
|
return fmt.Errorf("tls.handshakelayer: Message too large to send")
|
||
|
}
|
||
|
|
||
|
buffer = append(buffer, msg.Marshal()...)
|
||
|
}
|
||
|
|
||
|
// Send full-size fragments
|
||
|
var start int
|
||
|
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||
|
err := h.conn.WriteRecord(&TLSPlaintext{
|
||
|
contentType: RecordTypeHandshake,
|
||
|
fragment: buffer[start : start+maxFragmentLen],
|
||
|
})
|
||
|
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Send a final partial fragment if necessary
|
||
|
if start < len(buffer) {
|
||
|
err := h.conn.WriteRecord(&TLSPlaintext{
|
||
|
contentType: RecordTypeHandshake,
|
||
|
fragment: buffer[start:],
|
||
|
})
|
||
|
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|