package mint import ( "bytes" "fmt" "github.com/bifurcation/mint/syntax" ) type ExtensionBody interface { Type() ExtensionType Marshal() ([]byte, error) Unmarshal(data []byte) (int, error) } // struct { // ExtensionType extension_type; // opaque extension_data<0..2^16-1>; // } Extension; type Extension struct { ExtensionType ExtensionType ExtensionData []byte `tls:"head=2"` } func (ext Extension) Marshal() ([]byte, error) { return syntax.Marshal(ext) } func (ext *Extension) Unmarshal(data []byte) (int, error) { return syntax.Unmarshal(data, ext) } type ExtensionList []Extension type extensionListInner struct { List []Extension `tls:"head=2"` } func (el ExtensionList) Marshal() ([]byte, error) { return syntax.Marshal(extensionListInner{el}) } func (el *ExtensionList) Unmarshal(data []byte) (int, error) { var list extensionListInner read, err := syntax.Unmarshal(data, &list) if err != nil { return 0, err } *el = list.List return read, nil } func (el *ExtensionList) Add(src ExtensionBody) error { data, err := src.Marshal() if err != nil { return err } if el == nil { el = new(ExtensionList) } // If one already exists with this type, replace it for i := range *el { if (*el)[i].ExtensionType == src.Type() { (*el)[i].ExtensionData = data return nil } } // Otherwise append *el = append(*el, Extension{ ExtensionType: src.Type(), ExtensionData: data, }) return nil } func (el ExtensionList) Find(dst ExtensionBody) bool { for _, ext := range el { if ext.ExtensionType == dst.Type() { _, err := dst.Unmarshal(ext.ExtensionData) return err == nil } } return false } // struct { // NameType name_type; // select (name_type) { // case host_name: HostName; // } name; // } ServerName; // // enum { // host_name(0), (255) // } NameType; // // opaque HostName<1..2^16-1>; // // struct { // ServerName server_name_list<1..2^16-1> // } ServerNameList; // // But we only care about the case where there's a single DNS hostname. We // will never create anything else, and throw if we receive something else // // 2 1 2 // | listLen | NameType | nameLen | name | type ServerNameExtension string type serverNameInner struct { NameType uint8 HostName []byte `tls:"head=2,min=1"` } type serverNameListInner struct { ServerNameList []serverNameInner `tls:"head=2,min=1"` } func (sni ServerNameExtension) Type() ExtensionType { return ExtensionTypeServerName } func (sni ServerNameExtension) Marshal() ([]byte, error) { list := serverNameListInner{ ServerNameList: []serverNameInner{{ NameType: 0x00, // host_name HostName: []byte(sni), }}, } return syntax.Marshal(list) } func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) { var list serverNameListInner read, err := syntax.Unmarshal(data, &list) if err != nil { return 0, err } // Syntax requires at least one entry // Entries beyond the first are ignored if nameType := list.ServerNameList[0].NameType; nameType != 0x00 { return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType) } *sni = ServerNameExtension(list.ServerNameList[0].HostName) return read, nil } // struct { // NamedGroup group; // opaque key_exchange<1..2^16-1>; // } KeyShareEntry; // // struct { // select (Handshake.msg_type) { // case client_hello: // KeyShareEntry client_shares<0..2^16-1>; // // case hello_retry_request: // NamedGroup selected_group; // // case server_hello: // KeyShareEntry server_share; // }; // } KeyShare; type KeyShareEntry struct { Group NamedGroup KeyExchange []byte `tls:"head=2,min=1"` } func (kse KeyShareEntry) SizeValid() bool { return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group) } type KeyShareExtension struct { HandshakeType HandshakeType SelectedGroup NamedGroup Shares []KeyShareEntry } type KeyShareClientHelloInner struct { ClientShares []KeyShareEntry `tls:"head=2,min=0"` } type KeyShareHelloRetryInner struct { SelectedGroup NamedGroup } type KeyShareServerHelloInner struct { ServerShare KeyShareEntry } func (ks KeyShareExtension) Type() ExtensionType { return ExtensionTypeKeyShare } func (ks KeyShareExtension) Marshal() ([]byte, error) { switch ks.HandshakeType { case HandshakeTypeClientHello: for _, share := range ks.Shares { if !share.SizeValid() { return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") } } return syntax.Marshal(KeyShareClientHelloInner{ks.Shares}) case HandshakeTypeHelloRetryRequest: if len(ks.Shares) > 0 { return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest") } return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup}) case HandshakeTypeServerHello: if len(ks.Shares) != 1 { return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share") } if !ks.Shares[0].SizeValid() { return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") } return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]}) default: return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed") } } func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) { switch ks.HandshakeType { case HandshakeTypeClientHello: var inner KeyShareClientHelloInner read, err := syntax.Unmarshal(data, &inner) if err != nil { return 0, err } for _, share := range inner.ClientShares { if !share.SizeValid() { return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") } } ks.Shares = inner.ClientShares return read, nil case HandshakeTypeHelloRetryRequest: var inner KeyShareHelloRetryInner read, err := syntax.Unmarshal(data, &inner) if err != nil { return 0, err } ks.SelectedGroup = inner.SelectedGroup return read, nil case HandshakeTypeServerHello: var inner KeyShareServerHelloInner read, err := syntax.Unmarshal(data, &inner) if err != nil { return 0, err } if !inner.ServerShare.SizeValid() { return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") } ks.Shares = []KeyShareEntry{inner.ServerShare} return read, nil default: return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed") } } // struct { // NamedGroup named_group_list<2..2^16-1>; // } NamedGroupList; type SupportedGroupsExtension struct { Groups []NamedGroup `tls:"head=2,min=2"` } func (sg SupportedGroupsExtension) Type() ExtensionType { return ExtensionTypeSupportedGroups } func (sg SupportedGroupsExtension) Marshal() ([]byte, error) { return syntax.Marshal(sg) } func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) { return syntax.Unmarshal(data, sg) } // struct { // SignatureScheme supported_signature_algorithms<2..2^16-2>; // } SignatureSchemeList type SignatureAlgorithmsExtension struct { Algorithms []SignatureScheme `tls:"head=2,min=2"` } func (sa SignatureAlgorithmsExtension) Type() ExtensionType { return ExtensionTypeSignatureAlgorithms } func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) { return syntax.Marshal(sa) } func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) { return syntax.Unmarshal(data, sa) } // struct { // opaque identity<1..2^16-1>; // uint32 obfuscated_ticket_age; // } PskIdentity; // // opaque PskBinderEntry<32..255>; // // struct { // select (Handshake.msg_type) { // case client_hello: // PskIdentity identities<7..2^16-1>; // PskBinderEntry binders<33..2^16-1>; // // case server_hello: // uint16 selected_identity; // }; // // } PreSharedKeyExtension; type PSKIdentity struct { Identity []byte `tls:"head=2,min=1"` ObfuscatedTicketAge uint32 } type PSKBinderEntry struct { Binder []byte `tls:"head=1,min=32"` } type PreSharedKeyExtension struct { HandshakeType HandshakeType Identities []PSKIdentity Binders []PSKBinderEntry SelectedIdentity uint16 } type preSharedKeyClientInner struct { Identities []PSKIdentity `tls:"head=2,min=7"` Binders []PSKBinderEntry `tls:"head=2,min=33"` } type preSharedKeyServerInner struct { SelectedIdentity uint16 } func (psk PreSharedKeyExtension) Type() ExtensionType { return ExtensionTypePreSharedKey } func (psk PreSharedKeyExtension) Marshal() ([]byte, error) { switch psk.HandshakeType { case HandshakeTypeClientHello: return syntax.Marshal(preSharedKeyClientInner{ Identities: psk.Identities, Binders: psk.Binders, }) case HandshakeTypeServerHello: if len(psk.Identities) > 0 || len(psk.Binders) > 0 { return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index") } return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity}) default: return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported") } } func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) { switch psk.HandshakeType { case HandshakeTypeClientHello: var inner preSharedKeyClientInner read, err := syntax.Unmarshal(data, &inner) if err != nil { return 0, err } if len(inner.Identities) != len(inner.Binders) { return 0, fmt.Errorf("Lengths of identities and binders not equal") } psk.Identities = inner.Identities psk.Binders = inner.Binders return read, nil case HandshakeTypeServerHello: var inner preSharedKeyServerInner read, err := syntax.Unmarshal(data, &inner) if err != nil { return 0, err } psk.SelectedIdentity = inner.SelectedIdentity return read, nil default: return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported") } } func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) { for i, localID := range psk.Identities { if bytes.Equal(localID.Identity, id) { return psk.Binders[i].Binder, true } } return nil, false } // enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode; // // struct { // PskKeyExchangeMode ke_modes<1..255>; // } PskKeyExchangeModes; type PSKKeyExchangeModesExtension struct { KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"` } func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType { return ExtensionTypePSKKeyExchangeModes } func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) { return syntax.Marshal(pkem) } func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) { return syntax.Unmarshal(data, pkem) } // struct { // } EarlyDataIndication; type EarlyDataExtension struct{} func (ed EarlyDataExtension) Type() ExtensionType { return ExtensionTypeEarlyData } func (ed EarlyDataExtension) Marshal() ([]byte, error) { return []byte{}, nil } func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) { return 0, nil } // struct { // uint32 max_early_data_size; // } TicketEarlyDataInfo; type TicketEarlyDataInfoExtension struct { MaxEarlyDataSize uint32 } func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType { return ExtensionTypeTicketEarlyDataInfo } func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) { return syntax.Marshal(tedi) } func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) { return syntax.Unmarshal(data, tedi) } // opaque ProtocolName<1..2^8-1>; // // struct { // ProtocolName protocol_name_list<2..2^16-1> // } ProtocolNameList; type ALPNExtension struct { Protocols []string } type protocolNameInner struct { Name []byte `tls:"head=1,min=1"` } type alpnExtensionInner struct { Protocols []protocolNameInner `tls:"head=2,min=2"` } func (alpn ALPNExtension) Type() ExtensionType { return ExtensionTypeALPN } func (alpn ALPNExtension) Marshal() ([]byte, error) { protocols := make([]protocolNameInner, len(alpn.Protocols)) for i, protocol := range alpn.Protocols { protocols[i] = protocolNameInner{[]byte(protocol)} } return syntax.Marshal(alpnExtensionInner{protocols}) } func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) { var inner alpnExtensionInner read, err := syntax.Unmarshal(data, &inner) if err != nil { return 0, err } alpn.Protocols = make([]string, len(inner.Protocols)) for i, protocol := range inner.Protocols { alpn.Protocols[i] = string(protocol.Name) } return read, nil } // struct { // ProtocolVersion versions<2..254>; // } SupportedVersions; type SupportedVersionsExtension struct { Versions []uint16 `tls:"head=1,min=2,max=254"` } func (sv SupportedVersionsExtension) Type() ExtensionType { return ExtensionTypeSupportedVersions } func (sv SupportedVersionsExtension) Marshal() ([]byte, error) { return syntax.Marshal(sv) } func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) { return syntax.Unmarshal(data, sv) } // struct { // opaque cookie<1..2^16-1>; // } Cookie; type CookieExtension struct { Cookie []byte `tls:"head=2,min=1"` } func (c CookieExtension) Type() ExtensionType { return ExtensionTypeCookie } func (c CookieExtension) Marshal() ([]byte, error) { return syntax.Marshal(c) } func (c *CookieExtension) Unmarshal(data []byte) (int, error) { return syntax.Unmarshal(data, c) } // defaultCookieLength is the default length of a cookie const defaultCookieLength = 32 type defaultCookieHandler struct { data []byte } var _ CookieHandler = &defaultCookieHandler{} // NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) { h.data = make([]byte, defaultCookieLength) if _, err := prng.Read(h.data); err != nil { return nil, err } return h.data, nil } func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool { return bytes.Equal(h.data, data) }