2018-02-17 13:29:53 +08:00
package wire
import (
"bytes"
2018-09-03 04:18:54 +07:00
"crypto/rand"
"errors"
"fmt"
2018-02-17 13:29:53 +08:00
"github.com/lucas-clemente/quic-go/internal/protocol"
2018-04-18 15:48:08 -06:00
"github.com/lucas-clemente/quic-go/internal/utils"
2018-02-17 13:29:53 +08:00
)
// Header is the header of a QUIC packet.
// It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header.
type Header struct {
2018-09-03 04:18:54 +07:00
IsPublicHeader bool
Raw [ ] byte
Version protocol . VersionNumber
DestConnectionID protocol . ConnectionID
SrcConnectionID protocol . ConnectionID
OrigDestConnectionID protocol . ConnectionID // only needed in the Retry packet
PacketNumberLen protocol . PacketNumberLen
PacketNumber protocol . PacketNumber
2018-02-17 13:29:53 +08:00
IsVersionNegotiation bool
SupportedVersions [ ] protocol . VersionNumber // Version Number sent in a Version Negotiation Packet by the server
// only needed for the gQUIC Public Header
VersionFlag bool
ResetFlag bool
DiversificationNonce [ ] byte
// only needed for the IETF Header
Type protocol . PacketType
IsLongHeader bool
KeyPhase int
2018-09-03 04:18:54 +07:00
PayloadLen protocol . ByteCount
Token [ ] byte
}
2018-02-17 13:29:53 +08:00
2018-09-03 04:18:54 +07:00
// errInvalidPacketNumberLen is returned when a Header's PacketNumberLen cannot
// be used with the header format that is being written or measured.
var errInvalidPacketNumberLen = errors . New ( "invalid packet number length" )
// Write writes the Header.
func ( h * Header ) Write ( b * bytes . Buffer , pers protocol . Perspective , ver protocol . VersionNumber ) error {
if ! ver . UsesIETFHeaderFormat ( ) {
h . IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
return h . writePublicHeader ( b , pers , ver )
}
// write an IETF QUIC header
if h . IsLongHeader {
return h . writeLongHeader ( b , ver )
}
return h . writeShortHeader ( b , ver )
2018-02-17 13:29:53 +08:00
}
2018-09-03 04:18:54 +07:00
// TODO: add support for the key phase
func ( h * Header ) writeLongHeader ( b * bytes . Buffer , v protocol . VersionNumber ) error {
b . WriteByte ( byte ( 0x80 | h . Type ) )
utils . BigEndian . WriteUint32 ( b , uint32 ( h . Version ) )
connIDLen , err := encodeConnIDLen ( h . DestConnectionID , h . SrcConnectionID )
2018-02-17 13:29:53 +08:00
if err != nil {
2018-09-03 04:18:54 +07:00
return err
2018-02-17 13:29:53 +08:00
}
2018-09-03 04:18:54 +07:00
b . WriteByte ( connIDLen )
b . Write ( h . DestConnectionID . Bytes ( ) )
b . Write ( h . SrcConnectionID . Bytes ( ) )
2018-02-17 13:29:53 +08:00
2018-09-03 04:18:54 +07:00
if h . Type == protocol . PacketTypeInitial && v . UsesTokenInHeader ( ) {
utils . WriteVarInt ( b , uint64 ( len ( h . Token ) ) )
b . Write ( h . Token )
2018-02-17 13:29:53 +08:00
}
2018-03-25 22:37:41 -06:00
2018-09-03 04:18:54 +07:00
if h . Type == protocol . PacketTypeRetry {
odcil , err := encodeSingleConnIDLen ( h . OrigDestConnectionID )
if err != nil {
return err
}
// randomize the first 4 bits
odcilByte := make ( [ ] byte , 1 )
_ , _ = rand . Read ( odcilByte ) // it's safe to ignore the error here
odcilByte [ 0 ] = ( odcilByte [ 0 ] & 0xf0 ) | odcil
b . Write ( odcilByte )
b . Write ( h . OrigDestConnectionID . Bytes ( ) )
b . Write ( h . Token )
return nil
}
2018-02-17 13:29:53 +08:00
2018-09-03 04:18:54 +07:00
if v . UsesLengthInHeader ( ) {
utils . WriteVarInt ( b , uint64 ( h . PayloadLen ) )
}
if v . UsesVarintPacketNumbers ( ) {
return utils . WriteVarIntPacketNumber ( b , h . PacketNumber , h . PacketNumberLen )
}
utils . BigEndian . WriteUint32 ( b , uint32 ( h . PacketNumber ) )
if h . Type == protocol . PacketType0RTT && v == protocol . Version44 {
if len ( h . DiversificationNonce ) != 32 {
return errors . New ( "invalid diversification nonce length" )
}
b . Write ( h . DiversificationNonce )
}
return nil
2018-02-17 13:29:53 +08:00
}
2018-09-03 04:18:54 +07:00
func ( h * Header ) writeShortHeader ( b * bytes . Buffer , v protocol . VersionNumber ) error {
typeByte := byte ( 0x30 )
typeByte |= byte ( h . KeyPhase << 6 )
if ! v . UsesVarintPacketNumbers ( ) {
switch h . PacketNumberLen {
case protocol . PacketNumberLen1 :
case protocol . PacketNumberLen2 :
typeByte |= 0x1
case protocol . PacketNumberLen4 :
typeByte |= 0x2
default :
return errInvalidPacketNumberLen
}
}
b . WriteByte ( typeByte )
b . Write ( h . DestConnectionID . Bytes ( ) )
if ! v . UsesVarintPacketNumbers ( ) {
switch h . PacketNumberLen {
case protocol . PacketNumberLen1 :
b . WriteByte ( uint8 ( h . PacketNumber ) )
case protocol . PacketNumberLen2 :
utils . BigEndian . WriteUint16 ( b , uint16 ( h . PacketNumber ) )
case protocol . PacketNumberLen4 :
utils . BigEndian . WriteUint32 ( b , uint32 ( h . PacketNumber ) )
2018-02-17 13:29:53 +08:00
}
2018-09-03 04:18:54 +07:00
return nil
2018-02-17 13:29:53 +08:00
}
2018-09-03 04:18:54 +07:00
return utils . WriteVarIntPacketNumber ( b , h . PacketNumber , h . PacketNumberLen )
2018-02-17 13:29:53 +08:00
}
2018-09-03 04:18:54 +07:00
// writePublicHeader writes a Public Header.
func ( h * Header ) writePublicHeader ( b * bytes . Buffer , pers protocol . Perspective , _ protocol . VersionNumber ) error {
if h . ResetFlag || ( h . VersionFlag && pers == protocol . PerspectiveServer ) {
return errors . New ( "PublicHeader: Can only write regular packets" )
}
if h . SrcConnectionID . Len ( ) != 0 {
return errors . New ( "PublicHeader: SrcConnectionID must not be set" )
}
if len ( h . DestConnectionID ) != 0 && len ( h . DestConnectionID ) != 8 {
return fmt . Errorf ( "PublicHeader: wrong length for Connection ID: %d (expected 8)" , len ( h . DestConnectionID ) )
}
publicFlagByte := uint8 ( 0x00 )
if h . VersionFlag {
publicFlagByte |= 0x01
}
if h . DestConnectionID . Len ( ) > 0 {
publicFlagByte |= 0x08
}
if len ( h . DiversificationNonce ) > 0 {
if len ( h . DiversificationNonce ) != 32 {
return errors . New ( "invalid diversification nonce length" )
}
publicFlagByte |= 0x04
}
switch h . PacketNumberLen {
case protocol . PacketNumberLen1 :
publicFlagByte |= 0x00
case protocol . PacketNumberLen2 :
publicFlagByte |= 0x10
case protocol . PacketNumberLen4 :
publicFlagByte |= 0x20
}
b . WriteByte ( publicFlagByte )
if h . DestConnectionID . Len ( ) > 0 {
b . Write ( h . DestConnectionID )
}
if h . VersionFlag && pers == protocol . PerspectiveClient {
utils . BigEndian . WriteUint32 ( b , uint32 ( h . Version ) )
}
if len ( h . DiversificationNonce ) > 0 {
b . Write ( h . DiversificationNonce )
}
switch h . PacketNumberLen {
case protocol . PacketNumberLen1 :
b . WriteByte ( uint8 ( h . PacketNumber ) )
case protocol . PacketNumberLen2 :
utils . BigEndian . WriteUint16 ( b , uint16 ( h . PacketNumber ) )
case protocol . PacketNumberLen4 :
utils . BigEndian . WriteUint32 ( b , uint32 ( h . PacketNumber ) )
case protocol . PacketNumberLen6 :
return errInvalidPacketNumberLen
default :
return errors . New ( "PublicHeader: PacketNumberLen not set" )
2018-02-17 13:29:53 +08:00
}
2018-09-03 04:18:54 +07:00
return nil
2018-02-17 13:29:53 +08:00
}
// GetLength determines the length of the Header.
2018-09-03 04:18:54 +07:00
func ( h * Header ) GetLength ( v protocol . VersionNumber ) ( protocol . ByteCount , error ) {
if ! v . UsesIETFHeaderFormat ( ) {
return h . getPublicHeaderLength ( )
}
return h . getHeaderLength ( v )
}
func ( h * Header ) getHeaderLength ( v protocol . VersionNumber ) ( protocol . ByteCount , error ) {
if h . IsLongHeader {
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol . ByteCount ( h . DestConnectionID . Len ( ) + h . SrcConnectionID . Len ( ) ) + protocol . ByteCount ( h . PacketNumberLen )
if v . UsesLengthInHeader ( ) {
length += utils . VarIntLen ( uint64 ( h . PayloadLen ) )
}
if h . Type == protocol . PacketTypeInitial && v . UsesTokenInHeader ( ) {
length += utils . VarIntLen ( uint64 ( len ( h . Token ) ) ) + protocol . ByteCount ( len ( h . Token ) )
}
if h . Type == protocol . PacketType0RTT && v == protocol . Version44 {
length += protocol . ByteCount ( len ( h . DiversificationNonce ) )
}
return length , nil
}
length := protocol . ByteCount ( 1 /* type byte */ + h . DestConnectionID . Len ( ) )
if h . PacketNumberLen != protocol . PacketNumberLen1 && h . PacketNumberLen != protocol . PacketNumberLen2 && h . PacketNumberLen != protocol . PacketNumberLen4 {
return 0 , fmt . Errorf ( "invalid packet number length: %d" , h . PacketNumberLen )
2018-02-17 13:29:53 +08:00
}
2018-09-03 04:18:54 +07:00
length += protocol . ByteCount ( h . PacketNumberLen )
return length , nil
}
// getPublicHeaderLength gets the length of the publicHeader in bytes.
// It can only be called for regular packets.
func ( h * Header ) getPublicHeaderLength ( ) ( protocol . ByteCount , error ) {
length := protocol . ByteCount ( 1 ) // 1 byte for public flags
if h . PacketNumberLen == protocol . PacketNumberLen6 {
return 0 , errInvalidPacketNumberLen
}
if h . PacketNumberLen != protocol . PacketNumberLen1 && h . PacketNumberLen != protocol . PacketNumberLen2 && h . PacketNumberLen != protocol . PacketNumberLen4 {
return 0 , errPacketNumberLenNotSet
}
length += protocol . ByteCount ( h . PacketNumberLen )
length += protocol . ByteCount ( h . DestConnectionID . Len ( ) )
// Version Number in packets sent by the client
if h . VersionFlag {
length += 4
}
length += protocol . ByteCount ( len ( h . DiversificationNonce ) )
return length , nil
2018-02-17 13:29:53 +08:00
}
// Log logs the Header
2018-04-18 15:48:08 -06:00
func ( h * Header ) Log ( logger utils . Logger ) {
2018-09-03 04:18:54 +07:00
if h . IsPublicHeader {
2018-04-18 15:48:08 -06:00
h . logPublicHeader ( logger )
2018-02-17 13:29:53 +08:00
} else {
2018-04-18 15:48:08 -06:00
h . logHeader ( logger )
2018-02-17 13:29:53 +08:00
}
}
2018-09-03 04:18:54 +07:00
// logHeader logs an IETF QUIC header at debug level.
// It distinguishes Short Headers, Version Negotiation packets (Long Header
// with version 0), Retry packets, Version44 Long Headers and all other Long
// Headers, each of which has its own output format.
func (h *Header) logHeader(logger utils.Logger) {
	if !h.IsLongHeader {
		logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
		return
	}
	if h.Version == 0 {
		logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
		return
	}

	// The token is only printed for Initial and Retry packets.
	var token string
	switch h.Type {
	case protocol.PacketTypeInitial, protocol.PacketTypeRetry:
		token = "Token: (empty), "
		if len(h.Token) > 0 {
			token = fmt.Sprintf("Token: %#x, ", h.Token)
		}
	}

	switch {
	case h.Type == protocol.PacketTypeRetry:
		logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
	case h.Version == protocol.Version44:
		// The diversification nonce is only printed for 0-RTT packets.
		var divNonce string
		if h.Type == protocol.PacketType0RTT {
			divNonce = fmt.Sprintf("Diversification Nonce: %#x, ", h.DiversificationNonce)
		}
		logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, divNonce, h.Version)
	default:
		logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
	}
}
// logPublicHeader logs a gQUIC Public Header at debug level.
// A zero Version is printed as "(unset)".
func (h *Header) logPublicHeader(logger utils.Logger) {
	var versionStr string
	if h.Version == 0 {
		versionStr = "(unset)"
	} else {
		versionStr = h.Version.String()
	}
	logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, versionStr, h.DiversificationNonce)
}
// encodeConnIDLen packs the encoded lengths of both connection IDs into a
// single byte: the destination length in the high 4 bits, the source length
// in the low 4 bits. It fails if either length cannot be encoded.
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
	highNibble, err := encodeSingleConnIDLen(dest)
	if err != nil {
		return 0, err
	}
	lowNibble, err := encodeSingleConnIDLen(src)
	if err != nil {
		return 0, err
	}
	return highNibble<<4 | lowNibble, nil
}
// encodeSingleConnIDLen encodes the length of a connection ID into 4 bits.
// An empty connection ID is encoded as 0, lengths between 4 and 18 bytes are
// encoded as length - 3. All other lengths are rejected with an error.
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
	switch n := id.Len(); {
	case n == 0:
		return 0, nil
	case n < 4 || n > 18:
		return 0, fmt.Errorf("invalid connection ID length: %d bytes", n)
	default:
		return byte(n - 3), nil
	}
}
// decodeConnIDLen splits the connection ID length byte into its two halves:
// the high 4 bits hold the destination connection ID length, the low 4 bits
// the source connection ID length. Both are returned decoded, in bytes.
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
	destLen := decodeSingleConnIDLen(enc >> 4)
	srcLen := decodeSingleConnIDLen(enc & 0xf)
	return destLen, srcLen
}
// decodeSingleConnIDLen decodes an encoded connection ID length.
// 0 means that there is no connection ID, every other value enc stands for a
// length of enc + 3 bytes.
func decodeSingleConnIDLen(enc uint8) int {
	if enc != 0 {
		return int(enc) + 3
	}
	return 0
}