#!/usr/bin/env python
from contextlib import contextmanager
import platform
import signal
import threading
import time

import pytest
import requests

from constants import protocols
from util import start_cloudflared, wait_tunnel_ready, check_tunnel_not_connected


def supported_signals():
    """Return the signals the termination tests are parametrized over.

    Only SIGTERM is used on Windows; other platforms also exercise SIGINT.
    """
    if platform.system() == "Windows":
        return [signal.SIGTERM]
    return [signal.SIGTERM, signal.SIGINT]


class _EyeballThread(threading.Thread):
    """Thread that records an exception raised by its target.

    An exception (including a failed ``assert``) raised in a plain
    ``threading.Thread`` is not propagated to the thread that joins it, so a
    failing eyeball request would otherwise go unnoticed. The recorded error is
    re-raised on the test thread by ``TestTermination.raise_eyeball_error``.
    """

    def __init__(self, target, args):
        # Daemon so a stuck request cannot keep the test process alive.
        super().__init__(target=target, args=args, daemon=True)
        self.error = None

    def run(self):
        try:
            super().run()
        except Exception as err:  # re-raised later on the test thread
            self.error = err


class TestTermination:
    """Tests for how cloudflared shuts down after receiving a termination signal."""

    # Seconds passed to cloudflared as its grace period (see _extra_config).
    grace_period = 5
    # Seconds to wait for the eyeball SSE connection to be established; also
    # added to grace_period when joining the eyeball thread.
    timeout = 10
    # SSE endpoint requested through the tunnel.
    sse_endpoint = "/sse?freq=1s"

    def _extra_config(self, protocol):
        """Return config overrides selecting ``protocol`` and the class grace period."""
        return {
            "grace-period": f"{self.grace_period}s",
            "protocol": protocol,
        }

    def _start_cloudflared(self, tmp_path, config):
        """Return the ``start_cloudflared`` context manager used by every test here."""
        return start_cloudflared(
            tmp_path,
            config,
            cfd_pre_args=["tunnel", "--ha-connections", "1"],
            new_process=True,
            capture_output=False,
        )

    @pytest.mark.parametrize("signal", supported_signals())
    @pytest.mark.parametrize("protocol", protocols())
    def test_graceful_shutdown(self, tmp_path, component_tests_config, signal, protocol):
        """An SSE request open when the signal arrives keeps receiving data.

        The eyeball thread must read at least ``grace_period * 2`` lines
        before the stream ends or yields ``502 Bad Gateway``.
        """
        config = component_tests_config(self._extra_config(protocol))
        with self._start_cloudflared(tmp_path, config) as cloudflared:
            wait_tunnel_ready(tunnel_url=config.get_url())

            connected = threading.Event()
            in_flight_req = self.start_eyeball_thread(config, connected, early_terminate=False)
            self.wait_sse_connected(connected, in_flight_req)

            # Send signal after the SSE connection is established
            self.terminate_by_signal(cloudflared, signal)
            self.wait_eyeball_thread(in_flight_req, self.grace_period + self.timeout)

    @pytest.mark.parametrize("signal", supported_signals())
    @pytest.mark.parametrize("protocol", protocols())
    def test_shutdown_once_no_connection(self, tmp_path, component_tests_config, signal, protocol):
        """cloudflared terminates before the grace period expires once all
        eyeball connections are drained."""
        config = component_tests_config(self._extra_config(protocol))
        with self._start_cloudflared(tmp_path, config) as cloudflared:
            wait_tunnel_ready(tunnel_url=config.get_url())

            connected = threading.Event()
            in_flight_req = self.start_eyeball_thread(config, connected, early_terminate=True)
            self.wait_sse_connected(connected, in_flight_req)

            with self.within_grace_period():
                # Send signal after the SSE connection is established
                self.terminate_by_signal(cloudflared, signal)
                self.wait_eyeball_thread(in_flight_req, self.grace_period)

    @pytest.mark.parametrize("signal", supported_signals())
    @pytest.mark.parametrize("protocol", protocols())
    def test_no_connection_shutdown(self, tmp_path, component_tests_config, signal, protocol):
        """With no eyeball connections, cloudflared exits within the grace period."""
        config = component_tests_config(self._extra_config(protocol))
        with self._start_cloudflared(tmp_path, config) as cloudflared:
            wait_tunnel_ready(tunnel_url=config.get_url())
            with self.within_grace_period():
                self.terminate_by_signal(cloudflared, signal)

    def terminate_by_signal(self, cloudflared, sig):
        """Send ``sig`` to cloudflared, run ``check_tunnel_not_connected``,
        then block until the process exits."""
        cloudflared.send_signal(sig)
        check_tunnel_not_connected()
        cloudflared.wait()

    def start_eyeball_thread(self, config, connected, early_terminate):
        """Start ``stream_request`` on a thread that records its failures."""
        thread = _EyeballThread(
            target=self.stream_request,
            args=(config, connected, early_terminate),
        )
        thread.start()
        return thread

    def wait_sse_connected(self, connected, thread):
        """Wait up to ``timeout`` seconds for the eyeball SSE request to connect.

        An Event (rather than a Condition) is used so a notification sent
        before this call starts waiting is not lost.
        """
        established = connected.wait(self.timeout)
        self.raise_eyeball_error(thread)
        assert established, f"SSE connection was not established within {self.timeout}s"

    def wait_eyeball_thread(self, thread, timeout):
        """Join ``thread`` and fail if it is still running or if it raised."""
        thread.join(timeout)
        assert not thread.is_alive(), "eyeball thread is still alive"
        self.raise_eyeball_error(thread)

    @staticmethod
    def raise_eyeball_error(thread):
        """Re-raise an exception recorded by an eyeball thread, if any.

        Plain ``threading.Thread`` objects have no ``error`` attribute and are
        treated as successful.
        """
        error = getattr(thread, "error", None)
        if error is not None:
            raise error

    # Using this context asserts logic within the context is executed within grace period
    @contextmanager
    def within_grace_period(self):
        """Assert that the body of the ``with`` block finishes within ``grace_period``.

        A monotonic clock is used so wall-clock adjustments cannot skew the
        measurement. If the body raises, that exception propagates unchanged.
        """
        start = time.monotonic()
        yield
        duration = time.monotonic() - start
        assert duration < self.grace_period, (
            f"took {duration:.2f}s, grace period is {self.grace_period}s"
        )

    def stream_request(self, config, connected, early_terminate):
        """Open the SSE endpoint through the tunnel and count streamed lines.

        ``connected`` is set as soon as ``requests.get`` returns. With
        ``stream=True`` the body has not been consumed yet at that point.
        Iteration stops at a ``502 Bad Gateway`` line or at end of stream.
        When ``early_terminate`` is true, the request is closed after two
        lines. Otherwise at least ``grace_period * 2`` lines must have been read.
        """
        expected_terminate_message = b"502 Bad Gateway"
        url = config.get_url() + self.sse_endpoint
        with requests.get(url, timeout=5, stream=True) as resp:
            connected.set()
            lines = 0
            for line in resp.iter_lines():
                if line == expected_terminate_message:
                    break
                lines += 1
                if early_terminate and lines == 2:
                    return
            # /sse returns count followed by 2 new lines
            assert lines >= self.grace_period * 2, (
                f"only {lines} lines streamed, expected at least {self.grace_period * 2}"
            )