196 lines
6.0 KiB
Python
196 lines
6.0 KiB
Python
"""
|
|
Cloudflared Integration tests
|
|
"""
|
|
|
|
import unittest
|
|
import subprocess
|
|
import os
|
|
import tempfile
|
|
from contextlib import contextmanager
|
|
|
|
from pexpect import pxssh
|
|
|
|
|
|
class TestSSHBase(unittest.TestCase):
    """
    SSH test base class containing constants and helper funcs.

    All connection settings are read from environment variables while the
    class body executes, so a missing variable raises ``KeyError`` as soon as
    the module is imported.
    """

    # Host passed to ssh / pxssh as the server to connect to.
    HOSTNAME = os.environ["SSH_HOSTNAME"]
    # Remote account used for every connection.
    SSH_USER = os.environ["SSH_USER"]
    # "user@host" form used on the ssh/scp command line.
    SSH_TARGET = f"{SSH_USER}@{HOSTNAME}"
    # ssh config file (passed via -F / ssh_config) used by default.
    AUTHORIZED_KEYS_SSH_CONFIG = os.environ["AUTHORIZED_KEYS_SSH_CONFIG"]
    # Alternative ssh config file used by the short-lived-certificate test.
    SHORT_LIVED_CERT_SSH_CONFIG = os.environ["SHORT_LIVED_CERT_SSH_CONFIG"]
    # "-o" options shared by pxssh sessions and the plain ssh command line.
    SSH_OPTIONS = {"StrictHostKeyChecking": "no"}

    @classmethod
    def get_ssh_command(cls, pty=True):
        """
        Return ssh command arg list. If pty is true, a PTY is forced for the session.

        The "-o" options are generated from ``SSH_OPTIONS`` so the command line
        and the pxssh sessions always use the same settings.
        """
        cmd = ["ssh"]
        for name, value in cls.SSH_OPTIONS.items():
            cmd += ["-o", f"{name}={value}"]
        cmd += ["-F", cls.AUTHORIZED_KEYS_SSH_CONFIG, cls.SSH_TARGET]
        # -tt forces PTY allocation, -T disables it.
        cmd.append("-tt" if pty else "-T")
        return cmd

    @classmethod
    @contextmanager
    def ssh_session_manager(cls, *args, **kwargs):
        """
        Context manager for interacting with a pxssh session.

        Disables pty echo on the remote server and ensures session is terminated afterward.

        Recognised keyword arguments:
            ssh_config: ssh config file to use; defaults to
                ``AUTHORIZED_KEYS_SSH_CONFIG``.
            ssh_tunnels: tunnel specification forwarded to ``pxssh.login``;
                defaults to an empty dict.

        Positional arguments are accepted but ignored.
        """
        session = pxssh.pxssh(options=cls.SSH_OPTIONS)
        logged_in = False
        try:
            session.login(
                cls.HOSTNAME,
                username=cls.SSH_USER,
                original_prompt=r"[#@$]",
                ssh_config=kwargs.get("ssh_config", cls.AUTHORIZED_KEYS_SSH_CONFIG),
                ssh_tunnels=kwargs.get("ssh_tunnels", {}),
            )
            logged_in = True

            session.sendline("stty -echo")
            session.prompt()
            yield session
        finally:
            if logged_in:
                session.logout()
            elif not getattr(session, "closed", True):
                # Login failed after the ssh child was spawned: a logout is
                # not possible, but the child process must still be reaped.
                session.close()

    @staticmethod
    def get_command_output(session, cmd):
        """
        Executes command on remote ssh server and waits for prompt.

        Returns command output (everything received before the prompt),
        decoded and stripped of surrounding whitespace.
        """
        session.sendline(cmd)
        session.prompt()
        return session.before.decode().strip()

    def exec_command(self, cmd, shell=False):
        """
        Executes command locally. Raises Assertion error for non-zero return code.

        Returns stdout and stderr as a tuple of decoded, stripped strings.
        """
        proc = subprocess.run(
            cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=shell
        )

        out = proc.stdout.decode()
        err = proc.stderr.decode()
        # Include both streams in the failure message to ease debugging.
        self.assertEqual(proc.returncode, 0, msg=f"stdout: {out} stderr: {err}")
        return out.strip(), err.strip()
|
|
|
|
|
|
class TestSSHCommandExec(TestSSHBase):
    """
    Tests inline ssh command exec
    """

    # Name of file to be downloaded over SCP on remote server.
    REMOTE_SCP_FILENAME = os.environ["REMOTE_SCP_FILENAME"]

    @classmethod
    def get_scp_base_command(cls):
        """
        Return the scp arg list (verbose, default ssh config) without the
        source and destination operands, which callers append.
        """
        return [
            "scp",
            "-o",
            "StrictHostKeyChecking=no",
            "-v",
            "-F",
            cls.AUTHORIZED_KEYS_SSH_CONFIG,
        ]

    @unittest.skip(
        "This creates files on the remote. Should be skipped until server is dockerized."
    )
    def test_verbose_scp_sink_mode(self):
        """Upload a local temporary file to the remote user's login directory."""
        with tempfile.NamedTemporaryFile() as fl:
            self.exec_command(
                self.get_scp_base_command() + [fl.name, f"{self.SSH_TARGET}:"]
            )

    def test_verbose_scp_source_mode(self):
        """Download REMOTE_SCP_FILENAME into a temp dir and check it is non-empty."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.exec_command(
                self.get_scp_base_command()
                + [f"{self.SSH_TARGET}:{self.REMOTE_SCP_FILENAME}", tmpdirname]
            )
            # scp stores the download under the remote file's base name, so
            # strip any directory components before looking for it locally.
            local_filename = os.path.join(
                tmpdirname, os.path.basename(self.REMOTE_SCP_FILENAME)
            )

            self.assertTrue(os.path.exists(local_filename))
            self.assertGreater(os.path.getsize(local_filename), 0)

    def test_pty_command(self):
        """Commands run with a forced PTY execute as SSH_USER and see a tty."""
        base_cmd = self.get_ssh_command()

        out, _ = self.exec_command(base_cmd + ["whoami"])
        self.assertEqual(out.strip().lower(), self.SSH_USER.lower())

        out, _ = self.exec_command(base_cmd + ["tty"])
        self.assertNotEqual(out, "not a tty")

    def test_non_pty_command(self):
        """Commands run without a PTY execute as SSH_USER and see no tty."""
        base_cmd = self.get_ssh_command(pty=False)

        out, _ = self.exec_command(base_cmd + ["whoami"])
        self.assertEqual(out.strip().lower(), self.SSH_USER.lower())

        out, _ = self.exec_command(base_cmd + ["tty"])
        self.assertEqual(out, "not a tty")
|
|
|
|
|
|
class TestSSHShell(TestSSHBase):
    """
    Tests interactive SSH shell
    """

    # File path to a file on the remote server with root only read privileges.
    ROOT_ONLY_TEST_FILE_PATH = os.environ["ROOT_ONLY_TEST_FILE_PATH"]

    def test_ssh_pty(self):
        """Check user identity, environment and privileges of an interactive shell."""
        expected_user = self.SSH_USER.lower()

        with self.ssh_session_manager() as session:

            def run(remote_cmd):
                # Run one command in the open session and return its output.
                return self.get_command_output(session, remote_cmd)

            # The shell must be running as the requested user.
            username = run("whoami")
            self.assertEqual(username.lower(), expected_user)

            # $USER must agree with the login user.
            user_var = run("echo $USER")
            self.assertEqual(user_var.lower(), expected_user)

            # $HOME must be the directory the shell starts in.
            home_env = run("echo $HOME")
            pwd = run("pwd")
            self.assertEqual(pwd, home_env)

            # The starting directory must belong to the login user.
            self.assertIn(username, pwd)

            # The shell must carry only the user's own privileges:
            # reading a root owned 0700 file has to be refused.
            output = run(f"cat {self.ROOT_ONLY_TEST_FILE_PATH}")
            self.assertIn("Permission denied", output)

    def test_short_lived_cert_auth(self):
        """Log in using the short-lived certificate ssh config and verify the user."""
        cert_config = self.SHORT_LIVED_CERT_SSH_CONFIG
        with self.ssh_session_manager(ssh_config=cert_config) as session:
            username = self.get_command_output(session, "whoami")
            self.assertEqual(username.lower(), self.SSH_USER.lower())
|
|
|
|
|
|
# Only run the suite when executed as a script; an unguarded call would run
# the tests and raise SystemExit whenever the module is merely imported
# (e.g. by a test runner performing discovery).
if __name__ == "__main__":
    unittest.main()
|