diff --git a/README.md b/README.md index 29b974dc..a2a0ec7e 100644 --- a/README.md +++ b/README.md @@ -176,7 +176,7 @@ the configuration file, which means that they should be called before `append_co ### Remote mode Testgres supports the creation of PostgreSQL nodes on a remote host. This is useful when you want to run distributed tests involving multiple nodes spread across different machines. -To use this feature, you need to use the RemoteOperations class. +To use this feature, you need to use the RemoteOperations class. This feature is only supported with Linux. Here is an example of how you might set this up: ```python diff --git a/setup.py b/setup.py index 8cb0f70a..074de8a1 100755 --- a/setup.py +++ b/setup.py @@ -12,7 +12,6 @@ "six>=1.9.0", "psutil", "packaging", - "paramiko", "fabric", "sshtunnel" ] @@ -30,7 +29,7 @@ readme = f.read() setup( - version='1.9.0', + version='1.9.1', name='testgres', packages=['testgres', 'testgres.operations'], description='Testing utility for PostgreSQL and its extensions', diff --git a/testgres/__init__.py b/testgres/__init__.py index b63c7df1..383daf2d 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -46,6 +46,8 @@ First, \ Any +from .config import testgres_config + from .operations.os_ops import OsOperations, ConnectionParams from .operations.local_ops import LocalOperations from .operations.remote_ops import RemoteOperations @@ -53,7 +55,7 @@ __all__ = [ "get_new_node", "get_remote_node", - "NodeBackup", + "NodeBackup", "testgres_config", "TestgresConfig", "configure_testgres", "scoped_config", "push_config", "pop_config", "NodeConnection", "DatabaseError", "InternalError", "ProgrammingError", "OperationalError", "TestgresException", "ExecUtilException", "QueryException", "TimeoutException", "CatchUpException", "StartNodeException", "InitNodeException", "BackupException", diff --git a/testgres/cache.py b/testgres/cache.py index bf8658c9..21198e83 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -57,7 +57,9 @@ def call_initdb(initdb_dir, log=logfile): # our initdb caching mechanism breaks this contract. pg_control = os.path.join(data_dir, XLOG_CONTROL_FILE) system_id = generate_system_id() - os_ops.write(pg_control, system_id, truncate=True, binary=True, read_and_write=True) + cur_pg_control = os_ops.read(pg_control, binary=True) + new_pg_control = system_id + cur_pg_control[len(system_id):] + os_ops.write(pg_control, new_pg_control, truncate=True, binary=True, read_and_write=True) # XXX: build new WAL segment with our system id _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"] diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index 89071282..318ae675 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -198,9 +198,15 @@ def touch(self, filename): with open(filename, "a"): os.utime(filename, None) - def read(self, filename, encoding=None): - with open(filename, "r", encoding=encoding) as file: - return file.read() + def read(self, filename, encoding=None, binary=False): + mode = "rb" if binary else "r" + with open(filename, mode) as file: + content = file.read() + if binary: + return content + if isinstance(content, bytes): + return content.decode(encoding or 'utf-8') + return content def readlines(self, filename, num_lines=0, binary=False, encoding=None): """ diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 6815c7f1..5d9bfe7e 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,13 +1,12 @@ +import locale +import logging import os +import subprocess import tempfile import time -from typing import Optional import sshtunnel -import paramiko -from paramiko import SSHClient - from ..exceptions import ExecUtilException from .os_ops import OsOperations, ConnectionParams @@ -16,6 +15,9 @@ sshtunnel.SSH_TIMEOUT = 5.0 sshtunnel.TUNNEL_TIMEOUT = 5.0 +ConsoleEncoding = locale.getdefaultlocale()[1] +if not ConsoleEncoding: + ConsoleEncoding = 'UTF-8' error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory'] @@ -31,33 +33,29 @@ def kill(self): def cmdline(self): command = "ps -p {} -o cmd --no-headers".format(self.pid) - stdin, stdout, stderr = self.ssh.exec_command(command) - cmdline = stdout.read().decode('utf-8').strip() + stdin, stdout, stderr = self.ssh.exec_command(command, verbose=True, encoding=ConsoleEncoding) + cmdline = stdout.strip() return cmdline.split() class RemoteOperations(OsOperations): def __init__(self, conn_params: ConnectionParams): + if os.name != "posix": + raise EnvironmentError("Remote operations are supported only on Linux!") + super().__init__(conn_params.username) self.conn_params = conn_params self.host = conn_params.host self.ssh_key = conn_params.ssh_key - self.ssh = self.ssh_connect() self.remote = True self.username = conn_params.username or self.get_user() - self.tunnel = None + self.add_known_host(self.host) def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close_tunnel() - if getattr(self, 'ssh', None): - self.ssh.close() - - def __del__(self): - if getattr(self, 'ssh', None): - self.ssh.close() def close_tunnel(self): if getattr(self, 'tunnel', None): @@ -68,26 +66,17 @@ def close_tunnel(self): break time.sleep(0.5) - def ssh_connect(self) -> Optional[SSHClient]: - key = self._read_ssh_key() - ssh = paramiko.SSHClient() - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.host, username=self.username, pkey=key) - return ssh - - def _read_ssh_key(self): + def add_known_host(self, host): + cmd = 'ssh-keyscan -H %s >> /home/%s/.ssh/known_hosts' % (host, os.getlogin()) try: - with open(self.ssh_key, "r") as f: - key_data = f.read() - if "BEGIN OPENSSH PRIVATE KEY" in key_data: - key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) - else: - key = paramiko.RSAKey.from_private_key_file(self.ssh_key) - return key - except FileNotFoundError: - raise ExecUtilException(message="No such file or directory: '{}'".format(self.ssh_key)) - except Exception as e: - ExecUtilException(message="An error occurred while reading the ssh key: {}".format(e)) + subprocess.check_call( + cmd, + shell=True, + ) + logging.info("Successfully added %s to known_hosts." % host) + except subprocess.CalledProcessError as e: + raise ExecUtilException(message="Failed to add %s to known_hosts. Error: %s" % (host, str(e)), command=cmd, + exit_code=e.returncode, out=e.stderr) def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None, @@ -97,49 +86,34 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa Args: - cmd (str): The command to be executed. """ - if self.ssh is None or not self.ssh.get_transport() or not self.ssh.get_transport().is_active(): - self.ssh = self.ssh_connect() - - if isinstance(cmd, list): - cmd = ' '.join(item.decode('utf-8') if isinstance(item, bytes) else item for item in cmd) - if input: - stdin, stdout, stderr = self.ssh.exec_command(cmd) - stdin.write(input) - stdin.flush() - else: - stdin, stdout, stderr = self.ssh.exec_command(cmd) - exit_status = 0 - if wait_exit: - exit_status = stdout.channel.recv_exit_status() + if isinstance(cmd, str): + ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key, cmd] + elif isinstance(cmd, list): + ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key] + cmd + process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + result, error = process.communicate(input) + exit_status = process.returncode if encoding: - result = stdout.read().decode(encoding) - error = stderr.read().decode(encoding) - else: - result = stdout.read() - error = stderr.read() + result = result.decode(encoding) + error = error.decode(encoding) if expect_error: raise Exception(result, error) - if encoding: - error_found = exit_status != 0 or any( - marker.decode(encoding) in error for marker in error_markers) + if not error: + error_found = 0 else: error_found = exit_status != 0 or any( - marker in error for marker in error_markers) + marker in error for marker in [b'error', b'Permission denied', b'fatal', b'No such file or directory']) if error_found: - if exit_status == 0: - exit_status = 1 - if encoding: - message = "Utility exited with non-zero code. Error: {}".format(error.decode(encoding)) - else: + if isinstance(error, bytes): message = b"Utility exited with non-zero code. Error: " + error - raise ExecUtilException(message=message, - command=cmd, - exit_code=exit_status, - out=result) + else: + message = f"Utility exited with non-zero code. Error: {error}" + raise ExecUtilException(message=message, command=cmd, exit_code=exit_status, out=result) if verbose: return exit_status, result, error @@ -154,7 +128,7 @@ def environ(self, var_name: str) -> str: - var_name (str): The name of the environment variable. """ cmd = "echo ${}".format(var_name) - return self.exec_command(cmd, encoding='utf-8').strip() + return self.exec_command(cmd, encoding=ConsoleEncoding).strip() def find_executable(self, executable): search_paths = self.environ("PATH") @@ -185,11 +159,11 @@ def set_env(self, var_name: str, var_val: str): # Get environment variables def get_user(self): - return self.exec_command("echo $USER", encoding='utf-8').strip() + return self.exec_command("echo $USER", encoding=ConsoleEncoding).strip() def get_name(self): cmd = 'python3 -c "import os; print(os.name)"' - return self.exec_command(cmd, encoding='utf-8').strip() + return self.exec_command(cmd, encoding=ConsoleEncoding).strip() # Work with dirs def makedirs(self, path, remove_existing=False): @@ -236,7 +210,7 @@ def listdir(self, path): return result.splitlines() def path_exists(self, path): - result = self.exec_command("test -e {}; echo $?".format(path), encoding='utf-8') + result = self.exec_command("test -e {}; echo $?".format(path), encoding=ConsoleEncoding) return int(result.strip()) == 0 @property @@ -257,22 +231,25 @@ def mkdtemp(self, prefix=None): - prefix (str): The prefix of the temporary directory name. """ if prefix: - temp_dir = self.exec_command("mktemp -d {}XXXXX".format(prefix), encoding='utf-8') + command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"] else: - temp_dir = self.exec_command("mktemp -d", encoding='utf-8') + command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", "mktemp -d"] - if temp_dir: + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + if result.returncode == 0: + temp_dir = result.stdout.strip() if not os.path.isabs(temp_dir): - temp_dir = os.path.join('/home', self.username, temp_dir.strip()) + temp_dir = os.path.join('/home', self.username, temp_dir) return temp_dir else: - raise ExecUtilException("Could not create temporary directory.") + raise ExecUtilException(f"Could not create temporary directory. Error: {result.stderr}") def mkstemp(self, prefix=None): if prefix: - temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding='utf-8') + temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding=ConsoleEncoding) else: - temp_dir = self.exec_command("mktemp", encoding='utf-8') + temp_dir = self.exec_command("mktemp", encoding=ConsoleEncoding) if temp_dir: if not os.path.isabs(temp_dir): @@ -289,20 +266,7 @@ def copytree(self, src, dst): return self.exec_command("cp -r {} {}".format(src, dst)) # Work with files - def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding='utf-8'): - """ - Write data to a file on a remote host - - Args: - - filename (str): The file path where the data will be written. - - data (bytes or str): The data to be written to the file. - - truncate (bool): If True, the file will be truncated before writing ('w' or 'wb' option); - if False (default), data will be appended ('a' or 'ab' option). - - binary (bool): If True, the data will be written in binary mode ('wb' or 'ab' option); - if False (default), the data will be written in text mode ('w' or 'a' option). - - read_and_write (bool): If True, the file will be opened with read and write permissions ('r+' option); - if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option). - """ + def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding=ConsoleEncoding): mode = "wb" if binary else "w" if not truncate: mode = "ab" if binary else "a" @@ -311,35 +275,29 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file: if not truncate: - with self.ssh_connect() as ssh: - sftp = ssh.open_sftp() - try: - sftp.get(filename, tmp_file.name) - tmp_file.seek(0, os.SEEK_END) - except FileNotFoundError: - pass # File does not exist yet, we'll create it - sftp.close() + scp_cmd = ['scp', '-i', self.ssh_key, f"{self.username}@{self.host}:{filename}", tmp_file.name] + subprocess.run(scp_cmd, check=False) # The file might not exist yet + tmp_file.seek(0, os.SEEK_END) + if isinstance(data, bytes) and not binary: data = data.decode(encoding) elif isinstance(data, str) and binary: data = data.encode(encoding) + if isinstance(data, list): - # ensure each line ends with a newline - data = [(s if isinstance(s, str) else s.decode('utf-8')).rstrip('\n') + '\n' for s in data] + data = [(s if isinstance(s, str) else s.decode(ConsoleEncoding)).rstrip('\n') + '\n' for s in data] tmp_file.writelines(data) else: tmp_file.write(data) + tmp_file.flush() - with self.ssh_connect() as ssh: - sftp = ssh.open_sftp() - remote_directory = os.path.dirname(filename) - try: - sftp.stat(remote_directory) - except IOError: - sftp.mkdir(remote_directory) - sftp.put(tmp_file.name, filename) - sftp.close() + scp_cmd = ['scp', '-i', self.ssh_key, tmp_file.name, f"{self.username}@{self.host}:{filename}"] + subprocess.run(scp_cmd, check=True) + + remote_directory = os.path.dirname(filename) + mkdir_cmd = ['ssh', '-i', self.ssh_key, f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"] + subprocess.run(mkdir_cmd, check=True) os.remove(tmp_file.name) @@ -359,7 +317,7 @@ def read(self, filename, binary=False, encoding=None): result = self.exec_command(cmd, encoding=encoding) if not binary and result: - result = result.decode(encoding or 'utf-8') + result = result.decode(encoding or ConsoleEncoding) return result @@ -372,7 +330,7 @@ def readlines(self, filename, num_lines=0, binary=False, encoding=None): result = self.exec_command(cmd, encoding=encoding) if not binary and result: - lines = result.decode(encoding or 'utf-8').splitlines() + lines = result.decode(encoding or ConsoleEncoding).splitlines() else: lines = result.splitlines() @@ -400,13 +358,18 @@ def kill(self, pid, signal): def get_pid(self): # Get current process id - return int(self.exec_command("echo $$", encoding='utf-8')) + return int(self.exec_command("echo $$", encoding=ConsoleEncoding)) def get_process_children(self, pid): - command = "pgrep -P {}".format(pid) - stdin, stdout, stderr = self.ssh.exec_command(command) - children = stdout.readlines() - return [PsUtilProcessProxy(self.ssh, int(child_pid.strip())) for child_pid in children] + command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"pgrep -P {pid}"] + + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + if result.returncode == 0: + children = result.stdout.strip().splitlines() + return [PsUtilProcessProxy(self, int(child_pid.strip())) for child_pid in children] + else: + raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}") # Database control def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, ssh_key=None): @@ -424,18 +387,19 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s """ self.close_tunnel() self.tunnel = sshtunnel.open_tunnel( - (host, 22), # Remote server IP and SSH port - ssh_username=user or self.username, - ssh_pkey=ssh_key or self.ssh_key, - remote_bind_address=(host, port), # PostgreSQL server IP and PostgreSQL port - local_bind_address=('localhost', port) # Local machine IP and available port + (self.host, 22), # Remote server IP and SSH port + ssh_username=self.username, + ssh_pkey=self.ssh_key, + remote_bind_address=(self.host, port), # PostgreSQL server IP and PostgreSQL port + local_bind_address=('localhost', 0) + # Local machine IP and available port (0 means it will pick any available port) ) - self.tunnel.start() try: + # Use localhost and self.tunnel.local_bind_port to connect conn = pglib.connect( - host=host, # change to 'localhost' because we're connecting through a local ssh tunnel + host='localhost', # Connect to localhost port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel database=dbname, user=user or self.username, diff --git a/testgres/utils.py b/testgres/utils.py index 5e12eba9..b7df70d1 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -118,11 +118,13 @@ def get_bin_path(filename): return filename -def get_pg_config(pg_config_path=None): +def get_pg_config(pg_config_path=None, os_ops=None): """ Return output of pg_config (provided that it is installed). NOTE: this function caches the result by default (see GlobalConfig). """ + if os_ops: + tconf.os_ops = os_ops def cache_pg_config_data(cmd): # execute pg_config and get the output @@ -146,7 +148,7 @@ def cache_pg_config_data(cmd): _pg_config_data = {} # return cached data - if _pg_config_data: + if not pg_config_path and _pg_config_data: return _pg_config_data # try specified pg_config path or PG_CONFIG diff --git a/tests/test_remote.py b/tests/test_remote.py index 3794349c..2e0f0676 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -17,9 +17,6 @@ def setup(self): 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519') self.operations = RemoteOperations(conn_params) - yield - self.operations.__del__() - def test_exec_command_success(self): """ Test exec_command for successful command execution. diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index e8386383..44e77fbd 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -135,7 +135,6 @@ def test_init_after_cleanup(self): @unittest.skipUnless(util_exists('pg_resetwal'), 'might be missing') @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') def test_init_unique_system_id(self): - # FAIL # this function exists in PostgreSQL 9.6+ query = 'select system_identifier from pg_control_system()'
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: