119 lines
4.5 KiB
Python
119 lines
4.5 KiB
Python
#!/usr/bin/env python
|
|
from contextlib import contextmanager
|
|
import platform
|
|
import signal
|
|
import threading
|
|
import time
|
|
|
|
import pytest
|
|
import requests
|
|
|
|
from constants import protocols
|
|
from util import start_cloudflared, wait_tunnel_ready, check_tunnel_not_connected
|
|
|
|
|
|
def supported_signals():
    """Return the termination signals the tests should exercise on this OS.

    SIGTERM is always included; SIGINT is added on every platform except
    Windows.
    """
    signals = [signal.SIGTERM]
    if platform.system() != "Windows":
        signals.append(signal.SIGINT)
    return signals
|
|
|
|
|
|
class TestTermination:
    """Component tests for cloudflared's behaviour on termination signals.

    Each test starts cloudflared with a configured grace period, optionally
    opens a streaming request through the tunnel, sends a signal and checks
    how the process and the in-flight request wind down.
    """

    # Value, in seconds, passed to cloudflared as its "grace-period" setting.
    grace_period = 5
    # Generic upper bound, in seconds, for waits performed by the tests.
    timeout = 10
    # Path appended to the tunnel URL for the streaming request
    # (presumably one server-sent event per second -- confirm against origin).
    sse_endpoint = "/sse?freq=1s"

    def _extra_config(self, protocol):
        """Build the extra config entries: grace period and tunnel protocol."""
        return {
            "grace-period": f"{self.grace_period}s",
            "protocol": protocol,
        }

    def _start_stream_request(self, config, early_terminate):
        """Start ``stream_request`` in a thread and wait until it is connected.

        The condition is acquired *before* the thread is started, so the
        thread cannot notify until this method is already waiting; that
        rules out a lost wakeup.

        Returns:
            The started ``threading.Thread`` running the streaming request.

        Raises:
            AssertionError: if no notification arrives within ``self.timeout``.
        """
        connected = threading.Condition()
        in_flight_req = threading.Thread(
            target=self.stream_request,
            args=(config, connected, early_terminate))
        with connected:
            in_flight_req.start()
            established = connected.wait(self.timeout)
        assert established, "SSE connection was not established in time"
        return in_flight_req

    @pytest.mark.parametrize("signal", supported_signals())
    @pytest.mark.parametrize("protocol", protocols())
    def test_graceful_shutdown(self, tmp_path, component_tests_config, signal, protocol):
        """An in-flight streaming request finishes after cloudflared is signalled."""
        config = component_tests_config(self._extra_config(protocol))
        with start_cloudflared(
                tmp_path, config, new_process=True, capture_output=False) as cloudflared:
            wait_tunnel_ready(tunnel_url=config.get_url())

            in_flight_req = self._start_stream_request(
                config, early_terminate=False)

            # Send signal after the SSE connection is established
            self.terminate_by_signal(cloudflared, signal)
            self.wait_eyeball_thread(
                in_flight_req, self.grace_period + self.timeout)

    @pytest.mark.parametrize("signal", supported_signals())
    @pytest.mark.parametrize("protocol", protocols())
    def test_shutdown_once_no_connection(self, tmp_path, component_tests_config, signal, protocol):
        """cloudflared terminates before the grace period expires when all
        eyeball connections are drained."""
        config = component_tests_config(self._extra_config(protocol))
        with start_cloudflared(
                tmp_path, config, new_process=True, capture_output=False) as cloudflared:
            wait_tunnel_ready(tunnel_url=config.get_url())

            in_flight_req = self._start_stream_request(
                config, early_terminate=True)

            with self.within_grace_period():
                # Send signal after the SSE connection is established
                self.terminate_by_signal(cloudflared, signal)
                self.wait_eyeball_thread(in_flight_req, self.grace_period)

    @pytest.mark.parametrize("signal", supported_signals())
    @pytest.mark.parametrize("protocol", protocols())
    def test_no_connection_shutdown(self, tmp_path, component_tests_config, signal, protocol):
        """With no streaming request open, termination completes within the
        grace period."""
        config = component_tests_config(self._extra_config(protocol))
        with start_cloudflared(
                tmp_path, config, new_process=True, capture_output=False) as cloudflared:
            wait_tunnel_ready(tunnel_url=config.get_url())
            with self.within_grace_period():
                self.terminate_by_signal(cloudflared, signal)

    def terminate_by_signal(self, cloudflared, sig):
        """Send ``sig`` to cloudflared, check the tunnel is no longer
        connected, then block on ``cloudflared.wait()``."""
        cloudflared.send_signal(sig)
        check_tunnel_not_connected()
        cloudflared.wait()

    def wait_eyeball_thread(self, thread, timeout):
        """Join ``thread`` for at most ``timeout`` seconds and assert it ended."""
        thread.join(timeout)
        assert not thread.is_alive(), "eyeball thread is still alive"

    @contextmanager
    def within_grace_period(self):
        """Assert that the body of the ``with`` block runs within the grace period.

        A monotonic clock is used so wall-clock adjustments cannot skew the
        measurement. The timing assertion is only evaluated when the body
        completes normally, so an exception raised inside the block
        propagates unchanged instead of being masked.
        """
        start = time.monotonic()
        yield
        duration = time.monotonic() - start
        assert duration < self.grace_period, (
            f"took {duration:.2f}s, expected less than {self.grace_period}s")

    def stream_request(self, config, connected, early_terminate):
        """Open a streaming request through the tunnel and count received lines.

        Args:
            config: object whose ``get_url()`` gives the tunnel base URL.
            connected: ``threading.Condition`` notified once the response
                has been opened.
            early_terminate: when true, return after two lines have been
                read instead of streaming until the terminate message.

        NOTE(review): this is used as a thread target; an exception raised
        here (including the final assertion) is not propagated to the test
        that started the thread.
        """
        expected_terminate_message = b"502 Bad Gateway"
        url = config.get_url() + self.sse_endpoint

        with requests.get(url, timeout=5, stream=True) as resp:
            # Tell the waiting test that the connection is established.
            with connected:
                connected.notify_all()
            lines = 0
            for line in resp.iter_lines():
                if line == expected_terminate_message:
                    break
                lines += 1
                if early_terminate and lines == 2:
                    return
            # /sse returns count followed by 2 new lines
            assert lines >= (self.grace_period * 2)
|