From 25cfbec072fb8cc59c73578a9cd5f933594b7e27 Mon Sep 17 00:00:00 2001 From: cthuang Date: Thu, 11 Mar 2021 13:49:09 +0000 Subject: [PATCH] TUN-4050: Add component tests to assert reconnect behavior --- component-tests/config.py | 14 +++------ component-tests/conftest.py | 14 ++++----- component-tests/test_config.py | 20 ++++++------- component-tests/test_logging.py | 2 +- component-tests/test_reconnect.py | 50 +++++++++++++++++++++++++++++++ component-tests/util.py | 39 ++++++++++++++---------- 6 files changed, 96 insertions(+), 43 deletions(-) create mode 100644 component-tests/test_reconnect.py diff --git a/component-tests/config.py b/component-tests/config.py index d447c46b..4598450e 100644 --- a/component-tests/config.py +++ b/component-tests/config.py @@ -14,7 +14,8 @@ from util import LOGGER @dataclass(frozen=True) -class TunnelBaseConfig: +class BaseConfig: + cloudflared_binary: str no_autoupdate: bool = True metrics: str = f'localhost:{METRICS_PORT}' @@ -26,7 +27,7 @@ class TunnelBaseConfig: @dataclass(frozen=True) -class NamedTunnelBaseConfig(TunnelBaseConfig): +class NamedTunnelBaseConfig(BaseConfig): # The attributes of the parent class are ordered before attributes in this class, # so we have to use default values here and check if they are set in __post_init__ tunnel: str = None @@ -67,7 +68,7 @@ class NamedTunnelConfig(NamedTunnelBaseConfig): @dataclass(frozen=True) -class ClassicTunnelBaseConfig(TunnelBaseConfig): +class ClassicTunnelBaseConfig(BaseConfig): hostname: str = None origincert: str = None @@ -99,13 +100,6 @@ class ClassicTunnelConfig(ClassicTunnelBaseConfig): return "https://" + self.hostname -@dataclass -class ComponentTestConfig: - cloudflared_binary: str - named_tunnel_config: NamedTunnelConfig - classic_tunnel_config: ClassicTunnelConfig - - def build_config_from_env(): config_path = get_env("COMPONENT_TESTS_CONFIG") config_content = base64.b64decode( diff --git a/component-tests/conftest.py b/component-tests/conftest.py index d4b4147d..799914b5 100644 --- a/component-tests/conftest.py +++ b/component-tests/conftest.py @@ -4,7 +4,7 @@ import yaml from time import sleep -from config import ComponentTestConfig, NamedTunnelConfig, ClassicTunnelConfig +from config import NamedTunnelConfig, ClassicTunnelConfig from constants import BACKOFF_SECS from util import LOGGER @@ -19,12 +19,12 @@ def component_tests_config(): config = yaml.safe_load(stream) LOGGER.info(f"component tests base config {config}") - def _component_tests_config(extra_named_tunnel_config={}, extra_classic_tunnel_config={}): - named_tunnel_config = NamedTunnelConfig(additional_config=extra_named_tunnel_config, - tunnel=config['tunnel'], credentials_file=config['credentials_file'], ingress=config['ingress']) - classic_tunnel_config = ClassicTunnelConfig( - additional_config=extra_classic_tunnel_config, hostname=config['classic_hostname'], origincert=config['origincert']) - return ComponentTestConfig(config['cloudflared_binary'], named_tunnel_config, classic_tunnel_config) + def _component_tests_config(additional_config={}, named_tunnel=True): + if named_tunnel: + return NamedTunnelConfig(additional_config=additional_config, + cloudflared_binary=config['cloudflared_binary'], tunnel=config['tunnel'], credentials_file=config['credentials_file'], ingress=config['ingress']) + return ClassicTunnelConfig( + additional_config=additional_config, cloudflared_binary=config['cloudflared_binary'], hostname=config['classic_hostname'], origincert=config['origincert']) return _component_tests_config diff --git a/component-tests/test_config.py b/component-tests/test_config.py index dbe1e9b9..6402be57 100644 --- a/component-tests/test_config.py +++ b/component-tests/test_config.py @@ -31,27 +31,27 @@ class TestConfig: {"service": "http_status:404"} ], } - component_tests_config = component_tests_config(extra_config) + config = component_tests_config(extra_config) validate_args = ["ingress", "validate"] - _ = start_cloudflared(tmp_path, component_tests_config, validate_args) + _ = start_cloudflared(tmp_path, config, validate_args) - self.match_rule(tmp_path, component_tests_config, + self.match_rule(tmp_path, config, "http://example.com/index.html", 1) - self.match_rule(tmp_path, component_tests_config, + self.match_rule(tmp_path, config, "https://example.com/index.html", 1) - self.match_rule(tmp_path, component_tests_config, + self.match_rule(tmp_path, config, "https://api.example.com/login", 2) - self.match_rule(tmp_path, component_tests_config, + self.match_rule(tmp_path, config, "https://wss.example.com", 3) - self.match_rule(tmp_path, component_tests_config, + self.match_rule(tmp_path, config, "https://ssh.example.com", 4) - self.match_rule(tmp_path, component_tests_config, + self.match_rule(tmp_path, config, "https://api.example.com", 5) # This is used to check that the command tunnel ingress url matches rule number . Note that rule number uses 1-based indexing - def match_rule(self, tmp_path, component_tests_config, url, rule_num): + def match_rule(self, tmp_path, config, url, rule_num): args = ["ingress", "rule", url] - match_rule = start_cloudflared(tmp_path, component_tests_config, args) + match_rule = start_cloudflared(tmp_path, config, args) assert f"Matched rule #{rule_num}" .encode() in match_rule.stdout diff --git a/component-tests/test_logging.py b/component-tests/test_logging.py index f0898326..23964220 100644 --- a/component-tests/test_logging.py +++ b/component-tests/test_logging.py @@ -69,7 +69,7 @@ class TestLogging: max_batches = 3 batch_requests = 1000 for _ in range(max_batches): - send_requests(config.named_tunnel_config.get_url(), + send_requests(config.get_url(), batch_requests, require_ok=False) files = os.listdir(log_dir) if len(files) == 2: diff --git a/component-tests/test_reconnect.py b/component-tests/test_reconnect.py new file mode 100644 index 00000000..27caf2a6 --- /dev/null +++ b/component-tests/test_reconnect.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python +import copy + +from retrying import retry +from time import sleep + +from util import start_cloudflared, wait_tunnel_ready, check_tunnel_not_ready, send_requests + + +class TestReconnect(): + default_ha_conns = 4 + default_reconnect_secs = 5 + extra_config = { + "stdin-control": True, + } + + def test_named_reconnect(self, tmp_path, component_tests_config): + config = component_tests_config(self.extra_config) + with start_cloudflared(tmp_path, config, new_process=True, allow_input=True) as cloudflared: + # Repeat the test multiple times because some issues only occur after multiple reconnects + self.assert_reconnect(config, cloudflared, 5) + + def test_classic_reconnect(self, tmp_path, component_tests_config): + extra_config = copy.copy(self.extra_config) + extra_config["hello-world"] = True + config = component_tests_config( + additional_config=extra_config, named_tunnel=False) + with start_cloudflared(tmp_path, config, cfd_args=[], new_process=True, allow_input=True) as cloudflared: + self.assert_reconnect(config, cloudflared, 1) + + def send_reconnect(self, cloudflared, secs): + # Although it is recommended to use the Popen.communicate method, we cannot + # use it because it blocks on reading stdout and stderr until EOF is reached + cloudflared.stdin.write(f"reconnect {secs}s\n".encode()) + cloudflared.stdin.flush() + + def assert_reconnect(self, config, cloudflared, repeat): + wait_tunnel_ready() + for _ in range(repeat): + for i in range(self.default_ha_conns): + self.send_reconnect(cloudflared, self.default_reconnect_secs) + expect_connections = self.default_ha_conns-i-1 + if expect_connections > 0: + wait_tunnel_ready(expect_connections=expect_connections) + else: + check_tunnel_not_ready() + + sleep(self.default_reconnect_secs + 10) + wait_tunnel_ready() + send_requests(config.get_url(), 1) diff --git a/component-tests/util.py b/component-tests/util.py index dff80f0b..d3de533d 100644 --- a/component-tests/util.py +++ b/component-tests/util.py @@ -19,37 +19,46 @@ def write_config(path, config): return config_path -def start_cloudflared(path, component_test_config, cfd_args=["run"], cfd_pre_args=["tunnel"], new_process=False, classic=False, capture_output=True): - if classic: - config = component_test_config.classic_tunnel_config.full_config - else: - config = component_test_config.named_tunnel_config.full_config - config_path = write_config(path, config) - cmd = [component_test_config.cloudflared_binary] +def start_cloudflared(path, config, cfd_args=["run"], cfd_pre_args=["tunnel"], new_process=False, allow_input=False, capture_output=True): + config_path = write_config(path, config.full_config) + cmd = [config.cloudflared_binary] cmd += cfd_pre_args cmd += ["--config", config_path] cmd += cfd_args LOGGER.info(f"Run cmd {cmd} with config {config}") if new_process: - return run_cloudflared_background(cmd, capture_output) + return run_cloudflared_background(cmd, allow_input, capture_output) # By setting check=True, it will raise an exception if the process exits with non-zero exit code return subprocess.run(cmd, check=True, capture_output=capture_output) @contextmanager -def run_cloudflared_background(cmd, capture_output): +def run_cloudflared_background(cmd, allow_input, capture_output): output = subprocess.PIPE if capture_output else subprocess.DEVNULL + stdin = subprocess.PIPE if allow_input else None try: - cfd = subprocess.Popen(cmd, stdout=output, stderr=output) + cfd = subprocess.Popen(cmd, stdin=stdin, stdout=output, stderr=output) yield cfd finally: cfd.terminate() @retry(stop_max_attempt_number=MAX_RETRIES, wait_fixed=BACKOFF_SECS * 1000) -def wait_tunnel_ready(): +def wait_tunnel_ready(expect_connections=4): url = f'http://localhost:{METRICS_PORT}/ready' - send_requests(url, 1) + + with requests.Session() as s: + resp = send_request(s, url, True) + assert resp.json()[ + "readyConnections"] == expect_connections, f"Ready endpoint returned {resp.json()} but we expect {expect_connections} ready connections" + + +@retry(stop_max_attempt_number=MAX_RETRIES, wait_fixed=BACKOFF_SECS * 1000) +def check_tunnel_not_ready(): + url = f'http://localhost:{METRICS_PORT}/ready' + + resp = requests.get(url, timeout=1) + assert resp.status_code == 503, f"Expect {url} returns 503, got {resp.status_code}" # In some cases we don't need to check response status, such as when sending batch requests to generate logs @@ -58,8 +67,8 @@ def send_requests(url, count, require_ok=True): errors = 0 with requests.Session() as s: for _ in range(count): - ok = send_request(s, url, require_ok) - if not ok: + resp = send_request(s, url, require_ok) + if resp is None: errors += 1 sleep(0.01) if errors > 0: @@ -72,4 +81,4 @@ def send_request(session, url, require_ok): resp = session.get(url, timeout=BACKOFF_SECS) if require_ok: assert resp.status_code == 200, f"{url} returned {resp}" - return True if resp.status_code == 200 else False + return resp if resp.status_code == 200 else None