package mint import ( "fmt" "io" "net" ) const ( handshakeHeaderLen = 4 // handshake message header length maxHandshakeMessageLen = 1 << 24 // max handshake message length ) // struct { // HandshakeType msg_type; /* handshake type */ // uint24 length; /* bytes in message */ // select (HandshakeType) { // ... // } body; // } Handshake; // // We do the select{...} part in a different layer, so we treat the // actual message body as opaque: // // struct { // HandshakeType msg_type; // opaque msg<0..2^24-1> // } Handshake; // // TODO: File a spec bug type HandshakeMessage struct { // Omitted: length msgType HandshakeType body []byte } // Note: This could be done with the `syntax` module, using the simplified // syntax as discussed above. However, since this is so simple, there's not // much benefit to doing so. func (hm *HandshakeMessage) Marshal() []byte { if hm == nil { return []byte{} } msgLen := len(hm.body) data := make([]byte, 4+len(hm.body)) data[0] = byte(hm.msgType) data[1] = byte(msgLen >> 16) data[2] = byte(msgLen >> 8) data[3] = byte(msgLen) copy(data[4:], hm.body) return data } func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) { logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body) var body HandshakeMessageBody switch hm.msgType { case HandshakeTypeClientHello: body = new(ClientHelloBody) case HandshakeTypeServerHello: body = new(ServerHelloBody) case HandshakeTypeHelloRetryRequest: body = new(HelloRetryRequestBody) case HandshakeTypeEncryptedExtensions: body = new(EncryptedExtensionsBody) case HandshakeTypeCertificate: body = new(CertificateBody) case HandshakeTypeCertificateRequest: body = new(CertificateRequestBody) case HandshakeTypeCertificateVerify: body = new(CertificateVerifyBody) case HandshakeTypeFinished: body = &FinishedBody{VerifyDataLen: len(hm.body)} case HandshakeTypeNewSessionTicket: body = new(NewSessionTicketBody) case HandshakeTypeKeyUpdate: body = new(KeyUpdateBody) case HandshakeTypeEndOfEarlyData: body = new(EndOfEarlyDataBody) default: return body, fmt.Errorf("tls.handshakemessage: Unsupported body type") } _, err := body.Unmarshal(hm.body) return body, err } func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) { data, err := body.Marshal() if err != nil { return nil, err } return &HandshakeMessage{ msgType: body.Type(), body: data, }, nil } type HandshakeLayer struct { nonblocking bool // Should we operate in nonblocking mode conn *RecordLayer // Used for reading/writing records frame *frameReader // The buffered frame reader } type handshakeLayerFrameDetails struct{} func (d handshakeLayerFrameDetails) headerLen() int { return handshakeHeaderLen } func (d handshakeLayerFrameDetails) defaultReadLen() int { return handshakeHeaderLen + maxFragmentLen } func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { logf(logTypeIO, "Header=%x", hdr) return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil } func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer { h := HandshakeLayer{} h.conn = r h.frame = newFrameReader(&handshakeLayerFrameDetails{}) return &h } func (h *HandshakeLayer) readRecord() error { logf(logTypeIO, "Trying to read record") pt, err := h.conn.ReadRecord() if err != nil { return err } if pt.contentType != RecordTypeHandshake && pt.contentType != RecordTypeAlert { return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType) } if pt.contentType == RecordTypeAlert { logf(logTypeIO, "read alert %v", pt.fragment[1]) if len(pt.fragment) < 2 { h.sendAlert(AlertUnexpectedMessage) return io.EOF } return Alert(pt.fragment[1]) } logf(logTypeIO, "read handshake record of len %v", len(pt.fragment)) h.frame.addChunk(pt.fragment) return nil } // sendAlert sends a TLS alert message. func (h *HandshakeLayer) sendAlert(err Alert) error { tmp := make([]byte, 2) tmp[0] = AlertLevelError tmp[1] = byte(err) h.conn.WriteRecord(&TLSPlaintext{ contentType: RecordTypeAlert, fragment: tmp}, ) // closeNotify is a special case in that it isn't an error: if err != AlertCloseNotify { return &net.OpError{Op: "local error", Err: err} } return nil } func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { var hdr, body []byte var err error for { logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder)) if h.frame.needed() > 0 { logf(logTypeHandshake, "Trying to read a new record") err = h.readRecord() } if err != nil && (h.nonblocking || err != WouldBlock) { return nil, err } hdr, body, err = h.frame.process() if err == nil { break } if err != nil && (h.nonblocking || err != WouldBlock) { return nil, err } } logf(logTypeHandshake, "read handshake message") hm := &HandshakeMessage{} hm.msgType = HandshakeType(hdr[0]) hm.body = make([]byte, len(body)) copy(hm.body, body) return hm, nil } func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error { return h.WriteMessages([]*HandshakeMessage{hm}) } func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error { for _, hm := range hms { logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body) } // Write out headers and bodies buffer := []byte{} for _, msg := range hms { msgLen := len(msg.body) if msgLen > maxHandshakeMessageLen { return fmt.Errorf("tls.handshakelayer: Message too large to send") } buffer = append(buffer, msg.Marshal()...) } // Send full-size fragments var start int for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { err := h.conn.WriteRecord(&TLSPlaintext{ contentType: RecordTypeHandshake, fragment: buffer[start : start+maxFragmentLen], }) if err != nil { return err } } // Send a final partial fragment if necessary if start < len(buffer) { err := h.conn.WriteRecord(&TLSPlaintext{ contentType: RecordTypeHandshake, fragment: buffer[start:], }) if err != nil { return err } } return nil }