0
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2024-12-30 22:34:15 -05:00

core: quic listener will manage the underlying socket by itself (#5749)

* core: quic listener will manage the underlying socket by itself.

* format code

* rename sharedQUICTLSConfig to sharedQUICState, and it will now manage the number of active requests

* add comment

* strict unwrap type

* fix unwrap

* remove comment
This commit is contained in:
WeidiDeng 2023-10-16 23:28:15 +08:00 committed by GitHub
parent 0900844c81
commit 7c82e265da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 127 additions and 71 deletions

View file

@ -470,38 +470,90 @@ func ListenPacket(network, addr string) (net.PacketConn, error) {
// unixgram will be used; otherwise, udp will be used). // unixgram will be used; otherwise, udp will be used).
// //
// NOTE: This API is EXPERIMENTAL and may be changed or removed. // NOTE: This API is EXPERIMENTAL and may be changed or removed.
// func (na NetworkAddress) ListenQUIC(ctx context.Context, portOffset uint, config net.ListenConfig, tlsConf *tls.Config, activeRequests *int64) (http3.QUICEarlyListener, error) {
// TODO: See if we can find a more elegant solution closer to the new NetworkAddress.Listen API. lnKey := listenerKey("quic"+na.Network, na.JoinHostPort(portOffset))
func ListenQUIC(ln net.PacketConn, tlsConf *tls.Config, activeRequests *int64) (http3.QUICEarlyListener, error) {
lnKey := listenerKey("quic+"+ln.LocalAddr().Network(), ln.LocalAddr().String())
sharedEarlyListener, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) { sharedEarlyListener, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) {
sqtc := newSharedQUICTLSConfig(tlsConf) lnAny, err := na.Listen(ctx, portOffset, config)
if err != nil {
return nil, err
}
ln := lnAny.(net.PacketConn)
h3ln := ln
for {
// retrieve the underlying socket, so quic-go can optimize.
if unwrapper, ok := h3ln.(interface{ Unwrap() net.PacketConn }); ok {
h3ln = unwrapper.Unwrap()
} else {
break
}
}
sqs := newSharedQUICState(tlsConf, activeRequests)
// http3.ConfigureTLSConfig only uses this field and tls App sets this field as well // http3.ConfigureTLSConfig only uses this field and tls App sets this field as well
//nolint:gosec //nolint:gosec
quicTlsConfig := &tls.Config{GetConfigForClient: sqtc.getConfigForClient} quicTlsConfig := &tls.Config{GetConfigForClient: sqs.getConfigForClient}
earlyLn, err := quic.ListenEarly(ln, http3.ConfigureTLSConfig(quicTlsConfig), &quic.Config{ earlyLn, err := quic.ListenEarly(h3ln, http3.ConfigureTLSConfig(quicTlsConfig), &quic.Config{
Allow0RTT: true, Allow0RTT: true,
RequireAddressValidation: func(clientAddr net.Addr) bool { RequireAddressValidation: func(clientAddr net.Addr) bool {
var highLoad bool // TODO: make tunable?
if activeRequests != nil { return sqs.getActiveRequests() > 1000
highLoad = atomic.LoadInt64(activeRequests) > 1000 // TODO: make tunable?
}
return highLoad
}, },
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &sharedQuicListener{EarlyListener: earlyLn, sqtc: sqtc, key: lnKey}, nil // using the original net.PacketConn to close them properly
return &sharedQuicListener{EarlyListener: earlyLn, packetConn: ln, sqs: sqs, key: lnKey}, nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
sql := sharedEarlyListener.(*sharedQuicListener) sql := sharedEarlyListener.(*sharedQuicListener)
// add current tls.Config to sqtc, so GetConfigForClient will always return the latest tls.Config in case of context cancellation // add current tls.Config to sqs, so GetConfigForClient will always return the latest tls.Config in case of context cancellation,
ctx, cancel := sql.sqtc.addTLSConfig(tlsConf) // and the request counter will reflect current http server
ctx, cancel := sql.sqs.addState(tlsConf, activeRequests)
return &fakeCloseQuicListener{
sharedQuicListener: sql,
context: ctx,
contextCancel: cancel,
}, nil
}
// DEPRECATED: Use NetworkAddress.ListenQUIC instead. This function will likely be changed or removed in the future.
// TODO: See if we can find a more elegant solution closer to the new NetworkAddress.Listen API.
func ListenQUIC(ln net.PacketConn, tlsConf *tls.Config, activeRequests *int64) (http3.QUICEarlyListener, error) {
lnKey := listenerKey("quic+"+ln.LocalAddr().Network(), ln.LocalAddr().String())
sharedEarlyListener, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) {
sqs := newSharedQUICState(tlsConf, activeRequests)
// http3.ConfigureTLSConfig only uses this field and tls App sets this field as well
//nolint:gosec
quicTlsConfig := &tls.Config{GetConfigForClient: sqs.getConfigForClient}
earlyLn, err := quic.ListenEarly(ln, http3.ConfigureTLSConfig(quicTlsConfig), &quic.Config{
Allow0RTT: true,
RequireAddressValidation: func(clientAddr net.Addr) bool {
// TODO: make tunable?
return sqs.getActiveRequests() > 1000
},
})
if err != nil {
return nil, err
}
return &sharedQuicListener{EarlyListener: earlyLn, sqs: sqs, key: lnKey}, nil
})
if err != nil {
return nil, err
}
sql := sharedEarlyListener.(*sharedQuicListener)
// add current tls.Config and request counter to sqs, so GetConfigForClient will always return the latest tls.Config in case of context cancellation,
// and the request counter will reflect current http server
ctx, cancel := sql.sqs.addState(tlsConf, activeRequests)
// TODO: to serve QUIC over a unix socket, currently we need to hold onto // TODO: to serve QUIC over a unix socket, currently we need to hold onto
// the underlying net.PacketConn (which we wrap as unixConn to keep count // the underlying net.PacketConn (which we wrap as unixConn to keep count
@ -534,38 +586,50 @@ type contextAndCancelFunc struct {
context.CancelFunc context.CancelFunc
} }
// sharedQUICTLSConfig manages GetConfigForClient // sharedQUICState manages GetConfigForClient and current number of active requests
// see issue: https://github.com/caddyserver/caddy/pull/4849 // see issue: https://github.com/caddyserver/caddy/pull/4849
type sharedQUICTLSConfig struct { type sharedQUICState struct {
rmu sync.RWMutex rmu sync.RWMutex
tlsConfs map[*tls.Config]contextAndCancelFunc tlsConfs map[*tls.Config]contextAndCancelFunc
requestCounters map[*tls.Config]*int64
activeTlsConf *tls.Config activeTlsConf *tls.Config
activeRequestsCounter *int64
} }
// newSharedQUICTLSConfig creates a new sharedQUICTLSConfig // newSharedQUICState creates a new sharedQUICState
func newSharedQUICTLSConfig(tlsConfig *tls.Config) *sharedQUICTLSConfig { func newSharedQUICState(tlsConfig *tls.Config, activeRequests *int64) *sharedQUICState {
sqtc := &sharedQUICTLSConfig{ sqtc := &sharedQUICState{
tlsConfs: make(map[*tls.Config]contextAndCancelFunc), tlsConfs: make(map[*tls.Config]contextAndCancelFunc),
requestCounters: make(map[*tls.Config]*int64),
activeTlsConf: tlsConfig, activeTlsConf: tlsConfig,
activeRequestsCounter: activeRequests,
} }
sqtc.addTLSConfig(tlsConfig) sqtc.addState(tlsConfig, activeRequests)
return sqtc return sqtc
} }
// getConfigForClient is used as tls.Config's GetConfigForClient field // getConfigForClient is used as tls.Config's GetConfigForClient field
func (sqtc *sharedQUICTLSConfig) getConfigForClient(ch *tls.ClientHelloInfo) (*tls.Config, error) { func (sqs *sharedQUICState) getConfigForClient(ch *tls.ClientHelloInfo) (*tls.Config, error) {
sqtc.rmu.RLock() sqs.rmu.RLock()
defer sqtc.rmu.RUnlock() defer sqs.rmu.RUnlock()
return sqtc.activeTlsConf.GetConfigForClient(ch) return sqs.activeTlsConf.GetConfigForClient(ch)
} }
// addTLSConfig adds tls.Config to the map if not present and returns the corresponding context and its cancelFunc // getActiveRequests returns the number of active requests
// so that when cancelled, the active tls.Config will change func (sqs *sharedQUICState) getActiveRequests() int64 {
func (sqtc *sharedQUICTLSConfig) addTLSConfig(tlsConfig *tls.Config) (context.Context, context.CancelFunc) { // Prevent a race when a context is cancelled and active request counter is being changed
sqtc.rmu.Lock() sqs.rmu.RLock()
defer sqtc.rmu.Unlock() defer sqs.rmu.RUnlock()
return atomic.LoadInt64(sqs.activeRequestsCounter)
}
if cacc, ok := sqtc.tlsConfs[tlsConfig]; ok { // addState adds tls.Config and activeRequests to the map if not present and returns the corresponding context and its cancelFunc
// so that when cancelled, the active tls.Config and request counter will change
func (sqs *sharedQUICState) addState(tlsConfig *tls.Config, activeRequests *int64) (context.Context, context.CancelFunc) {
sqs.rmu.Lock()
defer sqs.rmu.Unlock()
if cacc, ok := sqs.tlsConfs[tlsConfig]; ok {
return cacc.Context, cacc.CancelFunc return cacc.Context, cacc.CancelFunc
} }
@ -573,23 +637,26 @@ func (sqtc *sharedQUICTLSConfig) addTLSConfig(tlsConfig *tls.Config) (context.Co
wrappedCancel := func() { wrappedCancel := func() {
cancel() cancel()
sqtc.rmu.Lock() sqs.rmu.Lock()
defer sqtc.rmu.Unlock() defer sqs.rmu.Unlock()
delete(sqtc.tlsConfs, tlsConfig) delete(sqs.tlsConfs, tlsConfig)
if sqtc.activeTlsConf == tlsConfig { delete(sqs.requestCounters, tlsConfig)
// select another tls.Config, if there is none, if sqs.activeTlsConf == tlsConfig {
// select another tls.Config and request counter, if there is none,
// related sharedQuicListener will be destroyed anyway // related sharedQuicListener will be destroyed anyway
for tc := range sqtc.tlsConfs { for tc, counter := range sqs.requestCounters {
sqtc.activeTlsConf = tc sqs.activeTlsConf = tc
sqs.activeRequestsCounter = counter
break break
} }
} }
} }
sqtc.tlsConfs[tlsConfig] = contextAndCancelFunc{ctx, wrappedCancel} sqs.tlsConfs[tlsConfig] = contextAndCancelFunc{ctx, wrappedCancel}
sqs.requestCounters[tlsConfig] = activeRequests
// there should be at most 2 tls.Configs // there should be at most 2 tls.Configs
if len(sqtc.tlsConfs) > 2 { if len(sqs.tlsConfs) > 2 {
Log().Warn("quic listener tls configs are more than 2", zap.Int("number of configs", len(sqtc.tlsConfs))) Log().Warn("quic listener tls configs are more than 2", zap.Int("number of configs", len(sqs.tlsConfs)))
} }
return ctx, wrappedCancel return ctx, wrappedCancel
} }
@ -597,13 +664,17 @@ func (sqtc *sharedQUICTLSConfig) addTLSConfig(tlsConfig *tls.Config) (context.Co
// sharedQuicListener is like sharedListener, but for quic.EarlyListeners. // sharedQuicListener is like sharedListener, but for quic.EarlyListeners.
type sharedQuicListener struct { type sharedQuicListener struct {
*quic.EarlyListener *quic.EarlyListener
sqtc *sharedQUICTLSConfig packetConn net.PacketConn // we have to hold these because quic-go won't close listeners it didn't create
sqs *sharedQUICState
key string key string
} }
// Destruct closes the underlying QUIC listener. // Destruct closes the underlying QUIC listener and its associated net.PacketConn.
func (sql *sharedQuicListener) Destruct() error { func (sql *sharedQuicListener) Destruct() error {
return sql.EarlyListener.Close() // close EarlyListener first to stop any operations being done to the net.PacketConn
_ = sql.EarlyListener.Close()
// then close the net.PacketConn
return sql.packetConn.Close()
} }
// sharedPacketConn is like sharedListener, but for net.PacketConns. // sharedPacketConn is like sharedListener, but for net.PacketConns.
@ -652,6 +723,11 @@ var _ quic.OOBCapablePacketConn = (*fakeClosePacketConn)(nil)
// but doesn't actually use these methods, the only methods needed are `ReadMsgUDP` and `SyscallConn`. // but doesn't actually use these methods, the only methods needed are `ReadMsgUDP` and `SyscallConn`.
var _ net.Conn = (*fakeClosePacketConn)(nil) var _ net.Conn = (*fakeClosePacketConn)(nil)
// Unwrap returns the underlying net.UDPConn for quic-go optimization
func (fcpc *fakeClosePacketConn) Unwrap() any {
return fcpc.UDPConn
}
// Close won't close the underlying socket unless there is no more reference, then listenerPool will close it. // Close won't close the underlying socket unless there is no more reference, then listenerPool will close it.
func (fcpc *fakeClosePacketConn) Close() error { func (fcpc *fakeClosePacketConn) Close() error {
if atomic.CompareAndSwapInt32(&fcpc.closed, 0, 1) { if atomic.CompareAndSwapInt32(&fcpc.closed, 0, 1) {

View file

@ -617,17 +617,6 @@ func (app *App) Stop() error {
zap.Error(err), zap.Error(err),
zap.Strings("addresses", server.Listen)) zap.Strings("addresses", server.Listen))
} }
// TODO: we have to manually close our listeners because quic-go won't
// close listeners it didn't create along with the server itself...
// see https://github.com/quic-go/quic-go/issues/3560
for _, el := range server.h3listeners {
if err := el.Close(); err != nil {
app.logger.Error("HTTP/3 listener close",
zap.Error(err),
zap.String("address", el.LocalAddr().String()))
}
}
} }
stopH2Listener := func(server *Server) { stopH2Listener := func(server *Server) {
defer finishedShutdown.Done() defer finishedShutdown.Done()

View file

@ -228,7 +228,6 @@ type Server struct {
server *http.Server server *http.Server
h3server *http3.Server h3server *http3.Server
h3listeners []net.PacketConn // TODO: we have to hold these because quic-go won't close listeners it didn't create
h2listeners []*http2Listener h2listeners []*http2Listener
addresses []caddy.NetworkAddress addresses []caddy.NetworkAddress
@ -555,13 +554,7 @@ func (s *Server) findLastRouteWithHostMatcher() int {
// the listener, with Server s as the handler. // the listener, with Server s as the handler.
func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error { func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error {
addr.Network = getHTTP3Network(addr.Network) addr.Network = getHTTP3Network(addr.Network)
lnAny, err := addr.Listen(s.ctx, 0, net.ListenConfig{}) h3ln, err := addr.ListenQUIC(s.ctx, 0, net.ListenConfig{}, tlsCfg, &s.activeRequests)
if err != nil {
return err
}
ln := lnAny.(net.PacketConn)
h3ln, err := caddy.ListenQUIC(ln, tlsCfg, &s.activeRequests)
if err != nil { if err != nil {
return fmt.Errorf("starting HTTP/3 QUIC listener: %v", err) return fmt.Errorf("starting HTTP/3 QUIC listener: %v", err)
} }
@ -579,8 +572,6 @@ func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error
} }
} }
s.h3listeners = append(s.h3listeners, ln)
//nolint:errcheck //nolint:errcheck
go s.h3server.ServeListener(h3ln) go s.h3server.ServeListener(h3ln)