mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-30 22:34:15 -05:00
core: quic listener will manage the underlying socket by itself (#5749)
* core: quic listener will manage the underlying socket by itself. * format code * rename sharedQUICTLSConfig to sharedQUICState, and it will now manage the number of active requests * add comment * strict unwrap type * fix unwrap * remove comment
This commit is contained in:
parent
0900844c81
commit
7c82e265da
3 changed files with 127 additions and 71 deletions
164
listeners.go
164
listeners.go
|
@ -470,38 +470,90 @@ func ListenPacket(network, addr string) (net.PacketConn, error) {
|
||||||
// unixgram will be used; otherwise, udp will be used).
|
// unixgram will be used; otherwise, udp will be used).
|
||||||
//
|
//
|
||||||
// NOTE: This API is EXPERIMENTAL and may be changed or removed.
|
// NOTE: This API is EXPERIMENTAL and may be changed or removed.
|
||||||
//
|
func (na NetworkAddress) ListenQUIC(ctx context.Context, portOffset uint, config net.ListenConfig, tlsConf *tls.Config, activeRequests *int64) (http3.QUICEarlyListener, error) {
|
||||||
// TODO: See if we can find a more elegant solution closer to the new NetworkAddress.Listen API.
|
lnKey := listenerKey("quic"+na.Network, na.JoinHostPort(portOffset))
|
||||||
func ListenQUIC(ln net.PacketConn, tlsConf *tls.Config, activeRequests *int64) (http3.QUICEarlyListener, error) {
|
|
||||||
lnKey := listenerKey("quic+"+ln.LocalAddr().Network(), ln.LocalAddr().String())
|
|
||||||
|
|
||||||
sharedEarlyListener, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) {
|
sharedEarlyListener, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) {
|
||||||
sqtc := newSharedQUICTLSConfig(tlsConf)
|
lnAny, err := na.Listen(ctx, portOffset, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ln := lnAny.(net.PacketConn)
|
||||||
|
|
||||||
|
h3ln := ln
|
||||||
|
for {
|
||||||
|
// retrieve the underlying socket, so quic-go can optimize.
|
||||||
|
if unwrapper, ok := h3ln.(interface{ Unwrap() net.PacketConn }); ok {
|
||||||
|
h3ln = unwrapper.Unwrap()
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqs := newSharedQUICState(tlsConf, activeRequests)
|
||||||
// http3.ConfigureTLSConfig only uses this field and tls App sets this field as well
|
// http3.ConfigureTLSConfig only uses this field and tls App sets this field as well
|
||||||
//nolint:gosec
|
//nolint:gosec
|
||||||
quicTlsConfig := &tls.Config{GetConfigForClient: sqtc.getConfigForClient}
|
quicTlsConfig := &tls.Config{GetConfigForClient: sqs.getConfigForClient}
|
||||||
earlyLn, err := quic.ListenEarly(ln, http3.ConfigureTLSConfig(quicTlsConfig), &quic.Config{
|
earlyLn, err := quic.ListenEarly(h3ln, http3.ConfigureTLSConfig(quicTlsConfig), &quic.Config{
|
||||||
Allow0RTT: true,
|
Allow0RTT: true,
|
||||||
RequireAddressValidation: func(clientAddr net.Addr) bool {
|
RequireAddressValidation: func(clientAddr net.Addr) bool {
|
||||||
var highLoad bool
|
// TODO: make tunable?
|
||||||
if activeRequests != nil {
|
return sqs.getActiveRequests() > 1000
|
||||||
highLoad = atomic.LoadInt64(activeRequests) > 1000 // TODO: make tunable?
|
|
||||||
}
|
|
||||||
return highLoad
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &sharedQuicListener{EarlyListener: earlyLn, sqtc: sqtc, key: lnKey}, nil
|
// using the original net.PacketConn to close them properly
|
||||||
|
return &sharedQuicListener{EarlyListener: earlyLn, packetConn: ln, sqs: sqs, key: lnKey}, nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sql := sharedEarlyListener.(*sharedQuicListener)
|
sql := sharedEarlyListener.(*sharedQuicListener)
|
||||||
// add current tls.Config to sqtc, so GetConfigForClient will always return the latest tls.Config in case of context cancellation
|
// add current tls.Config to sqs, so GetConfigForClient will always return the latest tls.Config in case of context cancellation,
|
||||||
ctx, cancel := sql.sqtc.addTLSConfig(tlsConf)
|
// and the request counter will reflect current http server
|
||||||
|
ctx, cancel := sql.sqs.addState(tlsConf, activeRequests)
|
||||||
|
|
||||||
|
return &fakeCloseQuicListener{
|
||||||
|
sharedQuicListener: sql,
|
||||||
|
context: ctx,
|
||||||
|
contextCancel: cancel,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DEPRECATED: Use NetworkAddress.ListenQUIC instead. This function will likely be changed or removed in the future.
|
||||||
|
// TODO: See if we can find a more elegant solution closer to the new NetworkAddress.Listen API.
|
||||||
|
func ListenQUIC(ln net.PacketConn, tlsConf *tls.Config, activeRequests *int64) (http3.QUICEarlyListener, error) {
|
||||||
|
lnKey := listenerKey("quic+"+ln.LocalAddr().Network(), ln.LocalAddr().String())
|
||||||
|
|
||||||
|
sharedEarlyListener, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) {
|
||||||
|
sqs := newSharedQUICState(tlsConf, activeRequests)
|
||||||
|
// http3.ConfigureTLSConfig only uses this field and tls App sets this field as well
|
||||||
|
//nolint:gosec
|
||||||
|
quicTlsConfig := &tls.Config{GetConfigForClient: sqs.getConfigForClient}
|
||||||
|
earlyLn, err := quic.ListenEarly(ln, http3.ConfigureTLSConfig(quicTlsConfig), &quic.Config{
|
||||||
|
Allow0RTT: true,
|
||||||
|
RequireAddressValidation: func(clientAddr net.Addr) bool {
|
||||||
|
// TODO: make tunable?
|
||||||
|
return sqs.getActiveRequests() > 1000
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &sharedQuicListener{EarlyListener: earlyLn, sqs: sqs, key: lnKey}, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sql := sharedEarlyListener.(*sharedQuicListener)
|
||||||
|
// add current tls.Config and request counter to sqs, so GetConfigForClient will always return the latest tls.Config in case of context cancellation,
|
||||||
|
// and the request counter will reflect current http server
|
||||||
|
ctx, cancel := sql.sqs.addState(tlsConf, activeRequests)
|
||||||
|
|
||||||
// TODO: to serve QUIC over a unix socket, currently we need to hold onto
|
// TODO: to serve QUIC over a unix socket, currently we need to hold onto
|
||||||
// the underlying net.PacketConn (which we wrap as unixConn to keep count
|
// the underlying net.PacketConn (which we wrap as unixConn to keep count
|
||||||
|
@ -534,38 +586,50 @@ type contextAndCancelFunc struct {
|
||||||
context.CancelFunc
|
context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// sharedQUICTLSConfig manages GetConfigForClient
|
// sharedQUICState manages GetConfigForClient and current number of active requests
|
||||||
// see issue: https://github.com/caddyserver/caddy/pull/4849
|
// see issue: https://github.com/caddyserver/caddy/pull/4849
|
||||||
type sharedQUICTLSConfig struct {
|
type sharedQUICState struct {
|
||||||
rmu sync.RWMutex
|
rmu sync.RWMutex
|
||||||
tlsConfs map[*tls.Config]contextAndCancelFunc
|
tlsConfs map[*tls.Config]contextAndCancelFunc
|
||||||
|
requestCounters map[*tls.Config]*int64
|
||||||
activeTlsConf *tls.Config
|
activeTlsConf *tls.Config
|
||||||
|
activeRequestsCounter *int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// newSharedQUICTLSConfig creates a new sharedQUICTLSConfig
|
// newSharedQUICState creates a new sharedQUICState
|
||||||
func newSharedQUICTLSConfig(tlsConfig *tls.Config) *sharedQUICTLSConfig {
|
func newSharedQUICState(tlsConfig *tls.Config, activeRequests *int64) *sharedQUICState {
|
||||||
sqtc := &sharedQUICTLSConfig{
|
sqtc := &sharedQUICState{
|
||||||
tlsConfs: make(map[*tls.Config]contextAndCancelFunc),
|
tlsConfs: make(map[*tls.Config]contextAndCancelFunc),
|
||||||
|
requestCounters: make(map[*tls.Config]*int64),
|
||||||
activeTlsConf: tlsConfig,
|
activeTlsConf: tlsConfig,
|
||||||
|
activeRequestsCounter: activeRequests,
|
||||||
}
|
}
|
||||||
sqtc.addTLSConfig(tlsConfig)
|
sqtc.addState(tlsConfig, activeRequests)
|
||||||
return sqtc
|
return sqtc
|
||||||
}
|
}
|
||||||
|
|
||||||
// getConfigForClient is used as tls.Config's GetConfigForClient field
|
// getConfigForClient is used as tls.Config's GetConfigForClient field
|
||||||
func (sqtc *sharedQUICTLSConfig) getConfigForClient(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
func (sqs *sharedQUICState) getConfigForClient(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
sqtc.rmu.RLock()
|
sqs.rmu.RLock()
|
||||||
defer sqtc.rmu.RUnlock()
|
defer sqs.rmu.RUnlock()
|
||||||
return sqtc.activeTlsConf.GetConfigForClient(ch)
|
return sqs.activeTlsConf.GetConfigForClient(ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
// addTLSConfig adds tls.Config to the map if not present and returns the corresponding context and its cancelFunc
|
// getActiveRequests returns the number of active requests
|
||||||
// so that when cancelled, the active tls.Config will change
|
func (sqs *sharedQUICState) getActiveRequests() int64 {
|
||||||
func (sqtc *sharedQUICTLSConfig) addTLSConfig(tlsConfig *tls.Config) (context.Context, context.CancelFunc) {
|
// Prevent a race when a context is cancelled and active request counter is being changed
|
||||||
sqtc.rmu.Lock()
|
sqs.rmu.RLock()
|
||||||
defer sqtc.rmu.Unlock()
|
defer sqs.rmu.RUnlock()
|
||||||
|
return atomic.LoadInt64(sqs.activeRequestsCounter)
|
||||||
|
}
|
||||||
|
|
||||||
if cacc, ok := sqtc.tlsConfs[tlsConfig]; ok {
|
// addState adds tls.Config and activeRequests to the map if not present and returns the corresponding context and its cancelFunc
|
||||||
|
// so that when cancelled, the active tls.Config and request counter will change
|
||||||
|
func (sqs *sharedQUICState) addState(tlsConfig *tls.Config, activeRequests *int64) (context.Context, context.CancelFunc) {
|
||||||
|
sqs.rmu.Lock()
|
||||||
|
defer sqs.rmu.Unlock()
|
||||||
|
|
||||||
|
if cacc, ok := sqs.tlsConfs[tlsConfig]; ok {
|
||||||
return cacc.Context, cacc.CancelFunc
|
return cacc.Context, cacc.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -573,23 +637,26 @@ func (sqtc *sharedQUICTLSConfig) addTLSConfig(tlsConfig *tls.Config) (context.Co
|
||||||
wrappedCancel := func() {
|
wrappedCancel := func() {
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
sqtc.rmu.Lock()
|
sqs.rmu.Lock()
|
||||||
defer sqtc.rmu.Unlock()
|
defer sqs.rmu.Unlock()
|
||||||
|
|
||||||
delete(sqtc.tlsConfs, tlsConfig)
|
delete(sqs.tlsConfs, tlsConfig)
|
||||||
if sqtc.activeTlsConf == tlsConfig {
|
delete(sqs.requestCounters, tlsConfig)
|
||||||
// select another tls.Config, if there is none,
|
if sqs.activeTlsConf == tlsConfig {
|
||||||
|
// select another tls.Config and request counter, if there is none,
|
||||||
// related sharedQuicListener will be destroyed anyway
|
// related sharedQuicListener will be destroyed anyway
|
||||||
for tc := range sqtc.tlsConfs {
|
for tc, counter := range sqs.requestCounters {
|
||||||
sqtc.activeTlsConf = tc
|
sqs.activeTlsConf = tc
|
||||||
|
sqs.activeRequestsCounter = counter
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sqtc.tlsConfs[tlsConfig] = contextAndCancelFunc{ctx, wrappedCancel}
|
sqs.tlsConfs[tlsConfig] = contextAndCancelFunc{ctx, wrappedCancel}
|
||||||
|
sqs.requestCounters[tlsConfig] = activeRequests
|
||||||
// there should be at most 2 tls.Configs
|
// there should be at most 2 tls.Configs
|
||||||
if len(sqtc.tlsConfs) > 2 {
|
if len(sqs.tlsConfs) > 2 {
|
||||||
Log().Warn("quic listener tls configs are more than 2", zap.Int("number of configs", len(sqtc.tlsConfs)))
|
Log().Warn("quic listener tls configs are more than 2", zap.Int("number of configs", len(sqs.tlsConfs)))
|
||||||
}
|
}
|
||||||
return ctx, wrappedCancel
|
return ctx, wrappedCancel
|
||||||
}
|
}
|
||||||
|
@ -597,13 +664,17 @@ func (sqtc *sharedQUICTLSConfig) addTLSConfig(tlsConfig *tls.Config) (context.Co
|
||||||
// sharedQuicListener is like sharedListener, but for quic.EarlyListeners.
|
// sharedQuicListener is like sharedListener, but for quic.EarlyListeners.
|
||||||
type sharedQuicListener struct {
|
type sharedQuicListener struct {
|
||||||
*quic.EarlyListener
|
*quic.EarlyListener
|
||||||
sqtc *sharedQUICTLSConfig
|
packetConn net.PacketConn // we have to hold these because quic-go won't close listeners it didn't create
|
||||||
|
sqs *sharedQUICState
|
||||||
key string
|
key string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destruct closes the underlying QUIC listener.
|
// Destruct closes the underlying QUIC listener and its associated net.PacketConn.
|
||||||
func (sql *sharedQuicListener) Destruct() error {
|
func (sql *sharedQuicListener) Destruct() error {
|
||||||
return sql.EarlyListener.Close()
|
// close EarlyListener first to stop any operations being done to the net.PacketConn
|
||||||
|
_ = sql.EarlyListener.Close()
|
||||||
|
// then close the net.PacketConn
|
||||||
|
return sql.packetConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// sharedPacketConn is like sharedListener, but for net.PacketConns.
|
// sharedPacketConn is like sharedListener, but for net.PacketConns.
|
||||||
|
@ -652,6 +723,11 @@ var _ quic.OOBCapablePacketConn = (*fakeClosePacketConn)(nil)
|
||||||
// but doesn't actually use these methods, the only methods needed are `ReadMsgUDP` and `SyscallConn`.
|
// but doesn't actually use these methods, the only methods needed are `ReadMsgUDP` and `SyscallConn`.
|
||||||
var _ net.Conn = (*fakeClosePacketConn)(nil)
|
var _ net.Conn = (*fakeClosePacketConn)(nil)
|
||||||
|
|
||||||
|
// Unwrap returns the underlying net.UDPConn for quic-go optimization
|
||||||
|
func (fcpc *fakeClosePacketConn) Unwrap() any {
|
||||||
|
return fcpc.UDPConn
|
||||||
|
}
|
||||||
|
|
||||||
// Close won't close the underlying socket unless there is no more reference, then listenerPool will close it.
|
// Close won't close the underlying socket unless there is no more reference, then listenerPool will close it.
|
||||||
func (fcpc *fakeClosePacketConn) Close() error {
|
func (fcpc *fakeClosePacketConn) Close() error {
|
||||||
if atomic.CompareAndSwapInt32(&fcpc.closed, 0, 1) {
|
if atomic.CompareAndSwapInt32(&fcpc.closed, 0, 1) {
|
||||||
|
|
|
@ -617,17 +617,6 @@ func (app *App) Stop() error {
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
zap.Strings("addresses", server.Listen))
|
zap.Strings("addresses", server.Listen))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: we have to manually close our listeners because quic-go won't
|
|
||||||
// close listeners it didn't create along with the server itself...
|
|
||||||
// see https://github.com/quic-go/quic-go/issues/3560
|
|
||||||
for _, el := range server.h3listeners {
|
|
||||||
if err := el.Close(); err != nil {
|
|
||||||
app.logger.Error("HTTP/3 listener close",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.String("address", el.LocalAddr().String()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
stopH2Listener := func(server *Server) {
|
stopH2Listener := func(server *Server) {
|
||||||
defer finishedShutdown.Done()
|
defer finishedShutdown.Done()
|
||||||
|
|
|
@ -228,7 +228,6 @@ type Server struct {
|
||||||
|
|
||||||
server *http.Server
|
server *http.Server
|
||||||
h3server *http3.Server
|
h3server *http3.Server
|
||||||
h3listeners []net.PacketConn // TODO: we have to hold these because quic-go won't close listeners it didn't create
|
|
||||||
h2listeners []*http2Listener
|
h2listeners []*http2Listener
|
||||||
addresses []caddy.NetworkAddress
|
addresses []caddy.NetworkAddress
|
||||||
|
|
||||||
|
@ -555,13 +554,7 @@ func (s *Server) findLastRouteWithHostMatcher() int {
|
||||||
// the listener, with Server s as the handler.
|
// the listener, with Server s as the handler.
|
||||||
func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error {
|
func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error {
|
||||||
addr.Network = getHTTP3Network(addr.Network)
|
addr.Network = getHTTP3Network(addr.Network)
|
||||||
lnAny, err := addr.Listen(s.ctx, 0, net.ListenConfig{})
|
h3ln, err := addr.ListenQUIC(s.ctx, 0, net.ListenConfig{}, tlsCfg, &s.activeRequests)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ln := lnAny.(net.PacketConn)
|
|
||||||
|
|
||||||
h3ln, err := caddy.ListenQUIC(ln, tlsCfg, &s.activeRequests)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("starting HTTP/3 QUIC listener: %v", err)
|
return fmt.Errorf("starting HTTP/3 QUIC listener: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -579,8 +572,6 @@ func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.h3listeners = append(s.h3listeners, ln)
|
|
||||||
|
|
||||||
//nolint:errcheck
|
//nolint:errcheck
|
||||||
go s.h3server.ServeListener(h3ln)
|
go s.h3server.ServeListener(h3ln)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue