From cfef0e737f3f3ba7b7d719be8ca51d38c30c2737 Mon Sep 17 00:00:00 2001
From: Devin Carr
Date: Wed, 31 Aug 2022 12:52:44 -0700
Subject: [PATCH] TUN-6720: Remove forcibly closing connection during reconnect signal

Previously, allowing the reconnect signal to forcibly close the connection
caused a race condition over which error was returned by the errgroup in
the tunnel connection. Allowing the signal to return and cancel the
connection's context instead provides a safer shutdown of the tunnel for
this test-only scenario.
---
 cmd/cloudflared/tunnel/cmd.go     |  2 +-
 component-tests/test_reconnect.py | 17 ++++++-----------
 component-tests/util.py           |  7 +++++--
 supervisor/supervisor.go          |  3 +--
 supervisor/tunnel.go              | 12 +++++++++---
 5 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go
index 793f8c2b..a45e6e0c 100644
--- a/cmd/cloudflared/tunnel/cmd.go
+++ b/cmd/cloudflared/tunnel/cmd.go
@@ -375,7 +375,7 @@ func StartServer(
 		errC <- metrics.ServeMetrics(metricsListener, ctx.Done(), readinessServer, quickTunnelURL, orchestrator, log)
 	}()
 
-	reconnectCh := make(chan supervisor.ReconnectSignal, 1)
+	reconnectCh := make(chan supervisor.ReconnectSignal, c.Int("ha-connections"))
 	if c.IsSet("stdin-control") {
 		log.Info().Msg("Enabling control through stdin")
 		go stdinControl(reconnectCh, log)
diff --git a/component-tests/test_reconnect.py b/component-tests/test_reconnect.py
index 0b601171..e125845a 100644
--- a/component-tests/test_reconnect.py
+++ b/component-tests/test_reconnect.py
@@ -47,17 +47,12 @@ class TestReconnect:
         cloudflared.stdin.flush()
 
     def assert_reconnect(self, config, cloudflared, repeat):
-        wait_tunnel_ready(tunnel_url=config.get_url(), require_min_connections=self.default_ha_conns)
+        wait_tunnel_ready(tunnel_url=config.get_url(),
+                          require_min_connections=self.default_ha_conns)
         for _ in range(repeat):
-            for i in range(self.default_ha_conns):
+            for _ in range(self.default_ha_conns):
                 self.send_reconnect(cloudflared, self.default_reconnect_secs)
-                expect_connections = self.default_ha_conns-i-1
-                if expect_connections > 0:
-                    # Don't check if tunnel returns 200 here because there is a race condition between wait_tunnel_ready
-                    # retrying to get 200 response and reconnecting
-                    wait_tunnel_ready(require_min_connections=expect_connections)
-                else:
-                    check_tunnel_not_connected()
-
+            check_tunnel_not_connected()
             sleep(self.default_reconnect_secs * 2)
-        wait_tunnel_ready(tunnel_url=config.get_url(), require_min_connections=self.default_ha_conns)
+        wait_tunnel_ready(tunnel_url=config.get_url(),
+                          require_min_connections=self.default_ha_conns)
diff --git a/component-tests/util.py b/component-tests/util.py
index 6e60d08e..34f5faf7 100644
--- a/component-tests/util.py
+++ b/component-tests/util.py
@@ -15,6 +15,7 @@ from constants import METRICS_PORT, MAX_RETRIES, BACKOFF_SECS
 
 LOGGER = logging.getLogger(__name__)
 
+
 def select_platform(plat):
     return pytest.mark.skipif(
         platform.system() != plat, reason=f"Only runs on {plat}")
@@ -108,13 +109,15 @@ def _log_cloudflared_logs(cfd_logs):
             LOGGER.warning(line)
 
 
-@retry(stop_max_attempt_number=MAX_RETRIES * BACKOFF_SECS, wait_fixed=1000)
+@retry(stop_max_attempt_number=MAX_RETRIES, wait_fixed=BACKOFF_SECS * 1000)
 def check_tunnel_not_connected():
     url = f'http://localhost:{METRICS_PORT}/ready'
 
     try:
-        resp = requests.get(url, timeout=1)
+        resp = requests.get(url, timeout=BACKOFF_SECS)
         assert resp.status_code == 503, f"Expect {url} returns 503, got {resp.status_code}"
+        assert resp.json()[
+            "readyConnections"] == 0, "Expected all connections to be terminated (pending reconnect)"
     # cloudflared might already terminate
     except requests.exceptions.ConnectionError as e:
         LOGGER.warning(f"Failed to connect to {url}, error: {e}")
diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go
index c6bca29e..bc076f0c 100644
--- a/supervisor/supervisor.go
+++ b/supervisor/supervisor.go
@@ -295,8 +295,7 @@ func (s *Supervisor) initialize(
 			s.config.ProtocolSelector.Current(),
 			false,
 		}
-		ch := signal.New(make(chan struct{}))
-		go s.startTunnel(ctx, i, ch)
+		go s.startTunnel(ctx, i, s.newConnectedTunnelSignal(i))
 		time.Sleep(registrationInterval)
 	}
 	return nil
diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go
index 2652c22d..f23eeff7 100644
--- a/supervisor/tunnel.go
+++ b/supervisor/tunnel.go
@@ -546,7 +546,13 @@ func (e *EdgeTunnelServer) serveH2mux(
 	})
 
 	errGroup.Go(func() error {
-		return listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
+		err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
+		if err != nil {
+			// forcefully break the connection (this is only used for testing)
+			// errgroup will return context canceled for the handler.ServeClassicTunnel
+			connLog.Logger().Debug().Msg("Forcefully breaking h2mux connection")
+		}
+		return err
 	})
 
 	return errGroup.Wait()
@@ -580,8 +586,8 @@ func (e *EdgeTunnelServer) serveHTTP2(
 		err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
 		if err != nil {
 			// forcefully break the connection (this is only used for testing)
+			// errgroup will return context canceled for the h2conn.Serve
 			connLog.Logger().Debug().Msg("Forcefully breaking http2 connection")
-			_ = tlsServerConn.Close()
 		}
 		return err
 	})
@@ -636,8 +642,8 @@ func (e *EdgeTunnelServer) serveQUIC(
 		err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
 		if err != nil {
 			// forcefully break the connection (this is only used for testing)
+			// errgroup will return context canceled for the quicConn.Serve
 			connLogger.Logger().Debug().Msg("Forcefully breaking quic connection")
-			quicConn.Close()
 		}
 		return err
 	})