0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2024-12-30 22:34:15 -05:00
caddy/vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go

101 lines
2.2 KiB
Go
Raw Normal View History

package handshake
import (
"encoding/asn1"
"fmt"
"net"
"time"
"github.com/lucas-clemente/quic-go/crypto"
)
// Prefix bytes written as the first byte of an encoded remote address
// (see encodeRemoteAddr / decodeRemoteAddr) to record how the remaining
// bytes must be interpreted.
const (
	// stkPrefixIP marks a payload holding raw IP bytes.
	stkPrefixIP byte = 0
	// stkPrefixString marks a payload holding the address's String() form.
	stkPrefixString byte = 1
)
// An STK is a source address token.
// This is the decoded form returned by STKGenerator.DecodeToken.
type STK struct {
	// RemoteAddr is the address recorded when the token was issued. For a
	// *net.UDPAddr it is only the textual IP (the port is not stored);
	// for any other net.Addr it is the full String() value.
	RemoteAddr string
	// SentTime is when the token was generated, at whole-second precision
	// (it is stored as Unix seconds).
	SentTime time.Time
}
// token is the struct that is used for ASN1 serialization and deserialization.
// encoding/asn1 marshals a struct as a SEQUENCE of its fields in declaration
// order, so the field order here is part of the token's wire format.
type token struct {
	// Data is the output of encodeRemoteAddr: one prefix byte
	// (stkPrefixIP or stkPrefixString) followed by the address bytes.
	Data []byte
	// Timestamp is the issue time as Unix seconds.
	Timestamp int64
}
// An STKGenerator generates STKs.
// Create one with NewSTKGenerator. NewToken and DecodeToken delegate the
// final encoding/decoding of the ASN.1 payload to stkSource (presumably an
// encrypting/authenticating source; its behavior is not visible here).
type STKGenerator struct {
	stkSource crypto.StkSource
}
// NewSTKGenerator initializes a new STKGenerator.
// It builds the underlying crypto.StkSource; if that fails, the error is
// returned unchanged together with a nil generator.
func NewSTKGenerator() (*STKGenerator, error) {
	src, err := crypto.NewStkSource()
	if err != nil {
		return nil, err
	}
	gen := &STKGenerator{stkSource: src}
	return gen, nil
}
// NewToken generates a new STK token for a given source address.
// The encoded address and the current time (Unix seconds) are ASN.1-marshalled
// into a token, and the result is passed to the generator's STK source, whose
// output (and error) is returned as-is.
func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) {
	payload := token{
		Data:      encodeRemoteAddr(raddr),
		Timestamp: time.Now().Unix(),
	}
	serialized, err := asn1.Marshal(payload)
	if err != nil {
		return nil, err
	}
	return g.stkSource.NewToken(serialized)
}
// DecodeToken decodes an STK token
func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) {
// if the client didn't send any STK, DecodeToken will be called with a nil-slice
if len(encrypted) == 0 {
return nil, nil
}
data, err := g.stkSource.DecodeToken(encrypted)
if err != nil {
return nil, err
}
t := &token{}
rest, err := asn1.Unmarshal(data, t)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
return &STK{
RemoteAddr: decodeRemoteAddr(t.Data),
SentTime: time.Unix(t.Timestamp, 0),
}, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
return append([]byte{stkPrefixIP}, udpAddr.IP...)
}
return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...)
}
// decodeRemoteAddr decodes the remote address saved in the STK.
// It reverses encodeRemoteAddr: an stkPrefixIP payload is rendered with
// net.IP.String, and any other prefix is treated as a raw string payload.
// Empty input yields "".
func decodeRemoteAddr(data []byte) string {
	// encodeRemoteAddr always writes a prefix byte, so this only guards
	// against malformed input.
	if len(data) == 0 {
		return ""
	}
	prefix, payload := data[0], data[1:]
	if prefix != stkPrefixIP {
		return string(payload)
	}
	return net.IP(payload).String()
}