""" Cloudflared Integration tests """ import unittest import subprocess import os import tempfile from contextlib import contextmanager from pexpect import pxssh class TestSSHBase(unittest.TestCase): """ SSH test base class containing constants and helper funcs """ HOSTNAME = os.environ["SSH_HOSTNAME"] SSH_USER = os.environ["SSH_USER"] SSH_TARGET = f"{SSH_USER}@{HOSTNAME}" AUTHORIZED_KEYS_SSH_CONFIG = os.environ["AUTHORIZED_KEYS_SSH_CONFIG"] SHORT_LIVED_CERT_SSH_CONFIG = os.environ["SHORT_LIVED_CERT_SSH_CONFIG"] SSH_OPTIONS = {"StrictHostKeyChecking": "no"} @classmethod def get_ssh_command(cls, pty=True): """ Return ssh command arg list. If pty is true, a PTY is forced for the session. """ cmd = [ "ssh", "-o", "StrictHostKeyChecking=no", "-F", cls.AUTHORIZED_KEYS_SSH_CONFIG, cls.SSH_TARGET, ] if not pty: cmd += ["-T"] else: cmd += ["-tt"] return cmd @classmethod @contextmanager def ssh_session_manager(cls, *args, **kwargs): """ Context manager for interacting with a pxssh session. Disables pty echo on the remote server and ensures session is terminated afterward. """ session = pxssh.pxssh(options=cls.SSH_OPTIONS) session.login( cls.HOSTNAME, username=cls.SSH_USER, original_prompt=r"[#@$]", ssh_config=kwargs.get("ssh_config", cls.AUTHORIZED_KEYS_SSH_CONFIG), ssh_tunnels=kwargs.get("ssh_tunnels", {}), ) try: session.sendline("stty -echo") session.prompt() yield session finally: session.logout() @staticmethod def get_command_output(session, cmd): """ Executes command on remote ssh server and waits for prompt. Returns command output """ session.sendline(cmd) session.prompt() return session.before.decode().strip() def exec_command(self, cmd, shell=False): """ Executes command locally. Raises Assertion error for non-zero return code. Returns stdout and stderr """ proc = subprocess.Popen( cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=shell ) raw_out, raw_err = proc.communicate() out = raw_out.decode() err = raw_err.decode() self.assertEqual(proc.returncode, 0, msg=f"stdout: {out} stderr: {err}") return out.strip(), err.strip() class TestSSHCommandExec(TestSSHBase): """ Tests inline ssh command exec """ # Name of file to be downloaded over SCP on remote server. REMOTE_SCP_FILENAME = os.environ["REMOTE_SCP_FILENAME"] @classmethod def get_scp_base_command(cls): return [ "scp", "-o", "StrictHostKeyChecking=no", "-v", "-F", cls.AUTHORIZED_KEYS_SSH_CONFIG, ] @unittest.skip( "This creates files on the remote. Should be skipped until server is dockerized." ) def test_verbose_scp_sink_mode(self): with tempfile.NamedTemporaryFile() as fl: self.exec_command( self.get_scp_base_command() + [fl.name, f"{self.SSH_TARGET}:"] ) def test_verbose_scp_source_mode(self): with tempfile.TemporaryDirectory() as tmpdirname: self.exec_command( self.get_scp_base_command() + [f"{self.SSH_TARGET}:{self.REMOTE_SCP_FILENAME}", tmpdirname] ) local_filename = os.path.join(tmpdirname, self.REMOTE_SCP_FILENAME) self.assertTrue(os.path.exists(local_filename)) self.assertTrue(os.path.getsize(local_filename) > 0) def test_pty_command(self): base_cmd = self.get_ssh_command() out, _ = self.exec_command(base_cmd + ["whoami"]) self.assertEqual(out.strip().lower(), self.SSH_USER.lower()) out, _ = self.exec_command(base_cmd + ["tty"]) self.assertNotEqual(out, "not a tty") def test_non_pty_command(self): base_cmd = self.get_ssh_command(pty=False) out, _ = self.exec_command(base_cmd + ["whoami"]) self.assertEqual(out.strip().lower(), self.SSH_USER.lower()) out, _ = self.exec_command(base_cmd + ["tty"]) self.assertEqual(out, "not a tty") class TestSSHShell(TestSSHBase): """ Tests interactive SSH shell """ # File path to a file on the remote server with root only read privileges. ROOT_ONLY_TEST_FILE_PATH = os.environ["ROOT_ONLY_TEST_FILE_PATH"] def test_ssh_pty(self): with self.ssh_session_manager() as session: # Test shell launched as correct user username = self.get_command_output(session, "whoami") self.assertEqual(username.lower(), self.SSH_USER.lower()) # Test USER env variable set user_var = self.get_command_output(session, "echo $USER") self.assertEqual(user_var.lower(), self.SSH_USER.lower()) # Test HOME env variable set to true user home. home_env = self.get_command_output(session, "echo $HOME") pwd = self.get_command_output(session, "pwd") self.assertEqual(pwd, home_env) # Test shell launched in correct user home dir. self.assertIn(username, pwd) # Ensure shell launched with correct user's permissions and privs. # Can't read root owned 0700 files. output = self.get_command_output( session, f"cat {self.ROOT_ONLY_TEST_FILE_PATH}" ) self.assertIn("Permission denied", output) def test_short_lived_cert_auth(self): with self.ssh_session_manager( ssh_config=self.SHORT_LIVED_CERT_SSH_CONFIG ) as session: username = self.get_command_output(session, "whoami") self.assertEqual(username.lower(), self.SSH_USER.lower()) unittest.main()