TUN-8621: Prevent QUIC connection from closing before grace period after unregistering

Whenever cloudflared receives a SIGTERM or SIGINT it goes into graceful shutdown mode, which unregisters the connection and closes the control stream. Unregistering makes it so we no longer receive any new requests and makes the edge close the connection, allowing in-flight requests to finish (within a 3 minute period).
 This was working fine for http2 connections, but the quic proxy was cancelling the context as soon as the controls stream ended, forcing the process to stop immediately.

 This commit changes the behavior so that we wait the full grace period before cancelling the request
This commit is contained in:
GoncaloGarcia 2024-08-30 12:51:20 +01:00 committed by chungthuang
parent 05249c7b51
commit e251a21810
8 changed files with 53 additions and 16 deletions

View File

@ -1,3 +1,7 @@
## 2024.9.2
### Bug Fixes
- We fixed a bug related to `--grace-period`. Tunnels that use QUIC as transport weren't abiding by this waiting period before forcefully closing the connections to the edge. From now on, both QUIC and HTTP2 tunnels will wait for either the grace period to end (defaults to 30 seconds) or until the last in-flight request is handled. Users that wish to maintain the previous behavior should set `--grace-period` to 0 if `--protocol` is set to `quic`. This will force `cloudflared` to shutdown as soon as either SIGTERM or SIGINT is received.
## 2024.2.1 ## 2024.2.1
### Notices ### Notices
- Starting from this version, tunnel diagnostics will be enabled by default. This will allow the engineering team to remotely get diagnostics from cloudflared during debug activities. Users still have the capability to opt-out of this feature by defining `--management-diagnostics=false` (or env `TUNNEL_MANAGEMENT_DIAGNOSTICS`). - Starting from this version, tunnel diagnostics will be enabled by default. This will allow the engineering team to remotely get diagnostics from cloudflared during debug activities. Users still have the capability to opt-out of this feature by defining `--management-diagnostics=false` (or env `TUNNEL_MANAGEMENT_DIAGNOSTICS`).

View File

@ -45,6 +45,7 @@ class TestTermination:
with connected: with connected:
connected.wait(self.timeout) connected.wait(self.timeout)
# Send signal after the SSE connection is established # Send signal after the SSE connection is established
with self.within_grace_period():
self.terminate_by_signal(cloudflared, signal) self.terminate_by_signal(cloudflared, signal)
self.wait_eyeball_thread( self.wait_eyeball_thread(
in_flight_req, self.grace_period + self.timeout) in_flight_req, self.grace_period + self.timeout)
@ -66,7 +67,7 @@ class TestTermination:
with connected: with connected:
connected.wait(self.timeout) connected.wait(self.timeout)
with self.within_grace_period(): with self.within_grace_period(has_connection=False):
# Send signal after the SSE connection is established # Send signal after the SSE connection is established
self.terminate_by_signal(cloudflared, signal) self.terminate_by_signal(cloudflared, signal)
self.wait_eyeball_thread(in_flight_req, self.grace_period) self.wait_eyeball_thread(in_flight_req, self.grace_period)
@ -78,7 +79,7 @@ class TestTermination:
with start_cloudflared( with start_cloudflared(
tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], new_process=True, capture_output=False) as cloudflared: tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], new_process=True, capture_output=False) as cloudflared:
wait_tunnel_ready(tunnel_url=config.get_url()) wait_tunnel_ready(tunnel_url=config.get_url())
with self.within_grace_period(): with self.within_grace_period(has_connection=False):
self.terminate_by_signal(cloudflared, signal) self.terminate_by_signal(cloudflared, signal)
def terminate_by_signal(self, cloudflared, sig): def terminate_by_signal(self, cloudflared, sig):
@ -92,13 +93,21 @@ class TestTermination:
# Using this context asserts logic within the context is executed within grace period # Using this context asserts logic within the context is executed within grace period
@contextmanager @contextmanager
def within_grace_period(self): def within_grace_period(self, has_connection=True):
try: try:
start = time.time() start = time.time()
yield yield
finally: finally:
# If the request takes longer than the grace period then we need to wait at most the grace period.
# If the request fell within the grace period cloudflared can close earlier, but to ensure that it doesn't
# close immediately we add a minimum boundary. If cloudflared shutdown in less than 1s it's likely that
# it shutdown as soon as it received SIGINT. The only way cloudflared can close immediately is if it has no
# in-flight requests
minimum = 1 if has_connection else 0
duration = time.time() - start duration = time.time() - start
assert duration < self.grace_period # Here we truncate to ensure that we don't fail on minute differences like 10.1 instead of 10
assert minimum <= int(duration) <= self.grace_period
def stream_request(self, config, connected, early_terminate): def stream_request(self, config, connected, early_terminate):
expected_terminate_message = "502 Bad Gateway" expected_terminate_message = "502 Bad Gateway"

View File

@ -6,6 +6,8 @@ import (
"net" "net"
"time" "time"
"github.com/pkg/errors"
"github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/management"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -116,27 +118,32 @@ func (c *controlStream) ServeControlStream(
} }
} }
c.waitForUnregister(ctx, registrationClient) return c.waitForUnregister(ctx, registrationClient)
return nil
} }
func (c *controlStream) waitForUnregister(ctx context.Context, registrationClient tunnelrpc.RegistrationClient) { func (c *controlStream) waitForUnregister(ctx context.Context, registrationClient tunnelrpc.RegistrationClient) error {
// wait for connection termination or start of graceful shutdown // wait for connection termination or start of graceful shutdown
defer registrationClient.Close() defer registrationClient.Close()
var shutdownError error
select { select {
case <-ctx.Done(): case <-ctx.Done():
shutdownError = ctx.Err()
break break
case <-c.gracefulShutdownC: case <-c.gracefulShutdownC:
c.stoppedGracefully = true c.stoppedGracefully = true
} }
c.observer.sendUnregisteringEvent(c.connIndex) c.observer.sendUnregisteringEvent(c.connIndex)
registrationClient.GracefulShutdown(ctx, c.gracePeriod) err := registrationClient.GracefulShutdown(ctx, c.gracePeriod)
if err != nil {
return errors.Wrap(err, "Error shutting down control stream")
}
c.observer.log.Info(). c.observer.log.Info().
Int(management.EventTypeKey, int(management.Cloudflared)). Int(management.EventTypeKey, int(management.Cloudflared)).
Uint8(LogFieldConnIndex, c.connIndex). Uint8(LogFieldConnIndex, c.connIndex).
IPAddr(LogFieldIPAddress, c.edgeAddress). IPAddr(LogFieldIPAddress, c.edgeAddress).
Msg("Unregistered tunnel connection") Msg("Unregistered tunnel connection")
return shutdownError
} }
func (c *controlStream) IsStopped() bool { func (c *controlStream) IsStopped() bool {

View File

@ -192,8 +192,9 @@ func (mc mockNamedTunnelRPCClient) RegisterConnection(
}, nil }, nil
} }
func (mc mockNamedTunnelRPCClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) { func (mc mockNamedTunnelRPCClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) error {
close(mc.unregistered) close(mc.unregistered)
return nil
} }
func (mockNamedTunnelRPCClient) Close() {} func (mockNamedTunnelRPCClient) Close() {}

View File

@ -69,6 +69,7 @@ type QUICConnection struct {
rpcTimeout time.Duration rpcTimeout time.Duration
streamWriteTimeout time.Duration streamWriteTimeout time.Duration
gracePeriod time.Duration
} }
// NewQUICConnection returns a new instance of QUICConnection. // NewQUICConnection returns a new instance of QUICConnection.
@ -86,6 +87,7 @@ func NewQUICConnection(
packetRouterConfig *ingress.GlobalRouterConfig, packetRouterConfig *ingress.GlobalRouterConfig,
rpcTimeout time.Duration, rpcTimeout time.Duration,
streamWriteTimeout time.Duration, streamWriteTimeout time.Duration,
gracePeriod time.Duration,
) (*QUICConnection, error) { ) (*QUICConnection, error) {
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger) udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
if err != nil { if err != nil {
@ -122,6 +124,7 @@ func NewQUICConnection(
connIndex: connIndex, connIndex: connIndex,
rpcTimeout: rpcTimeout, rpcTimeout: rpcTimeout,
streamWriteTimeout: streamWriteTimeout, streamWriteTimeout: streamWriteTimeout,
gracePeriod: gracePeriod,
}, nil }, nil
} }
@ -144,8 +147,17 @@ func (q *QUICConnection) Serve(ctx context.Context) error {
// In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control // In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control
// stream is already fully registered before the other goroutines can proceed. // stream is already fully registered before the other goroutines can proceed.
errGroup.Go(func() error { errGroup.Go(func() error {
defer cancel() // err is equal to nil if we exit due to unregistration. If that happens we want to wait the full
return q.serveControlStream(ctx, controlStream) // amount of the grace period, allowing requests to finish before we cancel the context, which will
// make cloudflared exit.
if err := q.serveControlStream(ctx, controlStream); err == nil {
select {
case <-ctx.Done():
case <-time.Tick(q.gracePeriod):
}
}
cancel()
return err
}) })
errGroup.Go(func() error { errGroup.Go(func() error {
defer cancel() defer cancel()

View File

@ -736,6 +736,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU
nil, nil,
15*time.Second, 15*time.Second,
0*time.Second, 0*time.Second,
0*time.Second,
) )
require.NoError(t, err) require.NoError(t, err)
return qc return qc

View File

@ -604,6 +604,7 @@ func (e *EdgeTunnelServer) serveQUIC(
e.config.PacketConfig, e.config.PacketConfig,
e.config.RPCTimeout, e.config.RPCTimeout,
e.config.WriteStreamTimeout, e.config.WriteStreamTimeout,
e.config.GracePeriod,
) )
if err != nil { if err != nil {
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection") connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")

View File

@ -23,7 +23,7 @@ type RegistrationClient interface {
edgeAddress net.IP, edgeAddress net.IP,
) (*pogs.ConnectionDetails, error) ) (*pogs.ConnectionDetails, error)
SendLocalConfiguration(ctx context.Context, config []byte) error SendLocalConfiguration(ctx context.Context, config []byte) error
GracefulShutdown(ctx context.Context, gracePeriod time.Duration) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) error
Close() Close()
} }
@ -79,7 +79,7 @@ func (r *registrationClient) SendLocalConfiguration(ctx context.Context, config
return err return err
} }
func (r *registrationClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) { func (r *registrationClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) error {
ctx, cancel := context.WithTimeout(ctx, gracePeriod) ctx, cancel := context.WithTimeout(ctx, gracePeriod)
defer cancel() defer cancel()
defer metrics.CapnpMetrics.ClientOperations.WithLabelValues(metrics.Registration, metrics.OperationUnregisterConnection).Inc() defer metrics.CapnpMetrics.ClientOperations.WithLabelValues(metrics.Registration, metrics.OperationUnregisterConnection).Inc()
@ -88,7 +88,9 @@ func (r *registrationClient) GracefulShutdown(ctx context.Context, gracePeriod t
err := r.client.UnregisterConnection(ctx) err := r.client.UnregisterConnection(ctx)
if err != nil { if err != nil {
metrics.CapnpMetrics.ClientFailures.WithLabelValues(metrics.Registration, metrics.OperationUnregisterConnection).Inc() metrics.CapnpMetrics.ClientFailures.WithLabelValues(metrics.Registration, metrics.OperationUnregisterConnection).Inc()
return err
} }
return nil
} }
func (r *registrationClient) Close() { func (r *registrationClient) Close() {