diff --git a/component-tests/test_termination.py b/component-tests/test_termination.py index 0d58496a..fbca69d0 100644 --- a/component-tests/test_termination.py +++ b/component-tests/test_termination.py @@ -1,11 +1,20 @@ #!/usr/bin/env python from contextlib import contextmanager -import requests +import platform import signal import threading import time -from util import start_cloudflared, wait_tunnel_ready, check_tunnel_not_connected, LOGGER +import pytest +import requests + +from util import start_cloudflared, wait_tunnel_ready, check_tunnel_not_connected + + +def supported_signals(): + if platform.system() == "Windows": + return [signal.SIGTERM] + return [signal.SIGTERM, signal.SIGINT] class TestTermination(): @@ -14,57 +23,56 @@ class TestTermination(): extra_config = { "grace-period": f"{grace_period}s", } - signals = [signal.SIGTERM, signal.SIGINT] sse_endpoint = "/sse?freq=1s" - def test_graceful_shutdown(self, tmp_path, component_tests_config): + @pytest.mark.parametrize("signal", supported_signals()) + def test_graceful_shutdown(self, tmp_path, component_tests_config, signal): config = component_tests_config(self.extra_config) - for sig in self.signals: - with start_cloudflared( - tmp_path, config, new_process=True, capture_output=False) as cloudflared: - wait_tunnel_ready(tunnel_url=config.get_url()) + with start_cloudflared( + tmp_path, config, new_process=True, capture_output=False) as cloudflared: + wait_tunnel_ready(tunnel_url=config.get_url()) - connected = threading.Condition() - in_flight_req = threading.Thread( - target=self.stream_request, args=(config, connected,)) - in_flight_req.start() + connected = threading.Condition() + in_flight_req = threading.Thread( + target=self.stream_request, args=(config, connected, False, )) + in_flight_req.start() - with connected: - connected.wait(self.timeout) - # Send signal after the SSE connection is established - self.terminate_by_signal(cloudflared, sig) - self.wait_eyeball_thread( - in_flight_req, self.grace_period + self.timeout) + with connected: + connected.wait(self.timeout) + # Send signal after the SSE connection is established + self.terminate_by_signal(cloudflared, signal) + self.wait_eyeball_thread( + in_flight_req, self.grace_period + self.timeout) # test cloudflared terminates before grace period expires when all eyeball # connections are drained - def test_shutdown_once_no_connection(self, tmp_path, component_tests_config): + @pytest.mark.parametrize("signal", supported_signals()) + def test_shutdown_once_no_connection(self, tmp_path, component_tests_config, signal): config = component_tests_config(self.extra_config) - for sig in self.signals: - with start_cloudflared( - tmp_path, config, new_process=True, capture_output=False) as cloudflared: - wait_tunnel_ready(tunnel_url=config.get_url()) + with start_cloudflared( + tmp_path, config, new_process=True, capture_output=False) as cloudflared: + wait_tunnel_ready(tunnel_url=config.get_url()) - connected = threading.Condition() - in_flight_req = threading.Thread( - target=self.stream_request, args=(config, connected, True, )) - in_flight_req.start() + connected = threading.Condition() + in_flight_req = threading.Thread( + target=self.stream_request, args=(config, connected, True, )) + in_flight_req.start() - with connected: - connected.wait(self.timeout) - with self.within_grace_period(): - # Send signal after the SSE connection is established - self.terminate_by_signal(cloudflared, sig) - self.wait_eyeball_thread(in_flight_req, self.grace_period) + with connected: + connected.wait(self.timeout) + with self.within_grace_period(): + # Send signal after the SSE connection is established + self.terminate_by_signal(cloudflared, signal) + self.wait_eyeball_thread(in_flight_req, self.grace_period) - def test_no_connection_shutdown(self, tmp_path, component_tests_config): + @pytest.mark.parametrize("signal", supported_signals()) + def test_no_connection_shutdown(self, tmp_path, component_tests_config, signal): config = component_tests_config(self.extra_config) - for sig in self.signals: - with start_cloudflared( - tmp_path, config, new_process=True, capture_output=False) as cloudflared: - wait_tunnel_ready(tunnel_url=config.get_url()) - with self.within_grace_period(): - self.terminate_by_signal(cloudflared, sig) + with start_cloudflared( + tmp_path, config, new_process=True, capture_output=False) as cloudflared: + wait_tunnel_ready(tunnel_url=config.get_url()) + with self.within_grace_period(): + self.terminate_by_signal(cloudflared, signal) def terminate_by_signal(self, cloudflared, sig): cloudflared.send_signal(sig)