cloudflared-mirror/ssh_server_tests/tests.py

196 lines
6.0 KiB
Python
Raw Permalink Normal View History

2019-09-18 16:33:13 +00:00
"""
Cloudflared Integration tests
"""
import unittest
import subprocess
import os
import tempfile
from contextlib import contextmanager
from pexpect import pxssh
class TestSSHBase(unittest.TestCase):
    """
    SSH test base class containing constants and helper funcs.

    Connection parameters are read from the environment when the class body
    executes, so a missing variable fails at import time with KeyError.
    """

    HOSTNAME = os.environ["SSH_HOSTNAME"]
    SSH_USER = os.environ["SSH_USER"]
    SSH_TARGET = f"{SSH_USER}@{HOSTNAME}"
    # Paths to ssh_config files, used with `ssh -F` / pxssh's ssh_config.
    AUTHORIZED_KEYS_SSH_CONFIG = os.environ["AUTHORIZED_KEYS_SSH_CONFIG"]
    SHORT_LIVED_CERT_SSH_CONFIG = os.environ["SHORT_LIVED_CERT_SSH_CONFIG"]
    # `-o` options shared by pxssh sessions and the ssh command line.
    SSH_OPTIONS = {"StrictHostKeyChecking": "no"}

    @classmethod
    def _ssh_option_args(cls):
        """Render SSH_OPTIONS as a flat ["-o", "Key=value", ...] list."""
        args = []
        for key, value in cls.SSH_OPTIONS.items():
            args += ["-o", f"{key}={value}"]
        return args

    @classmethod
    def get_ssh_command(cls, pty=True):
        """
        Return the ssh argument list for running a command on SSH_TARGET.

        If pty is true, `-tt` forces a PTY for the session; otherwise `-T`
        disables PTY allocation. The caller appends the remote command.
        """
        cmd = [
            "ssh",
            *cls._ssh_option_args(),
            "-F",
            cls.AUTHORIZED_KEYS_SSH_CONFIG,
            cls.SSH_TARGET,
        ]
        cmd.append("-tt" if pty else "-T")
        return cmd

    @classmethod
    @contextmanager
    def ssh_session_manager(cls, *args, **kwargs):
        """
        Context manager yielding a logged-in pxssh session to HOSTNAME.

        Keyword args:
            ssh_config: ssh_config file used for login
                (default: AUTHORIZED_KEYS_SSH_CONFIG).
            ssh_tunnels: tunnel spec forwarded to pxssh.login (default: {}).

        Positional args are accepted but ignored. Disables pty echo on the
        remote server (presumably so echoed commands do not end up in
        `session.before`). On exit the session is logged out and the local
        ssh child is always closed, even if logout itself fails.
        """
        session = pxssh.pxssh(options=cls.SSH_OPTIONS)
        session.login(
            cls.HOSTNAME,
            username=cls.SSH_USER,
            original_prompt=r"[#@$]",
            ssh_config=kwargs.get("ssh_config", cls.AUTHORIZED_KEYS_SSH_CONFIG),
            ssh_tunnels=kwargs.get("ssh_tunnels", {}),
        )
        try:
            session.sendline("stty -echo")
            session.prompt()
            yield session
        finally:
            try:
                session.logout()
            finally:
                # logout() only closes the child after it sees EOF; if it
                # times out we would otherwise leak the spawned ssh process.
                session.close()

    @staticmethod
    def get_command_output(session, cmd):
        """
        Execute cmd on the remote ssh server and wait for the prompt.

        Returns the command output, decoded and stripped.
        Raises AssertionError if the prompt does not reappear before the
        pxssh timeout, instead of silently returning partial output.
        """
        session.sendline(cmd)
        if not session.prompt():
            raise AssertionError(
                f"Timed out waiting for prompt after {cmd!r}; "
                f"partial output: {session.before!r}"
            )
        return session.before.decode().strip()

    def exec_command(self, cmd, shell=False):
        """
        Execute cmd locally (an argv list, or a string when shell=True).

        Fails the test via assertEqual on a non-zero return code, including
        both streams in the message. Returns (stdout, stderr), decoded and
        stripped; undecodable bytes are replaced so the failure message can
        always be rendered.
        """
        proc = subprocess.Popen(
            cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=shell
        )
        raw_out, raw_err = proc.communicate()
        out = raw_out.decode(errors="replace")
        err = raw_err.decode(errors="replace")
        self.assertEqual(proc.returncode, 0, msg=f"stdout: {out} stderr: {err}")
        return out.strip(), err.strip()
class TestSSHCommandExec(TestSSHBase):
    """
    Tests inline ssh command exec and verbose scp transfers.
    """

    # Name of file to be downloaded over SCP on remote server.
    REMOTE_SCP_FILENAME = os.environ["REMOTE_SCP_FILENAME"]

    @classmethod
    def get_scp_base_command(cls):
        """
        Return the verbose scp argument list; the caller appends source and
        destination. Uses the same -o options (SSH_OPTIONS) and ssh_config
        as the ssh commands.
        """
        cmd = ["scp"]
        for key, value in cls.SSH_OPTIONS.items():
            cmd += ["-o", f"{key}={value}"]
        cmd += ["-v", "-F", cls.AUTHORIZED_KEYS_SSH_CONFIG]
        return cmd

    @unittest.skip(
        "This creates files on the remote. Should be skipped until server is dockerized."
    )
    def test_verbose_scp_sink_mode(self):
        """Upload a local temp file to `SSH_TARGET:` (empty remote path)."""
        with tempfile.NamedTemporaryFile() as fl:
            self.exec_command(
                self.get_scp_base_command() + [fl.name, f"{self.SSH_TARGET}:"]
            )

    def test_verbose_scp_source_mode(self):
        """Download REMOTE_SCP_FILENAME into a temp dir; it must be non-empty."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.exec_command(
                self.get_scp_base_command()
                + [f"{self.SSH_TARGET}:{self.REMOTE_SCP_FILENAME}", tmpdirname]
            )
            # scp stores the file under its base name inside the target
            # directory, so drop any remote directory components.
            local_filename = os.path.join(
                tmpdirname, os.path.basename(self.REMOTE_SCP_FILENAME)
            )
            self.assertTrue(os.path.exists(local_filename))
            self.assertGreater(os.path.getsize(local_filename), 0)

    def test_pty_command(self):
        """With a forced PTY, commands run as SSH_USER and `tty` finds a tty."""
        base_cmd = self.get_ssh_command()
        out, _ = self.exec_command(base_cmd + ["whoami"])
        self.assertEqual(out.strip().lower(), self.SSH_USER.lower())
        out, _ = self.exec_command(base_cmd + ["tty"])
        self.assertNotEqual(out, "not a tty")

    def test_non_pty_command(self):
        """Without a PTY, commands run as SSH_USER and `tty` reports none."""
        base_cmd = self.get_ssh_command(pty=False)
        out, _ = self.exec_command(base_cmd + ["whoami"])
        self.assertEqual(out.strip().lower(), self.SSH_USER.lower())
        out, _ = self.exec_command(base_cmd + ["tty"])
        self.assertEqual(out, "not a tty")
class TestSSHShell(TestSSHBase):
    """
    Interactive-shell tests: each test opens a pxssh session and checks what
    the remote login shell reports about its user and environment.
    """

    # Remote path of a file that only root is allowed to read.
    ROOT_ONLY_TEST_FILE_PATH = os.environ["ROOT_ONLY_TEST_FILE_PATH"]

    def test_ssh_pty(self):
        expected_user = self.SSH_USER.lower()
        with self.ssh_session_manager() as shell:

            def run(command):
                # Send one command and return its stripped output.
                return self.get_command_output(shell, command)

            # The shell must be running as the configured user...
            whoami_out = run("whoami")
            self.assertEqual(whoami_out.lower(), expected_user)
            # ...and must export that user in $USER.
            self.assertEqual(run("echo $USER").lower(), expected_user)

            # The working directory at login must equal $HOME...
            home_dir = run("echo $HOME")
            cwd = run("pwd")
            self.assertEqual(cwd, home_dir)
            # ...and that directory must contain the user's name.
            self.assertIn(whoami_out, cwd)

            # The shell must have the user's privileges, not root's: reading
            # the root-only file has to be refused.
            cat_out = run(f"cat {self.ROOT_ONLY_TEST_FILE_PATH}")
            self.assertIn("Permission denied", cat_out)

    def test_short_lived_cert_auth(self):
        session_cm = self.ssh_session_manager(
            ssh_config=self.SHORT_LIVED_CERT_SSH_CONFIG
        )
        with session_cm as shell:
            whoami_out = self.get_command_output(shell, "whoami")
            self.assertEqual(whoami_out.lower(), self.SSH_USER.lower())
# Run the suite only when executed as a script, so importing this module
# (e.g. by a test collector) does not trigger unittest.main()'s sys.exit.
if __name__ == "__main__":
    unittest.main()