diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 01251e1c..7c774a1c 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,5 +1,5 @@ -import logging import os +import socket import subprocess import tempfile import platform @@ -45,14 +45,18 @@ def __init__(self, conn_params: ConnectionParams): self.conn_params = conn_params self.host = conn_params.host self.ssh_key = conn_params.ssh_key + self.port = conn_params.port + self.ssh_args = [] if self.ssh_key: - self.ssh_cmd = ["-i", self.ssh_key] - else: - self.ssh_cmd = [] + self.ssh_args += ["-i", self.ssh_key] + if self.port: + self.ssh_args += ["-p", self.port] self.remote = True - self.username = conn_params.username or self.get_user() + self.username = conn_params.username + self.ssh_dest = f"{self.username}@{self.host}" if self.username else self.host self.add_known_host(self.host) self.tunnel_process = None + self.tunnel_port = None def __enter__(self): return self @@ -60,31 +64,25 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close_ssh_tunnel() - def establish_ssh_tunnel(self, local_port, remote_port): - """ - Establish an SSH tunnel from a local port to a remote PostgreSQL port. - """ - ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"] - self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300) + @staticmethod + def is_port_open(host, port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(1) # Таймаут для попытки соединения + try: + sock.connect((host, port)) + return True + except socket.error: + return False def close_ssh_tunnel(self): - if hasattr(self, 'tunnel_process'): + if self.tunnel_process: self.tunnel_process.terminate() self.tunnel_process.wait() + print("SSH tunnel closed.") del self.tunnel_process else: print("No active tunnel to close.") - def add_known_host(self, host): - known_hosts_path = os.path.expanduser("~/.ssh/known_hosts") - cmd = 'ssh-keyscan -H %s >> %s' % (host, known_hosts_path) - - try: - subprocess.check_call(cmd, shell=True) - logging.info("Successfully added %s to known_hosts." % host) - except subprocess.CalledProcessError as e: - raise Exception("Failed to add %s to known_hosts. Error: %s" % (host, str(e))) - def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None, stderr=None, get_process=None, timeout=None): @@ -95,9 +93,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, """ ssh_cmd = [] if isinstance(cmd, str): - ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + [cmd] + ssh_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest, cmd] elif isinstance(cmd, list): - ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + cmd + ssh_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest] + cmd process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) if get_process: return process @@ -172,10 +170,6 @@ def set_env(self, var_name: str, var_val: str): """ return self.exec_command("export {}={}".format(var_name, var_val)) - # Get environment variables - def get_user(self): - return self.exec_command("echo $USER", encoding=get_default_encoding()).strip() - def get_name(self): cmd = 'python3 -c "import os; print(os.name)"' return self.exec_command(cmd, encoding=get_default_encoding()).strip() @@ -246,9 +240,9 @@ def mkdtemp(self, prefix=None): - prefix (str): The prefix of the temporary directory name. """ if prefix: - command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"] + command = ["ssh"] + self.ssh_args + [self.ssh_dest, f"mktemp -d {prefix}XXXXX"] else: - command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", "mktemp -d"] + command = ["ssh"] + self.ssh_args + [self.ssh_dest, "mktemp -d"] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) @@ -291,8 +285,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal mode = "r+b" if binary else "r+" with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file: + # Because in scp we set up port using -P option + scp_args = ['-P' if x == '-p' else x for x in self.ssh_args] + if not truncate: - scp_cmd = ['scp'] + self.ssh_cmd + [f"{self.username}@{self.host}:{filename}", tmp_file.name] + scp_cmd = ['scp'] + scp_args + [f"{self.ssh_dest}:{filename}", tmp_file.name] subprocess.run(scp_cmd, check=False) # The file might not exist yet tmp_file.seek(0, os.SEEK_END) @@ -308,11 +305,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal tmp_file.write(data) tmp_file.flush() - scp_cmd = ['scp'] + self.ssh_cmd + [tmp_file.name, f"{self.username}@{self.host}:{filename}"] + scp_cmd = ['scp'] + scp_args + [tmp_file.name, f"{self.ssh_dest}:{filename}"] subprocess.run(scp_cmd, check=True) remote_directory = os.path.dirname(filename) - mkdir_cmd = ['ssh'] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"] + mkdir_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest, f"mkdir -p {remote_directory}"] subprocess.run(mkdir_cmd, check=True) os.remove(tmp_file.name) @@ -377,7 +374,7 @@ def get_pid(self): return int(self.exec_command("echo $$", encoding=get_default_encoding())) def get_process_children(self, pid): - command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"pgrep -P {pid}"] + command = ["ssh"] + self.ssh_args + [self.ssh_dest, f"pgrep -P {pid}"] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) @@ -389,18 +386,11 @@ def get_process_children(self, pid): # Database control def db_connect(self, dbname, user, password=None, host="localhost", port=5432): - """ - Established SSH tunnel and Connects to a PostgreSQL - """ - self.establish_ssh_tunnel(local_port=port, remote_port=5432) - try: - conn = pglib.connect( - host=host, - port=port, - database=dbname, - user=user, - password=password, - ) - return conn - except Exception as e: - raise Exception(f"Could not connect to the database. Error: {e}") + conn = pglib.connect( + host=host, + port=port, + database=dbname, + user=user, + password=password, + ) + return conn pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy