diff --git a/setup.py b/setup.py index 074de8a1..16d4c300 100755 --- a/setup.py +++ b/setup.py @@ -11,9 +11,7 @@ "port-for>=0.4", "six>=1.9.0", "psutil", - "packaging", - "fabric", - "sshtunnel" + "packaging" ] # Add compatibility enum class @@ -29,7 +27,7 @@ readme = f.read() setup( - version='1.9.1', + version='1.9.2', name='testgres', packages=['testgres', 'testgres.operations'], description='Testing utility for PostgreSQL and its extensions', diff --git a/testgres/exceptions.py b/testgres/exceptions.py index 6832c788..ee329031 100644 --- a/testgres/exceptions.py +++ b/testgres/exceptions.py @@ -32,7 +32,16 @@ def __str__(self): if self.out: msg.append(u'----\n{}'.format(self.out)) - return six.text_type('\n').join(msg) + return self.convert_and_join(msg) + + @staticmethod + def convert_and_join(msg_list): + # Convert each byte element in the list to str + str_list = [six.text_type(item, 'utf-8') if isinstance(item, bytes) else six.text_type(item) for item in + msg_list] + + # Join the list into a single string with the specified delimiter + return six.text_type('\n').join(str_list) @six.python_2_unicode_compatible diff --git a/testgres/node.py b/testgres/node.py index 6483514b..84c25327 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -1371,7 +1371,7 @@ def pgbench(self, # should be the last one _params.append(dbname) - proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, proc=True) + proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, get_process=True) return proc diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index 318ae675..a692750e 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -18,7 +18,6 @@ from distutils.spawn import find_executable from distutils import rmtree - CMD_TIMEOUT_SEC = 60 error_markers = [b'error', b'Permission denied', b'fatal'] @@ -37,7 +36,8 @@ def __init__(self, conn_params=None): # Command execution def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=False, text=False, - input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None): + input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + get_process=None, timeout=None): """ Execute a command in a subprocess. @@ -69,9 +69,14 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, stdout=stdout, stderr=stderr, ) - if proc: + if get_process: return process - result, error = process.communicate(input) + + try: + result, error = process.communicate(input, timeout=timeout) + except subprocess.TimeoutExpired: + process.kill() + raise ExecUtilException("Command timed out after {} seconds.".format(timeout)) exit_status = process.returncode error_found = exit_status != 0 or any(marker in error for marker in error_markers) diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 5d9bfe7e..421c0a6d 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -3,17 +3,19 @@ import os import subprocess import tempfile -import time -import sshtunnel +# we support both pg8000 and psycopg2 +try: + import psycopg2 as pglib +except ImportError: + try: + import pg8000 as pglib + except ImportError: + raise ImportError("You must have psycopg2 or pg8000 modules installed") from ..exceptions import ExecUtilException from .os_ops import OsOperations, ConnectionParams -from .os_ops import pglib - -sshtunnel.SSH_TIMEOUT = 5.0 -sshtunnel.TUNNEL_TIMEOUT = 5.0 ConsoleEncoding = locale.getdefaultlocale()[1] if not ConsoleEncoding: @@ -50,21 +52,28 @@ def __init__(self, conn_params: ConnectionParams): self.remote = True self.username = conn_params.username or self.get_user() self.add_known_host(self.host) + self.tunnel_process = None def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - self.close_tunnel() + self.close_ssh_tunnel() - def close_tunnel(self): - if getattr(self, 'tunnel', None): - self.tunnel.stop(force=True) - start_time = time.time() - while self.tunnel.is_active: - if time.time() - start_time > sshtunnel.TUNNEL_TIMEOUT: - break - time.sleep(0.5) + def establish_ssh_tunnel(self, local_port, remote_port): + """ + Establish an SSH tunnel from a local port to a remote PostgreSQL port. + """ + ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"] + self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300) + + def close_ssh_tunnel(self): + if hasattr(self, 'tunnel_process'): + self.tunnel_process.terminate() + self.tunnel_process.wait() + del self.tunnel_process + else: + print("No active tunnel to close.") def add_known_host(self, host): cmd = 'ssh-keyscan -H %s >> /home/%s/.ssh/known_hosts' % (host, os.getlogin()) @@ -78,21 +87,29 @@ def add_known_host(self, host): raise ExecUtilException(message="Failed to add %s to known_hosts. Error: %s" % (host, str(e)), command=cmd, exit_code=e.returncode, out=e.stderr) - def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, + def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None, - stderr=None, proc=None): + stderr=None, get_process=None, timeout=None): """ Execute a command in the SSH session. Args: - cmd (str): The command to be executed. """ + ssh_cmd = [] if isinstance(cmd, str): ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key, cmd] elif isinstance(cmd, list): ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key] + cmd process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if get_process: + return process + + try: + result, error = process.communicate(input, timeout=timeout) + except subprocess.TimeoutExpired: + process.kill() + raise ExecUtilException("Command timed out after {} seconds.".format(timeout)) - result, error = process.communicate(input) exit_status = process.returncode if encoding: @@ -372,41 +389,19 @@ def get_process_children(self, pid): raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}") # Database control - def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, ssh_key=None): + def db_connect(self, dbname, user, password=None, host="localhost", port=5432): """ - Connects to a PostgreSQL database on the remote system. - Args: - - dbname (str): The name of the database to connect to. - - user (str): The username for the database connection. - - password (str, optional): The password for the database connection. Defaults to None. - - host (str, optional): The IP address of the remote system. Defaults to "localhost". - - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432. - - This function establishes a connection to a PostgreSQL database on the remote system using the specified - parameters. It returns a connection object that can be used to interact with the database. + Established SSH tunnel and Connects to a PostgreSQL """ - self.close_tunnel() - self.tunnel = sshtunnel.open_tunnel( - (self.host, 22), # Remote server IP and SSH port - ssh_username=self.username, - ssh_pkey=self.ssh_key, - remote_bind_address=(self.host, port), # PostgreSQL server IP and PostgreSQL port - local_bind_address=('localhost', 0) - # Local machine IP and available port (0 means it will pick any available port) - ) - self.tunnel.start() - + self.establish_ssh_tunnel(local_port=port, remote_port=5432) try: - # Use localhost and self.tunnel.local_bind_port to connect conn = pglib.connect( - host='localhost', # Connect to localhost - port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel + host=host, + port=port, database=dbname, - user=user or self.username, - password=password + user=user, + password=password, ) - return conn except Exception as e: - self.tunnel.stop() - raise ExecUtilException("Could not create db tunnel. {}".format(e)) + raise Exception(f"Could not connect to the database. Error: {e}") diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 44e77fbd..1042f3c4 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -735,9 +735,10 @@ def test_pgbench(self): options=['-q']).pgbench_run(time=2) # run TPC-B benchmark - out = node.pgbench(stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - options=['-T3']) + proc = node.pgbench(stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + options=['-T3']) + out = proc.communicate()[0] self.assertTrue(b'tps = ' in out) def test_pg_config(self):
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: