From 02c337551133888dfbde2ed0b2fba7cc0c65429e Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 10 Apr 2023 23:03:36 +0200 Subject: [PATCH 01/23] PBCKP-137 update node.py --- testgres/__init__.py | 4 +- testgres/node.py | 337 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 327 insertions(+), 14 deletions(-) diff --git a/testgres/__init__.py b/testgres/__init__.py index 9d5e37cf..1b33ba3b 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -32,7 +32,7 @@ ProcessType, \ DumpFormat -from .node import PostgresNode +from .node import PostgresNode, NodeApp from .utils import \ reserve_port, \ @@ -53,7 +53,7 @@ "NodeConnection", "DatabaseError", "InternalError", "ProgrammingError", "OperationalError", "TestgresException", "ExecUtilException", "QueryException", "TimeoutException", "CatchUpException", "StartNodeException", "InitNodeException", "BackupException", "XLogMethod", "IsolationLevel", "NodeStatus", "ProcessType", "DumpFormat", - "PostgresNode", + "PostgresNode", "NodeApp", "reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version", "First", "Any", ] diff --git a/testgres/node.py b/testgres/node.py index 378e6803..0d1232a2 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -2,6 +2,12 @@ import io import os +import random +import shutil +import signal +import threading +from queue import Queue + import psutil import subprocess import time @@ -11,6 +17,15 @@ except ImportError: from collections import Iterable +# we support both pg8000 and psycopg2 +try: + import psycopg2 as pglib +except ImportError: + try: + import pg8000 as pglib + except ImportError: + raise ImportError("You must have psycopg2 or pg8000 modules installed") + from shutil import rmtree from six import raise_from, iteritems, text_type from tempfile import mkstemp, mkdtemp @@ -86,6 +101,10 @@ from .backup import NodeBackup +InternalError = pglib.InternalError +ProgrammingError = pglib.ProgrammingError +OperationalError = pglib.OperationalError + class ProcessProxy(object): """ @@ -140,6 +159,9 @@ def __init__(self, name=None, port=None, base_dir=None): self.utils_log_name = self.utils_log_file self.pg_log_name = self.pg_log_file + # Node state + self.is_started = False + def __enter__(self): return self @@ -629,9 +651,38 @@ def get_control_data(self): return out_dict + def slow_start(self, replica=False): + """ + Starts the PostgreSQL instance and then polls the instance + until it reaches the expected state (primary or replica). The state is checked + using the pg_is_in_recovery() function. + + Args: + replica: If True, waits for the instance to be in recovery (i.e., replica mode). + If False, waits for the instance to be in primary mode. Default is False. + """ + self.start() + + if replica: + query = 'SELECT pg_is_in_recovery()' + else: + query = 'SELECT not pg_is_in_recovery()' + # Call poll_query_until until the expected value is returned + self.poll_query_until( + dbname="template1", + query=query, + suppress={pglib.InternalError, + QueryException, + pglib.ProgrammingError, + pglib.OperationalError}) + + def start(self, params=[], wait=True): """ - Start this node using pg_ctl. + Starts the PostgreSQL node using pg_ctl if node has not been started. + By default, it waits for the operation to complete before returning. + Optionally, it can return immediately without waiting for the start operation + to complete by setting the `wait` parameter to False. Args: params: additional arguments for pg_ctl. @@ -640,14 +691,16 @@ def start(self, params=[], wait=True): Returns: This instance of :class:`.PostgresNode`. """ + if self.is_started: + return self _params = [ - get_bin_path("pg_ctl"), - "-D", self.data_dir, - "-l", self.pg_log_file, - "-w" if wait else '-W', # --wait or --no-wait - "start" - ] + params # yapf: disable + get_bin_path("pg_ctl"), + "-D", self.data_dir, + "-l", self.pg_log_file, + "-w" if wait else '-W', # --wait or --no-wait + "start" + ] + params # yapf: disable try: execute_utility(_params, self.utils_log_file) @@ -657,20 +710,22 @@ def start(self, params=[], wait=True): raise_from(StartNodeException(msg, files), e) self._maybe_start_logger() - + self.is_started = True return self def stop(self, params=[], wait=True): """ - Stop this node using pg_ctl. + Stops the PostgreSQL node using pg_ctl if the node has been started. Args: - params: additional arguments for pg_ctl. - wait: wait until operation completes. + params: A list of additional arguments for pg_ctl. Defaults to None. + wait: If True, waits until the operation is complete. Defaults to True. Returns: This instance of :class:`.PostgresNode`. """ + if not self.is_started: + return self _params = [ get_bin_path("pg_ctl"), @@ -682,9 +737,25 @@ def stop(self, params=[], wait=True): execute_utility(_params, self.utils_log_file) self._maybe_stop_logger() - + self.is_started = False return self + def kill(self, someone=None): + """ + Kills the PostgreSQL node or a specified auxiliary process if the node is running. + + Args: + someone: A key to the auxiliary process in the auxiliary_pids dictionary. + If None, the main PostgreSQL node process will be killed. Defaults to None. + """ + if self.is_started: + sig = signal.SIGKILL if os.name != 'nt' else signal.SIGBREAK + if someone == None: + os.kill(self.pid, sig) + else: + os.kill(self.auxiliary_pids[someone][0], sig) + self.is_started = False + def restart(self, params=[]): """ Restart this node using pg_ctl. @@ -1359,3 +1430,245 @@ def connect(self, username=username, password=password, autocommit=autocommit) # yapf: disable + + def table_checksum(self, table, dbname="postgres"): + """ + Calculate the checksum of a table by hashing its rows. + + The function fetches rows from the table in chunks and calculates the checksum + by summing the hash values of each row. The function uses a separate thread + to fetch rows when there are more than 2000 rows in the table. + + Args: + table (str): The name of the table for which the checksum should be calculated. + dbname (str, optional): The name of the database where the table is located. Defaults to "postgres". + + Returns: + int: The calculated checksum of the table. + """ + + def fetch_rows(con, cursor_name): + while True: + rows = con.execute(f"FETCH FORWARD 2000 FROM {cursor_name}") + if not rows: + break + yield rows + + def process_rows(queue, con, cursor_name): + try: + for rows in fetch_rows(con, cursor_name): + queue.put(rows) + except Exception as e: + queue.put(e) + else: + queue.put(None) + + cursor_name = f"cur_{random.randint(0, 2 ** 48)}" + checksum = 0 + query_thread = None + + with self.connect(dbname=dbname) as con: + con.execute(f""" + DECLARE {cursor_name} NO SCROLL CURSOR FOR + SELECT t::text FROM {table} as t + """) + + queue = Queue(maxsize=50) + initial_rows = con.execute(f"FETCH FORWARD 2000 FROM {cursor_name}") + + if not initial_rows: + return 0 + + queue.put(initial_rows) + + if len(initial_rows) == 2000: + query_thread = threading.Thread(target=process_rows, args=(queue, con, cursor_name)) + query_thread.start() + else: + queue.put(None) + + while True: + rows = queue.get() + if rows is None: + break + if isinstance(rows, Exception): + raise rows + + for row in rows: + checksum += hash(row[0]) + + if query_thread is not None: + query_thread.join() + + con.execute(f"CLOSE {cursor_name}; ROLLBACK;") + + return checksum + + def pgbench_table_checksums(self, dbname="postgres", + pgbench_tables=('pgbench_branches', + 'pgbench_tellers', + 'pgbench_accounts', + 'pgbench_history') + ): + """ + Calculate the checksums of the specified pgbench tables using table_checksum method. + + Args: + dbname (str, optional): The name of the database where the pgbench tables are located. Defaults to "postgres". + pgbench_tables (tuple of str, optional): A tuple containing the names of the pgbench tables for which the + checksums should be calculated. Defaults to a tuple containing the + names of the default pgbench tables. + + Returns: + set of tuple: A set of tuples, where each tuple contains the table name and its corresponding checksum. + """ + return {(table, self.table_checksum(table, dbname)) + for table in pgbench_tables} + + def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): + """ + Update or remove configuration options in the specified configuration file, + updates the options specified in the options dictionary, removes any options + specified in the rm_options set, and writes the updated configuration back to + the file. + + Args: + options (dict): A dictionary containing the options to update or add, + with the option names as keys and their values as values. + config (str, optional): The name of the configuration file to update. + Defaults to 'postgresql.auto.conf'. + rm_options (set, optional): A set containing the names of the options to remove. + Defaults to an empty set. + """ + # parse postgresql.auto.conf + path = os.path.join(self.data_dir, config) + + with open(path, 'r') as f: + raw_content = f.read() + + current_options = {} + current_directives = [] + for line in raw_content.splitlines(): + + # ignore comments + if line.startswith('#'): + continue + + if line == '': + continue + + if line.startswith('include'): + current_directives.append(line) + continue + + name, var = line.partition('=')[::2] + name = name.strip() + var = var.strip() + var = var.strip('"') + var = var.strip("'") + + # remove options specified in rm_options list + if name in rm_options: + continue + + current_options[name] = var + + for option in options: + current_options[option] = options[option] + + auto_conf = '' + for option in current_options: + auto_conf += "{0} = '{1}'\n".format( + option, current_options[option]) + + for directive in current_directives: + auto_conf += directive + "\n" + + with open(path, 'wt') as f: + f.write(auto_conf) + + +class NodeApp: + """ + Functions that can be moved to testgres.PostgresNode + We use these functions in ProbackupController and need tp move them in some visible place + """ + + def __init__(self, test_path, nodes_to_cleanup): + self.test_path = test_path + self.nodes_to_cleanup = nodes_to_cleanup + + def make_empty( + self, + base_dir=None): + real_base_dir = os.path.join(self.test_path, base_dir) + shutil.rmtree(real_base_dir, ignore_errors=True) + os.makedirs(real_base_dir) + + node = PostgresNodeExtended(base_dir=real_base_dir) + node.should_rm_dirs = True + self.nodes_to_cleanup.append(node) + + return node + + def make_simple( + self, + base_dir=None, + set_replication=False, + ptrack_enable=False, + initdb_params=[], + pg_options={}): + + node = self.make_empty(base_dir) + node.init( + initdb_params=initdb_params, allow_streaming=set_replication) + + # set major version + with open(os.path.join(node.data_dir, 'PG_VERSION')) as f: + node.major_version_str = str(f.read().rstrip()) + node.major_version = float(node.major_version_str) + + # Sane default parameters + options = {} + options['max_connections'] = 100 + options['shared_buffers'] = '10MB' + options['fsync'] = 'off' + + options['wal_level'] = 'logical' + options['hot_standby'] = 'off' + + options['log_line_prefix'] = '%t [%p]: [%l-1] ' + options['log_statement'] = 'none' + options['log_duration'] = 'on' + options['log_min_duration_statement'] = 0 + options['log_connections'] = 'on' + options['log_disconnections'] = 'on' + options['restart_after_crash'] = 'off' + options['autovacuum'] = 'off' + + # Allow replication in pg_hba.conf + if set_replication: + options['max_wal_senders'] = 10 + + if ptrack_enable: + options['ptrack.map_size'] = '1' + options['shared_preload_libraries'] = 'ptrack' + + if node.major_version >= 13: + options['wal_keep_size'] = '200MB' + else: + options['wal_keep_segments'] = '12' + + # set default values + node.set_auto_conf(options) + + # Apply given parameters + node.set_auto_conf(pg_options) + + # kludge for testgres + # https://github.com/postgrespro/testgres/issues/54 + # for PG >= 13 remove 'wal_keep_segments' parameter + if node.major_version >= 13: + node.set_auto_conf({}, 'postgresql.conf', ['wal_keep_segments']) + + return node From 1512afde8a40bee606046e6a305aa4017ec8419a Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Tue, 11 Apr 2023 15:02:00 +0200 Subject: [PATCH 02/23] PBCKP-137 up version 1.8.6 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a5dc600e..5c6f4a07 100755 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ readme = f.read() setup( - version='1.8.5', + version='1.8.6', name='testgres', packages=['testgres'], description='Testing utility for PostgreSQL and its extensions', From 0d62e0e6881a8cd18e9acd58507fcae74ce71ad9 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Tue, 11 Apr 2023 22:50:33 +0200 Subject: [PATCH 03/23] PBCKP-137 update node.py --- testgres/node.py | 163 +++++++++++++++++++---------------------------- 1 file changed, 67 insertions(+), 96 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 0d1232a2..6d1d4544 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -12,6 +12,7 @@ import subprocess import time + try: from collections.abc import Iterable except ImportError: @@ -104,6 +105,7 @@ InternalError = pglib.InternalError ProgrammingError = pglib.ProgrammingError OperationalError = pglib.OperationalError +DatabaseError = pglib.DatabaseError class ProcessProxy(object): @@ -651,13 +653,15 @@ def get_control_data(self): return out_dict - def slow_start(self, replica=False): + def slow_start(self, replica=False, dbname='template1', username='dev'): """ Starts the PostgreSQL instance and then polls the instance until it reaches the expected state (primary or replica). The state is checked using the pg_is_in_recovery() function. Args: + dbname: + username: replica: If True, waits for the instance to be in recovery (i.e., replica mode). If False, waits for the instance to be in primary mode. Default is False. """ @@ -668,14 +672,15 @@ def slow_start(self, replica=False): else: query = 'SELECT not pg_is_in_recovery()' # Call poll_query_until until the expected value is returned - self.poll_query_until( - dbname="template1", - query=query, - suppress={pglib.InternalError, - QueryException, - pglib.ProgrammingError, - pglib.OperationalError}) - + self.poll_query_until(query=query, + expected=False, + dbname=dbname, + username=username, + suppress={InternalError, + QueryException, + ProgrammingError, + OperationalError, + DatabaseError}) def start(self, params=[], wait=True): """ @@ -1432,96 +1437,66 @@ def connect(self, autocommit=autocommit) # yapf: disable def table_checksum(self, table, dbname="postgres"): - """ - Calculate the checksum of a table by hashing its rows. - - The function fetches rows from the table in chunks and calculates the checksum - by summing the hash values of each row. The function uses a separate thread - to fetch rows when there are more than 2000 rows in the table. - - Args: - table (str): The name of the table for which the checksum should be calculated. - dbname (str, optional): The name of the database where the table is located. Defaults to "postgres". - - Returns: - int: The calculated checksum of the table. - """ - - def fetch_rows(con, cursor_name): - while True: - rows = con.execute(f"FETCH FORWARD 2000 FROM {cursor_name}") - if not rows: - break - yield rows - - def process_rows(queue, con, cursor_name): - try: - for rows in fetch_rows(con, cursor_name): - queue.put(rows) - except Exception as e: - queue.put(e) - else: - queue.put(None) - - cursor_name = f"cur_{random.randint(0, 2 ** 48)}" - checksum = 0 - query_thread = None - - with self.connect(dbname=dbname) as con: - con.execute(f""" - DECLARE {cursor_name} NO SCROLL CURSOR FOR - SELECT t::text FROM {table} as t - """) - - queue = Queue(maxsize=50) - initial_rows = con.execute(f"FETCH FORWARD 2000 FROM {cursor_name}") - - if not initial_rows: - return 0 - - queue.put(initial_rows) - - if len(initial_rows) == 2000: - query_thread = threading.Thread(target=process_rows, args=(queue, con, cursor_name)) - query_thread.start() - else: - queue.put(None) + con = self.connect(dbname=dbname) + + curname = "cur_" + str(random.randint(0, 2 ** 48)) + + con.execute(""" + DECLARE %s NO SCROLL CURSOR FOR + SELECT t::text FROM %s as t + """ % (curname, table)) + + que = Queue(maxsize=50) + sum = 0 + + rows = con.execute("FETCH FORWARD 2000 FROM %s" % curname) + if not rows: + return 0 + que.put(rows) + + th = None + if len(rows) == 2000: + def querier(): + try: + while True: + rows = con.execute("FETCH FORWARD 2000 FROM %s" % curname) + if not rows: + break + que.put(rows) + except Exception as e: + que.put(e) + else: + que.put(None) - while True: - rows = queue.get() - if rows is None: - break - if isinstance(rows, Exception): - raise rows + th = threading.Thread(target=querier) + th.start() + else: + que.put(None) - for row in rows: - checksum += hash(row[0]) + while True: + rows = que.get() + if rows is None: + break + if isinstance(rows, Exception): + raise rows + # hash uses SipHash since Python3.4, therefore it is good enough + for row in rows: + sum += hash(row[0]) - if query_thread is not None: - query_thread.join() + if th is not None: + th.join() - con.execute(f"CLOSE {cursor_name}; ROLLBACK;") + con.execute("CLOSE %s; ROLLBACK;" % curname) - return checksum + con.close() + return sum def pgbench_table_checksums(self, dbname="postgres", - pgbench_tables=('pgbench_branches', - 'pgbench_tellers', - 'pgbench_accounts', - 'pgbench_history') + pgbench_tables = ('pgbench_branches', + 'pgbench_tellers', + 'pgbench_accounts', + 'pgbench_history') ): - """ - Calculate the checksums of the specified pgbench tables using table_checksum method. - - Args: - dbname (str, optional): The name of the database where the pgbench tables are located. Defaults to "postgres". - pgbench_tables (tuple of str, optional): A tuple containing the names of the pgbench tables for which the - checksums should be calculated. Defaults to a tuple containing the - names of the default pgbench tables. - - Returns: - set of tuple: A set of tuples, where each tuple contains the table name and its corresponding checksum. - """ return {(table, self.table_checksum(table, dbname)) for table in pgbench_tables} @@ -1589,10 +1564,6 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): class NodeApp: - """ - Functions that can be moved to testgres.PostgresNode - We use these functions in ProbackupController and need tp move them in some visible place - """ def __init__(self, test_path, nodes_to_cleanup): self.test_path = test_path @@ -1605,7 +1576,7 @@ def make_empty( shutil.rmtree(real_base_dir, ignore_errors=True) os.makedirs(real_base_dir) - node = PostgresNodeExtended(base_dir=real_base_dir) + node = PostgresNode(base_dir=real_base_dir) node.should_rm_dirs = True self.nodes_to_cleanup.append(node) From 8be1b3a72cecd7dd15862c3258b97fb5834e6737 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 17 Apr 2023 10:43:16 +0200 Subject: [PATCH 04/23] PBCKP-137 update node --- testgres/node.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 6d1d4544..17b9a260 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -105,7 +105,6 @@ InternalError = pglib.InternalError ProgrammingError = pglib.ProgrammingError OperationalError = pglib.OperationalError -DatabaseError = pglib.DatabaseError class ProcessProxy(object): @@ -653,7 +652,7 @@ def get_control_data(self): return out_dict - def slow_start(self, replica=False, dbname='template1', username='dev'): + def slow_start(self, replica=False, dbname='template1', username=default_username()): """ Starts the PostgreSQL instance and then polls the instance until it reaches the expected state (primary or replica). The state is checked @@ -673,14 +672,12 @@ def slow_start(self, replica=False, dbname='template1', username='dev'): query = 'SELECT not pg_is_in_recovery()' # Call poll_query_until until the expected value is returned self.poll_query_until(query=query, - expected=False, dbname=dbname, username=username, suppress={InternalError, QueryException, ProgrammingError, - OperationalError, - DatabaseError}) + OperationalError}) def start(self, params=[], wait=True): """ @@ -970,7 +967,7 @@ def psql(self, return process.returncode, out, err @method_decorator(positional_args_hack(['dbname', 'query'])) - def safe_psql(self, query=None, **kwargs): + def safe_psql(self, query=None, expect_error=False, **kwargs): """ Execute a query using psql. @@ -980,6 +977,8 @@ def safe_psql(self, query=None, **kwargs): dbname: database name to connect to. username: database user name. input: raw input to be passed. + expect_error: if True - fail if we didn't get ret + if False - fail if we got ret **kwargs are passed to psql(). @@ -992,7 +991,12 @@ def safe_psql(self, query=None, **kwargs): ret, out, err = self.psql(query=query, **kwargs) if ret: - raise QueryException((err or b'').decode('utf-8'), query) + if expect_error: + out = (err or b'').decode('utf-8') + else: + raise QueryException((err or b'').decode('utf-8'), query) + elif expect_error: + assert False, f"Exception was expected, but query finished successfully: `{query}` " return out From 51f05de66ebc5604dc72ff17af08dd7d4fb1c9a1 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Tue, 2 May 2023 15:39:40 +0200 Subject: [PATCH 05/23] PBCKP-152 change local function on execution by ssh --- setup.py | 4 +- testgres/api.py | 6 + testgres/backup.py | 13 +- testgres/cache.py | 20 +-- testgres/config.py | 21 ++- testgres/connection.py | 39 +++--- testgres/defaults.py | 16 ++- testgres/logger.py | 14 ++ testgres/node.py | 23 +++- testgres/os_ops.py | 293 +++++++++++++++++++++++++++++++++++++++++ testgres/utils.py | 21 ++- 11 files changed, 422 insertions(+), 48 deletions(-) create mode 100644 testgres/os_ops.py diff --git a/setup.py b/setup.py index 5c6f4a07..8ae54e4f 100755 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ from distutils.core import setup # Basic dependencies -install_requires = ["pg8000", "port-for>=0.4", "six>=1.9.0", "psutil"] +install_requires = ["pg8000", "port-for>=0.4", "six>=1.9.0", "psutil", "fabric"] # Add compatibility enum class if sys.version_info < (3, 4): @@ -21,7 +21,7 @@ readme = f.read() setup( - version='1.8.6', + version='1.9.0', name='testgres', packages=['testgres'], description='Testing utility for PostgreSQL and its extensions', diff --git a/testgres/api.py b/testgres/api.py index e90cf7bd..bae46717 100644 --- a/testgres/api.py +++ b/testgres/api.py @@ -30,6 +30,7 @@ PostgresNode(name='...', port=..., base_dir='...') [(3,)] """ +from defaults import default_username from .node import PostgresNode @@ -37,6 +38,11 @@ def get_new_node(name=None, base_dir=None, **kwargs): """ Simply a wrapper around :class:`.PostgresNode` constructor. See :meth:`.PostgresNode.__init__` for details. + For remote connection you can add next parameters: + host='127.0.0.1', + hostname='localhost', + ssh_key=None, + username=default_username() """ # NOTE: leave explicit 'name' and 'base_dir' for compatibility return PostgresNode(name=name, base_dir=base_dir, **kwargs) diff --git a/testgres/backup.py b/testgres/backup.py index a725a1df..b3ffa833 100644 --- a/testgres/backup.py +++ b/testgres/backup.py @@ -6,6 +6,7 @@ from six import raise_from from tempfile import mkdtemp +from os_ops import OsOperations from .enums import XLogMethod from .consts import \ @@ -47,7 +48,7 @@ def __init__(self, username: database user name. xlog_method: none | fetch | stream (see docs) """ - + self.os_ops = node.os_ops if not node.status(): raise BackupException('Node must be running') @@ -60,8 +61,8 @@ def __init__(self, raise BackupException(msg) # Set default arguments - username = username or default_username() - base_dir = base_dir or mkdtemp(prefix=TMP_BACKUP) + username = username or self.os_ops.get_user() + base_dir = base_dir or self.os_ops.mkdtemp(prefix=TMP_BACKUP) # public self.original_node = node @@ -81,7 +82,7 @@ def __init__(self, "-D", data_dir, "-X", xlog_method.value ] # yapf: disable - execute_utility(_params, self.log_file) + execute_utility(_params, self.log_file, hostname=node.hostname, ssh_key=node.ssh_key) def __enter__(self): return self @@ -107,7 +108,7 @@ def _prepare_dir(self, destroy): available = not destroy if available: - dest_base_dir = mkdtemp(prefix=TMP_NODE) + dest_base_dir = self.os_ops.mkdtemp(prefix=TMP_NODE) data1 = os.path.join(self.base_dir, DATA_DIR) data2 = os.path.join(dest_base_dir, DATA_DIR) @@ -185,4 +186,4 @@ def cleanup(self): if self._available: self._available = False - rmtree(self.base_dir, ignore_errors=True) + self.os_ops.rmdirs(self.base_dir, ignore_errors=True) diff --git a/testgres/cache.py b/testgres/cache.py index c3cd9971..cd40e72f 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -6,6 +6,7 @@ from shutil import copytree from six import raise_from +from os_ops import OsOperations from .config import testgres_config from .consts import XLOG_CONTROL_FILE @@ -21,14 +22,16 @@ execute_utility -def cached_initdb(data_dir, logfile=None, params=None): +def cached_initdb(data_dir, logfile=None, hostname='localhost', ssh_key=None, params=None): """ Perform initdb or use cached node files. """ + os_ops = OsOperations(hostname=hostname, ssh_key=ssh_key) + def call_initdb(initdb_dir, log=None): try: _params = [get_bin_path("initdb"), "-D", initdb_dir, "-N"] - execute_utility(_params + (params or []), log) + execute_utility(_params + (params or []), log, hostname=hostname, ssh_key=ssh_key) except ExecUtilException as e: raise_from(InitNodeException("Failed to run initdb"), e) @@ -39,13 +42,14 @@ def call_initdb(initdb_dir, log=None): cached_data_dir = testgres_config.cached_initdb_dir # Initialize cached initdb - if not os.path.exists(cached_data_dir) or \ - not os.listdir(cached_data_dir): + + if not os_ops.path_exists(cached_data_dir) or \ + not os_ops.listdir(cached_data_dir): call_initdb(cached_data_dir) try: # Copy cached initdb to current data dir - copytree(cached_data_dir, data_dir) + os_ops.copytree(cached_data_dir, data_dir) # Assign this node a unique system id if asked to if testgres_config.cached_initdb_unique: @@ -53,12 +57,12 @@ def call_initdb(initdb_dir, log=None): # Some users might rely upon unique system ids, but # our initdb caching mechanism breaks this contract. pg_control = os.path.join(data_dir, XLOG_CONTROL_FILE) - with io.open(pg_control, "r+b") as f: - f.write(generate_system_id()) # overwrite id + system_id = generate_system_id() + os_ops.write(pg_control, system_id, truncate=True, binary=True, read_and_write=True) # XXX: build new WAL segment with our system id _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"] - execute_utility(_params, logfile) + execute_utility(_params, logfile, hostname=hostname, ssh_key=ssh_key) except ExecUtilException as e: msg = "Failed to reset WAL for system id" diff --git a/testgres/config.py b/testgres/config.py index cfcdadc2..b32536ba 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -43,6 +43,9 @@ class GlobalConfig(object): _cached_initdb_dir = None """ underlying class attribute for cached_initdb_dir property """ + + os_ops = None + """ OsOperation object that allows work on remote host """ @property def cached_initdb_dir(self): """ path to a temp directory for cached initdb. """ @@ -52,8 +55,15 @@ def cached_initdb_dir(self): def cached_initdb_dir(self, value): self._cached_initdb_dir = value + # NOTE: assign initial cached dir for initdb + if self.os_ops: + testgres_config.cached_initdb_dir = self.os_ops.mkdtemp(prefix=TMP_CACHE) + else: + testgres_config.cached_initdb_dir = mkdtemp(prefix=TMP_CACHE) + if value: cached_initdb_dirs.add(value) + return testgres_config.cached_initdb_dir @property def temp_dir(self): @@ -133,9 +143,12 @@ def copy(self): @atexit.register -def _rm_cached_initdb_dirs(): +def _rm_cached_initdb_dirs(os_ops=None): for d in cached_initdb_dirs: - rmtree(d, ignore_errors=True) + if os_ops: + os_ops.rmtree(d, ignore_errors=True) + else: + rmtree(d, ignore_errors=True) def push_config(**options): @@ -195,7 +208,3 @@ def configure_testgres(**options): """ testgres_config.update(options) - - -# NOTE: assign initial cached dir for initdb -testgres_config.cached_initdb_dir = mkdtemp(prefix=TMP_CACHE) diff --git a/testgres/connection.py b/testgres/connection.py index ee2a2128..f85f56be 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -1,4 +1,5 @@ # coding: utf-8 +from os_ops import OsOperations # we support both pg8000 and psycopg2 try: @@ -41,11 +42,12 @@ def __init__(self, self._node = node - self._connection = pglib.connect(database=dbname, - user=username, - password=password, - host=node.host, - port=node.port) + self.os_ops = OsOperations(node.host, node.hostname, node.ssh_key, node.username) + self._connection = self.os_ops.db_connect(dbname=dbname, + user=username, + password=password, + host=node.host, + port=node.port) self._connection.autocommit = autocommit self._cursor = self.connection.cursor() @@ -102,17 +104,24 @@ def rollback(self): return self def execute(self, query, *args): - self.cursor.execute(query, args) - try: - res = self.cursor.fetchall() - - # pg8000 might return tuples - if isinstance(res, tuple): - res = [tuple(t) for t in res] - - return res - except Exception: + with self.connection.cursor() as cursor: + cursor.execute(query, args) + try: + res = cursor.fetchall() + + # pg8000 might return tuples + if isinstance(res, tuple): + res = [tuple(t) for t in res] + + return res + except (pglib.ProgrammingError, pglib.InternalError) as e: + # An error occurred while trying to fetch results (e.g., no results to fetch) + print(f"Error fetching results: {e}") + return None + except (pglib.Error, Exception) as e: + # Handle other database errors + print(f"Error executing query: {e}") return None def close(self): diff --git a/testgres/defaults.py b/testgres/defaults.py index 8d5b892e..539183ae 100644 --- a/testgres/defaults.py +++ b/testgres/defaults.py @@ -13,12 +13,15 @@ def default_dbname(): return 'postgres' -def default_username(): +def default_username(os_ops=None): """ Return default username (current user). """ - - return getpass.getuser() + if os_ops: + user = os_ops.get_user() + else: + user = getpass.getuser() + return user def generate_app_name(): @@ -29,7 +32,7 @@ def generate_app_name(): return 'testgres-{}'.format(str(uuid.uuid4())) -def generate_system_id(): +def generate_system_id(os_ops=None): """ Generate a new 64-bit unique system identifier for node. """ @@ -44,7 +47,10 @@ def generate_system_id(): system_id = 0 system_id |= (secs << 32) system_id |= (usecs << 12) - system_id |= (os.getpid() & 0xFFF) + if os_ops: + system_id |= (os_ops.get_pid() & 0xFFF) + else: + system_id |= (os.getpid() & 0xFFF) # pack ULL in native byte order return struct.pack('=Q', system_id) diff --git a/testgres/logger.py b/testgres/logger.py index b4648f44..abd4d255 100644 --- a/testgres/logger.py +++ b/testgres/logger.py @@ -6,6 +6,20 @@ import time +# create logger +log = logging.getLogger('Testgres') +log.setLevel(logging.DEBUG) +# create console handler and set level to debug +ch = logging.StreamHandler() +ch.setLevel(logging.DEBUG) +# create formatter +formatter = logging.Formatter('\n%(asctime)s - %(name)s[%(levelname)s]: %(message)s') +# add formatter to ch +ch.setFormatter(formatter) +# add ch to logger +log.addHandler(ch) + + class TestgresLogger(threading.Thread): """ Helper class to implement reading from log files. diff --git a/testgres/node.py b/testgres/node.py index 17b9a260..d895bde8 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -12,6 +12,7 @@ import subprocess import time +from os_ops import OsOperations try: from collections.abc import Iterable @@ -129,7 +130,8 @@ def __repr__(self): class PostgresNode(object): - def __init__(self, name=None, port=None, base_dir=None): + def __init__(self, name=None, port=None, base_dir=None, + host='127.0.0.1', hostname='localhost', ssh_key=None, username=default_username()): """ PostgresNode constructor. @@ -147,10 +149,14 @@ def __init__(self, name=None, port=None, base_dir=None): self._master = None # basic - self.host = '127.0.0.1' self.name = name or generate_app_name() self.port = port or reserve_port() + self.host = host + self.hostname = hostname + self.ssh_key = ssh_key + self.os_ops = OsOperations(host, hostname, ssh_key, username=username) + # defaults for __exit__() self.cleanup_on_good_exit = testgres_config.node_cleanup_on_good_exit self.cleanup_on_bad_exit = testgres_config.node_cleanup_on_bad_exit @@ -455,9 +461,12 @@ def init(self, initdb_params=None, **kwargs): """ # initialize this PostgreSQL node - cached_initdb(data_dir=self.data_dir, - logfile=self.utils_log_file, - params=initdb_params) + cached_initdb( + data_dir=self.data_dir, + logfile=self.utils_log_file, + hostname=self.hostname, + ssh_key=self.ssh_key, + params=initdb_params) # initialize default config files self.default_conf(**kwargs) @@ -514,6 +523,10 @@ def get_auth_method(t): new_lines = [ u"local\treplication\tall\t\t\t{}\n".format(auth_local), u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), + + u"host\treplication\tall\t0.0.0.0/0\t{}\n".format(auth_host), + u"host\tall\tall\t0.0.0.0/0\t{}\n".format(auth_host), + u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host) ] # yapf: disable diff --git a/testgres/os_ops.py b/testgres/os_ops.py new file mode 100644 index 00000000..290fee77 --- /dev/null +++ b/testgres/os_ops.py @@ -0,0 +1,293 @@ +import base64 +import getpass +import os +import shutil +import subprocess +import tempfile +from contextlib import contextmanager +from shutil import rmtree + +try: + import psycopg2 as pglib +except ImportError: + try: + import pg8000 as pglib + except ImportError: + raise ImportError("You must have psycopg2 or pg8000 modules installed") + +from defaults import default_username +from testgres.logger import log + +import paramiko + + +class OsOperations: + + def __init__(self, host='127.0.0.1', hostname='localhost', ssh_key=None, username=default_username()): + self.host = host + self.ssh_key = ssh_key + self.username = username + self.remote = not (self.host == '127.0.0.1' and hostname == 'localhost') + self.ssh = None + + if self.remote: + self.ssh = self.connect() + + def __del__(self): + if self.ssh: + self.ssh.close() + + @contextmanager + def ssh_connect(self): + if not self.remote: + yield None + else: + with open(self.ssh_key, 'r') as f: + key_data = f.read() + if 'BEGIN OPENSSH PRIVATE KEY' in key_data: + key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) + else: + key = paramiko.RSAKey.from_private_key_file(self.ssh_key) + + with paramiko.SSHClient() as ssh: + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.host, username=self.username, pkey=key) + yield ssh + + def connect(self): + with self.ssh_connect() as ssh: + return ssh + + def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): + if isinstance(cmd, list): + cmd = ' '.join(cmd) + log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") + # Source global profile file + execute command + try: + if self.remote: + cmd = f"source /etc/profile.d/custom.sh; {cmd}" + with self.ssh_connect() as ssh: + stdin, stdout, stderr = ssh.exec_command(cmd) + exit_status = 0 + if wait_exit: + exit_status = stdout.channel.recv_exit_status() + result = stdout.read().decode('utf-8') + error = stderr.read().decode('utf-8') + else: + process = subprocess.run(cmd, shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + timeout=60) + exit_status = process.returncode + result = process.stdout + error = process.stderr + + if expect_error: + raise Exception(result, error) + if exit_status != 0 or 'error' in error.lower(): + log.error(f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}") + exit(1) + + if verbose: + return exit_status, result, error + else: + return result + + except Exception as e: + log.error(f"Unexpected error while executing command `{cmd}`: {e}") + return None + + def makedirs(self, path, remove_existing=False): + if remove_existing: + cmd = f'rm -rf {path} && mkdir -p {path}' + else: + cmd = f'mkdir -p {path}' + self.exec_command(cmd) + + def rmdirs(self, path, ignore_errors=True): + if self.remote: + cmd = f'rm -rf {path}' + self.exec_command(cmd) + else: + rmtree(path, ignore_errors=ignore_errors) + + def mkdtemp(self, prefix=None): + if self.remote: + temp_dir = self.exec_command(f'mkdtemp -d {prefix}') + return temp_dir.strip() + else: + return tempfile.mkdtemp(prefix=prefix) + + def path_exists(self, path): + if self.remote: + result = self.exec_command(f'test -e {path}; echo $?') + return int(result.strip()) == 0 + else: + return os.path.exists(path) + + def copytree(self, src, dst): + if self.remote: + self.exec_command(f'cp -r {src} {dst}') + else: + shutil.copytree(src, dst) + + def listdir(self, path): + if self.remote: + result = self.exec_command(f'ls {path}') + return result.splitlines() + else: + return os.listdir(path) + + def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + """ + Write data to a file, both locally and on a remote host. + + :param filename: The file path where the data will be written. + :param data: The data to be written to the file. + :param truncate: If True, the file will be truncated before writing ('w' or 'wb' option); + if False (default), data will be appended ('a' or 'ab' option). + :param binary: If True, the data will be written in binary mode ('wb' or 'ab' option); + if False (default), the data will be written in text mode ('w' or 'a' option). + :param read_and_write: If True, the file will be opened with read and write permissions ('r+' option); + if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option). + """ + mode = 'wb' if binary else 'w' + if not truncate: + mode = 'a' + mode + if read_and_write: + mode = 'r+' + mode + + if self.remote: + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file.write(data) + tmp_file.flush() + + sftp = self.ssh.open_sftp() + sftp.put(tmp_file.name, filename) + sftp.close() + else: + with open(filename, mode) as file: + file.write(data) + + def read(self, filename): + cmd = f'cat {filename}' + return self.exec_command(cmd) + + def readlines(self, filename): + return self.read(filename).splitlines() + + def get_name(self): + cmd = 'python3 -c "import os; print(os.name)"' + return self.exec_command(cmd).strip() + + def kill(self, pid, signal): + cmd = f'kill -{signal} {pid}' + self.exec_command(cmd) + + def environ(self, var_name): + cmd = f"echo ${var_name}" + return self.exec_command(cmd).strip() + + @property + def pathsep(self): + return ':' if self.get_name() == 'posix' else ';' + + def isfile(self, remote_file): + if self.remote: + stdout = self.exec_command(f'test -f {remote_file}; echo $?') + result = int(stdout.strip()) + return result == 0 + else: + return os.path.isfile(remote_file) + + def find_executable(self, executable): + search_paths = self.environ('PATH') + if not search_paths: + return None + + search_paths = search_paths.split(self.pathsep) + for path in search_paths: + remote_file = os.path.join(path, executable) + if self.isfile(remote_file): + return remote_file + + return None + + def is_executable(self, file): + # Check if the file is executable + if self.remote: + if not self.exec_command(f"test -x {file} && echo OK") == 'OK\n': + return False + else: + if not os.access(file, os.X_OK): + return False + return True + + def add_to_path(self, new_path): + os_name = self.get_name() + if os_name == 'posix': + dir_del = ':' + elif os_name == 'nt': + dir_del = ';' + else: + raise Exception(f"Unsupported operating system: {os_name}") + + # Check if the directory is already in PATH + path = self.environ('PATH') + if new_path not in path.split(dir_del): + if self.remote: + self.exec_command(f"export PATH={new_path}{dir_del}{path}") + else: + os.environ['PATH'] = f"{new_path}{dir_del}{path}" + return dir_del + + def set_env(self, var_name, var_val): + # Check if the directory is already in PATH + if self.remote: + self.exec_command(f"export {var_name}={var_val}") + else: + os.environ[var_name] = var_val + + def get_pid(self): + # Get current process id + if self.remote: + process_id = self.exec_command(f"echo $$") + else: + process_id = os.getpid() + return process_id + + def get_user(self): + # Get current user + if self.remote: + user = self.exec_command(f"echo $USER") + else: + user = getpass.getuser() + return user + + @contextmanager + def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + if self.remote: + with self.ssh_connect() as ssh: + # Set up a local port forwarding on a random port + local_port = ssh.forward_remote_port(host, port) + conn = pglib.connect( + host=host, + port=local_port, + dbname=dbname, + user=user, + password=password, + ) + try: + yield conn + finally: + conn.close() + ssh.close_forwarded_tcp(local_port) + else: + with pglib.connect( + host=host, + port=port, + dbname=dbname, + user=user, + password=password, + ) as conn: + yield conn + + diff --git a/testgres/utils.py b/testgres/utils.py index 4d99c69d..877549e7 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -15,6 +15,8 @@ from distutils.spawn import find_executable from six import iteritems +from fabric import Connection + from .config import testgres_config from .exceptions import ExecUtilException @@ -47,7 +49,7 @@ def release_port(port): bound_ports.discard(port) -def execute_utility(args, logfile=None): +def execute_utility(args, logfile=None, hostname='localhost', ssh_key=None): """ Execute utility (pg_ctl, pg_dump etc). @@ -59,6 +61,23 @@ def execute_utility(args, logfile=None): stdout of executed utility. """ + if hostname != 'localhost': + conn = Connection( + hostname, + connect_kwargs={ + "key_filename": f"{ssh_key}", + }, + ) + + # TODO skip remote ssh run if we are on the localhost. + # result = conn.run('hostname', hide=True) + # add logger + + cmd = ' '.join(args) + result = conn.run(cmd, hide=True) + + return result + # run utility if os.name == 'nt': # using output to a temporary file in Windows From 4f38bd505ee99d09f40478a5cd30a1a685481c5b Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Wed, 3 May 2023 11:50:35 +0200 Subject: [PATCH 06/23] PBCKP-152 multihost --- testgres/cache.py | 2 +- testgres/config.py | 10 ++++------ testgres/connection.py | 3 +-- testgres/os_ops.py | 31 ++++++++++++------------------- 4 files changed, 18 insertions(+), 28 deletions(-) diff --git a/testgres/cache.py b/testgres/cache.py index cd40e72f..0ca4d707 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -39,7 +39,7 @@ def call_initdb(initdb_dir, log=None): call_initdb(data_dir, logfile) else: # Fetch cached initdb dir - cached_data_dir = testgres_config.cached_initdb_dir + cached_data_dir = testgres_config.cached_initdb_dir() # Initialize cached initdb diff --git a/testgres/config.py b/testgres/config.py index b32536ba..2e4a3aaa 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -55,12 +55,6 @@ def cached_initdb_dir(self): def cached_initdb_dir(self, value): self._cached_initdb_dir = value - # NOTE: assign initial cached dir for initdb - if self.os_ops: - testgres_config.cached_initdb_dir = self.os_ops.mkdtemp(prefix=TMP_CACHE) - else: - testgres_config.cached_initdb_dir = mkdtemp(prefix=TMP_CACHE) - if value: cached_initdb_dirs.add(value) return testgres_config.cached_initdb_dir @@ -208,3 +202,7 @@ def configure_testgres(**options): """ testgres_config.update(options) + + +# NOTE: assign initial cached dir for initdb +testgres_config.cached_initdb_dir = mkdtemp(prefix=TMP_CACHE) diff --git a/testgres/connection.py b/testgres/connection.py index f85f56be..cc3dbdfe 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -42,8 +42,7 @@ def __init__(self, self._node = node - self.os_ops = OsOperations(node.host, node.hostname, node.ssh_key, node.username) - self._connection = self.os_ops.db_connect(dbname=dbname, + self._connection = node.os_ops.db_connect(dbname=dbname, user=username, password=password, host=node.host, diff --git a/testgres/os_ops.py b/testgres/os_ops.py index 290fee77..e87dcc88 100644 --- a/testgres/os_ops.py +++ b/testgres/os_ops.py @@ -262,32 +262,25 @@ def get_user(self): user = getpass.getuser() return user - @contextmanager def db_connect(self, dbname, user, password=None, host='localhost', port=5432): if self.remote: - with self.ssh_connect() as ssh: - # Set up a local port forwarding on a random port - local_port = ssh.forward_remote_port(host, port) - conn = pglib.connect( - host=host, - port=local_port, - dbname=dbname, - user=user, - password=password, - ) - try: - yield conn - finally: - conn.close() - ssh.close_forwarded_tcp(local_port) + local_port = self.ssh.forward_remote_port(host, port) + conn = pglib.connect( + host=host, + port=local_port, + dbname=dbname, + user=user, + password=password, + ) else: - with pglib.connect( + conn = pglib.connect( host=host, port=port, dbname=dbname, user=user, password=password, - ) as conn: - yield conn + ) + return conn + From f9b6bdbf48747f29f29e99977a130258e3bb3836 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Sun, 11 Jun 2023 00:03:28 +0200 Subject: [PATCH 07/23] PBCKP-152 --- testgres/backup.py | 7 +- testgres/cache.py | 12 +- testgres/config.py | 11 +- testgres/connection.py | 2 +- testgres/defaults.py | 18 +-- testgres/node.py | 194 +++++++++++------------ testgres/op_ops/local_ops.py | 224 ++++++++++++++++++++++++++ testgres/op_ops/os_ops.py | 99 ++++++++++++ testgres/op_ops/remote_ops.py | 259 ++++++++++++++++++++++++++++++ testgres/os_ops.py | 285 ---------------------------------- testgres/utils.py | 10 +- 11 files changed, 697 insertions(+), 424 deletions(-) create mode 100644 testgres/op_ops/local_ops.py create mode 100644 testgres/op_ops/os_ops.py create mode 100644 testgres/op_ops/remote_ops.py delete mode 100644 testgres/os_ops.py diff --git a/testgres/backup.py b/testgres/backup.py index 0a5ad67f..c0fd6e50 100644 --- a/testgres/backup.py +++ b/testgres/backup.py @@ -2,7 +2,6 @@ import os -from shutil import rmtree, copytree from six import raise_from from .enums import XLogMethod @@ -14,8 +13,6 @@ PG_CONF_FILE, \ BACKUP_LOG_FILE -from .defaults import default_username - from .exceptions import BackupException from .utils import \ @@ -80,7 +77,7 @@ def __init__(self, "-D", data_dir, "-X", xlog_method.value ] # yapf: disable - execute_utility(_params, self.log_file, hostname=node.hostname, ssh_key=node.ssh_key) + execute_utility(_params, self.log_file, self.os_ops) def __enter__(self): return self @@ -113,7 +110,7 @@ def _prepare_dir(self, destroy): try: # Copy backup to new data dir - copytree(data1, data2) + self.os_ops.copytree(data1, data2) except Exception as e: raise_from(BackupException('Failed to copy files'), e) else: diff --git a/testgres/cache.py b/testgres/cache.py index 6ef92002..4998e0d2 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -4,7 +4,8 @@ from six import raise_from -from .os_ops import OsOperations +from .op_ops.local_ops import LocalOperations +from .op_ops.os_ops import OsOperations from .config import testgres_config from .consts import XLOG_CONTROL_FILE @@ -20,16 +21,15 @@ execute_utility -def cached_initdb(data_dir, logfile=None, hostname='localhost', ssh_key=None, params=None): +def cached_initdb(data_dir, logfile=None, params=None, os_ops: OsOperations = LocalOperations()): """ Perform initdb or use cached node files. """ - os_ops = OsOperations(hostname=hostname, ssh_key=ssh_key) def call_initdb(initdb_dir, log=None): try: _params = [get_bin_path("initdb"), "-D", initdb_dir, "-N"] - execute_utility(_params + (params or []), log, hostname=hostname, ssh_key=ssh_key) + execute_utility(_params + (params or []), log, os_ops) except ExecUtilException as e: raise_from(InitNodeException("Failed to run initdb"), e) @@ -42,7 +42,7 @@ def call_initdb(initdb_dir, log=None): # Initialize cached initdb if not os_ops.path_exists(cached_data_dir) or \ - not os_ops.listdir(cached_data_dir): + not os_ops.listdir(cached_data_dir): call_initdb(cached_data_dir) try: @@ -60,7 +60,7 @@ def call_initdb(initdb_dir, log=None): # XXX: build new WAL segment with our system id _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"] - execute_utility(_params, logfile, hostname=hostname, ssh_key=ssh_key) + execute_utility(_params, logfile, os_ops) except ExecUtilException as e: msg = "Failed to reset WAL for system id" diff --git a/testgres/config.py b/testgres/config.py index 2e4a3aaa..1be76fbe 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -6,8 +6,8 @@ from contextlib import contextmanager from shutil import rmtree -from tempfile import mkdtemp +from .op_ops.local_ops import LocalOperations from .consts import TMP_CACHE @@ -137,12 +137,9 @@ def copy(self): @atexit.register -def _rm_cached_initdb_dirs(os_ops=None): +def _rm_cached_initdb_dirs(os_ops=LocalOperations()): for d in cached_initdb_dirs: - if os_ops: - os_ops.rmtree(d, ignore_errors=True) - else: - rmtree(d, ignore_errors=True) + os_ops.rmdirs(d, ignore_errors=True) def push_config(**options): @@ -205,4 +202,4 @@ def configure_testgres(**options): # NOTE: assign initial cached dir for initdb -testgres_config.cached_initdb_dir = mkdtemp(prefix=TMP_CACHE) +testgres_config.cached_initdb_dir = testgres_config.os_ops.mkdtemp(prefix=TMP_CACHE) diff --git a/testgres/connection.py b/testgres/connection.py index 763c2490..f3d01ebe 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -37,7 +37,7 @@ def __init__(self, # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() + username = username or default_username(node.os_ops) self._node = node diff --git a/testgres/defaults.py b/testgres/defaults.py index 539183ae..cac788b8 100644 --- a/testgres/defaults.py +++ b/testgres/defaults.py @@ -1,9 +1,10 @@ import datetime -import getpass import os import struct import uuid +from .op_ops.local_ops import LocalOperations + def default_dbname(): """ @@ -13,15 +14,11 @@ def default_dbname(): return 'postgres' -def default_username(os_ops=None): +def default_username(os_ops=LocalOperations()): """ Return default username (current user). """ - if os_ops: - user = os_ops.get_user() - else: - user = getpass.getuser() - return user + return os_ops.get_user() def generate_app_name(): @@ -32,7 +29,7 @@ def generate_app_name(): return 'testgres-{}'.format(str(uuid.uuid4())) -def generate_system_id(os_ops=None): +def generate_system_id(os_ops=LocalOperations()): """ Generate a new 64-bit unique system identifier for node. """ @@ -47,10 +44,7 @@ def generate_system_id(os_ops=None): system_id = 0 system_id |= (secs << 32) system_id |= (usecs << 12) - if os_ops: - system_id |= (os_ops.get_pid() & 0xFFF) - else: - system_id |= (os.getpid() & 0xFFF) + system_id |= (os_ops.get_pid() & 0xFFF) # pack ULL in native byte order return struct.pack('=Q', system_id) diff --git a/testgres/node.py b/testgres/node.py index 29ba2cf3..f06e5cc9 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -12,6 +12,9 @@ import subprocess import time +from op_ops.local_ops import LocalOperations +from op_ops.os_ops import OsOperations +from op_ops.remote_ops import RemoteOperations try: from collections.abc import Iterable @@ -101,7 +104,6 @@ clean_on_error from .backup import NodeBackup -from .os_ops import OsOperations InternalError = pglib.InternalError ProgrammingError = pglib.ProgrammingError @@ -156,7 +158,10 @@ def __init__(self, name=None, port=None, base_dir=None, self.host = host self.hostname = hostname self.ssh_key = ssh_key - self.os_ops = OsOperations(host, hostname, ssh_key, username=username) + if hostname == 'localhost' or host == '127.0.0.1': + self.os_ops = LocalOperations(username=username) + else: + self.os_ops = RemoteOperations(host, hostname, ssh_key) # defaults for __exit__() self.cleanup_on_good_exit = testgres_config.node_cleanup_on_good_exit @@ -201,8 +206,9 @@ def pid(self): if self.status(): pid_file = os.path.join(self.data_dir, PG_PID_FILE) - with io.open(pid_file) as f: - return int(f.readline()) + lines = self.os_ops.readlines(pid_file, num_lines=1) + pid = int(lines[0]) if lines else None + return pid # for clarity return 0 @@ -280,11 +286,11 @@ def master(self): @property def base_dir(self): if not self._base_dir: - self._base_dir = mkdtemp(prefix=TMP_NODE) + self._base_dir = self.os_ops.mkdtemp(prefix=TMP_NODE) # NOTE: it's safe to create a new dir - if not os.path.exists(self._base_dir): - os.makedirs(self._base_dir) + if not self.os_ops.exists(self._base_dir): + self.os_ops.makedirs(self._base_dir) return self._base_dir @@ -293,8 +299,8 @@ def logs_dir(self): path = os.path.join(self.base_dir, LOGS_DIR) # NOTE: it's safe to create a new dir - if not os.path.exists(path): - os.makedirs(path) + if not self.os_ops.exists(path): + self.os_ops.makedirs(path) return path @@ -371,9 +377,7 @@ def _create_recovery_conf(self, username, slot=None): # Since 12 recovery.conf had disappeared if self.version >= PgVer('12'): signal_name = os.path.join(self.data_dir, "standby.signal") - # cross-python touch(). It is vulnerable to races, but who cares? - with open(signal_name, 'a'): - os.utime(signal_name, None) + self.os_ops.touch(signal_name) else: line += "standby_mode=on\n" @@ -431,19 +435,13 @@ def _collect_special_files(self): for f, num_lines in files: # skip missing files - if not os.path.exists(f): + if not self.os_ops.path_exists(f): continue - with io.open(f, "rb") as _f: - if num_lines > 0: - # take last N lines of file - lines = b''.join(file_tail(_f, num_lines)).decode('utf-8') - else: - # read whole file - lines = _f.read().decode('utf-8') + lines = b''.join(self.os_ops.readlines(f, num_lines, encoding='utf-8')) - # fill list - result.append((f, lines)) + # fill list + result.append((f, lines)) return result @@ -465,8 +463,7 @@ def init(self, initdb_params=None, **kwargs): cached_initdb( data_dir=self.data_dir, logfile=self.utils_log_file, - hostname=self.hostname, - ssh_key=self.ssh_key, + os_ops=self.os_ops, params=initdb_params) # initialize default config files @@ -498,47 +495,44 @@ def default_conf(self, hba_conf = os.path.join(self.data_dir, HBA_CONF_FILE) # filter lines in hba file - with io.open(hba_conf, "r+") as conf: - # get rid of comments and blank lines - lines = [ - s for s in conf.readlines() - if len(s.strip()) > 0 and not s.startswith('#') - ] - - # write filtered lines - conf.seek(0) - conf.truncate() - conf.writelines(lines) - - # replication-related settings - if allow_streaming: - # get auth method for host or local users - def get_auth_method(t): - return next((s.split()[-1] - for s in lines if s.startswith(t)), 'trust') - - # get auth methods - auth_local = get_auth_method('local') - auth_host = get_auth_method('host') - - new_lines = [ - u"local\treplication\tall\t\t\t{}\n".format(auth_local), - u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), - - u"host\treplication\tall\t0.0.0.0/0\t{}\n".format(auth_host), - u"host\tall\tall\t0.0.0.0/0\t{}\n".format(auth_host), - - u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host) - ] # yapf: disable - - # write missing lines - for line in new_lines: - if line not in lines: - conf.write(line) + # get rid of comments and blank lines + hba_conf_file = self.os_ops.readlines(hba_conf) + lines = [ + s for s in hba_conf_file + if len(s.strip()) > 0 and not s.startswith('#') + ] + + # write filtered lines + self.os_ops.write(hba_conf_file, lines, truncate=True) + + # replication-related settings + if allow_streaming: + # get auth method for host or local users + def get_auth_method(t): + return next((s.split()[-1] + for s in lines if s.startswith(t)), 'trust') + + # get auth methods + auth_local = get_auth_method('local') + auth_host = get_auth_method('host') + + new_lines = [ + u"local\treplication\tall\t\t\t{}\n".format(auth_local), + u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), + + u"host\treplication\tall\t0.0.0.0/0\t{}\n".format(auth_host), + u"host\tall\tall\t0.0.0.0/0\t{}\n".format(auth_host), + + u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host) + ] # yapf: disable + + # write missing lines + for line in new_lines: + if line not in lines: + self.os_ops.write(hba_conf, line) # overwrite config file - with io.open(postgres_conf, "w") as conf: - conf.truncate() + self.os_ops.write(postgres_conf, '', truncate=True) self.append_conf(fsync=fsync, max_worker_processes=MAX_WORKER_PROCESSES, @@ -613,10 +607,10 @@ def append_conf(self, line='', filename=PG_CONF_FILE, **kwargs): lines.append('{} = {}'.format(option, value)) config_name = os.path.join(self.data_dir, filename) - with io.open(config_name, 'a') as conf: - for line in lines: - conf.write(text_type(line)) - conf.write(text_type('\n')) + conf_text = '' + for line in lines: + conf_text += text_type(line) + '\n' + self.os_ops.write(config_name, conf_text) return self @@ -971,10 +965,7 @@ def psql(self, psql_params.append(dbname) # start psql process - process = subprocess.Popen(psql_params, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + process = self.os_ops.exec_command(psql_params) # wait until it finishes and get stdout and stderr out, err = process.communicate(input=input) @@ -1351,7 +1342,7 @@ def pgbench(self, # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() + username = username or default_username(self.os_ops) _params = [ get_bin_path("pgbench"), @@ -1363,7 +1354,7 @@ def pgbench(self, # should be the last one _params.append(dbname) - proc = subprocess.Popen(_params, stdout=stdout, stderr=stderr) + proc = self.os_ops.exec_command(_params, wait_exit=True) return proc @@ -1403,7 +1394,7 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs): # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() + username = username or default_username(os_ops=self.os_ops) _params = [ get_bin_path("pgbench"), @@ -1534,10 +1525,8 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): Defaults to an empty set. """ # parse postgresql.auto.conf - path = os.path.join(self.data_dir, config) - - with open(path, 'r') as f: - raw_content = f.read() + auto_conf_file = os.path.join(self.data_dir, config) + raw_content = self.os_ops.read(auto_conf_file) current_options = {} current_directives = [] @@ -1577,22 +1566,22 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): for directive in current_directives: auto_conf += directive + "\n" - with open(path, 'wt') as f: - f.write(auto_conf) + self.os_ops.write(auto_conf_file, auto_conf) class NodeApp: - def __init__(self, test_path, nodes_to_cleanup): + def __init__(self, test_path, nodes_to_cleanup, os_ops=LocalOperations()): self.test_path = test_path self.nodes_to_cleanup = nodes_to_cleanup + self.os_ops = os_ops def make_empty( self, base_dir=None): real_base_dir = os.path.join(self.test_path, base_dir) - shutil.rmtree(real_base_dir, ignore_errors=True) - os.makedirs(real_base_dir) + self.os_ops.rmdirs(real_base_dir, ignore_errors=True) + self.os_ops.makedirs(real_base_dir) node = PostgresNode(base_dir=real_base_dir) node.should_rm_dirs = True @@ -1615,27 +1604,24 @@ def make_simple( initdb_params=initdb_params, allow_streaming=set_replication) # set major version - with open(os.path.join(node.data_dir, 'PG_VERSION')) as f: - node.major_version_str = str(f.read().rstrip()) - node.major_version = float(node.major_version_str) - - # Sane default parameters - options = {} - options['max_connections'] = 100 - options['shared_buffers'] = '10MB' - options['fsync'] = 'off' - - options['wal_level'] = 'logical' - options['hot_standby'] = 'off' - - options['log_line_prefix'] = '%t [%p]: [%l-1] ' - options['log_statement'] = 'none' - options['log_duration'] = 'on' - options['log_min_duration_statement'] = 0 - options['log_connections'] = 'on' - options['log_disconnections'] = 'on' - options['restart_after_crash'] = 'off' - options['autovacuum'] = 'off' + pg_version_file = self.os_ops.read(os.path.join(node.data_dir, 'PG_VERSION')) + node.major_version_str = str(pg_version_file.rstrip()) + node.major_version = float(node.major_version_str) + + # Set default parameters + options = {'max_connections': 100, + 'shared_buffers': '10MB', + 'fsync': 'off', + 'wal_level': 'logical', + 'hot_standby': 'off', + 'log_line_prefix': '%t [%p]: [%l-1] ', + 'log_statement': 'none', + 'log_duration': 'on', + 'log_min_duration_statement': 0, + 'log_connections': 'on', + 'log_disconnections': 'on', + 'restart_after_crash': 'off', + 'autovacuum': 'off'} # Allow replication in pg_hba.conf if set_replication: diff --git a/testgres/op_ops/local_ops.py b/testgres/op_ops/local_ops.py new file mode 100644 index 00000000..42a3b4b7 --- /dev/null +++ b/testgres/op_ops/local_ops.py @@ -0,0 +1,224 @@ +import getpass +import os +import shutil +import subprocess +import tempfile +from shutil import rmtree + +from testgres.logger import log + +from .os_ops import OsOperations +from .os_ops import pglib + +CMD_TIMEOUT_SEC = 60 + + +class LocalOperations(OsOperations): + + def __init__(self, username=None): + super().__init__() + self.username = username or self.get_user() + + # Command execution + def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): + if isinstance(cmd, list): + cmd = ' '.join(cmd) + log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") + # Source global profile file + execute command + try: + process = subprocess.run(cmd, shell=True, text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=CMD_TIMEOUT_SEC) + exit_status = process.returncode + result = process.stdout + error = process.stderr + + if expect_error: + raise Exception(result, error) + if exit_status != 0 or 'error' in error.lower(): + log.error(f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}") + exit(1) + + if verbose: + return exit_status, result, error + else: + return result + + except Exception as e: + log.error(f"Unexpected error while executing command `{cmd}`: {e}") + return None + + # Environment setup + def environ(self, var_name): + cmd = f"echo ${var_name}" + return self.exec_command(cmd).strip() + + def find_executable(self, executable): + search_paths = self.environ('PATH') + if not search_paths: + return None + + search_paths = search_paths.split(self.pathsep) + for path in search_paths: + remote_file = os.path.join(path, executable) + if self.isfile(remote_file): + return remote_file + + return None + + def is_executable(self, file): + # Check if the file is executable + return os.access(file, os.X_OK) + + def add_to_path(self, new_path): + pathsep = self.pathsep + # Check if the directory is already in PATH + path = self.environ('PATH') + if new_path not in path.split(pathsep): + if self.remote: + self.exec_command(f"export PATH={new_path}{pathsep}{path}") + else: + os.environ['PATH'] = f"{new_path}{pathsep}{path}" + return pathsep + + def set_env(self, var_name, var_val): + # Check if the directory is already in PATH + os.environ[var_name] = var_val + + # Get environment variables + def get_user(self): + return getpass.getuser() + + def get_name(self): + cmd = 'python3 -c "import os; print(os.name)"' + return self.exec_command(cmd).strip() + + # Work with dirs + def makedirs(self, path, remove_existing=False): + if remove_existing and os.path.exists(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) + + def rmdirs(self, path, ignore_errors=True): + return rmtree(path, ignore_errors=ignore_errors) + + def listdir(self, path): + return os.listdir(path) + + def path_exists(self, path): + return os.path.exists(path) + + @property + def pathsep(self): + os_name = self.get_name() + if os_name == 'posix': + pathsep = ':' + elif os_name == 'nt': + pathsep = ';' + else: + raise Exception(f"Unsupported operating system: {os_name}") + return pathsep + + def mkdtemp(self, prefix=None): + return tempfile.mkdtemp(prefix=prefix) + + def copytree(self, src, dst): + return shutil.copytree(src, dst) + + # Work with files + def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + """ + Write data to a file locally + Args: + filename: The file path where the data will be written. + data: The data to be written to the file. + truncate: If True, the file will be truncated before writing ('w' or 'wb' option); + if False (default), data will be appended ('a' or 'ab' option). + binary: If True, the data will be written in binary mode ('wb' or 'ab' option); + if False (default), the data will be written in text mode ('w' or 'a' option). + read_and_write: If True, the file will be opened with read and write permissions ('r+' option); + if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) + """ + mode = 'wb' if binary else 'w' + if not truncate: + mode = 'a' + mode + if read_and_write: + mode = 'r+' + mode + + with open(filename, mode) as file: + if isinstance(data, list): + file.writelines(data) + else: + file.write(data) + + def touch(self, filename): + """ + Create a new file or update the access and modification times of an existing file. + Args: + filename (str): The name of the file to touch. + + This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. + """ + # cross-python touch(). It is vulnerable to races, but who cares? + with open(filename, 'a'): + os.utime(filename, None) + + def read(self, filename): + with open(filename, 'r') as file: + return file.read() + + def readlines(self, filename, num_lines=0, encoding=None): + """ + Read lines from a local file. + If num_lines is greater than 0, only the last num_lines lines will be read. + """ + assert num_lines >= 0 + + if num_lines == 0: + with open(filename, 'r', encoding=encoding) as file: + return file.readlines() + + else: + bufsize = 8192 + buffers = 1 + + with open(filename, 'r', encoding=encoding) as file: + file.seek(0, os.SEEK_END) + end_pos = file.tell() + + while True: + offset = max(0, end_pos - bufsize * buffers) + file.seek(offset, os.SEEK_SET) + pos = file.tell() + lines = file.readlines() + cur_lines = len(lines) + + if cur_lines >= num_lines or pos == 0: + return lines[-num_lines:] + + buffers = int(buffers * max(2, int(num_lines / max(cur_lines, 1)))) # Adjust buffer size + + def isfile(self, remote_file): + return os.path.isfile(remote_file) + + # Processes control + def kill(self, pid, signal): + # Kill the process + cmd = f'kill -{signal} {pid}' + return self.exec_command(cmd) + + def get_pid(self): + # Get current process id + return os.getpid() + + # Database control + def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + conn = pglib.connect( + host=host, + port=port, + database=dbname, + user=user, + password=password, + ) + return conn diff --git a/testgres/op_ops/os_ops.py b/testgres/op_ops/os_ops.py new file mode 100644 index 00000000..89de2640 --- /dev/null +++ b/testgres/op_ops/os_ops.py @@ -0,0 +1,99 @@ +try: + import psycopg2 as pglib +except ImportError: + try: + import pg8000 as pglib + except ImportError: + raise ImportError("You must have psycopg2 or pg8000 modules installed") + +from testgres.defaults import default_username + + +class OsOperations: + + def __init__(self, username=None): + self.hostname = 'localhost' + self.remote = False + self.ssh = None + self.username = username + + # Command execution + def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): + raise NotImplementedError() + + # Environment setup + def environ(self, var_name): + raise NotImplementedError() + + def find_executable(self, executable): + raise NotImplementedError() + + def is_executable(self, file): + # Check if the file is executable + raise NotImplementedError() + + def add_to_path(self, new_path): + raise NotImplementedError() + + def set_env(self, var_name, var_val): + # Check if the directory is already in PATH + raise NotImplementedError() + + # Get environment variables + def get_user(self): + raise NotImplementedError() + + def get_name(self): + raise NotImplementedError() + + # Work with dirs + def makedirs(self, path, remove_existing=False): + raise NotImplementedError() + + def rmdirs(self, path, ignore_errors=True): + raise NotImplementedError() + + def listdir(self, path): + raise NotImplementedError() + + def path_exists(self, path): + raise NotImplementedError() + + @property + def pathsep(self): + raise NotImplementedError() + + def mkdtemp(self, prefix=None): + raise NotImplementedError() + + def copytree(self, src, dst): + raise NotImplementedError() + + # Work with files + def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + raise NotImplementedError() + + def touch(self, filename): + raise NotImplementedError() + + def read(self, filename): + raise NotImplementedError() + + def readlines(self, filename): + raise NotImplementedError() + + def isfile(self, remote_file): + raise NotImplementedError() + + # Processes control + def kill(self, pid, signal): + # Kill the process + raise NotImplementedError() + + def get_pid(self): + # Get current process id + raise NotImplementedError() + + # Database control + def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + raise NotImplementedError() diff --git a/testgres/op_ops/remote_ops.py b/testgres/op_ops/remote_ops.py new file mode 100644 index 00000000..d5faab4e --- /dev/null +++ b/testgres/op_ops/remote_ops.py @@ -0,0 +1,259 @@ +import os +import tempfile +from contextlib import contextmanager + +from testgres.logger import log + +from .os_ops import OsOperations +from .os_ops import pglib + +import paramiko + + +class RemoteOperations(OsOperations): + """ + This class specifically supports work with Linux systems. It utilizes the SSH + for making connections and performing various file and directory operations, command executions, + environment setup and management, process control, and database connections. + It uses the Paramiko library for SSH connections and operations. + + Some methods are designed to work with specific Linux shell commands, and thus may not work as expected + on other non-Linux systems. + + Attributes: + - hostname (str): The remote system's hostname. Default 'localhost'. + - host (str): The remote system's IP address. Default '127.0.0.1'. + - ssh_key (str): Path to the SSH private key for authentication. + - username (str): Username for the remote system. + - ssh (paramiko.SSHClient): SSH connection to the remote system. + """ + + def __init__(self, hostname='localhost', host='127.0.0.1', ssh_key=None, username=None): + super().__init__(username) + self.hostname = hostname + self.host = host + self.ssh_key = ssh_key + self.remote = True + self.ssh = self.connect() + self.username = username or self.get_user() + + def __del__(self): + if self.ssh: + self.ssh.close() + + @contextmanager + def ssh_connect(self): + if not self.remote: + yield None + else: + with open(self.ssh_key, 'r') as f: + key_data = f.read() + if 'BEGIN OPENSSH PRIVATE KEY' in key_data: + key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) + else: + key = paramiko.RSAKey.from_private_key_file(self.ssh_key) + + with paramiko.SSHClient() as ssh: + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.host, username=self.username, pkey=key) + yield ssh + + def connect(self): + with self.ssh_connect() as ssh: + return ssh + + # Command execution + def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding='utf-8'): + if isinstance(cmd, list): + cmd = ' '.join(cmd) + log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") + # Source global profile file + execute command + try: + cmd = f"source /etc/profile.d/custom.sh; {cmd}" + with self.ssh_connect() as ssh: + stdin, stdout, stderr = ssh.exec_command(cmd) + exit_status = 0 + if wait_exit: + exit_status = stdout.channel.recv_exit_status() + result = stdout.read().decode(encoding) + error = stderr.read().decode(encoding) + + if expect_error: + raise Exception(result, error) + if exit_status != 0 or 'error' in error.lower(): + log.error(f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}") + exit(1) + + if verbose: + return exit_status, result, error + else: + return result + + except Exception as e: + log.error(f"Unexpected error while executing command `{cmd}`: {e}") + return None + + # Environment setup + def environ(self, var_name): + cmd = f"echo ${var_name}" + return self.exec_command(cmd).strip() + + def find_executable(self, executable): + search_paths = self.environ('PATH') + if not search_paths: + return None + + search_paths = search_paths.split(self.pathsep) + for path in search_paths: + remote_file = os.path.join(path, executable) + if self.isfile(remote_file): + return remote_file + + return None + + def is_executable(self, file): + # Check if the file is executable + return self.exec_command(f"test -x {file} && echo OK") == 'OK\n' + + def add_to_path(self, new_path): + pathsep = self.pathsep + # Check if the directory is already in PATH + path = self.environ('PATH') + if new_path not in path.split(pathsep): + if self.remote: + self.exec_command(f"export PATH={new_path}{pathsep}{path}") + else: + os.environ['PATH'] = f"{new_path}{pathsep}{path}" + return pathsep + + def set_env(self, var_name, var_val): + # Check if the directory is already in PATH + return self.exec_command(f"export {var_name}={var_val}") + + # Get environment variables + def get_user(self): + return self.exec_command(f"echo $USER") + + def get_name(self): + cmd = 'python3 -c "import os; print(os.name)"' + return self.exec_command(cmd).strip() + + # Work with dirs + def makedirs(self, path, remove_existing=False): + if remove_existing: + cmd = f'rm -rf {path} && mkdir -p {path}' + else: + cmd = f'mkdir -p {path}' + return self.exec_command(cmd) + + def rmdirs(self, path, ignore_errors=True): + cmd = f'rm -rf {path}' + return self.exec_command(cmd) + + def listdir(self, path): + result = self.exec_command(f'ls {path}') + return result.splitlines() + + def path_exists(self, path): + result = self.exec_command(f'test -e {path}; echo $?') + return int(result.strip()) == 0 + + @property + def pathsep(self): + os_name = self.get_name() + if os_name == 'posix': + pathsep = ':' + elif os_name == 'nt': + pathsep = ';' + else: + raise Exception(f"Unsupported operating system: {os_name}") + return pathsep + + def mkdtemp(self, prefix=None): + temp_dir = self.exec_command(f'mkdtemp -d {prefix}') + return temp_dir.strip() + + def copytree(self, src, dst): + return self.exec_command(f'cp -r {src} {dst}') + + # Work with files + def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + """ + Write data to a file on a remote host + Args: + filename: The file path where the data will be written. + data: The data to be written to the file. + truncate: If True, the file will be truncated before writing ('w' or 'wb' option); + if False (default), data will be appended ('a' or 'ab' option). + binary: If True, the data will be written in binary mode ('wb' or 'ab' option); + if False (default), the data will be written in text mode ('w' or 'a' option). + read_and_write: If True, the file will be opened with read and write permissions ('r+' option); + if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) + """ + mode = 'wb' if binary else 'w' + if not truncate: + mode = 'a' + mode + if read_and_write: + mode = 'r+' + mode + + with tempfile.NamedTemporaryFile(mode=mode) as tmp_file: + if isinstance(data, list): + tmp_file.writelines(data) + else: + tmp_file.write(data) + tmp_file.flush() + + sftp = self.ssh.open_sftp() + sftp.put(tmp_file.name, filename) + sftp.close() + + def touch(self, filename): + """ + Create a new file or update the access and modification times of an existing file on the remote server. + + Args: + filename (str): The name of the file to touch. + + This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. + """ + self.exec_command(f'touch {filename}') + + def read(self, filename, encoding='utf-8'): + cmd = f'cat {filename}' + return self.exec_command(cmd, encoding=encoding) + + def readlines(self, filename, num_lines=0, encoding=None): + encoding = encoding or 'utf-8' + if num_lines > 0: + cmd = f'tail -n {num_lines} {filename}' + lines = self.exec_command(cmd, encoding) + else: + lines = self.read(filename, encoding=encoding).splitlines() + return lines + + def isfile(self, remote_file): + stdout = self.exec_command(f'test -f {remote_file}; echo $?') + result = int(stdout.strip()) + return result == 0 + + # Processes control + def kill(self, pid, signal): + # Kill the process + cmd = f'kill -{signal} {pid}' + return self.exec_command(cmd) + + def get_pid(self): + # Get current process id + return self.exec_command(f"echo $$") + + # Database control + def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + local_port = self.ssh.forward_remote_port(host, port) + conn = pglib.connect( + host=host, + port=local_port, + database=dbname, + user=user, + password=password, + ) + return conn diff --git a/testgres/os_ops.py b/testgres/os_ops.py deleted file mode 100644 index 0be8c2a7..00000000 --- a/testgres/os_ops.py +++ /dev/null @@ -1,285 +0,0 @@ -import getpass -import os -import shutil -import subprocess -import tempfile -from contextlib import contextmanager -from shutil import rmtree - -try: - import psycopg2 as pglib -except ImportError: - try: - import pg8000 as pglib - except ImportError: - raise ImportError("You must have psycopg2 or pg8000 modules installed") - -from testgres.defaults import default_username -from testgres.logger import log - -import paramiko - - -class OsOperations: - - def __init__(self, host='127.0.0.1', hostname='localhost', ssh_key=None, username=default_username()): - self.host = host - self.ssh_key = ssh_key - self.username = username - self.remote = not (self.host == '127.0.0.1' and hostname == 'localhost') - self.ssh = None - - if self.remote: - self.ssh = self.connect() - - def __del__(self): - if self.ssh: - self.ssh.close() - - @contextmanager - def ssh_connect(self): - if not self.remote: - yield None - else: - with open(self.ssh_key, 'r') as f: - key_data = f.read() - if 'BEGIN OPENSSH PRIVATE KEY' in key_data: - key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) - else: - key = paramiko.RSAKey.from_private_key_file(self.ssh_key) - - with paramiko.SSHClient() as ssh: - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.host, username=self.username, pkey=key) - yield ssh - - def connect(self): - with self.ssh_connect() as ssh: - return ssh - - def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): - if isinstance(cmd, list): - cmd = ' '.join(cmd) - log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") - # Source global profile file + execute command - try: - if self.remote: - cmd = f"source /etc/profile.d/custom.sh; {cmd}" - with self.ssh_connect() as ssh: - stdin, stdout, stderr = ssh.exec_command(cmd) - exit_status = 0 - if wait_exit: - exit_status = stdout.channel.recv_exit_status() - result = stdout.read().decode('utf-8') - error = stderr.read().decode('utf-8') - else: - process = subprocess.run(cmd, shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - timeout=60) - exit_status = process.returncode - result = process.stdout - error = process.stderr - - if expect_error: - raise Exception(result, error) - if exit_status != 0 or 'error' in error.lower(): - log.error(f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}") - exit(1) - - if verbose: - return exit_status, result, error - else: - return result - - except Exception as e: - log.error(f"Unexpected error while executing command `{cmd}`: {e}") - return None - - def makedirs(self, path, remove_existing=False): - if remove_existing: - cmd = f'rm -rf {path} && mkdir -p {path}' - else: - cmd = f'mkdir -p {path}' - self.exec_command(cmd) - - def rmdirs(self, path, ignore_errors=True): - if self.remote: - cmd = f'rm -rf {path}' - self.exec_command(cmd) - else: - rmtree(path, ignore_errors=ignore_errors) - - def mkdtemp(self, prefix=None): - if self.remote: - temp_dir = self.exec_command(f'mkdtemp -d {prefix}') - return temp_dir.strip() - else: - return tempfile.mkdtemp(prefix=prefix) - - def path_exists(self, path): - if self.remote: - result = self.exec_command(f'test -e {path}; echo $?') - return int(result.strip()) == 0 - else: - return os.path.exists(path) - - def copytree(self, src, dst): - if self.remote: - self.exec_command(f'cp -r {src} {dst}') - else: - shutil.copytree(src, dst) - - def listdir(self, path): - if self.remote: - result = self.exec_command(f'ls {path}') - return result.splitlines() - else: - return os.listdir(path) - - def write(self, filename, data, truncate=False, binary=False, read_and_write=False): - """ - Write data to a file, both locally and on a remote host. - - :param filename: The file path where the data will be written. - :param data: The data to be written to the file. - :param truncate: If True, the file will be truncated before writing ('w' or 'wb' option); - if False (default), data will be appended ('a' or 'ab' option). - :param binary: If True, the data will be written in binary mode ('wb' or 'ab' option); - if False (default), the data will be written in text mode ('w' or 'a' option). - :param read_and_write: If True, the file will be opened with read and write permissions ('r+' option); - if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option). - """ - mode = 'wb' if binary else 'w' - if not truncate: - mode = 'a' + mode - if read_and_write: - mode = 'r+' + mode - - if self.remote: - with tempfile.NamedTemporaryFile() as tmp_file: - tmp_file.write(data) - tmp_file.flush() - - sftp = self.ssh.open_sftp() - sftp.put(tmp_file.name, filename) - sftp.close() - else: - with open(filename, mode) as file: - file.write(data) - - def read(self, filename): - cmd = f'cat {filename}' - return self.exec_command(cmd) - - def readlines(self, filename): - return self.read(filename).splitlines() - - def get_name(self): - cmd = 'python3 -c "import os; print(os.name)"' - return self.exec_command(cmd).strip() - - def kill(self, pid, signal): - cmd = f'kill -{signal} {pid}' - self.exec_command(cmd) - - def environ(self, var_name): - cmd = f"echo ${var_name}" - return self.exec_command(cmd).strip() - - @property - def pathsep(self): - return ':' if self.get_name() == 'posix' else ';' - - def isfile(self, remote_file): - if self.remote: - stdout = self.exec_command(f'test -f {remote_file}; echo $?') - result = int(stdout.strip()) - return result == 0 - else: - return os.path.isfile(remote_file) - - def find_executable(self, executable): - search_paths = self.environ('PATH') - if not search_paths: - return None - - search_paths = search_paths.split(self.pathsep) - for path in search_paths: - remote_file = os.path.join(path, executable) - if self.isfile(remote_file): - return remote_file - - return None - - def is_executable(self, file): - # Check if the file is executable - if self.remote: - if not self.exec_command(f"test -x {file} && echo OK") == 'OK\n': - return False - else: - if not os.access(file, os.X_OK): - return False - return True - - def add_to_path(self, new_path): - os_name = self.get_name() - if os_name == 'posix': - dir_del = ':' - elif os_name == 'nt': - dir_del = ';' - else: - raise Exception(f"Unsupported operating system: {os_name}") - - # Check if the directory is already in PATH - path = self.environ('PATH') - if new_path not in path.split(dir_del): - if self.remote: - self.exec_command(f"export PATH={new_path}{dir_del}{path}") - else: - os.environ['PATH'] = f"{new_path}{dir_del}{path}" - return dir_del - - def set_env(self, var_name, var_val): - # Check if the directory is already in PATH - if self.remote: - self.exec_command(f"export {var_name}={var_val}") - else: - os.environ[var_name] = var_val - - def get_pid(self): - # Get current process id - if self.remote: - process_id = self.exec_command(f"echo $$") - else: - process_id = os.getpid() - return process_id - - def get_user(self): - # Get current user - if self.remote: - user = self.exec_command(f"echo $USER") - else: - user = getpass.getuser() - return user - - def db_connect(self, dbname, user, password=None, host='localhost', port=5432): - if self.remote: - local_port = self.ssh.forward_remote_port(host, port) - conn = pglib.connect( - host=host, - port=local_port, - dbname=dbname, - user=user, - password=password, - ) - else: - conn = pglib.connect( - host=host, - port=port, - dbname=dbname, - user=user, - password=password, - ) - return conn - - - diff --git a/testgres/utils.py b/testgres/utils.py index b27fb6b8..73ca6f1a 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -19,6 +19,8 @@ from six import iteritems from fabric import Connection +from .op_ops.local_ops import LocalOperations +from .op_ops.os_ops import OsOperations from .config import testgres_config from .exceptions import ExecUtilException @@ -52,7 +54,7 @@ def release_port(port): bound_ports.discard(port) -def execute_utility(args, logfile=None, hostname='localhost', ssh_key=None): +def execute_utility(args, logfile=None, os_ops: OsOperations = LocalOperations()): """ Execute utility (pg_ctl, pg_dump etc). @@ -64,11 +66,11 @@ def execute_utility(args, logfile=None, hostname='localhost', ssh_key=None): stdout of executed utility. """ - if hostname != 'localhost': + if os_ops.hostname != 'localhost': conn = Connection( - hostname, + os_ops.hostname, connect_kwargs={ - "key_filename": f"{ssh_key}", + "key_filename": f"{os_ops.ssh_key}", }, ) From ac77ef78f640b9b518817403c6841b25e8f46e9b Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 12 Jun 2023 00:31:28 +0200 Subject: [PATCH 08/23] PBCKP-152 use black for formatting --- testgres/cache.py | 7 +- testgres/config.py | 9 ++- testgres/defaults.py | 3 +- testgres/node.py | 42 +++++------- testgres/{op_ops => os_ops}/local_ops.py | 62 ++++++++++------- testgres/{op_ops => os_ops}/os_ops.py | 13 ++-- testgres/{op_ops => os_ops}/remote_ops.py | 81 +++++++++++++---------- testgres/utils.py | 20 +++--- 8 files changed, 127 insertions(+), 110 deletions(-) rename testgres/{op_ops => os_ops}/local_ops.py (80%) rename testgres/{op_ops => os_ops}/os_ops.py (90%) rename testgres/{op_ops => os_ops}/remote_ops.py (79%) diff --git a/testgres/cache.py b/testgres/cache.py index 4998e0d2..1df5a8ea 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -4,8 +4,8 @@ from six import raise_from -from .op_ops.local_ops import LocalOperations -from .op_ops.os_ops import OsOperations +from .os_ops.local_ops import LocalOperations +from .os_ops.os_ops import OsOperations from .config import testgres_config from .consts import XLOG_CONTROL_FILE @@ -25,6 +25,7 @@ def cached_initdb(data_dir, logfile=None, params=None, os_ops: OsOperations = Lo """ Perform initdb or use cached node files. """ + testgres_config.os_ops = os_ops def call_initdb(initdb_dir, log=None): try: @@ -60,7 +61,7 @@ def call_initdb(initdb_dir, log=None): # XXX: build new WAL segment with our system id _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"] - execute_utility(_params, logfile, os_ops) + execute_utility(_params, logfile, os_ops=os_ops) except ExecUtilException as e: msg = "Failed to reset WAL for system id" diff --git a/testgres/config.py b/testgres/config.py index 1be76fbe..fd942664 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -5,9 +5,8 @@ import tempfile from contextlib import contextmanager -from shutil import rmtree -from .op_ops.local_ops import LocalOperations +from .os_ops.local_ops import LocalOperations from .consts import TMP_CACHE @@ -44,7 +43,7 @@ class GlobalConfig(object): _cached_initdb_dir = None """ underlying class attribute for cached_initdb_dir property """ - os_ops = None + os_ops = LocalOperations() """ OsOperation object that allows work on remote host """ @property def cached_initdb_dir(self): @@ -137,9 +136,9 @@ def copy(self): @atexit.register -def _rm_cached_initdb_dirs(os_ops=LocalOperations()): +def _rm_cached_initdb_dirs(): for d in cached_initdb_dirs: - os_ops.rmdirs(d, ignore_errors=True) + testgres_config.os_ops.rmdirs(d, ignore_errors=True) def push_config(**options): diff --git a/testgres/defaults.py b/testgres/defaults.py index cac788b8..5ffc08de 100644 --- a/testgres/defaults.py +++ b/testgres/defaults.py @@ -1,9 +1,8 @@ import datetime -import os import struct import uuid -from .op_ops.local_ops import LocalOperations +from .os_ops.local_ops import LocalOperations def default_dbname(): diff --git a/testgres/node.py b/testgres/node.py index f06e5cc9..6456e5a9 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -1,20 +1,16 @@ # coding: utf-8 -import io import os import random -import shutil import signal import threading from queue import Queue import psutil -import subprocess import time -from op_ops.local_ops import LocalOperations -from op_ops.os_ops import OsOperations -from op_ops.remote_ops import RemoteOperations +from .os_ops.local_ops import LocalOperations +from .os_ops.remote_ops import RemoteOperations try: from collections.abc import Iterable @@ -32,7 +28,6 @@ from shutil import rmtree from six import raise_from, iteritems, text_type -from tempfile import mkstemp, mkdtemp from .enums import \ NodeStatus, \ @@ -96,7 +91,6 @@ eprint, \ get_bin_path, \ get_pg_version, \ - file_tail, \ reserve_port, \ release_port, \ execute_utility, \ @@ -163,6 +157,7 @@ def __init__(self, name=None, port=None, base_dir=None, else: self.os_ops = RemoteOperations(host, hostname, ssh_key) + testgres_config.os_ops = self.os_ops # defaults for __exit__() self.cleanup_on_good_exit = testgres_config.node_cleanup_on_good_exit self.cleanup_on_bad_exit = testgres_config.node_cleanup_on_bad_exit @@ -289,7 +284,7 @@ def base_dir(self): self._base_dir = self.os_ops.mkdtemp(prefix=TMP_NODE) # NOTE: it's safe to create a new dir - if not self.os_ops.exists(self._base_dir): + if not self.os_ops.path_exists(self._base_dir): self.os_ops.makedirs(self._base_dir) return self._base_dir @@ -299,7 +294,7 @@ def logs_dir(self): path = os.path.join(self.base_dir, LOGS_DIR) # NOTE: it's safe to create a new dir - if not self.os_ops.exists(path): + if not self.os_ops.path_exists(path): self.os_ops.makedirs(path) return path @@ -628,7 +623,7 @@ def status(self): "-D", self.data_dir, "status" ] # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) return NodeStatus.Running except ExecUtilException as e: @@ -650,7 +645,7 @@ def get_control_data(self): _params += ["-D"] if self._pg_version >= PgVer('9.5') else [] _params += [self.data_dir] - data = execute_utility(_params, self.utils_log_file) + data = execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) out_dict = {} @@ -713,7 +708,7 @@ def start(self, params=[], wait=True): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) except ExecUtilException as e: msg = 'Cannot start node' files = self._collect_special_files() @@ -744,7 +739,7 @@ def stop(self, params=[], wait=True): "stop" ] + params # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) self._maybe_stop_logger() self.is_started = False @@ -786,7 +781,7 @@ def restart(self, params=[]): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) except ExecUtilException as e: msg = 'Cannot restart node' files = self._collect_special_files() @@ -813,7 +808,7 @@ def reload(self, params=[]): "reload" ] + params # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) return self @@ -835,7 +830,7 @@ def promote(self, dbname=None, username=None): "promote" ] # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) # for versions below 10 `promote` is asynchronous so we need to wait # until it actually becomes writable @@ -870,7 +865,7 @@ def pg_ctl(self, params): "-w" # wait ] + params # yapf: disable - return execute_utility(_params, self.utils_log_file) + return execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) def free_port(self): """ @@ -1035,10 +1030,9 @@ def dump(self, # Generate tmpfile or tmpdir def tmpfile(): if format == DumpFormat.Directory: - fname = mkdtemp(prefix=TMP_DUMP) + fname = self.os_ops.mkdtemp(prefix=TMP_DUMP) else: - fd, fname = mkstemp(prefix=TMP_DUMP) - os.close(fd) + fname = self.os_ops.mkstemp(prefix=TMP_DUMP) return fname # Set default arguments @@ -1056,7 +1050,7 @@ def tmpfile(): "-F", format.value ] # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) return filename @@ -1085,7 +1079,7 @@ def restore(self, filename, dbname=None, username=None): # try pg_restore if dump is binary formate, and psql if not try: - execute_utility(_params, self.utils_log_name) + execute_utility(_params, self.utils_log_name, os_ops=self.os_ops) except ExecUtilException: self.psql(filename=filename, dbname=dbname, username=username) @@ -1417,7 +1411,7 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs): # should be the last one _params.append(dbname) - return execute_utility(_params, self.utils_log_file) + return execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) def connect(self, dbname=None, diff --git a/testgres/op_ops/local_ops.py b/testgres/os_ops/local_ops.py similarity index 80% rename from testgres/op_ops/local_ops.py rename to testgres/os_ops/local_ops.py index 42a3b4b7..d6977c9f 100644 --- a/testgres/op_ops/local_ops.py +++ b/testgres/os_ops/local_ops.py @@ -14,7 +14,6 @@ class LocalOperations(OsOperations): - def __init__(self, username=None): super().__init__() self.username = username or self.get_user() @@ -22,22 +21,28 @@ def __init__(self, username=None): # Command execution def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): if isinstance(cmd, list): - cmd = ' '.join(cmd) + cmd = " ".join(cmd) log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") # Source global profile file + execute command try: - process = subprocess.run(cmd, shell=True, text=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - timeout=CMD_TIMEOUT_SEC) + process = subprocess.run( + cmd, + shell=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=CMD_TIMEOUT_SEC, + ) exit_status = process.returncode result = process.stdout error = process.stderr if expect_error: raise Exception(result, error) - if exit_status != 0 or 'error' in error.lower(): - log.error(f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}") + if exit_status != 0 or "error" in error.lower(): + log.error( + f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}" + ) exit(1) if verbose: @@ -55,7 +60,7 @@ def environ(self, var_name): return self.exec_command(cmd).strip() def find_executable(self, executable): - search_paths = self.environ('PATH') + search_paths = self.environ("PATH") if not search_paths: return None @@ -74,12 +79,12 @@ def is_executable(self, file): def add_to_path(self, new_path): pathsep = self.pathsep # Check if the directory is already in PATH - path = self.environ('PATH') + path = self.environ("PATH") if new_path not in path.split(pathsep): if self.remote: self.exec_command(f"export PATH={new_path}{pathsep}{path}") else: - os.environ['PATH'] = f"{new_path}{pathsep}{path}" + os.environ["PATH"] = f"{new_path}{pathsep}{path}" return pathsep def set_env(self, var_name, var_val): @@ -112,10 +117,10 @@ def path_exists(self, path): @property def pathsep(self): os_name = self.get_name() - if os_name == 'posix': - pathsep = ':' - elif os_name == 'nt': - pathsep = ';' + if os_name == "posix": + pathsep = ":" + elif os_name == "nt": + pathsep = ";" else: raise Exception(f"Unsupported operating system: {os_name}") return pathsep @@ -123,6 +128,11 @@ def pathsep(self): def mkdtemp(self, prefix=None): return tempfile.mkdtemp(prefix=prefix) + def mkstemp(self, prefix=None): + fd, filename = tempfile.mkstemp(prefix=prefix) + os.close(fd) # Close the file descriptor immediately after creating the file + return filename + def copytree(self, src, dst): return shutil.copytree(src, dst) @@ -140,11 +150,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal read_and_write: If True, the file will be opened with read and write permissions ('r+' option); if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) """ - mode = 'wb' if binary else 'w' + mode = "wb" if binary else "w" if not truncate: - mode = 'a' + mode + mode = "a" + mode if read_and_write: - mode = 'r+' + mode + mode = "r+" + mode with open(filename, mode) as file: if isinstance(data, list): @@ -161,11 +171,11 @@ def touch(self, filename): This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. """ # cross-python touch(). It is vulnerable to races, but who cares? - with open(filename, 'a'): + with open(filename, "a"): os.utime(filename, None) def read(self, filename): - with open(filename, 'r') as file: + with open(filename, "r") as file: return file.read() def readlines(self, filename, num_lines=0, encoding=None): @@ -176,14 +186,14 @@ def readlines(self, filename, num_lines=0, encoding=None): assert num_lines >= 0 if num_lines == 0: - with open(filename, 'r', encoding=encoding) as file: + with open(filename, "r", encoding=encoding) as file: return file.readlines() else: bufsize = 8192 buffers = 1 - with open(filename, 'r', encoding=encoding) as file: + with open(filename, "r", encoding=encoding) as file: file.seek(0, os.SEEK_END) end_pos = file.tell() @@ -197,7 +207,9 @@ def readlines(self, filename, num_lines=0, encoding=None): if cur_lines >= num_lines or pos == 0: return lines[-num_lines:] - buffers = int(buffers * max(2, int(num_lines / max(cur_lines, 1)))) # Adjust buffer size + buffers = int( + buffers * max(2, int(num_lines / max(cur_lines, 1))) + ) # Adjust buffer size def isfile(self, remote_file): return os.path.isfile(remote_file) @@ -205,7 +217,7 @@ def isfile(self, remote_file): # Processes control def kill(self, pid, signal): # Kill the process - cmd = f'kill -{signal} {pid}' + cmd = f"kill -{signal} {pid}" return self.exec_command(cmd) def get_pid(self): @@ -213,7 +225,7 @@ def get_pid(self): return os.getpid() # Database control - def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + def db_connect(self, dbname, user, password=None, host="localhost", port=5432): conn = pglib.connect( host=host, port=port, diff --git a/testgres/op_ops/os_ops.py b/testgres/os_ops/os_ops.py similarity index 90% rename from testgres/op_ops/os_ops.py rename to testgres/os_ops/os_ops.py index 89de2640..1ee1f869 100644 --- a/testgres/op_ops/os_ops.py +++ b/testgres/os_ops/os_ops.py @@ -1,18 +1,15 @@ try: - import psycopg2 as pglib + import psycopg2 as pglib # noqa: F401 except ImportError: try: - import pg8000 as pglib + import pg8000 as pglib # noqa: F401 except ImportError: raise ImportError("You must have psycopg2 or pg8000 modules installed") -from testgres.defaults import default_username - class OsOperations: - def __init__(self, username=None): - self.hostname = 'localhost' + self.hostname = "localhost" self.remote = False self.ssh = None self.username = username @@ -49,7 +46,7 @@ def get_name(self): # Work with dirs def makedirs(self, path, remove_existing=False): raise NotImplementedError() - + def rmdirs(self, path, ignore_errors=True): raise NotImplementedError() @@ -95,5 +92,5 @@ def get_pid(self): raise NotImplementedError() # Database control - def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + def db_connect(self, dbname, user, password=None, host="localhost", port=5432): raise NotImplementedError() diff --git a/testgres/op_ops/remote_ops.py b/testgres/os_ops/remote_ops.py similarity index 79% rename from testgres/op_ops/remote_ops.py rename to testgres/os_ops/remote_ops.py index d5faab4e..e1460b75 100644 --- a/testgres/op_ops/remote_ops.py +++ b/testgres/os_ops/remote_ops.py @@ -28,7 +28,9 @@ class RemoteOperations(OsOperations): - ssh (paramiko.SSHClient): SSH connection to the remote system. """ - def __init__(self, hostname='localhost', host='127.0.0.1', ssh_key=None, username=None): + def __init__( + self, hostname="localhost", host="127.0.0.1", ssh_key=None, username=None + ): super().__init__(username) self.hostname = hostname self.host = host @@ -46,9 +48,9 @@ def ssh_connect(self): if not self.remote: yield None else: - with open(self.ssh_key, 'r') as f: + with open(self.ssh_key, "r") as f: key_data = f.read() - if 'BEGIN OPENSSH PRIVATE KEY' in key_data: + if "BEGIN OPENSSH PRIVATE KEY" in key_data: key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) else: key = paramiko.RSAKey.from_private_key_file(self.ssh_key) @@ -63,9 +65,11 @@ def connect(self): return ssh # Command execution - def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding='utf-8'): + def exec_command( + self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding="utf-8" + ): if isinstance(cmd, list): - cmd = ' '.join(cmd) + cmd = " ".join(cmd) log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") # Source global profile file + execute command try: @@ -80,8 +84,10 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, if expect_error: raise Exception(result, error) - if exit_status != 0 or 'error' in error.lower(): - log.error(f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}") + if exit_status != 0 or "error" in error.lower(): + log.error( + f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}" + ) exit(1) if verbose: @@ -99,7 +105,7 @@ def environ(self, var_name): return self.exec_command(cmd).strip() def find_executable(self, executable): - search_paths = self.environ('PATH') + search_paths = self.environ("PATH") if not search_paths: return None @@ -113,17 +119,17 @@ def find_executable(self, executable): def is_executable(self, file): # Check if the file is executable - return self.exec_command(f"test -x {file} && echo OK") == 'OK\n' + return self.exec_command(f"test -x {file} && echo OK") == "OK\n" def add_to_path(self, new_path): pathsep = self.pathsep # Check if the directory is already in PATH - path = self.environ('PATH') + path = self.environ("PATH") if new_path not in path.split(pathsep): if self.remote: self.exec_command(f"export PATH={new_path}{pathsep}{path}") else: - os.environ['PATH'] = f"{new_path}{pathsep}{path}" + os.environ["PATH"] = f"{new_path}{pathsep}{path}" return pathsep def set_env(self, var_name, var_val): @@ -132,7 +138,7 @@ def set_env(self, var_name, var_val): # Get environment variables def get_user(self): - return self.exec_command(f"echo $USER") + return self.exec_command("echo $USER") def get_name(self): cmd = 'python3 -c "import os; print(os.name)"' @@ -141,40 +147,45 @@ def get_name(self): # Work with dirs def makedirs(self, path, remove_existing=False): if remove_existing: - cmd = f'rm -rf {path} && mkdir -p {path}' + cmd = f"rm -rf {path} && mkdir -p {path}" else: - cmd = f'mkdir -p {path}' + cmd = f"mkdir -p {path}" return self.exec_command(cmd) def rmdirs(self, path, ignore_errors=True): - cmd = f'rm -rf {path}' + cmd = f"rm -rf {path}" return self.exec_command(cmd) def listdir(self, path): - result = self.exec_command(f'ls {path}') + result = self.exec_command(f"ls {path}") return result.splitlines() def path_exists(self, path): - result = self.exec_command(f'test -e {path}; echo $?') + result = self.exec_command(f"test -e {path}; echo $?") return int(result.strip()) == 0 @property def pathsep(self): os_name = self.get_name() - if os_name == 'posix': - pathsep = ':' - elif os_name == 'nt': - pathsep = ';' + if os_name == "posix": + pathsep = ":" + elif os_name == "nt": + pathsep = ";" else: raise Exception(f"Unsupported operating system: {os_name}") return pathsep def mkdtemp(self, prefix=None): - temp_dir = self.exec_command(f'mkdtemp -d {prefix}') + temp_dir = self.exec_command(f"mkdtemp -d {prefix}") return temp_dir.strip() + def mkstemp(self, prefix=None): + cmd = f"mktemp {prefix}XXXXXX" + filename = self.exec_command(cmd).strip() + return filename + def copytree(self, src, dst): - return self.exec_command(f'cp -r {src} {dst}') + return self.exec_command(f"cp -r {src} {dst}") # Work with files def write(self, filename, data, truncate=False, binary=False, read_and_write=False): @@ -190,11 +201,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal read_and_write: If True, the file will be opened with read and write permissions ('r+' option); if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) """ - mode = 'wb' if binary else 'w' + mode = "wb" if binary else "w" if not truncate: - mode = 'a' + mode + mode = "a" + mode if read_and_write: - mode = 'r+' + mode + mode = "r+" + mode with tempfile.NamedTemporaryFile(mode=mode) as tmp_file: if isinstance(data, list): @@ -216,38 +227,38 @@ def touch(self, filename): This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. """ - self.exec_command(f'touch {filename}') + self.exec_command(f"touch {filename}") - def read(self, filename, encoding='utf-8'): - cmd = f'cat {filename}' + def read(self, filename, encoding="utf-8"): + cmd = f"cat {filename}" return self.exec_command(cmd, encoding=encoding) def readlines(self, filename, num_lines=0, encoding=None): - encoding = encoding or 'utf-8' + encoding = encoding or "utf-8" if num_lines > 0: - cmd = f'tail -n {num_lines} {filename}' + cmd = f"tail -n {num_lines} {filename}" lines = self.exec_command(cmd, encoding) else: lines = self.read(filename, encoding=encoding).splitlines() return lines def isfile(self, remote_file): - stdout = self.exec_command(f'test -f {remote_file}; echo $?') + stdout = self.exec_command(f"test -f {remote_file}; echo $?") result = int(stdout.strip()) return result == 0 # Processes control def kill(self, pid, signal): # Kill the process - cmd = f'kill -{signal} {pid}' + cmd = f"kill -{signal} {pid}" return self.exec_command(cmd) def get_pid(self): # Get current process id - return self.exec_command(f"echo $$") + return self.exec_command("echo $$") # Database control - def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + def db_connect(self, dbname, user, password=None, host="localhost", port=5432): local_port = self.ssh.forward_remote_port(host, port) conn = pglib.connect( host=host, diff --git a/testgres/utils.py b/testgres/utils.py index 73ca6f1a..72fd1b9d 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -12,6 +12,9 @@ from contextlib import contextmanager from packaging.version import Version + +from .os_ops.remote_ops import RemoteOperations + try: from shutil import which as find_executable except ImportError: @@ -19,8 +22,8 @@ from six import iteritems from fabric import Connection -from .op_ops.local_ops import LocalOperations -from .op_ops.os_ops import OsOperations +from .os_ops.local_ops import LocalOperations +from .os_ops.os_ops import OsOperations from .config import testgres_config from .exceptions import ExecUtilException @@ -59,6 +62,7 @@ def execute_utility(args, logfile=None, os_ops: OsOperations = LocalOperations() Execute utility (pg_ctl, pg_dump etc). Args: + os_ops: LocalOperations for local node or RemoteOperations for node that connected by ssh. args: utility + arguments (list). logfile: path to file to store stdout and stderr. @@ -66,21 +70,20 @@ def execute_utility(args, logfile=None, os_ops: OsOperations = LocalOperations() stdout of executed utility. """ - if os_ops.hostname != 'localhost': + if isinstance(os_ops, RemoteOperations): conn = Connection( os_ops.hostname, connect_kwargs={ "key_filename": f"{os_ops.ssh_key}", }, ) - # TODO skip remote ssh run if we are on the localhost. # result = conn.run('hostname', hide=True) - # add logger + # add logger cmd = ' '.join(args) - result = conn.run(cmd, hide=True) - + result = conn.run(cmd, hide=True) + return result # run utility @@ -173,8 +176,9 @@ def get_bin_path(filename): def get_pg_config(pg_config_path=None): """ Return output of pg_config (provided that it is installed). - NOTE: this fuction caches the result by default (see GlobalConfig). + NOTE: this function caches the result by default (see GlobalConfig). """ + def cache_pg_config_data(cmd): # execute pg_config and get the output out = subprocess.check_output([cmd]).decode('utf-8') From b04804102457adf6ef29646e9dd86ceaccb24127 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 12 Jun 2023 23:08:10 +0200 Subject: [PATCH 09/23] PBCKP-152 fix failed tests --- testgres/__init__.py | 5 ++ testgres/cache.py | 7 +- testgres/config.py | 2 +- testgres/connection.py | 24 ++---- testgres/defaults.py | 2 +- testgres/node.py | 21 +++-- testgres/operations/__init__.py | 0 testgres/{os_ops => operations}/local_ops.py | 45 ++++++----- testgres/{os_ops => operations}/os_ops.py | 0 testgres/{os_ops => operations}/remote_ops.py | 55 +++++++++---- testgres/utils.py | 8 +- tests/test_remote.py | 81 +++++++++++++++++++ tests/test_simple.py | 6 +- 13 files changed, 181 insertions(+), 75 deletions(-) create mode 100644 testgres/operations/__init__.py rename testgres/{os_ops => operations}/local_ops.py (82%) rename testgres/{os_ops => operations}/os_ops.py (100%) rename testgres/{os_ops => operations}/remote_ops.py (84%) create mode 100755 tests/test_remote.py diff --git a/testgres/__init__.py b/testgres/__init__.py index 1b33ba3b..405262dd 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -46,6 +46,10 @@ First, \ Any +from .operations.os_ops import OsOperations +from .operations.local_ops import LocalOperations +from .operations.remote_ops import RemoteOperations + __all__ = [ "get_new_node", "NodeBackup", @@ -56,4 +60,5 @@ "PostgresNode", "NodeApp", "reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version", "First", "Any", + "OsOperations", "LocalOperations", "RemoteOperations" ] diff --git a/testgres/cache.py b/testgres/cache.py index 1df5a8ea..ef07e976 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -4,8 +4,6 @@ from six import raise_from -from .os_ops.local_ops import LocalOperations -from .os_ops.os_ops import OsOperations from .config import testgres_config from .consts import XLOG_CONTROL_FILE @@ -20,6 +18,9 @@ get_bin_path, \ execute_utility +from .operations.local_ops import LocalOperations +from .operations.os_ops import OsOperations + def cached_initdb(data_dir, logfile=None, params=None, os_ops: OsOperations = LocalOperations()): """ @@ -38,7 +39,7 @@ def call_initdb(initdb_dir, log=None): call_initdb(data_dir, logfile) else: # Fetch cached initdb dir - cached_data_dir = testgres_config.cached_initdb_dir() + cached_data_dir = testgres_config.cached_initdb_dir # Initialize cached initdb diff --git a/testgres/config.py b/testgres/config.py index fd942664..b21d8356 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -6,8 +6,8 @@ from contextlib import contextmanager -from .os_ops.local_ops import LocalOperations from .consts import TMP_CACHE +from .operations.local_ops import LocalOperations class GlobalConfig(object): diff --git a/testgres/connection.py b/testgres/connection.py index f3d01ebe..6725b14f 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -102,23 +102,15 @@ def rollback(self): return self def execute(self, query, *args): + self.cursor.execute(query, args) try: - with self.connection.cursor() as cursor: - cursor.execute(query, args) - try: - res = cursor.fetchall() - - # pg8000 might return tuples - if isinstance(res, tuple): - res = [tuple(t) for t in res] - - return res - except (pglib.ProgrammingError, pglib.InternalError) as e: - # An error occurred while trying to fetch results (e.g., no results to fetch) - print(f"Error fetching results: {e}") - return None - except (pglib.Error, Exception) as e: - # Handle other database errors + res = self.cursor.fetchall() + # pg8000 might return tuples + if isinstance(res, tuple): + res = [tuple(t) for t in res] + + return res + except Exception as e: print(f"Error executing query: {e}") return None diff --git a/testgres/defaults.py b/testgres/defaults.py index 5ffc08de..34bcc08b 100644 --- a/testgres/defaults.py +++ b/testgres/defaults.py @@ -2,7 +2,7 @@ import struct import uuid -from .os_ops.local_ops import LocalOperations +from .operations.local_ops import LocalOperations def default_dbname(): diff --git a/testgres/node.py b/testgres/node.py index 6456e5a9..9aa47d84 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -9,9 +9,6 @@ import psutil import time -from .os_ops.local_ops import LocalOperations -from .os_ops.remote_ops import RemoteOperations - try: from collections.abc import Iterable except ImportError: @@ -99,6 +96,9 @@ from .backup import NodeBackup +from .operations.local_ops import LocalOperations +from .operations.remote_ops import RemoteOperations + InternalError = pglib.InternalError ProgrammingError = pglib.ProgrammingError OperationalError = pglib.OperationalError @@ -201,7 +201,7 @@ def pid(self): if self.status(): pid_file = os.path.join(self.data_dir, PG_PID_FILE) - lines = self.os_ops.readlines(pid_file, num_lines=1) + lines = self.os_ops.readlines(pid_file) pid = int(lines[0]) if lines else None return pid @@ -433,7 +433,8 @@ def _collect_special_files(self): if not self.os_ops.path_exists(f): continue - lines = b''.join(self.os_ops.readlines(f, num_lines, encoding='utf-8')) + file_lines = self.os_ops.readlines(f, num_lines, binary=True, encoding=None) + lines = b''.join(file_lines) # fill list result.append((f, lines)) @@ -498,7 +499,7 @@ def default_conf(self, ] # write filtered lines - self.os_ops.write(hba_conf_file, lines, truncate=True) + self.os_ops.write(hba_conf, lines, truncate=True) # replication-related settings if allow_streaming: @@ -960,11 +961,9 @@ def psql(self, psql_params.append(dbname) # start psql process - process = self.os_ops.exec_command(psql_params) + status_code, out, err = self.os_ops.exec_command(psql_params, shell=False, verbose=True, input=input) - # wait until it finishes and get stdout and stderr - out, err = process.communicate(input=input) - return process.returncode, out, err + return status_code, out, err @method_decorator(positional_args_hack(['dbname', 'query'])) def safe_psql(self, query=None, expect_error=False, **kwargs): @@ -1348,7 +1347,7 @@ def pgbench(self, # should be the last one _params.append(dbname) - proc = self.os_ops.exec_command(_params, wait_exit=True) + proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, shell=False, proc=True) return proc diff --git a/testgres/operations/__init__.py b/testgres/operations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/testgres/os_ops/local_ops.py b/testgres/operations/local_ops.py similarity index 82% rename from testgres/os_ops/local_ops.py rename to testgres/operations/local_ops.py index d6977c9f..acb10df8 100644 --- a/testgres/os_ops/local_ops.py +++ b/testgres/operations/local_ops.py @@ -19,18 +19,25 @@ def __init__(self, username=None): self.username = username or self.get_user() # Command execution - def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): - if isinstance(cmd, list): - cmd = " ".join(cmd) + def exec_command(self, cmd, wait_exit=False, verbose=False, + expect_error=False, encoding=None, shell=True, text=False, + input=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None): log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") # Source global profile file + execute command try: + if proc: + return subprocess.Popen(cmd, + shell=shell, + stdin=input or subprocess.PIPE, + stdout=stdout, + stderr=stderr) process = subprocess.run( cmd, - shell=True, - text=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + input=input, + shell=shell, + text=text, + stdout=stdout, + stderr=stderr, timeout=CMD_TIMEOUT_SEC, ) exit_status = process.returncode @@ -39,11 +46,11 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): if expect_error: raise Exception(result, error) - if exit_status != 0 or "error" in error.lower(): + if exit_status != 0 or "error" in error.lower().decode(encoding or 'utf-8'): # Decode error for comparison log.error( - f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}" + f"Problem in executing command: `{cmd}`\nerror: {error.decode(encoding or 'utf-8')}\nexit_code: {exit_status}" + # Decode for logging ) - exit(1) if verbose: return exit_status, result, error @@ -152,9 +159,9 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal """ mode = "wb" if binary else "w" if not truncate: - mode = "a" + mode + mode = "ab" if binary else "a" if read_and_write: - mode = "r+" + mode + mode = "r+b" if binary else "r+" with open(filename, mode) as file: if isinstance(data, list): @@ -174,26 +181,26 @@ def touch(self, filename): with open(filename, "a"): os.utime(filename, None) - def read(self, filename): - with open(filename, "r") as file: + def read(self, filename, encoding=None): + with open(filename, "r", encoding=encoding) as file: return file.read() - def readlines(self, filename, num_lines=0, encoding=None): + def readlines(self, filename, num_lines=0, binary=False, encoding=None): """ Read lines from a local file. If num_lines is greater than 0, only the last num_lines lines will be read. """ assert num_lines >= 0 - + mode = 'rb' if binary else 'r' if num_lines == 0: - with open(filename, "r", encoding=encoding) as file: + with open(filename, mode, encoding=encoding) as file: # open in binary mode return file.readlines() else: bufsize = 8192 buffers = 1 - with open(filename, "r", encoding=encoding) as file: + with open(filename, mode, encoding=encoding) as file: # open in binary mode file.seek(0, os.SEEK_END) end_pos = file.tell() @@ -205,7 +212,7 @@ def readlines(self, filename, num_lines=0, encoding=None): cur_lines = len(lines) if cur_lines >= num_lines or pos == 0: - return lines[-num_lines:] + return lines[-num_lines:] # get last num_lines from lines buffers = int( buffers * max(2, int(num_lines / max(cur_lines, 1))) diff --git a/testgres/os_ops/os_ops.py b/testgres/operations/os_ops.py similarity index 100% rename from testgres/os_ops/os_ops.py rename to testgres/operations/os_ops.py diff --git a/testgres/os_ops/remote_ops.py b/testgres/operations/remote_ops.py similarity index 84% rename from testgres/os_ops/remote_ops.py rename to testgres/operations/remote_ops.py index e1460b75..dbe88dbe 100644 --- a/testgres/os_ops/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,3 +1,4 @@ +import io import os import tempfile from contextlib import contextmanager @@ -65,9 +66,9 @@ def connect(self): return ssh # Command execution - def exec_command( - self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding="utf-8" - ): + def exec_command(self, cmd, wait_exit=False, verbose=False, + expect_error=False, encoding=None, shell=True, text=False, + input=None, stdout=None, stderr=None, proc=None): if isinstance(cmd, list): cmd = " ".join(cmd) log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") @@ -75,20 +76,31 @@ def exec_command( try: cmd = f"source /etc/profile.d/custom.sh; {cmd}" with self.ssh_connect() as ssh: - stdin, stdout, stderr = ssh.exec_command(cmd) + if input: + # encode input and feed it to stdin + stdin, stdout, stderr = ssh.exec_command(cmd) + stdin.write(input) + stdin.flush() + else: + stdin, stdout, stderr = ssh.exec_command(cmd) exit_status = 0 if wait_exit: exit_status = stdout.channel.recv_exit_status() - result = stdout.read().decode(encoding) - error = stderr.read().decode(encoding) + if encoding: + result = stdout.read().decode(encoding) + error = stderr.read().decode(encoding) + else: + # Save as binary string + result = io.BytesIO(stdout.read()).getvalue() + error = io.BytesIO(stderr.read()).getvalue() + error_str = stderr.read() if expect_error: raise Exception(result, error) - if exit_status != 0 or "error" in error.lower(): + if exit_status != 0 or 'error' in error_str: log.error( f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}" ) - exit(1) if verbose: return exit_status, result, error @@ -203,9 +215,9 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal """ mode = "wb" if binary else "w" if not truncate: - mode = "a" + mode + mode = "ab" if binary else "a" if read_and_write: - mode = "r+" + mode + mode = "r+b" if binary else "r+" with tempfile.NamedTemporaryFile(mode=mode) as tmp_file: if isinstance(data, list): @@ -229,17 +241,28 @@ def touch(self, filename): """ self.exec_command(f"touch {filename}") - def read(self, filename, encoding="utf-8"): + def read(self, filename, binary=False, encoding=None): cmd = f"cat {filename}" - return self.exec_command(cmd, encoding=encoding) + result = self.exec_command(cmd, encoding=encoding) + + if not binary and result: + result = result.decode(encoding or 'utf-8') + + return result - def readlines(self, filename, num_lines=0, encoding=None): - encoding = encoding or "utf-8" + def readlines(self, filename, num_lines=0, binary=False, encoding=None): if num_lines > 0: cmd = f"tail -n {num_lines} {filename}" - lines = self.exec_command(cmd, encoding) else: - lines = self.read(filename, encoding=encoding).splitlines() + cmd = f"cat {filename}" + + result = self.exec_command(cmd, encoding=encoding) + + if not binary and result: + lines = result.decode(encoding or 'utf-8').splitlines() + else: + lines = result.splitlines() + return lines def isfile(self, remote_file): diff --git a/testgres/utils.py b/testgres/utils.py index 72fd1b9d..b72c7da0 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -13,8 +13,6 @@ from contextlib import contextmanager from packaging.version import Version -from .os_ops.remote_ops import RemoteOperations - try: from shutil import which as find_executable except ImportError: @@ -22,8 +20,10 @@ from six import iteritems from fabric import Connection -from .os_ops.local_ops import LocalOperations -from .os_ops.os_ops import OsOperations + +from .operations.remote_ops import RemoteOperations +from .operations.local_ops import LocalOperations +from .operations.os_ops import OsOperations from .config import testgres_config from .exceptions import ExecUtilException diff --git a/tests/test_remote.py b/tests/test_remote.py new file mode 100755 index 00000000..47804dfb --- /dev/null +++ b/tests/test_remote.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# coding: utf-8 + +import os +import time + +import pytest +from docker import DockerClient +from paramiko import RSAKey + +from testgres import RemoteOperations + + +class TestRemoteOperations: + @pytest.fixture(scope="class", autouse=True) + def setup_class(self): + # Create shared volume + self.volume_path = os.path.abspath("./tmp/ssh_key") + os.makedirs(self.volume_path, exist_ok=True) + + # Generate SSH keys + private_key_path = os.path.join(self.volume_path, "id_rsa") + public_key_path = os.path.join(self.volume_path, "id_rsa.pub") + + private_key = RSAKey.generate(4096) + private_key.write_private_key_file(private_key_path) + + with open(public_key_path, "w") as f: + f.write(f"{private_key.get_name()} {private_key.get_base64()}") + + self.docker = DockerClient.from_env() + self.container = self.docker.containers.run( + "rastasheep/ubuntu-sshd:18.04", + detach=True, + tty=True, + ports={22: 8022}, + ) + + # Wait for the container to start sshd + time.sleep(10) + + yield + + # Stop and remove the container after tests + self.container.stop() + self.container.remove() + + @pytest.fixture(scope="function", autouse=True) + def setup(self): + self.operations = RemoteOperations( + host="localhost", + username="root", + ssh_key=os.path.join(self.volume_path, "id_rsa") + ) + + yield + + self.operations.__del__() + + def test_exec_command(self): + cmd = "python3 --version" + response = self.operations.exec_command(cmd) + + assert "Python 3.9" in response + + def test_is_executable(self): + cmd = "python3" + response = self.operations.is_executable(cmd) + + assert response is True + + def test_makedirs_and_rmdirs(self): + path = "/test_dir" + + # Test makedirs + self.operations.makedirs(path) + assert self.operations.path_exists(path) + + # Test rmdirs + self.operations.rmdirs(path) + assert not self.operations.path_exists(path) diff --git a/tests/test_simple.py b/tests/test_simple.py index 94420b04..e8b8abee 100755 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -151,8 +151,6 @@ def test_init_unique_system_id(self): self.assertGreater(id2, id1) def test_node_exit(self): - base_dir = None - with self.assertRaises(QueryException): with get_new_node().init() as node: base_dir = node.base_dir @@ -281,7 +279,7 @@ def test_psql(self): node.safe_psql('copy horns from stdin (format csv)', input=b"1\n2\n3\n\\.\n") _sum = node.safe_psql('select sum(w) from horns') - self.assertEqual(_sum, b'6\n') + self.assertEqual(b'6\n', _sum) # check psql's default args, fails with self.assertRaises(QueryException): @@ -614,7 +612,7 @@ def test_users(self): with get_new_node().init().start() as node: node.psql('create role test_user login') value = node.safe_psql('select 1', username='test_user') - self.assertEqual(value, b'1\n') + self.assertEqual(b'1\n', value) def test_poll_query_until(self): with get_new_node() as node: From e098b9796de56f62a5347cdd0fa6576a91ac7b40 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Tue, 13 Jun 2023 22:04:51 +0200 Subject: [PATCH 10/23] PBCKP-152 fix failed tests --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2b188565..b5162dce 100755 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ setup( version='1.9.0', name='testgres', - packages=['testgres'], + packages=['testgres', 'testgres.operations'], description='Testing utility for PostgreSQL and its extensions', url='https://github.com/postgrespro/testgres', long_description=readme, From 1c405ef7dde782d6dd6507708ceb5c35a2e33440 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Wed, 14 Jun 2023 11:15:33 +0200 Subject: [PATCH 11/23] PBCKP-152 add tests for remote_ops.py --- testgres/node.py | 6 +- testgres/operations/remote_ops.py | 243 +++++++++++++++++++----------- tests/test_remote.py | 210 +++++++++++++++++++------- 3 files changed, 313 insertions(+), 146 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 9aa47d84..9d183f96 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -152,10 +152,10 @@ def __init__(self, name=None, port=None, base_dir=None, self.host = host self.hostname = hostname self.ssh_key = ssh_key - if hostname == 'localhost' or host == '127.0.0.1': - self.os_ops = LocalOperations(username=username) - else: + if hostname != 'localhost' or host != '127.0.0.1': self.os_ops = RemoteOperations(host, hostname, ssh_key) + else: + self.os_ops = LocalOperations(username=username) testgres_config.os_ops = self.os_ops # defaults for __exit__() diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index dbe88dbe..e2248015 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,106 +1,98 @@ -import io import os import tempfile -from contextlib import contextmanager +from typing import Optional -from testgres.logger import log +import paramiko +from paramiko import SSHClient +from logger import log from .os_ops import OsOperations from .os_ops import pglib -import paramiko +error_markers = [b'error', b'Permission denied'] class RemoteOperations(OsOperations): - """ - This class specifically supports work with Linux systems. It utilizes the SSH - for making connections and performing various file and directory operations, command executions, - environment setup and management, process control, and database connections. - It uses the Paramiko library for SSH connections and operations. - - Some methods are designed to work with specific Linux shell commands, and thus may not work as expected - on other non-Linux systems. - - Attributes: - - hostname (str): The remote system's hostname. Default 'localhost'. - - host (str): The remote system's IP address. Default '127.0.0.1'. - - ssh_key (str): Path to the SSH private key for authentication. - - username (str): Username for the remote system. - - ssh (paramiko.SSHClient): SSH connection to the remote system. - """ - - def __init__( - self, hostname="localhost", host="127.0.0.1", ssh_key=None, username=None - ): + def __init__(self, hostname="localhost", host="127.0.0.1", ssh_key=None, username=None): super().__init__(username) - self.hostname = hostname self.host = host self.ssh_key = ssh_key self.remote = True - self.ssh = self.connect() + self.ssh = self.ssh_connect() self.username = username or self.get_user() def __del__(self): if self.ssh: self.ssh.close() - @contextmanager - def ssh_connect(self): + def ssh_connect(self) -> Optional[SSHClient]: if not self.remote: - yield None + return None else: + key = self._read_ssh_key() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.host, username=self.username, pkey=key) + return ssh + + def _read_ssh_key(self): + try: with open(self.ssh_key, "r") as f: key_data = f.read() if "BEGIN OPENSSH PRIVATE KEY" in key_data: key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) else: key = paramiko.RSAKey.from_private_key_file(self.ssh_key) + return key + except FileNotFoundError: + log.error(f"No such file or directory: '{self.ssh_key}'") + except Exception as e: + log.error(f"An error occurred while reading the ssh key: {e}") - with paramiko.SSHClient() as ssh: - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.host, username=self.username, pkey=key) - yield ssh - - def connect(self): - with self.ssh_connect() as ssh: - return ssh - - # Command execution - def exec_command(self, cmd, wait_exit=False, verbose=False, - expect_error=False, encoding=None, shell=True, text=False, - input=None, stdout=None, stderr=None, proc=None): + def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, + encoding=None, shell=True, text=False, input=None, stdout=None, + stderr=None, proc=None): + """ + Execute a command in the SSH session. + Args: + - cmd (str): The command to be executed. + """ if isinstance(cmd, list): cmd = " ".join(cmd) - log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") - # Source global profile file + execute command try: - cmd = f"source /etc/profile.d/custom.sh; {cmd}" - with self.ssh_connect() as ssh: - if input: - # encode input and feed it to stdin - stdin, stdout, stderr = ssh.exec_command(cmd) - stdin.write(input) - stdin.flush() - else: - stdin, stdout, stderr = ssh.exec_command(cmd) - exit_status = 0 - if wait_exit: - exit_status = stdout.channel.recv_exit_status() - if encoding: - result = stdout.read().decode(encoding) - error = stderr.read().decode(encoding) - else: - # Save as binary string - result = io.BytesIO(stdout.read()).getvalue() - error = io.BytesIO(stderr.read()).getvalue() - error_str = stderr.read() + if input: + stdin, stdout, stderr = self.ssh.exec_command(cmd) + stdin.write(input.encode("utf-8")) + stdin.flush() + else: + stdin, stdout, stderr = self.ssh.exec_command(cmd) + exit_status = 0 + if wait_exit: + exit_status = stdout.channel.recv_exit_status() + + if encoding: + result = stdout.read().decode(encoding) + error = stderr.read().decode(encoding) + else: + result = stdout.read() + error = stderr.read() if expect_error: raise Exception(result, error) - if exit_status != 0 or 'error' in error_str: + + if encoding: + error_found = exit_status != 0 or any( + marker.decode(encoding) in error for marker in error_markers) + else: + error_found = exit_status != 0 or any( + marker in error for marker in error_markers) + + if error_found: log.error( f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}" ) + if exit_status == 0: + exit_status = 1 if verbose: return exit_status, result, error @@ -112,7 +104,12 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, return None # Environment setup - def environ(self, var_name): + def environ(self, var_name: str) -> str: + """ + Get the value of an environment variable. + Args: + - var_name (str): The name of the environment variable. + """ cmd = f"echo ${var_name}" return self.exec_command(cmd).strip() @@ -131,7 +128,8 @@ def find_executable(self, executable): def is_executable(self, file): # Check if the file is executable - return self.exec_command(f"test -x {file} && echo OK") == "OK\n" + is_exec = self.exec_command(f"test -x {file} && echo OK") + return is_exec == b"OK\n" def add_to_path(self, new_path): pathsep = self.pathsep @@ -144,8 +142,13 @@ def add_to_path(self, new_path): os.environ["PATH"] = f"{new_path}{pathsep}{path}" return pathsep - def set_env(self, var_name, var_val): - # Check if the directory is already in PATH + def set_env(self, var_name: str, var_val: str) -> None: + """ + Set the value of an environment variable. + Args: + - var_name (str): The name of the environment variable. + - var_val (str): The value to be set for the environment variable. + """ return self.exec_command(f"export {var_name}={var_val}") # Get environment variables @@ -158,22 +161,47 @@ def get_name(self): # Work with dirs def makedirs(self, path, remove_existing=False): + """ + Create a directory in the remote server. + Args: + - path (str): The path to the directory to be created. + - remove_existing (bool): If True, the existing directory at the path will be removed. + """ if remove_existing: cmd = f"rm -rf {path} && mkdir -p {path}" else: cmd = f"mkdir -p {path}" - return self.exec_command(cmd) + exit_status, result, error = self.exec_command(cmd, verbose=True) + if exit_status != 0: + raise Exception(f"Couldn't create dir {path} because of error {error}") + return result - def rmdirs(self, path, ignore_errors=True): + def rmdirs(self, path, verbose=False, ignore_errors=True): + """ + Remove a directory in the remote server. + Args: + - path (str): The path to the directory to be removed. + - verbose (bool): If True, return exit status, result, and error. + - ignore_errors (bool): If True, do not raise error if directory does not exist. + """ cmd = f"rm -rf {path}" - return self.exec_command(cmd) + exit_status, result, error = self.exec_command(cmd, verbose=True) + if verbose: + return exit_status, result, error + else: + return result def listdir(self, path): + """ + List all files and directories in a directory. + Args: + path (str): The path to the directory. + """ result = self.exec_command(f"ls {path}") return result.splitlines() def path_exists(self, path): - result = self.exec_command(f"test -e {path}; echo $?") + result = self.exec_command(f"test -e {path}; echo $?", encoding='utf-8') return int(result.strip()) == 0 @property @@ -188,7 +216,12 @@ def pathsep(self): return pathsep def mkdtemp(self, prefix=None): - temp_dir = self.exec_command(f"mkdtemp -d {prefix}") + """ + Creates a temporary directory in the remote server. + Args: + prefix (str): The prefix of the temporary directory name. + """ + temp_dir = self.exec_command(f"mkdtemp -d {prefix}", encoding='utf-8') return temp_dir.strip() def mkstemp(self, prefix=None): @@ -200,18 +233,19 @@ def copytree(self, src, dst): return self.exec_command(f"cp -r {src} {dst}") # Work with files - def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding='utf-8'): """ Write data to a file on a remote host + Args: - filename: The file path where the data will be written. - data: The data to be written to the file. - truncate: If True, the file will be truncated before writing ('w' or 'wb' option); - if False (default), data will be appended ('a' or 'ab' option). - binary: If True, the data will be written in binary mode ('wb' or 'ab' option); - if False (default), the data will be written in text mode ('w' or 'a' option). - read_and_write: If True, the file will be opened with read and write permissions ('r+' option); - if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) + - filename (str): The file path where the data will be written. + - data (bytes or str): The data to be written to the file. + - truncate (bool): If True, the file will be truncated before writing ('w' or 'wb' option); + if False (default), data will be appended ('a' or 'ab' option). + - binary (bool): If True, the data will be written in binary mode ('wb' or 'ab' option); + if False (default), the data will be written in text mode ('w' or 'a' option). + - read_and_write (bool): If True, the file will be opened with read and write permissions ('r+' option); + if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option). """ mode = "wb" if binary else "w" if not truncate: @@ -220,15 +254,18 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal mode = "r+b" if binary else "r+" with tempfile.NamedTemporaryFile(mode=mode) as tmp_file: - if isinstance(data, list): - tmp_file.writelines(data) - else: - tmp_file.write(data) + if isinstance(data, bytes) and not binary: + data = data.decode(encoding) + elif isinstance(data, str) and binary: + data = data.encode(encoding) + + tmp_file.write(data) tmp_file.flush() - sftp = self.ssh.open_sftp() - sftp.put(tmp_file.name, filename) - sftp.close() + with self.ssh_connect() as ssh: + sftp = ssh.open_sftp() + sftp.put(tmp_file.name, filename) + sftp.close() def touch(self, filename): """ @@ -281,8 +318,29 @@ def get_pid(self): return self.exec_command("echo $$") # Database control - def db_connect(self, dbname, user, password=None, host="localhost", port=5432): - local_port = self.ssh.forward_remote_port(host, port) + def db_connect(self, dbname, user, password=None, host="127.0.0.1", hostname="localhost", port=5432): + """ + Connects to a PostgreSQL database on the remote system. + Args: + - dbname (str): The name of the database to connect to. + - user (str): The username for the database connection. + - password (str, optional): The password for the database connection. Defaults to None. + - host (str, optional): The IP address of the remote system. Defaults to "127.0.0.1". + - hostname (str, optional): The hostname of the remote system. Defaults to "localhost". + - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432. + + This function establishes a connection to a PostgreSQL database on the remote system using the specified + parameters. It returns a connection object that can be used to interact with the database. + """ + transport = self.ssh.get_transport() + local_port = 9090 # or any other available port + + transport.open_channel( + 'direct-tcpip', + (hostname, port), + (host, local_port) + ) + conn = pglib.connect( host=host, port=local_port, @@ -291,3 +349,4 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432): password=password, ) return conn + diff --git a/tests/test_remote.py b/tests/test_remote.py index 47804dfb..0155956c 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -1,76 +1,66 @@ -#!/usr/bin/env python -# coding: utf-8 - -import os -import time - import pytest -from docker import DockerClient -from paramiko import RSAKey -from testgres import RemoteOperations +from testgres.operations.remote_ops import RemoteOperations class TestRemoteOperations: - @pytest.fixture(scope="class", autouse=True) - def setup_class(self): - # Create shared volume - self.volume_path = os.path.abspath("./tmp/ssh_key") - os.makedirs(self.volume_path, exist_ok=True) - - # Generate SSH keys - private_key_path = os.path.join(self.volume_path, "id_rsa") - public_key_path = os.path.join(self.volume_path, "id_rsa.pub") - - private_key = RSAKey.generate(4096) - private_key.write_private_key_file(private_key_path) - - with open(public_key_path, "w") as f: - f.write(f"{private_key.get_name()} {private_key.get_base64()}") - - self.docker = DockerClient.from_env() - self.container = self.docker.containers.run( - "rastasheep/ubuntu-sshd:18.04", - detach=True, - tty=True, - ports={22: 8022}, - ) - - # Wait for the container to start sshd - time.sleep(10) - - yield - - # Stop and remove the container after tests - self.container.stop() - self.container.remove() @pytest.fixture(scope="function", autouse=True) def setup(self): self.operations = RemoteOperations( - host="localhost", - username="root", - ssh_key=os.path.join(self.volume_path, "id_rsa") + host="172.18.0.3", + username="dev", + ssh_key='/home/vika/Desktop/work/probackup/dev-ee-probackup/container_files/postgres/ssh/id_ed25519' ) yield self.operations.__del__() - def test_exec_command(self): + def test_exec_command_success(self): + """ + Test exec_command for successful command execution. + """ cmd = "python3 --version" - response = self.operations.exec_command(cmd) + response = self.operations.exec_command(cmd, wait_exit=True) - assert "Python 3.9" in response + assert b'Python 3.' in response - def test_is_executable(self): - cmd = "python3" + def test_exec_command_failure(self): + """ + Test exec_command for command execution failure. + """ + cmd = "nonexistent_command" + exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True) + + assert error == b'bash: line 1: nonexistent_command: command not found\n' + + def test_is_executable_true(self): + """ + Test is_executable for an existing executable. + """ + cmd = "postgres" response = self.operations.is_executable(cmd) assert response is True - def test_makedirs_and_rmdirs(self): - path = "/test_dir" + def test_is_executable_false(self): + """ + Test is_executable for a non-executable. + """ + cmd = "python" + response = self.operations.is_executable(cmd) + + assert response is False + + def test_makedirs_and_rmdirs_success(self): + """ + Test makedirs and rmdirs for successful directory creation and removal. + """ + cmd = "pwd" + pwd = self.operations.exec_command(cmd, wait_exit=True, encoding='utf-8').strip() + + path = f"{pwd}/test_dir" # Test makedirs self.operations.makedirs(path) @@ -79,3 +69,121 @@ def test_makedirs_and_rmdirs(self): # Test rmdirs self.operations.rmdirs(path) assert not self.operations.path_exists(path) + + def test_makedirs_and_rmdirs_failure(self): + """ + Test makedirs and rmdirs for directory creation and removal failure. + """ + # Try to create a directory in a read-only location + path = "/root/test_dir" + + # Test makedirs + with pytest.raises(Exception): + self.operations.makedirs(path) + + # Test rmdirs + exit_status, result, error = self.operations.rmdirs(path, verbose=True) + assert error == b"rm: cannot remove '/root/test_dir': Permission denied\n" + + def test_listdir(self): + """ + Test listdir for listing directory contents. + """ + path = "/etc" + files = self.operations.listdir(path) + + assert isinstance(files, list) + + def test_path_exists_true(self): + """ + Test path_exists for an existing path. + """ + path = "/etc" + response = self.operations.path_exists(path) + + assert response is True + + def test_path_exists_false(self): + """ + Test path_exists for a non-existing path. + """ + path = "/nonexistent_path" + response = self.operations.path_exists(path) + + assert response is False + + def test_write_text_file(self): + """ + Test write for writing data to a text file. + """ + filename = "/tmp/test_file.txt" + data = "Hello, world!" + + self.operations.write(filename, data) + + response = self.operations.read(filename) + + assert response == data + + def test_write_binary_file(self): + """ + Test write for writing data to a binary file. + """ + filename = "/tmp/test_file.bin" + data = b"\x00\x01\x02\x03" + + self.operations.write(filename, data, binary=True) + + response = self.operations.read(filename, binary=True) + + assert response == data + + def test_read_text_file(self): + """ + Test read for reading data from a text file. + """ + filename = "/etc/hosts" + + response = self.operations.read(filename) + + assert isinstance(response, str) + + def test_read_binary_file(self): + """ + Test read for reading data from a binary file. + """ + filename = "/usr/bin/python3" + + response = self.operations.read(filename, binary=True) + + assert isinstance(response, bytes) + + def test_touch(self): + """ + Test touch for creating a new file or updating access and modification times of an existing file. + """ + filename = "/tmp/test_file.txt" + + self.operations.touch(filename) + + assert self.operations.isfile(filename) + + def test_isfile_true(self): + """ + Test isfile for an existing file. + """ + filename = "/etc/hosts" + + response = self.operations.isfile(filename) + + assert response is True + + def test_isfile_false(self): + """ + Test isfile for a non-existing file. + """ + filename = "/nonexistent_file.txt" + + response = self.operations.isfile(filename) + + assert response is False From 8c373e63b131aa9ddc20b3dd6aa1c350a4c9e347 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Wed, 14 Jun 2023 23:32:31 +0200 Subject: [PATCH 12/23] PBCKP-152 add testgres tests for remote node --- tests/test_simple_remote.py | 1006 +++++++++++++++++++++++++++++++++++ 1 file changed, 1006 insertions(+) create mode 100755 tests/test_simple_remote.py diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py new file mode 100755 index 00000000..179f3ffb --- /dev/null +++ b/tests/test_simple_remote.py @@ -0,0 +1,1006 @@ +#!/usr/bin/env python +# coding: utf-8 + +import os +import re +import subprocess +import tempfile + + +import testgres +import time +import six +import unittest +import psutil + +import logging.config + +from contextlib import contextmanager +from shutil import rmtree + +from testgres.exceptions import \ + InitNodeException, \ + StartNodeException, \ + ExecUtilException, \ + BackupException, \ + QueryException, \ + TimeoutException, \ + TestgresException + +from testgres.config import \ + TestgresConfig, \ + configure_testgres, \ + scoped_config, \ + pop_config + +from testgres import \ + NodeStatus, \ + ProcessType, \ + IsolationLevel, \ + get_new_node + +from testgres import \ + get_bin_path, \ + get_pg_config, \ + get_pg_version + +from testgres import \ + First, \ + Any + +# NOTE: those are ugly imports +from testgres import bound_ports +from testgres.utils import PgVer +from testgres.node import ProcessProxy + + +def pg_version_ge(version): + cur_ver = PgVer(get_pg_version()) + min_ver = PgVer(version) + return cur_ver >= min_ver + + +def util_exists(util): + def good_properties(f): + return (os.path.exists(f) and # noqa: W504 + os.path.isfile(f) and # noqa: W504 + os.access(f, os.X_OK)) # yapf: disable + + # try to resolve it + if good_properties(get_bin_path(util)): + return True + + # check if util is in PATH + for path in os.environ["PATH"].split(os.pathsep): + if good_properties(os.path.join(path, util)): + return True + + +@contextmanager +def removing(f): + try: + yield f + finally: + if os.path.isfile(f): + os.remove(f) + elif os.path.isdir(f): + rmtree(f, ignore_errors=True) + + +def get_remote_node(): + return get_new_node(host='172.18.0.3', username='dev', ssh_key='/home/vika/Desktop/work/probackup/dev-ee-probackup/container_files/postgres/ssh/id_ed25519') + + +class TestgresRemoteTests(unittest.TestCase): + + def test_node_repr(self): + with get_remote_node() as node: + pattern = r"PostgresNode\(name='.+', port=.+, base_dir='.+'\)" + self.assertIsNotNone(re.match(pattern, str(node))) + + def test_custom_init(self): + with get_remote_node() as node: + # enable page checksums + node.init(initdb_params=['-k']).start() + + with get_remote_node() as node: + node.init( + allow_streaming=True, + initdb_params=['--auth-local=reject', '--auth-host=reject']) + + hba_file = os.path.join(node.data_dir, 'pg_hba.conf') + with open(hba_file, 'r') as conf: + lines = conf.readlines() + + # check number of lines + self.assertGreaterEqual(len(lines), 6) + + # there should be no trust entries at all + self.assertFalse(any('trust' in s for s in lines)) + + def test_double_init(self): + with get_remote_node().init() as node: + # can't initialize node more than once + with self.assertRaises(InitNodeException): + node.init() + + def test_init_after_cleanup(self): + with get_remote_node() as node: + node.init().start().execute('select 1') + node.cleanup() + node.init().start().execute('select 1') + + @unittest.skipUnless(util_exists('pg_resetwal'), 'might be missing') + @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') + def test_init_unique_system_id(self): + # this function exists in PostgreSQL 9.6+ + query = 'select system_identifier from pg_control_system()' + + with scoped_config(cache_initdb=False): + with get_remote_node().init().start() as node0: + id0 = node0.execute(query)[0] + + with scoped_config(cache_initdb=True, + cached_initdb_unique=True) as config: + + self.assertTrue(config.cache_initdb) + self.assertTrue(config.cached_initdb_unique) + + # spawn two nodes; ids must be different + with get_remote_node().init().start() as node1, \ + get_remote_node().init().start() as node2: + + id1 = node1.execute(query)[0] + id2 = node2.execute(query)[0] + + # ids must increase + self.assertGreater(id1, id0) + self.assertGreater(id2, id1) + + def test_node_exit(self): + with self.assertRaises(QueryException): + with get_remote_node().init() as node: + base_dir = node.base_dir + node.safe_psql('select 1') + + # we should save the DB for "debugging" + self.assertTrue(os.path.exists(base_dir)) + rmtree(base_dir, ignore_errors=True) + + with get_remote_node().init() as node: + base_dir = node.base_dir + + # should have been removed by default + self.assertFalse(os.path.exists(base_dir)) + + def test_double_start(self): + with get_remote_node().init().start() as node: + # can't start node more than once + node.start() + self.assertTrue(node.is_started) + + def test_uninitialized_start(self): + with get_remote_node() as node: + # node is not initialized yet + with self.assertRaises(StartNodeException): + node.start() + + def test_restart(self): + with get_remote_node() as node: + node.init().start() + + # restart, ok + res = node.execute('select 1') + self.assertEqual(res, [(1, )]) + node.restart() + res = node.execute('select 2') + self.assertEqual(res, [(2, )]) + + # restart, fail + with self.assertRaises(StartNodeException): + node.append_conf('pg_hba.conf', 'DUMMY') + node.restart() + + def test_reload(self): + with get_remote_node() as node: + node.init().start() + + # change client_min_messages and save old value + cmm_old = node.execute('show client_min_messages') + node.append_conf(client_min_messages='DEBUG1') + + # reload config + node.reload() + + # check new value + cmm_new = node.execute('show client_min_messages') + self.assertEqual('debug1', cmm_new[0][0].lower()) + self.assertNotEqual(cmm_old, cmm_new) + + def test_pg_ctl(self): + with get_remote_node() as node: + node.init().start() + + status = node.pg_ctl(['status']) + self.assertTrue('PID' in status) + + def test_status(self): + self.assertTrue(NodeStatus.Running) + self.assertFalse(NodeStatus.Stopped) + self.assertFalse(NodeStatus.Uninitialized) + + # check statuses after each operation + with get_remote_node() as node: + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Uninitialized) + + node.init() + + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Stopped) + + node.start() + + self.assertNotEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Running) + + node.stop() + + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Stopped) + + node.cleanup() + + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Uninitialized) + + def test_psql(self): + with get_remote_node().init().start() as node: + + # check returned values (1 arg) + res = node.psql('select 1') + self.assertEqual(res, (0, b'1\n', b'')) + + # check returned values (2 args) + res = node.psql('postgres', 'select 2') + self.assertEqual(res, (0, b'2\n', b'')) + + # check returned values (named) + res = node.psql(query='select 3', dbname='postgres') + self.assertEqual(res, (0, b'3\n', b'')) + + # check returned values (1 arg) + res = node.safe_psql('select 4') + self.assertEqual(res, b'4\n') + + # check returned values (2 args) + res = node.safe_psql('postgres', 'select 5') + self.assertEqual(res, b'5\n') + + # check returned values (named) + res = node.safe_psql(query='select 6', dbname='postgres') + self.assertEqual(res, b'6\n') + + # check feeding input + node.safe_psql('create table horns (w int)') + node.safe_psql('copy horns from stdin (format csv)', + input=b"1\n2\n3\n\\.\n") + _sum = node.safe_psql('select sum(w) from horns') + self.assertEqual(b'6\n', _sum) + + # check psql's default args, fails + with self.assertRaises(QueryException): + node.psql() + + node.stop() + + # check psql on stopped node, fails + with self.assertRaises(QueryException): + node.safe_psql('select 1') + + def test_transactions(self): + with get_remote_node().init().start() as node: + + with node.connect() as con: + con.begin() + con.execute('create table test(val int)') + con.execute('insert into test values (1)') + con.commit() + + con.begin() + con.execute('insert into test values (2)') + res = con.execute('select * from test order by val asc') + self.assertListEqual(res, [(1, ), (2, )]) + con.rollback() + + con.begin() + res = con.execute('select * from test') + self.assertListEqual(res, [(1, )]) + con.rollback() + + con.begin() + con.execute('drop table test') + con.commit() + + def test_control_data(self): + with get_remote_node() as node: + + # node is not initialized yet + with self.assertRaises(ExecUtilException): + node.get_control_data() + + node.init() + data = node.get_control_data() + + # check returned dict + self.assertIsNotNone(data) + self.assertTrue(any('pg_control' in s for s in data.keys())) + + def test_backup_simple(self): + with get_remote_node() as master: + + # enable streaming for backups + master.init(allow_streaming=True) + + # node must be running + with self.assertRaises(BackupException): + master.backup() + + # it's time to start node + master.start() + + # fill node with some data + master.psql('create table test as select generate_series(1, 4) i') + + with master.backup(xlog_method='stream') as backup: + with backup.spawn_primary().start() as slave: + res = slave.execute('select * from test order by i asc') + self.assertListEqual(res, [(1, ), (2, ), (3, ), (4, )]) + + def test_backup_multiple(self): + with get_remote_node() as node: + node.init(allow_streaming=True).start() + + with node.backup(xlog_method='fetch') as backup1, \ + node.backup(xlog_method='fetch') as backup2: + + self.assertNotEqual(backup1.base_dir, backup2.base_dir) + + with node.backup(xlog_method='fetch') as backup: + with backup.spawn_primary('node1', destroy=False) as node1, \ + backup.spawn_primary('node2', destroy=False) as node2: + + self.assertNotEqual(node1.base_dir, node2.base_dir) + + def test_backup_exhaust(self): + with get_remote_node() as node: + node.init(allow_streaming=True).start() + + with node.backup(xlog_method='fetch') as backup: + + # exhaust backup by creating new node + with backup.spawn_primary(): + pass + + # now let's try to create one more node + with self.assertRaises(BackupException): + backup.spawn_primary() + + def test_backup_wrong_xlog_method(self): + with get_remote_node() as node: + node.init(allow_streaming=True).start() + + with self.assertRaises(BackupException, + msg='Invalid xlog_method "wrong"'): + node.backup(xlog_method='wrong') + + def test_pg_ctl_wait_option(self): + with get_remote_node() as node: + node.init().start(wait=False) + while True: + try: + node.stop(wait=False) + break + except ExecUtilException: + # it's ok to get this exception here since node + # could be not started yet + pass + + def test_replicate(self): + with get_remote_node() as node: + node.init(allow_streaming=True).start() + + with node.replicate().start() as replica: + res = replica.execute('select 1') + self.assertListEqual(res, [(1, )]) + + node.execute('create table test (val int)', commit=True) + + replica.catchup() + + res = node.execute('select * from test') + self.assertListEqual(res, []) + + @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') + def test_synchronous_replication(self): + with get_remote_node() as master: + old_version = not pg_version_ge('9.6') + + master.init(allow_streaming=True).start() + + if not old_version: + master.append_conf('synchronous_commit = remote_apply') + + # create standby + with master.replicate() as standby1, master.replicate() as standby2: + standby1.start() + standby2.start() + + # check formatting + self.assertEqual( + '1 ("{}", "{}")'.format(standby1.name, standby2.name), + str(First(1, (standby1, standby2)))) # yapf: disable + self.assertEqual( + 'ANY 1 ("{}", "{}")'.format(standby1.name, standby2.name), + str(Any(1, (standby1, standby2)))) # yapf: disable + + # set synchronous_standby_names + master.set_synchronous_standbys(First(2, [standby1, standby2])) + master.restart() + + # the following part of the test is only applicable to newer + # versions of PostgresQL + if not old_version: + master.safe_psql('create table abc(a int)') + + # Create a large transaction that will take some time to apply + # on standby to check that it applies synchronously + # (If set synchronous_commit to 'on' or other lower level then + # standby most likely won't catchup so fast and test will fail) + master.safe_psql( + 'insert into abc select generate_series(1, 1000000)') + res = standby1.safe_psql('select count(*) from abc') + self.assertEqual(res, b'1000000\n') + + @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') + def test_logical_replication(self): + with get_remote_node() as node1, get_remote_node() as node2: + node1.init(allow_logical=True) + node1.start() + node2.init().start() + + create_table = 'create table test (a int, b int)' + node1.safe_psql(create_table) + node2.safe_psql(create_table) + + # create publication / create subscription + pub = node1.publish('mypub') + sub = node2.subscribe(pub, 'mysub') + + node1.safe_psql('insert into test values (1, 1), (2, 2)') + + # wait until changes apply on subscriber and check them + sub.catchup() + res = node2.execute('select * from test') + self.assertListEqual(res, [(1, 1), (2, 2)]) + + # disable and put some new data + sub.disable() + node1.safe_psql('insert into test values (3, 3)') + + # enable and ensure that data successfully transfered + sub.enable() + sub.catchup() + res = node2.execute('select * from test') + self.assertListEqual(res, [(1, 1), (2, 2), (3, 3)]) + + # Add new tables. Since we added "all tables" to publication + # (default behaviour of publish() method) we don't need + # to explicitely perform pub.add_tables() + create_table = 'create table test2 (c char)' + node1.safe_psql(create_table) + node2.safe_psql(create_table) + sub.refresh() + + # put new data + node1.safe_psql('insert into test2 values (\'a\'), (\'b\')') + sub.catchup() + res = node2.execute('select * from test2') + self.assertListEqual(res, [('a', ), ('b', )]) + + # drop subscription + sub.drop() + pub.drop() + + # create new publication and subscription for specific table + # (ommitting copying data as it's already done) + pub = node1.publish('newpub', tables=['test']) + sub = node2.subscribe(pub, 'newsub', copy_data=False) + + node1.safe_psql('insert into test values (4, 4)') + sub.catchup() + res = node2.execute('select * from test') + self.assertListEqual(res, [(1, 1), (2, 2), (3, 3), (4, 4)]) + + # explicitely add table + with self.assertRaises(ValueError): + pub.add_tables([]) # fail + pub.add_tables(['test2']) + node1.safe_psql('insert into test2 values (\'c\')') + sub.catchup() + res = node2.execute('select * from test2') + self.assertListEqual(res, [('a', ), ('b', )]) + + @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') + def test_logical_catchup(self): + """ Runs catchup for 100 times to be sure that it is consistent """ + with get_remote_node() as node1, get_remote_node() as node2: + node1.init(allow_logical=True) + node1.start() + node2.init().start() + + create_table = 'create table test (key int primary key, val int); ' + node1.safe_psql(create_table) + node1.safe_psql('alter table test replica identity default') + node2.safe_psql(create_table) + + # create publication / create subscription + sub = node2.subscribe(node1.publish('mypub'), 'mysub') + + for i in range(0, 100): + node1.execute('insert into test values ({0}, {0})'.format(i)) + sub.catchup() + res = node2.execute('select * from test') + self.assertListEqual(res, [( + i, + i, + )]) + node1.execute('delete from test') + + @unittest.skipIf(pg_version_ge('10'), 'requires <10') + def test_logical_replication_fail(self): + with get_remote_node() as node: + with self.assertRaises(InitNodeException): + node.init(allow_logical=True) + + def test_replication_slots(self): + with get_remote_node() as node: + node.init(allow_streaming=True).start() + + with node.replicate(slot='slot1').start() as replica: + replica.execute('select 1') + + # cannot create new slot with the same name + with self.assertRaises(TestgresException): + node.replicate(slot='slot1') + + def test_incorrect_catchup(self): + with get_remote_node() as node: + node.init(allow_streaming=True).start() + + # node has no master, can't catch up + with self.assertRaises(TestgresException): + node.catchup() + + def test_promotion(self): + with get_remote_node() as master: + master.init().start() + master.safe_psql('create table abc(id serial)') + + with master.replicate().start() as replica: + master.stop() + replica.promote() + + # make standby becomes writable master + replica.safe_psql('insert into abc values (1)') + res = replica.safe_psql('select * from abc') + self.assertEqual(res, b'1\n') + + def test_dump(self): + query_create = 'create table test as select generate_series(1, 2) as val' + query_select = 'select * from test order by val asc' + + with get_remote_node().init().start() as node1: + + node1.execute(query_create) + for format in ['plain', 'custom', 'directory', 'tar']: + with removing(node1.dump(format=format)) as dump: + with get_remote_node().init().start() as node3: + if format == 'directory': + self.assertTrue(os.path.isdir(dump)) + else: + self.assertTrue(os.path.isfile(dump)) + # restore dump + node3.restore(filename=dump) + res = node3.execute(query_select) + self.assertListEqual(res, [(1, ), (2, )]) + + def test_users(self): + with get_remote_node().init().start() as node: + node.psql('create role test_user login') + value = node.safe_psql('select 1', username='test_user') + self.assertEqual(b'1\n', value) + + def test_poll_query_until(self): + with get_remote_node() as node: + node.init().start() + + get_time = 'select extract(epoch from now())' + check_time = 'select extract(epoch from now()) - {} >= 5' + + start_time = node.execute(get_time)[0][0] + node.poll_query_until(query=check_time.format(start_time)) + end_time = node.execute(get_time)[0][0] + + self.assertTrue(end_time - start_time >= 5) + + # check 0 columns + with self.assertRaises(QueryException): + node.poll_query_until( + query='select from pg_catalog.pg_class limit 1') + + # check None, fail + with self.assertRaises(QueryException): + node.poll_query_until(query='create table abc (val int)') + + # check None, ok + node.poll_query_until(query='create table def()', + expected=None) # returns nothing + + # check 0 rows equivalent to expected=None + node.poll_query_until( + query='select * from pg_catalog.pg_class where true = false', + expected=None) + + # check arbitrary expected value, fail + with self.assertRaises(TimeoutException): + node.poll_query_until(query='select 3', + expected=1, + max_attempts=3, + sleep_time=0.01) + + # check arbitrary expected value, ok + node.poll_query_until(query='select 2', expected=2) + + # check timeout + with self.assertRaises(TimeoutException): + node.poll_query_until(query='select 1 > 2', + max_attempts=3, + sleep_time=0.01) + + # check ProgrammingError, fail + with self.assertRaises(testgres.ProgrammingError): + node.poll_query_until(query='dummy1') + + # check ProgrammingError, ok + with self.assertRaises(TimeoutException): + node.poll_query_until(query='dummy2', + max_attempts=3, + sleep_time=0.01, + suppress={testgres.ProgrammingError}) + + # check 1 arg, ok + node.poll_query_until('select true') + + def test_logging(self): + logfile = tempfile.NamedTemporaryFile('w', delete=True) + + log_conf = { + 'version': 1, + 'handlers': { + 'file': { + 'class': 'logging.FileHandler', + 'filename': logfile.name, + 'formatter': 'base_format', + 'level': logging.DEBUG, + }, + }, + 'formatters': { + 'base_format': { + 'format': '%(node)-5s: %(message)s', + }, + }, + 'root': { + 'handlers': ('file', ), + 'level': 'DEBUG', + }, + } + + logging.config.dictConfig(log_conf) + + with scoped_config(use_python_logging=True): + node_name = 'master' + + with get_new_node(name=node_name) as master: + master.init().start() + + # execute a dummy query a few times + for i in range(20): + master.execute('select 1') + time.sleep(0.01) + + # let logging worker do the job + time.sleep(0.1) + + # check that master's port is found + with open(logfile.name, 'r') as log: + lines = log.readlines() + self.assertTrue(any(node_name in s for s in lines)) + + # test logger after stop/start/restart + master.stop() + master.start() + master.restart() + self.assertTrue(master._logger.is_alive()) + + @unittest.skipUnless(util_exists('pgbench'), 'might be missing') + def test_pgbench(self): + with get_remote_node().init().start() as node: + + # initialize pgbench DB and run benchmarks + node.pgbench_init(scale=2, foreign_keys=True, + options=['-q']).pgbench_run(time=2) + + # run TPC-B benchmark + proc = node.pgbench(stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + options=['-T3']) + + out, _ = proc.communicate() + out = out.decode('utf-8') + + self.assertTrue('tps' in out) + + def test_pg_config(self): + # check same instances + a = get_pg_config() + b = get_pg_config() + self.assertEqual(id(a), id(b)) + + # save right before config change + c1 = get_pg_config() + + # modify setting for this scope + with scoped_config(cache_pg_config=False) as config: + + # sanity check for value + self.assertFalse(config.cache_pg_config) + + # save right after config change + c2 = get_pg_config() + + # check different instances after config change + self.assertNotEqual(id(c1), id(c2)) + + # check different instances + a = get_pg_config() + b = get_pg_config() + self.assertNotEqual(id(a), id(b)) + + def test_config_stack(self): + # no such option + with self.assertRaises(TypeError): + configure_testgres(dummy=True) + + # we have only 1 config in stack + with self.assertRaises(IndexError): + pop_config() + + d0 = TestgresConfig.cached_initdb_dir + d1 = 'dummy_abc' + d2 = 'dummy_def' + + with scoped_config(cached_initdb_dir=d1) as c1: + self.assertEqual(c1.cached_initdb_dir, d1) + + with scoped_config(cached_initdb_dir=d2) as c2: + + stack_size = len(testgres.config.config_stack) + + # try to break a stack + with self.assertRaises(TypeError): + with scoped_config(dummy=True): + pass + + self.assertEqual(c2.cached_initdb_dir, d2) + self.assertEqual(len(testgres.config.config_stack), stack_size) + + self.assertEqual(c1.cached_initdb_dir, d1) + + self.assertEqual(TestgresConfig.cached_initdb_dir, d0) + + def test_unix_sockets(self): + with get_remote_node() as node: + node.init(unix_sockets=False, allow_streaming=True) + node.start() + + node.execute('select 1') + node.safe_psql('select 1') + + with node.replicate().start() as r: + r.execute('select 1') + r.safe_psql('select 1') + + def test_auto_name(self): + with get_remote_node().init(allow_streaming=True).start() as m: + with m.replicate().start() as r: + + # check that nodes are running + self.assertTrue(m.status()) + self.assertTrue(r.status()) + + # check their names + self.assertNotEqual(m.name, r.name) + self.assertTrue('testgres' in m.name) + self.assertTrue('testgres' in r.name) + + def test_file_tail(self): + from testgres.utils import file_tail + + s1 = "the quick brown fox jumped over that lazy dog\n" + s2 = "abc\n" + s3 = "def\n" + + with tempfile.NamedTemporaryFile(mode='r+', delete=True) as f: + sz = 0 + while sz < 3 * 8192: + sz += len(s1) + f.write(s1) + f.write(s2) + f.write(s3) + + f.seek(0) + lines = file_tail(f, 3) + self.assertEqual(lines[0], s1) + self.assertEqual(lines[1], s2) + self.assertEqual(lines[2], s3) + + f.seek(0) + lines = file_tail(f, 1) + self.assertEqual(lines[0], s3) + + def test_isolation_levels(self): + with get_remote_node().init().start() as node: + with node.connect() as con: + # string levels + con.begin('Read Uncommitted').commit() + con.begin('Read Committed').commit() + con.begin('Repeatable Read').commit() + con.begin('Serializable').commit() + + # enum levels + con.begin(IsolationLevel.ReadUncommitted).commit() + con.begin(IsolationLevel.ReadCommitted).commit() + con.begin(IsolationLevel.RepeatableRead).commit() + con.begin(IsolationLevel.Serializable).commit() + + # check wrong level + with self.assertRaises(QueryException): + con.begin('Garbage').commit() + + def test_ports_management(self): + # check that no ports have been bound yet + self.assertEqual(len(bound_ports), 0) + + with get_remote_node() as node: + # check that we've just bound a port + self.assertEqual(len(bound_ports), 1) + + # check that bound_ports contains our port + port_1 = list(bound_ports)[0] + port_2 = node.port + self.assertEqual(port_1, port_2) + + # check that port has been freed successfully + self.assertEqual(len(bound_ports), 0) + + def test_exceptions(self): + str(StartNodeException('msg', [('file', 'lines')])) + str(ExecUtilException('msg', 'cmd', 1, 'out')) + str(QueryException('msg', 'query')) + + def test_version_management(self): + a = PgVer('10.0') + b = PgVer('10') + c = PgVer('9.6.5') + d = PgVer('15.0') + e = PgVer('15rc1') + f = PgVer('15beta4') + + self.assertTrue(a == b) + self.assertTrue(b > c) + self.assertTrue(a > c) + self.assertTrue(d > e) + self.assertTrue(e > f) + self.assertTrue(d > f) + + version = get_pg_version() + with get_remote_node() as node: + self.assertTrue(isinstance(version, six.string_types)) + self.assertTrue(isinstance(node.version, PgVer)) + self.assertEqual(node.version, PgVer(version)) + + def test_child_pids(self): + master_processes = [ + ProcessType.AutovacuumLauncher, + ProcessType.BackgroundWriter, + ProcessType.Checkpointer, + ProcessType.StatsCollector, + ProcessType.WalSender, + ProcessType.WalWriter, + ] + + if pg_version_ge('10'): + master_processes.append(ProcessType.LogicalReplicationLauncher) + + repl_processes = [ + ProcessType.Startup, + ProcessType.WalReceiver, + ] + + with get_remote_node().init().start() as master: + + # master node doesn't have a source walsender! + with self.assertRaises(TestgresException): + master.source_walsender + + with master.connect() as con: + self.assertGreater(con.pid, 0) + + with master.replicate().start() as replica: + + # test __str__ method + str(master.child_processes[0]) + + master_pids = master.auxiliary_pids + for ptype in master_processes: + self.assertIn(ptype, master_pids) + + replica_pids = replica.auxiliary_pids + for ptype in repl_processes: + self.assertIn(ptype, replica_pids) + + # there should be exactly 1 source walsender for replica + self.assertEqual(len(master_pids[ProcessType.WalSender]), 1) + pid1 = master_pids[ProcessType.WalSender][0] + pid2 = replica.source_walsender.pid + self.assertEqual(pid1, pid2) + + replica.stop() + + # there should be no walsender after we've stopped replica + with self.assertRaises(TestgresException): + replica.source_walsender + + def test_child_process_dies(self): + # test for FileNotFound exception during child_processes() function + with subprocess.Popen(["sleep", "60"]) as process: + self.assertEqual(process.poll(), None) + # collect list of processes currently running + children = psutil.Process(os.getpid()).children() + # kill a process, so received children dictionary becomes invalid + process.kill() + process.wait() + # try to handle children list -- missing processes will have ptype "ProcessType.Unknown" + [ProcessProxy(p) for p in children] + + +if __name__ == '__main__': + if os.environ.get('ALT_CONFIG'): + suite = unittest.TestSuite() + + # Small subset of tests for alternative configs (PG_BIN or PG_CONFIG) + suite.addTest(TestgresTests('test_pg_config')) + suite.addTest(TestgresTests('test_pg_ctl')) + suite.addTest(TestgresTests('test_psql')) + suite.addTest(TestgresTests('test_replicate')) + + print('Running tests for alternative config:') + for t in suite: + print(t) + print() + + runner = unittest.TextTestRunner() + runner.run(suite) + else: + unittest.main() From 72e6d5d466bb76473b376c9bc4b9f36fa4afbda0 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Sat, 17 Jun 2023 02:08:23 +0200 Subject: [PATCH 13/23] PBCKP-152 fixed test_simple and test_remote --- setup.py | 3 +- testgres/backup.py | 4 +- testgres/cache.py | 7 +- testgres/config.py | 6 + testgres/connection.py | 2 +- testgres/defaults.py | 10 +- testgres/logger.py | 23 ++-- testgres/node.py | 97 ++++++++------- testgres/operations/local_ops.py | 101 +++++++++++----- testgres/operations/os_ops.py | 2 +- testgres/operations/remote_ops.py | 195 +++++++++++++++++++----------- testgres/utils.py | 109 +++-------------- tests/test_remote.py | 23 ++-- tests/test_simple.py | 14 +-- tests/test_simple_remote.py | 61 +++++----- 15 files changed, 353 insertions(+), 304 deletions(-) diff --git a/setup.py b/setup.py index b5162dce..8cb0f70a 100755 --- a/setup.py +++ b/setup.py @@ -13,7 +13,8 @@ "psutil", "packaging", "paramiko", - "fabric" + "fabric", + "sshtunnel" ] # Add compatibility enum class diff --git a/testgres/backup.py b/testgres/backup.py index c0fd6e50..c4cc952b 100644 --- a/testgres/backup.py +++ b/testgres/backup.py @@ -77,7 +77,7 @@ def __init__(self, "-D", data_dir, "-X", xlog_method.value ] # yapf: disable - execute_utility(_params, self.log_file, self.os_ops) + execute_utility(_params, self.log_file) def __enter__(self): return self @@ -139,7 +139,7 @@ def spawn_primary(self, name=None, destroy=True): # Build a new PostgresNode NodeClass = self.original_node.__class__ - with clean_on_error(NodeClass(name=name, base_dir=base_dir)) as node: + with clean_on_error(NodeClass(name=name, base_dir=base_dir, os_ops=self.original_node.os_ops)) as node: # New nodes should always remove dir tree node._should_rm_dirs = True diff --git a/testgres/cache.py b/testgres/cache.py index ef07e976..bf8658c9 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -26,12 +26,11 @@ def cached_initdb(data_dir, logfile=None, params=None, os_ops: OsOperations = Lo """ Perform initdb or use cached node files. """ - testgres_config.os_ops = os_ops - def call_initdb(initdb_dir, log=None): + def call_initdb(initdb_dir, log=logfile): try: _params = [get_bin_path("initdb"), "-D", initdb_dir, "-N"] - execute_utility(_params + (params or []), log, os_ops) + execute_utility(_params + (params or []), log) except ExecUtilException as e: raise_from(InitNodeException("Failed to run initdb"), e) @@ -62,7 +61,7 @@ def call_initdb(initdb_dir, log=None): # XXX: build new WAL segment with our system id _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"] - execute_utility(_params, logfile, os_ops=os_ops) + execute_utility(_params, logfile) except ExecUtilException as e: msg = "Failed to reset WAL for system id" diff --git a/testgres/config.py b/testgres/config.py index b21d8356..b6c43926 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from .consts import TMP_CACHE +from .operations.os_ops import OsOperations from .operations.local_ops import LocalOperations @@ -121,6 +122,11 @@ def copy(self): return copy.copy(self) + @staticmethod + def set_os_ops(os_ops: OsOperations): + testgres_config.os_ops = os_ops + testgres_config.cached_initdb_dir = os_ops.mkdtemp(prefix=TMP_CACHE) + # cached dirs to be removed cached_initdb_dirs = set() diff --git a/testgres/connection.py b/testgres/connection.py index 6725b14f..d28d81bd 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -37,7 +37,7 @@ def __init__(self, # Set default arguments dbname = dbname or default_dbname() - username = username or default_username(node.os_ops) + username = username or default_username() self._node = node diff --git a/testgres/defaults.py b/testgres/defaults.py index 34bcc08b..d77361d7 100644 --- a/testgres/defaults.py +++ b/testgres/defaults.py @@ -2,7 +2,7 @@ import struct import uuid -from .operations.local_ops import LocalOperations +from .config import testgres_config as tconf def default_dbname(): @@ -13,11 +13,11 @@ def default_dbname(): return 'postgres' -def default_username(os_ops=LocalOperations()): +def default_username(): """ Return default username (current user). """ - return os_ops.get_user() + return tconf.os_ops.get_user() def generate_app_name(): @@ -28,7 +28,7 @@ def generate_app_name(): return 'testgres-{}'.format(str(uuid.uuid4())) -def generate_system_id(os_ops=LocalOperations()): +def generate_system_id(): """ Generate a new 64-bit unique system identifier for node. """ @@ -43,7 +43,7 @@ def generate_system_id(os_ops=LocalOperations()): system_id = 0 system_id |= (secs << 32) system_id |= (usecs << 12) - system_id |= (os_ops.get_pid() & 0xFFF) + system_id |= (tconf.os_ops.get_pid() & 0xFFF) # pack ULL in native byte order return struct.pack('=Q', system_id) diff --git a/testgres/logger.py b/testgres/logger.py index abd4d255..59579002 100644 --- a/testgres/logger.py +++ b/testgres/logger.py @@ -5,19 +5,20 @@ import threading import time - # create logger log = logging.getLogger('Testgres') -log.setLevel(logging.DEBUG) -# create console handler and set level to debug -ch = logging.StreamHandler() -ch.setLevel(logging.DEBUG) -# create formatter -formatter = logging.Formatter('\n%(asctime)s - %(name)s[%(levelname)s]: %(message)s') -# add formatter to ch -ch.setFormatter(formatter) -# add ch to logger -log.addHandler(ch) + +if not log.handlers: + log.setLevel(logging.WARN) + # create console handler and set level to debug + ch = logging.StreamHandler() + ch.setLevel(logging.WARN) + # create formatter + formatter = logging.Formatter('\n%(asctime)s - %(name)s[%(levelname)s]: %(message)s') + # add formatter to ch + ch.setFormatter(formatter) + # add ch to logger + log.addHandler(ch) class TestgresLogger(threading.Thread): diff --git a/testgres/node.py b/testgres/node.py index 9d183f96..5ad18ace 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -6,7 +6,6 @@ import threading from queue import Queue -import psutil import time try: @@ -23,7 +22,6 @@ except ImportError: raise ImportError("You must have psycopg2 or pg8000 modules installed") -from shutil import rmtree from six import raise_from, iteritems, text_type from .enums import \ @@ -128,7 +126,7 @@ def __repr__(self): class PostgresNode(object): def __init__(self, name=None, port=None, base_dir=None, - host='127.0.0.1', hostname='localhost', ssh_key=None, username=default_username()): + host='127.0.0.1', hostname='localhost', ssh_key=None, username=default_username(), os_ops=None): """ PostgresNode constructor. @@ -147,15 +145,19 @@ def __init__(self, name=None, port=None, base_dir=None, # basic self.name = name or generate_app_name() - self.port = port or reserve_port() - self.host = host - self.hostname = hostname - self.ssh_key = ssh_key - if hostname != 'localhost' or host != '127.0.0.1': - self.os_ops = RemoteOperations(host, hostname, ssh_key) + if os_ops: + self.os_ops = os_ops + elif ssh_key: + self.os_ops = RemoteOperations(host=host, hostname=hostname, ssh_key=ssh_key, username=username) else: - self.os_ops = LocalOperations(username=username) + self.os_ops = LocalOperations(host=host, hostname=hostname, username=username) + + self.port = self.os_ops.port or reserve_port() + + self.host = self.os_ops.host + self.hostname = self.os_ops.hostname + self.ssh_key = self.os_ops.ssh_key testgres_config.os_ops = self.os_ops # defaults for __exit__() @@ -243,7 +245,7 @@ def child_processes(self): """ # get a list of postmaster's children - children = psutil.Process(self.pid).children() + children = self.os_ops.get_remote_children(self.pid) return [ProcessProxy(p) for p in children] @@ -511,21 +513,18 @@ def get_auth_method(t): # get auth methods auth_local = get_auth_method('local') auth_host = get_auth_method('host') + subnet_base = ".".join(self.os_ops.host.split('.')[:-1] + ['0']) new_lines = [ u"local\treplication\tall\t\t\t{}\n".format(auth_local), u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), - - u"host\treplication\tall\t0.0.0.0/0\t{}\n".format(auth_host), - u"host\tall\tall\t0.0.0.0/0\t{}\n".format(auth_host), - - u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host) + u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host), + u"host\treplication\t{}\t{}/24\t\t{}\n".format(self.os_ops.username, subnet_base, auth_host), + u"host\tall\t{}\t{}/24\t\t{}\n".format(self.os_ops.username, subnet_base, auth_host) ] # yapf: disable # write missing lines - for line in new_lines: - if line not in lines: - self.os_ops.write(hba_conf, line) + self.os_ops.write(hba_conf, new_lines) # overwrite config file self.os_ops.write(postgres_conf, '', truncate=True) @@ -533,7 +532,7 @@ def get_auth_method(t): self.append_conf(fsync=fsync, max_worker_processes=MAX_WORKER_PROCESSES, log_statement=log_statement, - listen_addresses=self.host, + listen_addresses='*', port=self.port) # yapf:disable # common replication settings @@ -598,9 +597,11 @@ def append_conf(self, line='', filename=PG_CONF_FILE, **kwargs): value = 'on' if value else 'off' elif not str(value).replace('.', '', 1).isdigit(): value = "'{}'".format(value) - - # format a new config line - lines.append('{} = {}'.format(option, value)) + if value == '*': + lines.append("{} = '*'".format(option)) + else: + # format a new config line + lines.append('{} = {}'.format(option, value)) config_name = os.path.join(self.data_dir, filename) conf_text = '' @@ -624,7 +625,9 @@ def status(self): "-D", self.data_dir, "status" ] # yapf: disable - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + out = execute_utility(_params, self.utils_log_file) + if 'no server running' in out: + return NodeStatus.Uninitialized return NodeStatus.Running except ExecUtilException as e: @@ -646,7 +649,7 @@ def get_control_data(self): _params += ["-D"] if self._pg_version >= PgVer('9.5') else [] _params += [self.data_dir] - data = execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + data = execute_utility(_params, self.utils_log_file) out_dict = {} @@ -709,8 +712,8 @@ def start(self, params=[], wait=True): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) - except ExecUtilException as e: + execute_utility(_params, self.utils_log_file) + except Exception as e: msg = 'Cannot start node' files = self._collect_special_files() raise_from(StartNodeException(msg, files), e) @@ -740,7 +743,7 @@ def stop(self, params=[], wait=True): "stop" ] + params # yapf: disable - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + execute_utility(_params, self.utils_log_file) self._maybe_stop_logger() self.is_started = False @@ -782,7 +785,7 @@ def restart(self, params=[]): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + execute_utility(_params, self.utils_log_file) except ExecUtilException as e: msg = 'Cannot restart node' files = self._collect_special_files() @@ -809,7 +812,7 @@ def reload(self, params=[]): "reload" ] + params # yapf: disable - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + execute_utility(_params, self.utils_log_file) return self @@ -831,7 +834,7 @@ def promote(self, dbname=None, username=None): "promote" ] # yapf: disable - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + execute_utility(_params, self.utils_log_file) # for versions below 10 `promote` is asynchronous so we need to wait # until it actually becomes writable @@ -866,7 +869,7 @@ def pg_ctl(self, params): "-w" # wait ] + params # yapf: disable - return execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + return execute_utility(_params, self.utils_log_file) def free_port(self): """ @@ -898,7 +901,7 @@ def cleanup(self, max_attempts=3): else: rm_dir = self.data_dir # just data, save logs - rmtree(rm_dir, ignore_errors=True) + self.os_ops.rmdirs(rm_dir, ignore_errors=True) return self @@ -951,7 +954,7 @@ def psql(self, # select query source if query: - psql_params.extend(("-c", query)) + psql_params.extend(("-c", '"{}"'.format(query))) elif filename: psql_params.extend(("-f", filename)) else: @@ -961,7 +964,7 @@ def psql(self, psql_params.append(dbname) # start psql process - status_code, out, err = self.os_ops.exec_command(psql_params, shell=False, verbose=True, input=input) + status_code, out, err = self.os_ops.exec_command(psql_params, verbose=True, input=input) return status_code, out, err @@ -987,13 +990,17 @@ def safe_psql(self, query=None, expect_error=False, **kwargs): # force this setting kwargs['ON_ERROR_STOP'] = 1 - - ret, out, err = self.psql(query=query, **kwargs) + try: + ret, out, err = self.psql(query=query, **kwargs) + except ExecUtilException as e: + ret = e.exit_code + out = e.out + err = e.message if ret: if expect_error: - out = (err or b'').decode('utf-8') + out = err or b'' else: - raise QueryException((err or b'').decode('utf-8'), query) + raise QueryException(err or b'', query) elif expect_error: assert False, f"Exception was expected, but query finished successfully: `{query}` " @@ -1049,7 +1056,7 @@ def tmpfile(): "-F", format.value ] # yapf: disable - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + execute_utility(_params, self.utils_log_file) return filename @@ -1078,7 +1085,7 @@ def restore(self, filename, dbname=None, username=None): # try pg_restore if dump is binary formate, and psql if not try: - execute_utility(_params, self.utils_log_name, os_ops=self.os_ops) + execute_utility(_params, self.utils_log_name) except ExecUtilException: self.psql(filename=filename, dbname=dbname, username=username) @@ -1335,7 +1342,7 @@ def pgbench(self, # Set default arguments dbname = dbname or default_dbname() - username = username or default_username(self.os_ops) + username = username or default_username() _params = [ get_bin_path("pgbench"), @@ -1347,7 +1354,7 @@ def pgbench(self, # should be the last one _params.append(dbname) - proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, shell=False, proc=True) + proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, proc=True) return proc @@ -1387,7 +1394,7 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs): # Set default arguments dbname = dbname or default_dbname() - username = username or default_username(os_ops=self.os_ops) + username = username or default_username() _params = [ get_bin_path("pgbench"), @@ -1410,7 +1417,7 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs): # should be the last one _params.append(dbname) - return execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + return execute_utility(_params, self.utils_log_file) def connect(self, dbname=None, diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index acb10df8..010e3cc0 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -5,32 +5,66 @@ import tempfile from shutil import rmtree +import psutil + +from testgres.exceptions import ExecUtilException from testgres.logger import log from .os_ops import OsOperations from .os_ops import pglib +try: + from shutil import which as find_executable +except ImportError: + from distutils.spawn import find_executable + CMD_TIMEOUT_SEC = 60 class LocalOperations(OsOperations): - def __init__(self, username=None): - super().__init__() + def __init__(self, host='127.0.0.1', hostname='localhost', port=None, username=None): + super().__init__(username) + self.host = host + self.hostname = hostname + self.port = port + self.ssh_key = None self.username = username or self.get_user() # Command execution def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None): - log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") - # Source global profile file + execute command - try: + """ + Execute a command in a subprocess. + + Args: + - cmd: The command to execute. + - wait_exit: Whether to wait for the subprocess to exit before returning. + - verbose: Whether to return verbose output. + - expect_error: Whether to raise an error if the subprocess exits with an error status. + - encoding: The encoding to use for decoding the subprocess output. + - shell: Whether to use shell when executing the subprocess. + - text: Whether to return str instead of bytes for the subprocess output. + - input: The input to pass to the subprocess. + - stdout: The stdout to use for the subprocess. + - stderr: The stderr to use for the subprocess. + - proc: The process to use for subprocess creation. + :return: The output of the subprocess. + """ + if isinstance(cmd, list): + cmd = " ".join(cmd) + log.debug(f"Executing command: `{cmd}`") + + if os.name == 'nt': + with tempfile.NamedTemporaryFile() as buf: + process = subprocess.Popen(cmd, stdout=buf, stderr=subprocess.STDOUT) + process.communicate() + buf.seek(0) + result = buf.read().decode(encoding) + return result + else: if proc: - return subprocess.Popen(cmd, - shell=shell, - stdin=input or subprocess.PIPE, - stdout=stdout, - stderr=stderr) + return subprocess.Popen(cmd, shell=shell, stdin=input, stdout=stdout, stderr=stderr) process = subprocess.run( cmd, input=input, @@ -43,41 +77,32 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, exit_status = process.returncode result = process.stdout error = process.stderr + found_error = "error" in error.decode(encoding or 'utf-8').lower() + if encoding: + result = result.decode(encoding) + error = error.decode(encoding) if expect_error: raise Exception(result, error) - if exit_status != 0 or "error" in error.lower().decode(encoding or 'utf-8'): # Decode error for comparison - log.error( - f"Problem in executing command: `{cmd}`\nerror: {error.decode(encoding or 'utf-8')}\nexit_code: {exit_status}" - # Decode for logging - ) - + if exit_status != 0 or found_error: + if exit_status == 0: + exit_status = 1 + raise ExecUtilException(message=f'Utility exited with non-zero code. Error `{error}`', + command=cmd, + exit_code=exit_status, + out=result) if verbose: return exit_status, result, error else: return result - except Exception as e: - log.error(f"Unexpected error while executing command `{cmd}`: {e}") - return None - # Environment setup def environ(self, var_name): cmd = f"echo ${var_name}" - return self.exec_command(cmd).strip() + return self.exec_command(cmd, encoding='utf-8').strip() def find_executable(self, executable): - search_paths = self.environ("PATH") - if not search_paths: - return None - - search_paths = search_paths.split(self.pathsep) - for path in search_paths: - remote_file = os.path.join(path, executable) - if self.isfile(remote_file): - return remote_file - - return None + return find_executable(executable) def is_executable(self, file): # Check if the file is executable @@ -157,6 +182,9 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal read_and_write: If True, the file will be opened with read and write permissions ('r+' option); if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) """ + # If it is a bytes str or list + if isinstance(data, bytes) or isinstance(data, list) and all(isinstance(item, bytes) for item in data): + binary = True mode = "wb" if binary else "w" if not truncate: mode = "ab" if binary else "a" @@ -221,6 +249,12 @@ def readlines(self, filename, num_lines=0, binary=False, encoding=None): def isfile(self, remote_file): return os.path.isfile(remote_file) + def isdir(self, dirname): + return os.path.isdir(dirname) + + def remove_file(self, filename): + return os.remove(filename) + # Processes control def kill(self, pid, signal): # Kill the process @@ -231,6 +265,9 @@ def get_pid(self): # Get current process id return os.getpid() + def get_remote_children(self, pid): + return psutil.Process(pid).children() + # Database control def db_connect(self, dbname, user, password=None, host="localhost", port=5432): conn = pglib.connect( diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index 1ee1f869..68925616 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -15,7 +15,7 @@ def __init__(self, username=None): self.username = username # Command execution - def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): + def exec_command(self, cmd, **kwargs): raise NotImplementedError() # Environment setup diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index e2248015..d45614a1 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -2,20 +2,30 @@ import tempfile from typing import Optional +import sshtunnel + import paramiko from paramiko import SSHClient -from logger import log +from testgres.exceptions import ExecUtilException +from testgres.logger import log + from .os_ops import OsOperations from .os_ops import pglib +sshtunnel.SSH_TIMEOUT = 5.0 +sshtunnel.TUNNEL_TIMEOUT = 5.0 + + error_markers = [b'error', b'Permission denied'] class RemoteOperations(OsOperations): - def __init__(self, hostname="localhost", host="127.0.0.1", ssh_key=None, username=None): + def __init__(self, host="127.0.0.1", hostname='localhost', port=None, ssh_key=None, username=None): super().__init__(username) self.host = host + self.hostname = hostname + self.port = port self.ssh_key = ssh_key self.remote = True self.ssh = self.ssh_connect() @@ -57,51 +67,50 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa Args: - cmd (str): The command to be executed. """ + if self.ssh is None or not self.ssh.get_transport() or not self.ssh.get_transport().is_active(): + self.ssh = self.ssh_connect() + if isinstance(cmd, list): cmd = " ".join(cmd) - try: - if input: - stdin, stdout, stderr = self.ssh.exec_command(cmd) - stdin.write(input.encode("utf-8")) - stdin.flush() - else: - stdin, stdout, stderr = self.ssh.exec_command(cmd) - exit_status = 0 - if wait_exit: - exit_status = stdout.channel.recv_exit_status() - - if encoding: - result = stdout.read().decode(encoding) - error = stderr.read().decode(encoding) - else: - result = stdout.read() - error = stderr.read() + if input: + stdin, stdout, stderr = self.ssh.exec_command(cmd) + stdin.write(input) + stdin.flush() + else: + stdin, stdout, stderr = self.ssh.exec_command(cmd) + exit_status = 0 + if wait_exit: + exit_status = stdout.channel.recv_exit_status() + + if encoding: + result = stdout.read().decode(encoding) + error = stderr.read().decode(encoding) + else: + result = stdout.read() + error = stderr.read() - if expect_error: - raise Exception(result, error) + if expect_error: + raise Exception(result, error) - if encoding: - error_found = exit_status != 0 or any( - marker.decode(encoding) in error for marker in error_markers) - else: - error_found = exit_status != 0 or any( - marker in error for marker in error_markers) - - if error_found: - log.error( - f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}" - ) - if exit_status == 0: - exit_status = 1 - - if verbose: - return exit_status, result, error - else: - return result + if encoding: + error_found = exit_status != 0 or any( + marker.decode(encoding) in error for marker in error_markers) + else: + error_found = exit_status != 0 or any( + marker in error for marker in error_markers) - except Exception as e: - log.error(f"Unexpected error while executing command `{cmd}`: {e}") - return None + if error_found: + if exit_status == 0: + exit_status = 1 + raise ExecUtilException(message=f"Utility exited with non-zero code. Error: {error.decode(encoding or 'utf-8')}", + command=cmd, + exit_code=exit_status, + out=result) + + if verbose: + return exit_status, result, error + else: + return result # Environment setup def environ(self, var_name: str) -> str: @@ -111,7 +120,7 @@ def environ(self, var_name: str) -> str: - var_name (str): The name of the environment variable. """ cmd = f"echo ${var_name}" - return self.exec_command(cmd).strip() + return self.exec_command(cmd, encoding='utf-8').strip() def find_executable(self, executable): search_paths = self.environ("PATH") @@ -142,7 +151,7 @@ def add_to_path(self, new_path): os.environ["PATH"] = f"{new_path}{pathsep}{path}" return pathsep - def set_env(self, var_name: str, var_val: str) -> None: + def set_env(self, var_name: str, var_val: str): """ Set the value of an environment variable. Args: @@ -153,11 +162,11 @@ def set_env(self, var_name: str, var_val: str) -> None: # Get environment variables def get_user(self): - return self.exec_command("echo $USER") + return self.exec_command("echo $USER", encoding='utf-8').strip() def get_name(self): cmd = 'python3 -c "import os; print(os.name)"' - return self.exec_command(cmd).strip() + return self.exec_command(cmd, encoding='utf-8').strip() # Work with dirs def makedirs(self, path, remove_existing=False): @@ -219,10 +228,19 @@ def mkdtemp(self, prefix=None): """ Creates a temporary directory in the remote server. Args: - prefix (str): The prefix of the temporary directory name. + - prefix (str): The prefix of the temporary directory name. """ - temp_dir = self.exec_command(f"mkdtemp -d {prefix}", encoding='utf-8') - return temp_dir.strip() + if prefix: + temp_dir = self.exec_command(f"mktemp -d {prefix}XXXXX", encoding='utf-8') + else: + temp_dir = self.exec_command("mktemp -d", encoding='utf-8') + + if temp_dir: + if not os.path.isabs(temp_dir): + temp_dir = os.path.join('/home', self.username, temp_dir.strip()) + return temp_dir + else: + raise ExecUtilException("Could not create temporary directory.") def mkstemp(self, prefix=None): cmd = f"mktemp {prefix}XXXXXX" @@ -230,6 +248,10 @@ def mkstemp(self, prefix=None): return filename def copytree(self, src, dst): + if not os.path.isabs(dst): + dst = os.path.join('~', dst) + if self.isdir(dst): + raise FileExistsError(f"Directory {dst} already exists.") return self.exec_command(f"cp -r {src} {dst}") # Work with files @@ -253,20 +275,40 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal if read_and_write: mode = "r+b" if binary else "r+" - with tempfile.NamedTemporaryFile(mode=mode) as tmp_file: + with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file: + if not truncate: + with self.ssh_connect() as ssh: + sftp = ssh.open_sftp() + try: + sftp.get(filename, tmp_file.name) + tmp_file.seek(0, os.SEEK_END) + except FileNotFoundError: + pass # File does not exist yet, we'll create it + sftp.close() if isinstance(data, bytes) and not binary: data = data.decode(encoding) elif isinstance(data, str) and binary: data = data.encode(encoding) - - tmp_file.write(data) + if isinstance(data, list): + # ensure each line ends with a newline + data = [s if s.endswith('\n') else s + '\n' for s in data] + tmp_file.writelines(data) + else: + tmp_file.write(data) tmp_file.flush() with self.ssh_connect() as ssh: sftp = ssh.open_sftp() + remote_directory = os.path.dirname(filename) + try: + sftp.stat(remote_directory) + except IOError: + sftp.mkdir(remote_directory) sftp.put(tmp_file.name, filename) sftp.close() + os.remove(tmp_file.name) + def touch(self, filename): """ Create a new file or update the access and modification times of an existing file on the remote server. @@ -307,6 +349,15 @@ def isfile(self, remote_file): result = int(stdout.strip()) return result == 0 + def isdir(self, dirname): + cmd = f"if [ -d {dirname} ]; then echo True; else echo False; fi" + response = self.exec_command(cmd, encoding='utf-8') + return response.strip() == "True" + + def remove_file(self, filename): + cmd = f"rm {filename}" + return self.exec_command(cmd) + # Processes control def kill(self, pid, signal): # Kill the process @@ -317,8 +368,14 @@ def get_pid(self): # Get current process id return self.exec_command("echo $$") + def get_remote_children(self, pid): + command = f"pgrep -P {pid}" + stdin, stdout, stderr = self.ssh.exec_command(command) + children = stdout.readlines() + return [int(child_pid.strip()) for child_pid in children] + # Database control - def db_connect(self, dbname, user, password=None, host="127.0.0.1", hostname="localhost", port=5432): + def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432): """ Connects to a PostgreSQL database on the remote system. Args: @@ -332,21 +389,19 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", hostname="lo This function establishes a connection to a PostgreSQL database on the remote system using the specified parameters. It returns a connection object that can be used to interact with the database. """ - transport = self.ssh.get_transport() - local_port = 9090 # or any other available port - - transport.open_channel( - 'direct-tcpip', - (hostname, port), - (host, local_port) - ) - - conn = pglib.connect( - host=host, - port=local_port, - database=dbname, - user=user, - password=password, - ) - return conn + with sshtunnel.open_tunnel( + (host, 22), # Remote server IP and SSH port + ssh_username=self.username, + ssh_pkey=self.ssh_key, + remote_bind_address=(host, port), # PostgreSQL server IP and PostgreSQL port + local_bind_address=('localhost', port), # Local machine IP and available port + ): + conn = pglib.connect( + host=host, + port=port, + dbname=dbname, + user=user, + password=password + ) + return conn diff --git a/testgres/utils.py b/testgres/utils.py index b72c7da0..73c36e2c 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -3,30 +3,18 @@ from __future__ import division from __future__ import print_function -import io import os import port_for -import subprocess import sys -import tempfile from contextlib import contextmanager from packaging.version import Version -try: - from shutil import which as find_executable -except ImportError: - from distutils.spawn import find_executable from six import iteritems -from fabric import Connection -from .operations.remote_ops import RemoteOperations -from .operations.local_ops import LocalOperations -from .operations.os_ops import OsOperations - -from .config import testgres_config -from .exceptions import ExecUtilException +from .config import testgres_config as tconf +from .logger import log # rows returned by PG_CONFIG _pg_config_data = {} @@ -57,90 +45,34 @@ def release_port(port): bound_ports.discard(port) -def execute_utility(args, logfile=None, os_ops: OsOperations = LocalOperations()): +def execute_utility(args, logfile=None): """ Execute utility (pg_ctl, pg_dump etc). Args: - os_ops: LocalOperations for local node or RemoteOperations for node that connected by ssh. args: utility + arguments (list). logfile: path to file to store stdout and stderr. Returns: stdout of executed utility. """ - - if isinstance(os_ops, RemoteOperations): - conn = Connection( - os_ops.hostname, - connect_kwargs={ - "key_filename": f"{os_ops.ssh_key}", - }, - ) - # TODO skip remote ssh run if we are on the localhost. - # result = conn.run('hostname', hide=True) - # add logger - - cmd = ' '.join(args) - result = conn.run(cmd, hide=True) - - return result - - # run utility - if os.name == 'nt': - # using output to a temporary file in Windows - buf = tempfile.NamedTemporaryFile() - - process = subprocess.Popen( - args, # util + params - stdout=buf, - stderr=subprocess.STDOUT) - process.communicate() - - # get result - buf.file.flush() - buf.file.seek(0) - out = buf.file.read() - buf.close() - else: - process = subprocess.Popen( - args, # util + params - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - - # get result - out, _ = process.communicate() - - # decode result - out = '' if not out else out.decode('utf-8') - - # format command command = u' '.join(args) + exit_status, out, error = tconf.os_ops.exec_command(command, verbose=True) + # decode result + out = '' if not out else out + if isinstance(out, bytes): + out = out.decode('utf-8') # write new log entry if possible if logfile: try: - with io.open(logfile, 'a') as file_out: - file_out.write(command) - - if out: - # comment-out lines - lines = ('# ' + line for line in out.splitlines(True)) - file_out.write(u'\n') - file_out.writelines(lines) - - file_out.write(u'\n') + tconf.os_ops.write(filename=logfile, data=command, truncate=True) + if out: + # comment-out lines + lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n'] + tconf.os_ops.write(filename=logfile, data=lines) except IOError: - pass - - exit_code = process.returncode - if exit_code: - message = 'Utility exited with non-zero code' - raise ExecUtilException(message=message, - command=command, - exit_code=exit_code, - out=out) - + log.warn(f"Problem with writing to logfile `{logfile}` during run command `{command}`") return out @@ -149,23 +81,22 @@ def get_bin_path(filename): Return absolute path to an executable using PG_BIN or PG_CONFIG. This function does nothing if 'filename' is already absolute. """ - # check if it's already absolute if os.path.isabs(filename): return filename - # try PG_CONFIG + # try PG_CONFIG - get from local machine pg_config = os.environ.get("PG_CONFIG") if pg_config: bindir = get_pg_config()["BINDIR"] return os.path.join(bindir, filename) # try PG_BIN - pg_bin = os.environ.get("PG_BIN") + pg_bin = tconf.os_ops.environ("PG_BIN") if pg_bin: return os.path.join(pg_bin, filename) - pg_config_path = find_executable('pg_config') + pg_config_path = tconf.os_ops.find_executable('pg_config') if pg_config_path: bindir = get_pg_config(pg_config_path)["BINDIR"] return os.path.join(bindir, filename) @@ -181,7 +112,7 @@ def get_pg_config(pg_config_path=None): def cache_pg_config_data(cmd): # execute pg_config and get the output - out = subprocess.check_output([cmd]).decode('utf-8') + out = tconf.os_ops.exec_command(cmd, encoding='utf-8') data = {} for line in out.splitlines(): @@ -196,7 +127,7 @@ def cache_pg_config_data(cmd): return data # drop cache if asked to - if not testgres_config.cache_pg_config: + if not tconf.cache_pg_config: global _pg_config_data _pg_config_data = {} @@ -226,7 +157,7 @@ def get_pg_version(): # get raw version (e.g. postgres (PostgreSQL) 9.5.7) _params = [get_bin_path('postgres'), '--version'] - raw_ver = subprocess.check_output(_params).decode('utf-8') + raw_ver = tconf.os_ops.exec_command(_params, encoding='utf-8') # cook version of PostgreSQL version = raw_ver.strip().split(' ')[-1] \ diff --git a/tests/test_remote.py b/tests/test_remote.py index 0155956c..7bc6b2f1 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -1,6 +1,7 @@ import pytest -from testgres.operations.remote_ops import RemoteOperations +from testgres import ExecUtilException +from testgres import RemoteOperations class TestRemoteOperations: @@ -31,9 +32,11 @@ def test_exec_command_failure(self): Test exec_command for command execution failure. """ cmd = "nonexistent_command" - exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True) - - assert error == b'bash: line 1: nonexistent_command: command not found\n' + try: + exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True) + except ExecUtilException as e: + error = e.message + assert error == 'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n' def test_is_executable_true(self): """ @@ -82,8 +85,11 @@ def test_makedirs_and_rmdirs_failure(self): self.operations.makedirs(path) # Test rmdirs - exit_status, result, error = self.operations.rmdirs(path, verbose=True) - assert error == b"rm: cannot remove '/root/test_dir': Permission denied\n" + try: + exit_status, result, error = self.operations.rmdirs(path, verbose=True) + except ExecUtilException as e: + error = e.message + assert error == "Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n" def test_listdir(self): """ @@ -119,11 +125,12 @@ def test_write_text_file(self): filename = "/tmp/test_file.txt" data = "Hello, world!" + self.operations.write(filename, data, truncate=True) self.operations.write(filename, data) response = self.operations.read(filename) - assert response == data + assert response == data + data def test_write_binary_file(self): """ @@ -132,7 +139,7 @@ def test_write_binary_file(self): filename = "/tmp/test_file.bin" data = b"\x00\x01\x02\x03" - self.operations.write(filename, data, binary=True) + self.operations.write(filename, data, binary=True, truncate=True) response = self.operations.read(filename, binary=True) diff --git a/tests/test_simple.py b/tests/test_simple.py index e8b8abee..2f8ff62b 100755 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -252,34 +252,34 @@ def test_psql(self): # check returned values (1 arg) res = node.psql('select 1') - self.assertEqual(res, (0, b'1\n', b'')) + self.assertEqual((0, b'1\n', b''), res) # check returned values (2 args) res = node.psql('postgres', 'select 2') - self.assertEqual(res, (0, b'2\n', b'')) + self.assertEqual((0, b'2\n', b''), res) # check returned values (named) res = node.psql(query='select 3', dbname='postgres') - self.assertEqual(res, (0, b'3\n', b'')) + self.assertEqual((0, b'3\n', b''), res) # check returned values (1 arg) res = node.safe_psql('select 4') - self.assertEqual(res, b'4\n') + self.assertEqual(b'4\n', res) # check returned values (2 args) res = node.safe_psql('postgres', 'select 5') - self.assertEqual(res, b'5\n') + self.assertEqual(b'5\n', res) # check returned values (named) res = node.safe_psql(query='select 6', dbname='postgres') - self.assertEqual(res, b'6\n') + self.assertEqual(b'6\n', res) # check feeding input node.safe_psql('create table horns (w int)') node.safe_psql('copy horns from stdin (format csv)', input=b"1\n2\n3\n\\.\n") _sum = node.safe_psql('select sum(w) from horns') - self.assertEqual(b'6\n', _sum) + self.assertEqual(_sum, b'6\n') # check psql's default args, fails with self.assertRaises(QueryException): diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 179f3ffb..18c3450a 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -16,7 +16,6 @@ import logging.config from contextlib import contextmanager -from shutil import rmtree from testgres.exceptions import \ InitNodeException, \ @@ -31,13 +30,13 @@ TestgresConfig, \ configure_testgres, \ scoped_config, \ - pop_config + pop_config, testgres_config from testgres import \ NodeStatus, \ ProcessType, \ IsolationLevel, \ - get_new_node + get_new_node, RemoteOperations from testgres import \ get_bin_path, \ @@ -54,6 +53,12 @@ from testgres.node import ProcessProxy +os_ops = RemoteOperations(host='172.18.0.3', + username='dev', + ssh_key='/home/vika/Desktop/work/probackup/dev-ee-probackup/container_files/postgres/ssh/id_ed25519') +testgres_config.set_os_ops(os_ops=os_ops) + + def pg_version_ge(version): cur_ver = PgVer(get_pg_version()) min_ver = PgVer(version) @@ -62,16 +67,16 @@ def pg_version_ge(version): def util_exists(util): def good_properties(f): - return (os.path.exists(f) and # noqa: W504 - os.path.isfile(f) and # noqa: W504 - os.access(f, os.X_OK)) # yapf: disable + return (os_ops.path_exists(f) and # noqa: W504 + os_ops.isfile(f) and # noqa: W504 + os_ops.is_executable(f)) # yapf: disable # try to resolve it if good_properties(get_bin_path(util)): return True # check if util is in PATH - for path in os.environ["PATH"].split(os.pathsep): + for path in os_ops.environ("PATH").split(os_ops.pathsep): if good_properties(os.path.join(path, util)): return True @@ -81,14 +86,15 @@ def removing(f): try: yield f finally: - if os.path.isfile(f): - os.remove(f) - elif os.path.isdir(f): - rmtree(f, ignore_errors=True) + if os_ops.isfile(f): + os_ops.remove_file(f) + + elif os_ops.isdir(f): + os_ops.rmdirs(f, ignore_errors=True) def get_remote_node(): - return get_new_node(host='172.18.0.3', username='dev', ssh_key='/home/vika/Desktop/work/probackup/dev-ee-probackup/container_files/postgres/ssh/id_ed25519') + return get_new_node(host=os_ops.host, username=os_ops.username, ssh_key=os_ops.ssh_key) class TestgresRemoteTests(unittest.TestCase): @@ -109,14 +115,13 @@ def test_custom_init(self): initdb_params=['--auth-local=reject', '--auth-host=reject']) hba_file = os.path.join(node.data_dir, 'pg_hba.conf') - with open(hba_file, 'r') as conf: - lines = conf.readlines() + lines = os_ops.readlines(hba_file) - # check number of lines - self.assertGreaterEqual(len(lines), 6) + # check number of lines + self.assertGreaterEqual(len(lines), 6) - # there should be no trust entries at all - self.assertFalse(any('trust' in s for s in lines)) + # there should be no trust entries at all + self.assertFalse(any('trust' in s for s in lines)) def test_double_init(self): with get_remote_node().init() as node: @@ -164,14 +169,14 @@ def test_node_exit(self): node.safe_psql('select 1') # we should save the DB for "debugging" - self.assertTrue(os.path.exists(base_dir)) - rmtree(base_dir, ignore_errors=True) + self.assertTrue(os_ops.path_exists(base_dir)) + os_ops.rmdirs(base_dir, ignore_errors=True) with get_remote_node().init() as node: base_dir = node.base_dir # should have been removed by default - self.assertFalse(os.path.exists(base_dir)) + self.assertFalse(os_ops.path_exists(base_dir)) def test_double_start(self): with get_remote_node().init().start() as node: @@ -607,9 +612,9 @@ def test_dump(self): with removing(node1.dump(format=format)) as dump: with get_remote_node().init().start() as node3: if format == 'directory': - self.assertTrue(os.path.isdir(dump)) + self.assertTrue(os_ops.isdir(dump)) else: - self.assertTrue(os.path.isfile(dump)) + self.assertTrue(os_ops.isfile(dump)) # restore dump node3.restore(filename=dump) res = node3.execute(query_select) @@ -986,14 +991,14 @@ def test_child_process_dies(self): if __name__ == '__main__': - if os.environ.get('ALT_CONFIG'): + if os_ops.environ('ALT_CONFIG'): suite = unittest.TestSuite() # Small subset of tests for alternative configs (PG_BIN or PG_CONFIG) - suite.addTest(TestgresTests('test_pg_config')) - suite.addTest(TestgresTests('test_pg_ctl')) - suite.addTest(TestgresTests('test_psql')) - suite.addTest(TestgresTests('test_replicate')) + suite.addTest(TestgresRemoteTests('test_pg_config')) + suite.addTest(TestgresRemoteTests('test_pg_ctl')) + suite.addTest(TestgresRemoteTests('test_psql')) + suite.addTest(TestgresRemoteTests('test_replicate')) print('Running tests for alternative config:') for t in suite: From 2c2d2c5cf0eaf6a5a4c21fbf2d589adbdccbdbed Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Thu, 22 Jun 2023 23:14:55 +0200 Subject: [PATCH 14/23] PBCKP-588 test fix test_restore_after_failover --- testgres/node.py | 13 +++-- testgres/operations/local_ops.py | 19 ++----- testgres/operations/os_ops.py | 3 -- testgres/operations/remote_ops.py | 84 +++++++++++++++++++------------ testgres/utils.py | 16 +++--- 5 files changed, 74 insertions(+), 61 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 5ad18ace..2ab17c75 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -625,9 +625,11 @@ def status(self): "-D", self.data_dir, "status" ] # yapf: disable - out = execute_utility(_params, self.utils_log_file) - if 'no server running' in out: + status_code, out, err = execute_utility(_params, self.utils_log_file, verbose=True) + if 'does not exist' in err: return NodeStatus.Uninitialized + elif'no server running' in out: + return NodeStatus.Stopped return NodeStatus.Running except ExecUtilException as e: @@ -712,14 +714,17 @@ def start(self, params=[], wait=True): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file) + exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True) + if 'does not exist' in error: + raise Exception + if 'server started' in out: + self.is_started = True except Exception as e: msg = 'Cannot start node' files = self._collect_special_files() raise_from(StartNodeException(msg, files), e) self._maybe_start_logger() - self.is_started = True return self def stop(self, params=[], wait=True): diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index 010e3cc0..6a26910d 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -52,7 +52,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, :return: The output of the subprocess. """ if isinstance(cmd, list): - cmd = " ".join(cmd) + cmd = ' '.join(item.decode('utf-8') if isinstance(item, bytes) else item for item in cmd) log.debug(f"Executing command: `{cmd}`") if os.name == 'nt': @@ -98,8 +98,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, # Environment setup def environ(self, var_name): - cmd = f"echo ${var_name}" - return self.exec_command(cmd, encoding='utf-8').strip() + return os.environ.get(var_name) def find_executable(self, executable): return find_executable(executable) @@ -108,17 +107,6 @@ def is_executable(self, file): # Check if the file is executable return os.access(file, os.X_OK) - def add_to_path(self, new_path): - pathsep = self.pathsep - # Check if the directory is already in PATH - path = self.environ("PATH") - if new_path not in path.split(pathsep): - if self.remote: - self.exec_command(f"export PATH={new_path}{pathsep}{path}") - else: - os.environ["PATH"] = f"{new_path}{pathsep}{path}" - return pathsep - def set_env(self, var_name, var_val): # Check if the directory is already in PATH os.environ[var_name] = var_val @@ -128,8 +116,7 @@ def get_user(self): return getpass.getuser() def get_name(self): - cmd = 'python3 -c "import os; print(os.name)"' - return self.exec_command(cmd).strip() + return os.name # Work with dirs def makedirs(self, path, remove_existing=False): diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index 68925616..c3f57653 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -29,9 +29,6 @@ def is_executable(self, file): # Check if the file is executable raise NotImplementedError() - def add_to_path(self, new_path): - raise NotImplementedError() - def set_env(self, var_name, var_val): # Check if the directory is already in PATH raise NotImplementedError() diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index d45614a1..8e94a7fe 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -20,6 +20,22 @@ error_markers = [b'error', b'Permission denied'] +class PsUtilProcessProxy: + def __init__(self, ssh, pid): + self.ssh = ssh + self.pid = pid + + def kill(self): + command = f"kill {self.pid}" + self.ssh.exec_command(command) + + def cmdline(self): + command = f"ps -p {self.pid} -o cmd --no-headers" + stdin, stdout, stderr = self.ssh.exec_command(command) + cmdline = stdout.read().decode('utf-8').strip() + return cmdline.split() + + class RemoteOperations(OsOperations): def __init__(self, host="127.0.0.1", hostname='localhost', port=None, ssh_key=None, username=None): super().__init__(username) @@ -71,7 +87,7 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa self.ssh = self.ssh_connect() if isinstance(cmd, list): - cmd = " ".join(cmd) + cmd = ' '.join(item.decode('utf-8') if isinstance(item, bytes) else item for item in cmd) if input: stdin, stdout, stderr = self.ssh.exec_command(cmd) stdin.write(input) @@ -140,17 +156,6 @@ def is_executable(self, file): is_exec = self.exec_command(f"test -x {file} && echo OK") return is_exec == b"OK\n" - def add_to_path(self, new_path): - pathsep = self.pathsep - # Check if the directory is already in PATH - path = self.environ("PATH") - if new_path not in path.split(pathsep): - if self.remote: - self.exec_command(f"export PATH={new_path}{pathsep}{path}") - else: - os.environ["PATH"] = f"{new_path}{pathsep}{path}" - return pathsep - def set_env(self, var_name: str, var_val: str): """ Set the value of an environment variable. @@ -243,9 +248,17 @@ def mkdtemp(self, prefix=None): raise ExecUtilException("Could not create temporary directory.") def mkstemp(self, prefix=None): - cmd = f"mktemp {prefix}XXXXXX" - filename = self.exec_command(cmd).strip() - return filename + if prefix: + temp_dir = self.exec_command(f"mktemp {prefix}XXXXX", encoding='utf-8') + else: + temp_dir = self.exec_command("mktemp", encoding='utf-8') + + if temp_dir: + if not os.path.isabs(temp_dir): + temp_dir = os.path.join('/home', self.username, temp_dir.strip()) + return temp_dir + else: + raise ExecUtilException("Could not create temporary directory.") def copytree(self, src, dst): if not os.path.isabs(dst): @@ -291,7 +304,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal data = data.encode(encoding) if isinstance(data, list): # ensure each line ends with a newline - data = [s if s.endswith('\n') else s + '\n' for s in data] + data = [(s if isinstance(s, str) else s.decode('utf-8')).rstrip('\n') + '\n' for s in data] tmp_file.writelines(data) else: tmp_file.write(data) @@ -351,8 +364,8 @@ def isfile(self, remote_file): def isdir(self, dirname): cmd = f"if [ -d {dirname} ]; then echo True; else echo False; fi" - response = self.exec_command(cmd, encoding='utf-8') - return response.strip() == "True" + response = self.exec_command(cmd) + return response.strip() == b"True" def remove_file(self, filename): cmd = f"rm {filename}" @@ -366,16 +379,16 @@ def kill(self, pid, signal): def get_pid(self): # Get current process id - return self.exec_command("echo $$") + return int(self.exec_command("echo $$", encoding='utf-8')) def get_remote_children(self, pid): command = f"pgrep -P {pid}" stdin, stdout, stderr = self.ssh.exec_command(command) children = stdout.readlines() - return [int(child_pid.strip()) for child_pid in children] + return [PsUtilProcessProxy(self.ssh, int(child_pid.strip())) for child_pid in children] # Database control - def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432): + def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, ssh_key=None): """ Connects to a PostgreSQL database on the remote system. Args: @@ -389,19 +402,26 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432): This function establishes a connection to a PostgreSQL database on the remote system using the specified parameters. It returns a connection object that can be used to interact with the database. """ - with sshtunnel.open_tunnel( - (host, 22), # Remote server IP and SSH port - ssh_username=self.username, - ssh_pkey=self.ssh_key, - remote_bind_address=(host, port), # PostgreSQL server IP and PostgreSQL port - local_bind_address=('localhost', port), # Local machine IP and available port - ): + tunnel = sshtunnel.open_tunnel( + (host, 22), # Remote server IP and SSH port + ssh_username=user or self.username, + ssh_pkey=ssh_key or self.ssh_key, + remote_bind_address=(host, port), # PostgreSQL server IP and PostgreSQL port + local_bind_address=('localhost', port) # Local machine IP and available port + ) + + tunnel.start() + + try: conn = pglib.connect( - host=host, - port=port, + host=host, # change to 'localhost' because we're connecting through a local ssh tunnel + port=tunnel.local_bind_port, # use the local bind port set up by the tunnel dbname=dbname, - user=user, + user=user or self.username, password=password ) - return conn + return conn + except Exception as e: + tunnel.stop() + raise e diff --git a/testgres/utils.py b/testgres/utils.py index 73c36e2c..d8321b3e 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -45,7 +45,7 @@ def release_port(port): bound_ports.discard(port) -def execute_utility(args, logfile=None): +def execute_utility(args, logfile=None, verbose=False): """ Execute utility (pg_ctl, pg_dump etc). @@ -56,24 +56,28 @@ def execute_utility(args, logfile=None): Returns: stdout of executed utility. """ - command = u' '.join(args) - exit_status, out, error = tconf.os_ops.exec_command(command, verbose=True) + exit_status, out, error = tconf.os_ops.exec_command(args, verbose=True) # decode result out = '' if not out else out if isinstance(out, bytes): out = out.decode('utf-8') + if isinstance(error, bytes): + error = error.decode('utf-8') # write new log entry if possible if logfile: try: - tconf.os_ops.write(filename=logfile, data=command, truncate=True) + tconf.os_ops.write(filename=logfile, data=args, truncate=True) if out: # comment-out lines lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n'] tconf.os_ops.write(filename=logfile, data=lines) except IOError: - log.warn(f"Problem with writing to logfile `{logfile}` during run command `{command}`") - return out + log.warn(f"Problem with writing to logfile `{logfile}` during run command `{args}`") + if verbose: + return exit_status, out, error + else: + return out def get_bin_path(filename): From 1b4f74aa1f9eb48a7c19a07eba7a6a2083b5c26a Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Fri, 23 Jun 2023 01:55:22 +0200 Subject: [PATCH 15/23] PBCKP-588 test partially fixed test_simple_remote.py 41/43 --- testgres/node.py | 8 +++++--- testgres/operations/remote_ops.py | 30 +++++++++++++++++++++++++----- testgres/pubsub.py | 2 +- tests/test_simple_remote.py | 25 +++++++++++++------------ 4 files changed, 44 insertions(+), 21 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 2ab17c75..7a6e475c 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -519,8 +519,8 @@ def get_auth_method(t): u"local\treplication\tall\t\t\t{}\n".format(auth_local), u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host), - u"host\treplication\t{}\t{}/24\t\t{}\n".format(self.os_ops.username, subnet_base, auth_host), - u"host\tall\t{}\t{}/24\t\t{}\n".format(self.os_ops.username, subnet_base, auth_host) + u"host\treplication\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host), + u"host\tall\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host) ] # yapf: disable # write missing lines @@ -790,7 +790,9 @@ def restart(self, params=[]): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file) + error_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True) + if 'could not start server' in error: + raise ExecUtilException except ExecUtilException as e: msg = 'Cannot restart node' files = self._collect_special_files() diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 8e94a7fe..274a87cf 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,5 +1,6 @@ import os import tempfile +import time from typing import Optional import sshtunnel @@ -46,11 +47,29 @@ def __init__(self, host="127.0.0.1", hostname='localhost', port=None, ssh_key=No self.remote = True self.ssh = self.ssh_connect() self.username = username or self.get_user() + self.tunnel = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close_tunnel() + if getattr(self, 'ssh', None): + self.ssh.close() def __del__(self): - if self.ssh: + if getattr(self, 'ssh', None): self.ssh.close() + def close_tunnel(self): + if getattr(self, 'tunnel', None): + self.tunnel.stop(force=True) + start_time = time.time() + while self.tunnel.is_active: + if time.time() - start_time > sshtunnel.TUNNEL_TIMEOUT: + break + time.sleep(0.5) + def ssh_connect(self) -> Optional[SSHClient]: if not self.remote: return None @@ -402,7 +421,8 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s This function establishes a connection to a PostgreSQL database on the remote system using the specified parameters. It returns a connection object that can be used to interact with the database. """ - tunnel = sshtunnel.open_tunnel( + self.close_tunnel() + self.tunnel = sshtunnel.open_tunnel( (host, 22), # Remote server IP and SSH port ssh_username=user or self.username, ssh_pkey=ssh_key or self.ssh_key, @@ -410,12 +430,12 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s local_bind_address=('localhost', port) # Local machine IP and available port ) - tunnel.start() + self.tunnel.start() try: conn = pglib.connect( host=host, # change to 'localhost' because we're connecting through a local ssh tunnel - port=tunnel.local_bind_port, # use the local bind port set up by the tunnel + port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel dbname=dbname, user=user or self.username, password=password @@ -423,5 +443,5 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s return conn except Exception as e: - tunnel.stop() + self.tunnel.stop() raise e diff --git a/testgres/pubsub.py b/testgres/pubsub.py index da85caac..1be673bb 100644 --- a/testgres/pubsub.py +++ b/testgres/pubsub.py @@ -214,4 +214,4 @@ def catchup(self, username=None): username=username or self.pub.username, max_attempts=LOGICAL_REPL_MAX_CATCHUP_ATTEMPTS) except Exception as e: - raise_from(CatchUpException("Failed to catch up", query), e) + raise_from(CatchUpException("Failed to catch up"), e) diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 18c3450a..0b104ff0 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -6,7 +6,6 @@ import subprocess import tempfile - import testgres import time import six @@ -138,6 +137,7 @@ def test_init_after_cleanup(self): @unittest.skipUnless(util_exists('pg_resetwal'), 'might be missing') @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') def test_init_unique_system_id(self): + # FAIL # this function exists in PostgreSQL 9.6+ query = 'select system_identifier from pg_control_system()' @@ -291,7 +291,7 @@ def test_psql(self): node.safe_psql('copy horns from stdin (format csv)', input=b"1\n2\n3\n\\.\n") _sum = node.safe_psql('select sum(w) from horns') - self.assertEqual(b'6\n', _sum) + self.assertEqual(_sum, b'6\n') # check psql's default args, fails with self.assertRaises(QueryException): @@ -688,6 +688,7 @@ def test_poll_query_until(self): node.poll_query_until('select true') def test_logging(self): + # FAIL logfile = tempfile.NamedTemporaryFile('w', delete=True) log_conf = { @@ -747,14 +748,11 @@ def test_pgbench(self): options=['-q']).pgbench_run(time=2) # run TPC-B benchmark - proc = node.pgbench(stdout=subprocess.PIPE, + out = node.pgbench(stdout=subprocess.PIPE, stderr=subprocess.STDOUT, options=['-T3']) - out, _ = proc.communicate() - out = out.decode('utf-8') - - self.assertTrue('tps' in out) + self.assertTrue(b'tps = ' in out) def test_pg_config(self): # check same instances @@ -764,7 +762,6 @@ def test_pg_config(self): # save right before config change c1 = get_pg_config() - # modify setting for this scope with scoped_config(cache_pg_config=False) as config: @@ -819,12 +816,16 @@ def test_unix_sockets(self): node.init(unix_sockets=False, allow_streaming=True) node.start() - node.execute('select 1') - node.safe_psql('select 1') + res_exec = node.execute('select 1') + res_psql = node.safe_psql('select 1') + self.assertEqual(res_exec, [(1,)]) + self.assertEqual(res_psql, b'1\n') with node.replicate().start() as r: - r.execute('select 1') - r.safe_psql('select 1') + res_exec = r.execute('select 1') + res_psql = r.safe_psql('select 1') + self.assertEqual(res_exec, [(1,)]) + self.assertEqual(res_psql, b'1\n') def test_auto_name(self): with get_remote_node().init(allow_streaming=True).start() as m: From 2e916dfb4d044c24c421e80e4cddaa9799d0114d Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 26 Jun 2023 01:30:56 +0200 Subject: [PATCH 16/23] PBCKP-588 fixes after review --- testgres/logger.py | 15 --------------- testgres/node.py | 2 +- testgres/operations/local_ops.py | 4 +--- testgres/operations/os_ops.py | 3 +++ testgres/operations/remote_ops.py | 12 +++++------- testgres/utils.py | 5 ++--- tests/test_remote.py | 2 +- tests/test_simple.py | 16 +++++++++------- tests/test_simple_remote.py | 15 ++++++++------- 9 files changed, 30 insertions(+), 44 deletions(-) diff --git a/testgres/logger.py b/testgres/logger.py index 59579002..b4648f44 100644 --- a/testgres/logger.py +++ b/testgres/logger.py @@ -5,21 +5,6 @@ import threading import time -# create logger -log = logging.getLogger('Testgres') - -if not log.handlers: - log.setLevel(logging.WARN) - # create console handler and set level to debug - ch = logging.StreamHandler() - ch.setLevel(logging.WARN) - # create formatter - formatter = logging.Formatter('\n%(asctime)s - %(name)s[%(levelname)s]: %(message)s') - # add formatter to ch - ch.setFormatter(formatter) - # add ch to logger - log.addHandler(ch) - class TestgresLogger(threading.Thread): """ diff --git a/testgres/node.py b/testgres/node.py index 7a6e475c..5d3dfb72 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -245,7 +245,7 @@ def child_processes(self): """ # get a list of postmaster's children - children = self.os_ops.get_remote_children(self.pid) + children = self.os_ops.get_process_children(self.pid) return [ProcessProxy(p) for p in children] diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index 6a26910d..acd066cf 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -8,7 +8,6 @@ import psutil from testgres.exceptions import ExecUtilException -from testgres.logger import log from .os_ops import OsOperations from .os_ops import pglib @@ -53,7 +52,6 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, """ if isinstance(cmd, list): cmd = ' '.join(item.decode('utf-8') if isinstance(item, bytes) else item for item in cmd) - log.debug(f"Executing command: `{cmd}`") if os.name == 'nt': with tempfile.NamedTemporaryFile() as buf: @@ -252,7 +250,7 @@ def get_pid(self): # Get current process id return os.getpid() - def get_remote_children(self, pid): + def get_process_children(self, pid): return psutil.Process(pid).children() # Database control diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index c3f57653..4b1349b7 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -88,6 +88,9 @@ def get_pid(self): # Get current process id raise NotImplementedError() + def get_process_children(self, pid): + raise NotImplementedError() + # Database control def db_connect(self, dbname, user, password=None, host="localhost", port=5432): raise NotImplementedError() diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 274a87cf..b27ca47d 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -9,7 +9,6 @@ from paramiko import SSHClient from testgres.exceptions import ExecUtilException -from testgres.logger import log from .os_ops import OsOperations from .os_ops import pglib @@ -90,9 +89,9 @@ def _read_ssh_key(self): key = paramiko.RSAKey.from_private_key_file(self.ssh_key) return key except FileNotFoundError: - log.error(f"No such file or directory: '{self.ssh_key}'") + raise ExecUtilException(message=f"No such file or directory: '{self.ssh_key}'") except Exception as e: - log.error(f"An error occurred while reading the ssh key: {e}") + ExecUtilException(message=f"An error occurred while reading the ssh key: {e}") def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdout=None, @@ -400,7 +399,7 @@ def get_pid(self): # Get current process id return int(self.exec_command("echo $$", encoding='utf-8')) - def get_remote_children(self, pid): + def get_process_children(self, pid): command = f"pgrep -P {pid}" stdin, stdout, stderr = self.ssh.exec_command(command) children = stdout.readlines() @@ -414,8 +413,7 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s - dbname (str): The name of the database to connect to. - user (str): The username for the database connection. - password (str, optional): The password for the database connection. Defaults to None. - - host (str, optional): The IP address of the remote system. Defaults to "127.0.0.1". - - hostname (str, optional): The hostname of the remote system. Defaults to "localhost". + - host (str, optional): The IP address of the remote system. Defaults to "localhost". - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432. This function establishes a connection to a PostgreSQL database on the remote system using the specified @@ -444,4 +442,4 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s return conn except Exception as e: self.tunnel.stop() - raise e + raise ExecUtilException("Could not create db tunnel.") diff --git a/testgres/utils.py b/testgres/utils.py index d8321b3e..273c9287 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -12,9 +12,8 @@ from six import iteritems - +from .exceptions import ExecUtilException from .config import testgres_config as tconf -from .logger import log # rows returned by PG_CONFIG _pg_config_data = {} @@ -73,7 +72,7 @@ def execute_utility(args, logfile=None, verbose=False): lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n'] tconf.os_ops.write(filename=logfile, data=lines) except IOError: - log.warn(f"Problem with writing to logfile `{logfile}` during run command `{args}`") + raise ExecUtilException(f"Problem with writing to logfile `{logfile}` during run command `{args}`") if verbose: return exit_status, out, error else: diff --git a/tests/test_remote.py b/tests/test_remote.py index 7bc6b2f1..cdaa6574 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -11,7 +11,7 @@ def setup(self): self.operations = RemoteOperations( host="172.18.0.3", username="dev", - ssh_key='/home/vika/Desktop/work/probackup/dev-ee-probackup/container_files/postgres/ssh/id_ed25519' + ssh_key='../../container_files/postgres/ssh/id_ed25519' ) yield diff --git a/tests/test_simple.py b/tests/test_simple.py index 2f8ff62b..94420b04 100755 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -151,6 +151,8 @@ def test_init_unique_system_id(self): self.assertGreater(id2, id1) def test_node_exit(self): + base_dir = None + with self.assertRaises(QueryException): with get_new_node().init() as node: base_dir = node.base_dir @@ -252,27 +254,27 @@ def test_psql(self): # check returned values (1 arg) res = node.psql('select 1') - self.assertEqual((0, b'1\n', b''), res) + self.assertEqual(res, (0, b'1\n', b'')) # check returned values (2 args) res = node.psql('postgres', 'select 2') - self.assertEqual((0, b'2\n', b''), res) + self.assertEqual(res, (0, b'2\n', b'')) # check returned values (named) res = node.psql(query='select 3', dbname='postgres') - self.assertEqual((0, b'3\n', b''), res) + self.assertEqual(res, (0, b'3\n', b'')) # check returned values (1 arg) res = node.safe_psql('select 4') - self.assertEqual(b'4\n', res) + self.assertEqual(res, b'4\n') # check returned values (2 args) res = node.safe_psql('postgres', 'select 5') - self.assertEqual(b'5\n', res) + self.assertEqual(res, b'5\n') # check returned values (named) res = node.safe_psql(query='select 6', dbname='postgres') - self.assertEqual(b'6\n', res) + self.assertEqual(res, b'6\n') # check feeding input node.safe_psql('create table horns (w int)') @@ -612,7 +614,7 @@ def test_users(self): with get_new_node().init().start() as node: node.psql('create role test_user login') value = node.safe_psql('select 1', username='test_user') - self.assertEqual(b'1\n', value) + self.assertEqual(value, b'1\n') def test_poll_query_until(self): with get_new_node() as node: diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 0b104ff0..f86e623f 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -35,7 +35,8 @@ NodeStatus, \ ProcessType, \ IsolationLevel, \ - get_new_node, RemoteOperations + get_new_node, \ + RemoteOperations from testgres import \ get_bin_path, \ @@ -54,7 +55,7 @@ os_ops = RemoteOperations(host='172.18.0.3', username='dev', - ssh_key='/home/vika/Desktop/work/probackup/dev-ee-probackup/container_files/postgres/ssh/id_ed25519') + ssh_key='../../container_files/postgres/ssh/id_ed25519') testgres_config.set_os_ops(os_ops=os_ops) @@ -92,8 +93,8 @@ def removing(f): os_ops.rmdirs(f, ignore_errors=True) -def get_remote_node(): - return get_new_node(host=os_ops.host, username=os_ops.username, ssh_key=os_ops.ssh_key) +def get_remote_node(name=None): + return get_new_node(name=name, host=os_ops.host, username=os_ops.username, ssh_key=os_ops.ssh_key) class TestgresRemoteTests(unittest.TestCase): @@ -696,7 +697,7 @@ def test_logging(self): 'handlers': { 'file': { 'class': 'logging.FileHandler', - 'filename': logfile.name, + 'filename': logfile, 'formatter': 'base_format', 'level': logging.DEBUG, }, @@ -717,7 +718,7 @@ def test_logging(self): with scoped_config(use_python_logging=True): node_name = 'master' - with get_new_node(name=node_name) as master: + with get_remote_node(name=node_name) as master: master.init().start() # execute a dummy query a few times @@ -729,7 +730,7 @@ def test_logging(self): time.sleep(0.1) # check that master's port is found - with open(logfile.name, 'r') as log: + with open(logfile, 'r') as log: lines = log.readlines() self.assertTrue(any(node_name in s for s in lines)) From 0528541e70a3a3dbe3e019b7bc02552c84e40649 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 26 Jun 2023 02:23:41 +0200 Subject: [PATCH 17/23] PBCKP-588 fixes after review - add ConnectionParams --- testgres/__init__.py | 4 ++-- testgres/api.py | 9 ++++----- testgres/backup.py | 2 +- testgres/node.py | 17 +++++++---------- testgres/operations/local_ops.py | 15 +++++++-------- testgres/operations/os_ops.py | 11 ++++++++--- testgres/operations/remote_ops.py | 31 +++++++++++++------------------ tests/test_remote.py | 11 +++++------ tests/test_simple_remote.py | 12 ++++++------ 9 files changed, 53 insertions(+), 59 deletions(-) diff --git a/testgres/__init__.py b/testgres/__init__.py index 405262dd..ce2636b4 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -46,7 +46,7 @@ First, \ Any -from .operations.os_ops import OsOperations +from .operations.os_ops import OsOperations, ConnectionParams from .operations.local_ops import LocalOperations from .operations.remote_ops import RemoteOperations @@ -60,5 +60,5 @@ "PostgresNode", "NodeApp", "reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version", "First", "Any", - "OsOperations", "LocalOperations", "RemoteOperations" + "OsOperations", "LocalOperations", "RemoteOperations", "ConnectionParams" ] diff --git a/testgres/api.py b/testgres/api.py index b5b76715..8f553529 100644 --- a/testgres/api.py +++ b/testgres/api.py @@ -37,11 +37,10 @@ def get_new_node(name=None, base_dir=None, **kwargs): """ Simply a wrapper around :class:`.PostgresNode` constructor. See :meth:`.PostgresNode.__init__` for details. - For remote connection you can add next parameters: - host='127.0.0.1', - hostname='localhost', - ssh_key=None, - username=default_username() + For remote connection you can add the next parameter: + conn_params = ConnectionParams(host='127.0.0.1', + ssh_key=None, + username=default_username()) """ # NOTE: leave explicit 'name' and 'base_dir' for compatibility return PostgresNode(name=name, base_dir=base_dir, **kwargs) diff --git a/testgres/backup.py b/testgres/backup.py index c4cc952b..a89e214d 100644 --- a/testgres/backup.py +++ b/testgres/backup.py @@ -139,7 +139,7 @@ def spawn_primary(self, name=None, destroy=True): # Build a new PostgresNode NodeClass = self.original_node.__class__ - with clean_on_error(NodeClass(name=name, base_dir=base_dir, os_ops=self.original_node.os_ops)) as node: + with clean_on_error(NodeClass(name=name, base_dir=base_dir, conn_params=self.original_node.os_ops.conn_params)) as node: # New nodes should always remove dir tree node._should_rm_dirs = True diff --git a/testgres/node.py b/testgres/node.py index 5d3dfb72..d12e7324 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -94,6 +94,7 @@ from .backup import NodeBackup +from .operations.os_ops import ConnectionParams from .operations.local_ops import LocalOperations from .operations.remote_ops import RemoteOperations @@ -125,8 +126,7 @@ def __repr__(self): class PostgresNode(object): - def __init__(self, name=None, port=None, base_dir=None, - host='127.0.0.1', hostname='localhost', ssh_key=None, username=default_username(), os_ops=None): + def __init__(self, name=None, port=None, base_dir=None, conn_params: ConnectionParams = ConnectionParams()): """ PostgresNode constructor. @@ -146,17 +146,14 @@ def __init__(self, name=None, port=None, base_dir=None, # basic self.name = name or generate_app_name() - if os_ops: - self.os_ops = os_ops - elif ssh_key: - self.os_ops = RemoteOperations(host=host, hostname=hostname, ssh_key=ssh_key, username=username) + if conn_params.ssh_key: + self.os_ops = RemoteOperations(conn_params) else: - self.os_ops = LocalOperations(host=host, hostname=hostname, username=username) + self.os_ops = LocalOperations(conn_params) - self.port = self.os_ops.port or reserve_port() + self.port = port or reserve_port() self.host = self.os_ops.host - self.hostname = self.os_ops.hostname self.ssh_key = self.os_ops.ssh_key testgres_config.os_ops = self.os_ops @@ -628,7 +625,7 @@ def status(self): status_code, out, err = execute_utility(_params, self.utils_log_file, verbose=True) if 'does not exist' in err: return NodeStatus.Uninitialized - elif'no server running' in out: + elif 'no server running' in out: return NodeStatus.Stopped return NodeStatus.Running diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index acd066cf..bbe6b0d4 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -7,9 +7,9 @@ import psutil -from testgres.exceptions import ExecUtilException +from ..exceptions import ExecUtilException -from .os_ops import OsOperations +from .os_ops import OsOperations, ConnectionParams from .os_ops import pglib try: @@ -21,13 +21,12 @@ class LocalOperations(OsOperations): - def __init__(self, host='127.0.0.1', hostname='localhost', port=None, username=None): - super().__init__(username) - self.host = host - self.hostname = hostname - self.port = port + def __init__(self, conn_params: ConnectionParams = ConnectionParams()): + super().__init__(conn_params.username) + self.conn_params = conn_params + self.host = conn_params.host self.ssh_key = None - self.username = username or self.get_user() + self.username = conn_params.username or self.get_user() # Command execution def exec_command(self, cmd, wait_exit=False, verbose=False, diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index 4b1349b7..9261cacf 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -7,11 +7,16 @@ raise ImportError("You must have psycopg2 or pg8000 modules installed") +class ConnectionParams: + def __init__(self, host='127.0.0.1', ssh_key=None, username=None): + self.host = host + self.ssh_key = ssh_key + self.username = username + + class OsOperations: def __init__(self, username=None): - self.hostname = "localhost" - self.remote = False - self.ssh = None + self.ssh_key = None self.username = username # Command execution diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index b27ca47d..0a90426c 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -8,9 +8,9 @@ import paramiko from paramiko import SSHClient -from testgres.exceptions import ExecUtilException +from ..exceptions import ExecUtilException -from .os_ops import OsOperations +from .os_ops import OsOperations, ConnectionParams from .os_ops import pglib sshtunnel.SSH_TIMEOUT = 5.0 @@ -37,15 +37,13 @@ def cmdline(self): class RemoteOperations(OsOperations): - def __init__(self, host="127.0.0.1", hostname='localhost', port=None, ssh_key=None, username=None): - super().__init__(username) - self.host = host - self.hostname = hostname - self.port = port - self.ssh_key = ssh_key - self.remote = True + def __init__(self, conn_params: ConnectionParams): + super().__init__(conn_params.username) + self.conn_params = conn_params + self.host = conn_params.host + self.ssh_key = conn_params.ssh_key self.ssh = self.ssh_connect() - self.username = username or self.get_user() + self.username = conn_params.username or self.get_user() self.tunnel = None def __enter__(self): @@ -70,14 +68,11 @@ def close_tunnel(self): time.sleep(0.5) def ssh_connect(self) -> Optional[SSHClient]: - if not self.remote: - return None - else: - key = self._read_ssh_key() - ssh = paramiko.SSHClient() - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.host, username=self.username, pkey=key) - return ssh + key = self._read_ssh_key() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.host, username=self.username, pkey=key) + return ssh def _read_ssh_key(self): try: diff --git a/tests/test_remote.py b/tests/test_remote.py index cdaa6574..ceb06ee3 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -2,20 +2,19 @@ from testgres import ExecUtilException from testgres import RemoteOperations +from testgres import ConnectionParams class TestRemoteOperations: @pytest.fixture(scope="function", autouse=True) def setup(self): - self.operations = RemoteOperations( - host="172.18.0.3", - username="dev", - ssh_key='../../container_files/postgres/ssh/id_ed25519' - ) + conn_params = ConnectionParams(host="172.18.0.3", + username="dev", + ssh_key='../../container_files/postgres/ssh/id_ed25519') + self.operations = RemoteOperations(conn_params) yield - self.operations.__del__() def test_exec_command_success(self): diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index f86e623f..80cf7674 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -50,12 +50,12 @@ # NOTE: those are ugly imports from testgres import bound_ports from testgres.utils import PgVer -from testgres.node import ProcessProxy +from testgres.node import ProcessProxy, ConnectionParams - -os_ops = RemoteOperations(host='172.18.0.3', - username='dev', - ssh_key='../../container_files/postgres/ssh/id_ed25519') +conn_params = ConnectionParams(host="172.18.0.3", + username="dev", + ssh_key='../../container_files/postgres/ssh/id_ed25519') +os_ops = RemoteOperations(conn_params) testgres_config.set_os_ops(os_ops=os_ops) @@ -94,7 +94,7 @@ def removing(f): def get_remote_node(name=None): - return get_new_node(name=name, host=os_ops.host, username=os_ops.username, ssh_key=os_ops.ssh_key) + return get_new_node(name=name, conn_params=conn_params) class TestgresRemoteTests(unittest.TestCase): From 089ab9b73afbc1fa60c9bab88d0c3e7e52bd5fd4 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 26 Jun 2023 02:37:59 +0200 Subject: [PATCH 18/23] PBCKP-588 fixes after review - remove f-strings --- testgres/connection.py | 2 +- testgres/node.py | 2 +- testgres/operations/local_ops.py | 6 ++-- testgres/operations/remote_ops.py | 58 +++++++++++++++---------------- testgres/utils.py | 2 +- tests/test_simple_remote.py | 5 ++- 6 files changed, 37 insertions(+), 38 deletions(-) diff --git a/testgres/connection.py b/testgres/connection.py index d28d81bd..aeb040ce 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -111,7 +111,7 @@ def execute(self, query, *args): return res except Exception as e: - print(f"Error executing query: {e}") + print("Error executing query: {}".format(e)) return None def close(self): diff --git a/testgres/node.py b/testgres/node.py index d12e7324..0f709b17 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -1006,7 +1006,7 @@ def safe_psql(self, query=None, expect_error=False, **kwargs): else: raise QueryException(err or b'', query) elif expect_error: - assert False, f"Exception was expected, but query finished successfully: `{query}` " + assert False, "Exception was expected, but query finished successfully: `{}` ".format(query) return out diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index bbe6b0d4..c2ee29cd 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -84,7 +84,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, if exit_status != 0 or found_error: if exit_status == 0: exit_status = 1 - raise ExecUtilException(message=f'Utility exited with non-zero code. Error `{error}`', + raise ExecUtilException(message='Utility exited with non-zero code. Error `{}`'.format(error), command=cmd, exit_code=exit_status, out=result) @@ -138,7 +138,7 @@ def pathsep(self): elif os_name == "nt": pathsep = ";" else: - raise Exception(f"Unsupported operating system: {os_name}") + raise Exception("Unsupported operating system: {}".format(os_name)) return pathsep def mkdtemp(self, prefix=None): @@ -242,7 +242,7 @@ def remove_file(self, filename): # Processes control def kill(self, pid, signal): # Kill the process - cmd = f"kill -{signal} {pid}" + cmd = "kill -{} {}".format(signal, pid) return self.exec_command(cmd) def get_pid(self): diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 0a90426c..eb996f58 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -26,11 +26,11 @@ def __init__(self, ssh, pid): self.pid = pid def kill(self): - command = f"kill {self.pid}" + command = "kill {}".format(self.pid) self.ssh.exec_command(command) def cmdline(self): - command = f"ps -p {self.pid} -o cmd --no-headers" + command = "ps -p {} -o cmd --no-headers".format(self.pid) stdin, stdout, stderr = self.ssh.exec_command(command) cmdline = stdout.read().decode('utf-8').strip() return cmdline.split() @@ -84,9 +84,9 @@ def _read_ssh_key(self): key = paramiko.RSAKey.from_private_key_file(self.ssh_key) return key except FileNotFoundError: - raise ExecUtilException(message=f"No such file or directory: '{self.ssh_key}'") + raise ExecUtilException(message="No such file or directory: '{}'".format(self.ssh_key)) except Exception as e: - ExecUtilException(message=f"An error occurred while reading the ssh key: {e}") + ExecUtilException(message="An error occurred while reading the ssh key: {}".format(e)) def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdout=None, @@ -131,7 +131,7 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa if error_found: if exit_status == 0: exit_status = 1 - raise ExecUtilException(message=f"Utility exited with non-zero code. Error: {error.decode(encoding or 'utf-8')}", + raise ExecUtilException(message="Utility exited with non-zero code. Error: {}".format(error.decode(encoding or 'utf-8')), command=cmd, exit_code=exit_status, out=result) @@ -148,7 +148,7 @@ def environ(self, var_name: str) -> str: Args: - var_name (str): The name of the environment variable. """ - cmd = f"echo ${var_name}" + cmd = "echo ${}".format(var_name) return self.exec_command(cmd, encoding='utf-8').strip() def find_executable(self, executable): @@ -166,7 +166,7 @@ def find_executable(self, executable): def is_executable(self, file): # Check if the file is executable - is_exec = self.exec_command(f"test -x {file} && echo OK") + is_exec = self.exec_command("test -x {} && echo OK".format(file)) return is_exec == b"OK\n" def set_env(self, var_name: str, var_val: str): @@ -176,7 +176,7 @@ def set_env(self, var_name: str, var_val: str): - var_name (str): The name of the environment variable. - var_val (str): The value to be set for the environment variable. """ - return self.exec_command(f"export {var_name}={var_val}") + return self.exec_command("export {}={}".format(var_name, var_val)) # Get environment variables def get_user(self): @@ -195,12 +195,12 @@ def makedirs(self, path, remove_existing=False): - remove_existing (bool): If True, the existing directory at the path will be removed. """ if remove_existing: - cmd = f"rm -rf {path} && mkdir -p {path}" + cmd = "rm -rf {} && mkdir -p {}".format(path, path) else: - cmd = f"mkdir -p {path}" + cmd = "mkdir -p {}".format(path) exit_status, result, error = self.exec_command(cmd, verbose=True) if exit_status != 0: - raise Exception(f"Couldn't create dir {path} because of error {error}") + raise Exception("Couldn't create dir {} because of error {}".format(path, error)) return result def rmdirs(self, path, verbose=False, ignore_errors=True): @@ -211,7 +211,7 @@ def rmdirs(self, path, verbose=False, ignore_errors=True): - verbose (bool): If True, return exit status, result, and error. - ignore_errors (bool): If True, do not raise error if directory does not exist. """ - cmd = f"rm -rf {path}" + cmd = "rm -rf {}".format(path) exit_status, result, error = self.exec_command(cmd, verbose=True) if verbose: return exit_status, result, error @@ -224,11 +224,11 @@ def listdir(self, path): Args: path (str): The path to the directory. """ - result = self.exec_command(f"ls {path}") + result = self.exec_command("ls {}".format(path)) return result.splitlines() def path_exists(self, path): - result = self.exec_command(f"test -e {path}; echo $?", encoding='utf-8') + result = self.exec_command("test -e {}; echo $?".format(path), encoding='utf-8') return int(result.strip()) == 0 @property @@ -239,7 +239,7 @@ def pathsep(self): elif os_name == "nt": pathsep = ";" else: - raise Exception(f"Unsupported operating system: {os_name}") + raise Exception("Unsupported operating system: {}".format(os_name)) return pathsep def mkdtemp(self, prefix=None): @@ -249,7 +249,7 @@ def mkdtemp(self, prefix=None): - prefix (str): The prefix of the temporary directory name. """ if prefix: - temp_dir = self.exec_command(f"mktemp -d {prefix}XXXXX", encoding='utf-8') + temp_dir = self.exec_command("mktemp -d {}XXXXX".format(prefix), encoding='utf-8') else: temp_dir = self.exec_command("mktemp -d", encoding='utf-8') @@ -262,7 +262,7 @@ def mkdtemp(self, prefix=None): def mkstemp(self, prefix=None): if prefix: - temp_dir = self.exec_command(f"mktemp {prefix}XXXXX", encoding='utf-8') + temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding='utf-8') else: temp_dir = self.exec_command("mktemp", encoding='utf-8') @@ -277,8 +277,8 @@ def copytree(self, src, dst): if not os.path.isabs(dst): dst = os.path.join('~', dst) if self.isdir(dst): - raise FileExistsError(f"Directory {dst} already exists.") - return self.exec_command(f"cp -r {src} {dst}") + raise FileExistsError("Directory {} already exists.".format(dst)) + return self.exec_command("cp -r {} {}".format(src, dst)) # Work with files def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding='utf-8'): @@ -344,10 +344,10 @@ def touch(self, filename): This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. """ - self.exec_command(f"touch {filename}") + self.exec_command("touch {}".format(filename)) def read(self, filename, binary=False, encoding=None): - cmd = f"cat {filename}" + cmd = "cat {}".format(filename) result = self.exec_command(cmd, encoding=encoding) if not binary and result: @@ -357,9 +357,9 @@ def read(self, filename, binary=False, encoding=None): def readlines(self, filename, num_lines=0, binary=False, encoding=None): if num_lines > 0: - cmd = f"tail -n {num_lines} {filename}" + cmd = "tail -n {} {}".format(num_lines, filename) else: - cmd = f"cat {filename}" + cmd = "cat {}".format(filename) result = self.exec_command(cmd, encoding=encoding) @@ -371,23 +371,23 @@ def readlines(self, filename, num_lines=0, binary=False, encoding=None): return lines def isfile(self, remote_file): - stdout = self.exec_command(f"test -f {remote_file}; echo $?") + stdout = self.exec_command("test -f {}; echo $?".format(remote_file)) result = int(stdout.strip()) return result == 0 def isdir(self, dirname): - cmd = f"if [ -d {dirname} ]; then echo True; else echo False; fi" + cmd = "if [ -d {} ]; then echo True; else echo False; fi".format(dirname) response = self.exec_command(cmd) return response.strip() == b"True" def remove_file(self, filename): - cmd = f"rm {filename}" + cmd = "rm {}".format(filename) return self.exec_command(cmd) # Processes control def kill(self, pid, signal): # Kill the process - cmd = f"kill -{signal} {pid}" + cmd = "kill -{} {}".format(signal, pid) return self.exec_command(cmd) def get_pid(self): @@ -395,7 +395,7 @@ def get_pid(self): return int(self.exec_command("echo $$", encoding='utf-8')) def get_process_children(self, pid): - command = f"pgrep -P {pid}" + command = "pgrep -P {}".format(pid) stdin, stdout, stderr = self.ssh.exec_command(command) children = stdout.readlines() return [PsUtilProcessProxy(self.ssh, int(child_pid.strip())) for child_pid in children] @@ -437,4 +437,4 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s return conn except Exception as e: self.tunnel.stop() - raise ExecUtilException("Could not create db tunnel.") + raise ExecUtilException("Could not create db tunnel. {}".format(e)) diff --git a/testgres/utils.py b/testgres/utils.py index 273c9287..1772d748 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -72,7 +72,7 @@ def execute_utility(args, logfile=None, verbose=False): lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n'] tconf.os_ops.write(filename=logfile, data=lines) except IOError: - raise ExecUtilException(f"Problem with writing to logfile `{logfile}` during run command `{args}`") + raise ExecUtilException("Problem with writing to logfile `{}` during run command `{}`".format(logfile, args)) if verbose: return exit_status, out, error else: diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 80cf7674..448a60ca 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -750,9 +750,8 @@ def test_pgbench(self): # run TPC-B benchmark out = node.pgbench(stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - options=['-T3']) - + stderr=subprocess.STDOUT, + options=['-T3']) self.assertTrue(b'tps = ' in out) def test_pg_config(self): From 190d084a7dbb00f4b844a5d3392194f47fd26073 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Tue, 27 Jun 2023 23:43:18 +0200 Subject: [PATCH 19/23] PBCKP-588 fixes after review - replace subprocess.run on subprocess.Popen --- testgres/operations/local_ops.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index c2ee29cd..fb47194f 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -60,27 +60,26 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, result = buf.read().decode(encoding) return result else: - if proc: - return subprocess.Popen(cmd, shell=shell, stdin=input, stdout=stdout, stderr=stderr) - process = subprocess.run( + process = subprocess.Popen( cmd, - input=input, shell=shell, - text=text, stdout=stdout, stderr=stderr, - timeout=CMD_TIMEOUT_SEC, ) + if proc: + return process + result, error = process.communicate(input) exit_status = process.returncode - result = process.stdout - error = process.stderr + found_error = "error" in error.decode(encoding or 'utf-8').lower() + if encoding: result = result.decode(encoding) error = error.decode(encoding) if expect_error: raise Exception(result, error) + if exit_status != 0 or found_error: if exit_status == 0: exit_status = 1 From 0c26f77db688c57f66bc2029dc92737612e8d736 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Wed, 28 Jun 2023 16:21:12 +0200 Subject: [PATCH 20/23] PBCKP-588 fix failed tests - psql, set_auto_conf --- testgres/node.py | 40 +++++++++++++++++--------- testgres/operations/local_ops.py | 36 ++++++++++++----------- testgres/operations/remote_ops.py | 13 ++++++--- tests/test_remote.py | 15 ++++++---- tests/test_simple_remote.py | 48 ++++++++++++------------------- 5 files changed, 82 insertions(+), 70 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 0f709b17..a146b08d 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -3,6 +3,7 @@ import os import random import signal +import subprocess import threading from queue import Queue @@ -714,14 +715,13 @@ def start(self, params=[], wait=True): exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True) if 'does not exist' in error: raise Exception - if 'server started' in out: - self.is_started = True except Exception as e: msg = 'Cannot start node' files = self._collect_special_files() raise_from(StartNodeException(msg, files), e) self._maybe_start_logger() + self.is_started = True return self def stop(self, params=[], wait=True): @@ -958,7 +958,10 @@ def psql(self, # select query source if query: - psql_params.extend(("-c", '"{}"'.format(query))) + if self.os_ops.remote: + psql_params.extend(("-c", '"{}"'.format(query))) + else: + psql_params.extend(("-c", query)) elif filename: psql_params.extend(("-f", filename)) else: @@ -966,11 +969,20 @@ def psql(self, # should be the last one psql_params.append(dbname) + if not self.os_ops.remote: + # start psql process + process = subprocess.Popen(psql_params, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + # wait until it finishes and get stdout and stderr + out, err = process.communicate(input=input) + return process.returncode, out, err + else: + status_code, out, err = self.os_ops.exec_command(psql_params, verbose=True, input=input) - # start psql process - status_code, out, err = self.os_ops.exec_command(psql_params, verbose=True, input=input) - - return status_code, out, err + return status_code, out, err @method_decorator(positional_args_hack(['dbname', 'query'])) def safe_psql(self, query=None, expect_error=False, **kwargs): @@ -1002,9 +1014,9 @@ def safe_psql(self, query=None, expect_error=False, **kwargs): err = e.message if ret: if expect_error: - out = err or b'' + out = (err or b'').decode('utf-8') else: - raise QueryException(err or b'', query) + raise QueryException((err or b'').decode('utf-8'), query) elif expect_error: assert False, "Exception was expected, but query finished successfully: `{}` ".format(query) @@ -1529,18 +1541,18 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): Defaults to an empty set. """ # parse postgresql.auto.conf - auto_conf_file = os.path.join(self.data_dir, config) - raw_content = self.os_ops.read(auto_conf_file) + path = os.path.join(self.data_dir, config) + lines = self.os_ops.readlines(path) current_options = {} current_directives = [] - for line in raw_content.splitlines(): + for line in lines: # ignore comments if line.startswith('#'): continue - if line == '': + if line.strip() == '': continue if line.startswith('include'): @@ -1570,7 +1582,7 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): for directive in current_directives: auto_conf += directive + "\n" - self.os_ops.write(auto_conf_file, auto_conf) + self.os_ops.write(path, auto_conf, truncate=True) class NodeApp: diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index fb47194f..edd8cde2 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -1,6 +1,7 @@ import getpass import os import shutil +import stat import subprocess import tempfile from shutil import rmtree @@ -8,8 +9,7 @@ import psutil from ..exceptions import ExecUtilException - -from .os_ops import OsOperations, ConnectionParams +from .os_ops import ConnectionParams, OsOperations from .os_ops import pglib try: @@ -18,20 +18,24 @@ from distutils.spawn import find_executable CMD_TIMEOUT_SEC = 60 +error_markers = [b'error', b'Permission denied', b'fatal'] class LocalOperations(OsOperations): - def __init__(self, conn_params: ConnectionParams = ConnectionParams()): - super().__init__(conn_params.username) + def __init__(self, conn_params=None): + if conn_params is None: + conn_params = ConnectionParams() + super(LocalOperations, self).__init__(conn_params.username) self.conn_params = conn_params self.host = conn_params.host self.ssh_key = None + self.remote = False self.username = conn_params.username or self.get_user() # Command execution def exec_command(self, cmd, wait_exit=False, verbose=False, - expect_error=False, encoding=None, shell=True, text=False, - input=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None): + expect_error=False, encoding=None, shell=False, text=False, + input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None): """ Execute a command in a subprocess. @@ -49,9 +53,6 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, - proc: The process to use for subprocess creation. :return: The output of the subprocess. """ - if isinstance(cmd, list): - cmd = ' '.join(item.decode('utf-8') if isinstance(item, bytes) else item for item in cmd) - if os.name == 'nt': with tempfile.NamedTemporaryFile() as buf: process = subprocess.Popen(cmd, stdout=buf, stderr=subprocess.STDOUT) @@ -71,7 +72,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, result, error = process.communicate(input) exit_status = process.returncode - found_error = "error" in error.decode(encoding or 'utf-8').lower() + error_found = exit_status != 0 or any(marker in error for marker in error_markers) if encoding: result = result.decode(encoding) @@ -80,7 +81,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, if expect_error: raise Exception(result, error) - if exit_status != 0 or found_error: + if exit_status != 0 or error_found: if exit_status == 0: exit_status = 1 raise ExecUtilException(message='Utility exited with non-zero code. Error `{}`'.format(error), @@ -101,7 +102,7 @@ def find_executable(self, executable): def is_executable(self, file): # Check if the file is executable - return os.access(file, os.X_OK) + return os.stat(file).st_mode & stat.S_IXUSR def set_env(self, var_name, var_val): # Check if the directory is already in PATH @@ -116,9 +117,12 @@ def get_name(self): # Work with dirs def makedirs(self, path, remove_existing=False): - if remove_existing and os.path.exists(path): - shutil.rmtree(path) - os.makedirs(path, exist_ok=True) + if remove_existing: + shutil.rmtree(path, ignore_errors=True) + try: + os.makedirs(path) + except FileExistsError: + pass def rmdirs(self, path, ignore_errors=True): return rmtree(path, ignore_errors=ignore_errors) @@ -141,7 +145,7 @@ def pathsep(self): return pathsep def mkdtemp(self, prefix=None): - return tempfile.mkdtemp(prefix=prefix) + return tempfile.mkdtemp(prefix='{}'.format(prefix)) def mkstemp(self, prefix=None): fd, filename = tempfile.mkstemp(prefix=prefix) diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index eb996f58..bdeb423a 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -17,7 +17,7 @@ sshtunnel.TUNNEL_TIMEOUT = 5.0 -error_markers = [b'error', b'Permission denied'] +error_markers = [b'error', b'Permission denied', b'fatal'] class PsUtilProcessProxy: @@ -43,6 +43,7 @@ def __init__(self, conn_params: ConnectionParams): self.host = conn_params.host self.ssh_key = conn_params.ssh_key self.ssh = self.ssh_connect() + self.remote = True self.username = conn_params.username or self.get_user() self.tunnel = None @@ -89,7 +90,7 @@ def _read_ssh_key(self): ExecUtilException(message="An error occurred while reading the ssh key: {}".format(e)) def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, - encoding=None, shell=True, text=False, input=None, stdout=None, + encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None, stderr=None, proc=None): """ Execute a command in the SSH session. @@ -131,7 +132,11 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa if error_found: if exit_status == 0: exit_status = 1 - raise ExecUtilException(message="Utility exited with non-zero code. Error: {}".format(error.decode(encoding or 'utf-8')), + if encoding: + message = "Utility exited with non-zero code. Error: {}".format(error.decode(encoding)) + else: + message = b"Utility exited with non-zero code. Error: " + error + raise ExecUtilException(message=message, command=cmd, exit_code=exit_status, out=result) @@ -429,7 +434,7 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s conn = pglib.connect( host=host, # change to 'localhost' because we're connecting through a local ssh tunnel port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel - dbname=dbname, + database=dbname, user=user or self.username, password=password ) diff --git a/tests/test_remote.py b/tests/test_remote.py index ceb06ee3..3794349c 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -1,3 +1,5 @@ +import os + import pytest from testgres import ExecUtilException @@ -9,9 +11,10 @@ class TestRemoteOperations: @pytest.fixture(scope="function", autouse=True) def setup(self): - conn_params = ConnectionParams(host="172.18.0.3", - username="dev", - ssh_key='../../container_files/postgres/ssh/id_ed25519') + conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3', + username='dev', + ssh_key=os.getenv( + 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519') self.operations = RemoteOperations(conn_params) yield @@ -35,7 +38,7 @@ def test_exec_command_failure(self): exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True) except ExecUtilException as e: error = e.message - assert error == 'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n' + assert error == b'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n' def test_is_executable_true(self): """ @@ -62,7 +65,7 @@ def test_makedirs_and_rmdirs_success(self): cmd = "pwd" pwd = self.operations.exec_command(cmd, wait_exit=True, encoding='utf-8').strip() - path = f"{pwd}/test_dir" + path = "{}/test_dir".format(pwd) # Test makedirs self.operations.makedirs(path) @@ -88,7 +91,7 @@ def test_makedirs_and_rmdirs_failure(self): exit_status, result, error = self.operations.rmdirs(path, verbose=True) except ExecUtilException as e: error = e.message - assert error == "Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n" + assert error == b"Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n" def test_listdir(self): """ diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 448a60ca..5028bc75 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -52,9 +52,10 @@ from testgres.utils import PgVer from testgres.node import ProcessProxy, ConnectionParams -conn_params = ConnectionParams(host="172.18.0.3", - username="dev", - ssh_key='../../container_files/postgres/ssh/id_ed25519') +conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3', + username='dev', + ssh_key=os.getenv( + 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519') os_ops = RemoteOperations(conn_params) testgres_config.set_os_ops(os_ops=os_ops) @@ -148,14 +149,12 @@ def test_init_unique_system_id(self): with scoped_config(cache_initdb=True, cached_initdb_unique=True) as config: - self.assertTrue(config.cache_initdb) self.assertTrue(config.cached_initdb_unique) # spawn two nodes; ids must be different with get_remote_node().init().start() as node1, \ get_remote_node().init().start() as node2: - id1 = node1.execute(query)[0] id2 = node2.execute(query)[0] @@ -197,10 +196,10 @@ def test_restart(self): # restart, ok res = node.execute('select 1') - self.assertEqual(res, [(1, )]) + self.assertEqual(res, [(1,)]) node.restart() res = node.execute('select 2') - self.assertEqual(res, [(2, )]) + self.assertEqual(res, [(2,)]) # restart, fail with self.assertRaises(StartNodeException): @@ -262,7 +261,6 @@ def test_status(self): def test_psql(self): with get_remote_node().init().start() as node: - # check returned values (1 arg) res = node.psql('select 1') self.assertEqual(res, (0, b'1\n', b'')) @@ -306,7 +304,6 @@ def test_psql(self): def test_transactions(self): with get_remote_node().init().start() as node: - with node.connect() as con: con.begin() con.execute('create table test(val int)') @@ -316,12 +313,12 @@ def test_transactions(self): con.begin() con.execute('insert into test values (2)') res = con.execute('select * from test order by val asc') - self.assertListEqual(res, [(1, ), (2, )]) + self.assertListEqual(res, [(1,), (2,)]) con.rollback() con.begin() res = con.execute('select * from test') - self.assertListEqual(res, [(1, )]) + self.assertListEqual(res, [(1,)]) con.rollback() con.begin() @@ -330,7 +327,6 @@ def test_transactions(self): def test_control_data(self): with get_remote_node() as node: - # node is not initialized yet with self.assertRaises(ExecUtilException): node.get_control_data() @@ -344,7 +340,6 @@ def test_control_data(self): def test_backup_simple(self): with get_remote_node() as master: - # enable streaming for backups master.init(allow_streaming=True) @@ -361,7 +356,7 @@ def test_backup_simple(self): with master.backup(xlog_method='stream') as backup: with backup.spawn_primary().start() as slave: res = slave.execute('select * from test order by i asc') - self.assertListEqual(res, [(1, ), (2, ), (3, ), (4, )]) + self.assertListEqual(res, [(1,), (2,), (3,), (4,)]) def test_backup_multiple(self): with get_remote_node() as node: @@ -369,13 +364,11 @@ def test_backup_multiple(self): with node.backup(xlog_method='fetch') as backup1, \ node.backup(xlog_method='fetch') as backup2: - self.assertNotEqual(backup1.base_dir, backup2.base_dir) with node.backup(xlog_method='fetch') as backup: with backup.spawn_primary('node1', destroy=False) as node1, \ backup.spawn_primary('node2', destroy=False) as node2: - self.assertNotEqual(node1.base_dir, node2.base_dir) def test_backup_exhaust(self): @@ -383,7 +376,6 @@ def test_backup_exhaust(self): node.init(allow_streaming=True).start() with node.backup(xlog_method='fetch') as backup: - # exhaust backup by creating new node with backup.spawn_primary(): pass @@ -418,7 +410,7 @@ def test_replicate(self): with node.replicate().start() as replica: res = replica.execute('select 1') - self.assertListEqual(res, [(1, )]) + self.assertListEqual(res, [(1,)]) node.execute('create table test (val int)', commit=True) @@ -512,7 +504,7 @@ def test_logical_replication(self): node1.safe_psql('insert into test2 values (\'a\'), (\'b\')') sub.catchup() res = node2.execute('select * from test2') - self.assertListEqual(res, [('a', ), ('b', )]) + self.assertListEqual(res, [('a',), ('b',)]) # drop subscription sub.drop() @@ -530,12 +522,12 @@ def test_logical_replication(self): # explicitely add table with self.assertRaises(ValueError): - pub.add_tables([]) # fail + pub.add_tables([]) # fail pub.add_tables(['test2']) node1.safe_psql('insert into test2 values (\'c\')') sub.catchup() res = node2.execute('select * from test2') - self.assertListEqual(res, [('a', ), ('b', )]) + self.assertListEqual(res, [('a',), ('b',)]) @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') def test_logical_catchup(self): @@ -619,7 +611,7 @@ def test_dump(self): # restore dump node3.restore(filename=dump) res = node3.execute(query_select) - self.assertListEqual(res, [(1, ), (2, )]) + self.assertListEqual(res, [(1,), (2,)]) def test_users(self): with get_remote_node().init().start() as node: @@ -651,7 +643,7 @@ def test_poll_query_until(self): # check None, ok node.poll_query_until(query='create table def()', - expected=None) # returns nothing + expected=None) # returns nothing # check 0 rows equivalent to expected=None node.poll_query_until( @@ -697,7 +689,7 @@ def test_logging(self): 'handlers': { 'file': { 'class': 'logging.FileHandler', - 'filename': logfile, + 'filename': logfile.name, 'formatter': 'base_format', 'level': logging.DEBUG, }, @@ -708,7 +700,7 @@ def test_logging(self): }, }, 'root': { - 'handlers': ('file', ), + 'handlers': ('file',), 'level': 'DEBUG', }, } @@ -730,7 +722,7 @@ def test_logging(self): time.sleep(0.1) # check that master's port is found - with open(logfile, 'r') as log: + with open(logfile.name, 'r') as log: lines = log.readlines() self.assertTrue(any(node_name in s for s in lines)) @@ -743,7 +735,6 @@ def test_logging(self): @unittest.skipUnless(util_exists('pgbench'), 'might be missing') def test_pgbench(self): with get_remote_node().init().start() as node: - # initialize pgbench DB and run benchmarks node.pgbench_init(scale=2, foreign_keys=True, options=['-q']).pgbench_run(time=2) @@ -764,7 +755,6 @@ def test_pg_config(self): c1 = get_pg_config() # modify setting for this scope with scoped_config(cache_pg_config=False) as config: - # sanity check for value self.assertFalse(config.cache_pg_config) @@ -796,7 +786,6 @@ def test_config_stack(self): self.assertEqual(c1.cached_initdb_dir, d1) with scoped_config(cached_initdb_dir=d2) as c2: - stack_size = len(testgres.config.config_stack) # try to break a stack @@ -830,7 +819,6 @@ def test_unix_sockets(self): def test_auto_name(self): with get_remote_node().init(allow_streaming=True).start() as m: with m.replicate().start() as r: - # check that nodes are running self.assertTrue(m.status()) self.assertTrue(r.status()) From 0796bc4b334745e59857d2d0a8f8de2d107023a3 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Wed, 26 Jul 2023 09:32:30 +0200 Subject: [PATCH 21/23] PBCKP-152 - test_restore_target_time cut --- testgres/node.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index a146b08d..244f3c1f 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -659,7 +659,7 @@ def get_control_data(self): return out_dict - def slow_start(self, replica=False, dbname='template1', username=default_username()): + def slow_start(self, replica=False, dbname='template1', username=default_username(), max_attempts=0): """ Starts the PostgreSQL instance and then polls the instance until it reaches the expected state (primary or replica). The state is checked @@ -670,6 +670,7 @@ def slow_start(self, replica=False, dbname='template1', username=default_usernam username: replica: If True, waits for the instance to be in recovery (i.e., replica mode). If False, waits for the instance to be in primary mode. Default is False. + max_attempts: """ self.start() @@ -684,7 +685,8 @@ def slow_start(self, replica=False, dbname='template1', username=default_usernam suppress={InternalError, QueryException, ProgrammingError, - OperationalError}) + OperationalError}, + max_attempts=max_attempts) def start(self, params=[], wait=True): """ @@ -719,7 +721,6 @@ def start(self, params=[], wait=True): msg = 'Cannot start node' files = self._collect_special_files() raise_from(StartNodeException(msg, files), e) - self._maybe_start_logger() self.is_started = True return self @@ -1139,9 +1140,9 @@ def poll_query_until(self, # sanity checks assert max_attempts >= 0 assert sleep_time > 0 - attempts = 0 while max_attempts == 0 or attempts < max_attempts: + print(f"Pooling {attempts}") try: res = self.execute(dbname=dbname, query=query, From 0f14034bdef296144016a53bae1e601bc243cc08 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Fri, 28 Jul 2023 09:38:08 +0200 Subject: [PATCH 22/23] PBCKP-152 - node set listen address --- testgres/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testgres/node.py b/testgres/node.py index 244f3c1f..fb259b89 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -530,7 +530,7 @@ def get_auth_method(t): self.append_conf(fsync=fsync, max_worker_processes=MAX_WORKER_PROCESSES, log_statement=log_statement, - listen_addresses='*', + listen_addresses=self.host, port=self.port) # yapf:disable # common replication settings From 12aa7bab9df4a6a2c1e4dca4df5d7f6f17bdba41 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Wed, 2 Aug 2023 00:50:33 +0200 Subject: [PATCH 23/23] Add info about remote mode in README.md --- README.md | 27 +++++++++ testgres/__init__.py | 3 +- testgres/api.py | 12 +++- testgres/node.py | 6 +- testgres/operations/remote_ops.py | 7 ++- testgres/utils.py | 13 ++++- tests/README.md | 29 ++++++++++ tests/test_simple_remote.py | 92 +++++++++++++++---------------- 8 files changed, 130 insertions(+), 59 deletions(-) diff --git a/README.md b/README.md index 6b26ba96..29b974dc 100644 --- a/README.md +++ b/README.md @@ -173,6 +173,33 @@ with testgres.get_new_node().init() as master: Note that `default_conf()` is called by `init()` function; both of them overwrite the configuration file, which means that they should be called before `append_conf()`. +### Remote mode +Testgres supports the creation of PostgreSQL nodes on a remote host. This is useful when you want to run distributed tests involving multiple nodes spread across different machines. + +To use this feature, you need to use the RemoteOperations class. +Here is an example of how you might set this up: + +```python +from testgres import ConnectionParams, RemoteOperations, TestgresConfig, get_remote_node + +# Set up connection params +conn_params = ConnectionParams( + host='your_host', # replace with your host + username='user_name', # replace with your username + ssh_key='path_to_ssh_key' # replace with your SSH key path +) +os_ops = RemoteOperations(conn_params) + +# Add remote testgres config before test +TestgresConfig.set_os_ops(os_ops=os_ops) + +# Proceed with your test +def test_basic_query(self): + with get_remote_node(conn_params=conn_params) as node: + node.init().start() + res = node.execute('SELECT 1') + self.assertEqual(res, [(1,)]) +``` ## Authors diff --git a/testgres/__init__.py b/testgres/__init__.py index ce2636b4..b63c7df1 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -1,4 +1,4 @@ -from .api import get_new_node +from .api import get_new_node, get_remote_node from .backup import NodeBackup from .config import \ @@ -52,6 +52,7 @@ __all__ = [ "get_new_node", + "get_remote_node", "NodeBackup", "TestgresConfig", "configure_testgres", "scoped_config", "push_config", "pop_config", "NodeConnection", "DatabaseError", "InternalError", "ProgrammingError", "OperationalError", diff --git a/testgres/api.py b/testgres/api.py index 8f553529..e4b1cdd5 100644 --- a/testgres/api.py +++ b/testgres/api.py @@ -37,10 +37,18 @@ def get_new_node(name=None, base_dir=None, **kwargs): """ Simply a wrapper around :class:`.PostgresNode` constructor. See :meth:`.PostgresNode.__init__` for details. + """ + # NOTE: leave explicit 'name' and 'base_dir' for compatibility + return PostgresNode(name=name, base_dir=base_dir, **kwargs) + + +def get_remote_node(name=None, conn_params=None): + """ + Simply a wrapper around :class:`.PostgresNode` constructor for remote node. + See :meth:`.PostgresNode.__init__` for details. For remote connection you can add the next parameter: conn_params = ConnectionParams(host='127.0.0.1', ssh_key=None, username=default_username()) """ - # NOTE: leave explicit 'name' and 'base_dir' for compatibility - return PostgresNode(name=name, base_dir=base_dir, **kwargs) + return get_new_node(name=name, conn_params=conn_params) diff --git a/testgres/node.py b/testgres/node.py index fb259b89..6483514b 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -146,8 +146,9 @@ def __init__(self, name=None, port=None, base_dir=None, conn_params: ConnectionP # basic self.name = name or generate_app_name() - - if conn_params.ssh_key: + if testgres_config.os_ops: + self.os_ops = testgres_config.os_ops + elif conn_params.ssh_key: self.os_ops = RemoteOperations(conn_params) else: self.os_ops = LocalOperations(conn_params) @@ -157,7 +158,6 @@ def __init__(self, name=None, port=None, base_dir=None, conn_params: ConnectionP self.host = self.os_ops.host self.ssh_key = self.os_ops.ssh_key - testgres_config.os_ops = self.os_ops # defaults for __exit__() self.cleanup_on_good_exit = testgres_config.node_cleanup_on_good_exit self.cleanup_on_bad_exit = testgres_config.node_cleanup_on_bad_exit diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index bdeb423a..6815c7f1 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -17,7 +17,7 @@ sshtunnel.TUNNEL_TIMEOUT = 5.0 -error_markers = [b'error', b'Permission denied', b'fatal'] +error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory'] class PsUtilProcessProxy: @@ -203,7 +203,10 @@ def makedirs(self, path, remove_existing=False): cmd = "rm -rf {} && mkdir -p {}".format(path, path) else: cmd = "mkdir -p {}".format(path) - exit_status, result, error = self.exec_command(cmd, verbose=True) + try: + exit_status, result, error = self.exec_command(cmd, verbose=True) + except ExecUtilException as e: + raise Exception("Couldn't create dir {} because of error {}".format(path, e.message)) if exit_status != 0: raise Exception("Couldn't create dir {} because of error {}".format(path, error)) return result diff --git a/testgres/utils.py b/testgres/utils.py index 1772d748..58e18deb 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -87,9 +87,12 @@ def get_bin_path(filename): # check if it's already absolute if os.path.isabs(filename): return filename + if tconf.os_ops.remote: + pg_config = os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") + else: + # try PG_CONFIG - get from local machine + pg_config = os.environ.get("PG_CONFIG") - # try PG_CONFIG - get from local machine - pg_config = os.environ.get("PG_CONFIG") if pg_config: bindir = get_pg_config()["BINDIR"] return os.path.join(bindir, filename) @@ -139,7 +142,11 @@ def cache_pg_config_data(cmd): return _pg_config_data # try specified pg_config path or PG_CONFIG - pg_config = pg_config_path or os.environ.get("PG_CONFIG") + if tconf.os_ops.remote: + pg_config = pg_config_path or os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") + else: + # try PG_CONFIG - get from local machine + pg_config = pg_config_path or os.environ.get("PG_CONFIG") if pg_config: return cache_pg_config_data(pg_config) diff --git a/tests/README.md b/tests/README.md index a6d50992..d89efc7e 100644 --- a/tests/README.md +++ b/tests/README.md @@ -27,3 +27,32 @@ export PYTHON_VERSION=3 # or 2 # Run tests ./run_tests.sh ``` + + +#### Remote host tests + +1. Start remote host or docker container +2. Make sure that you run ssh +```commandline +sudo apt-get install openssh-server +sudo systemctl start sshd +``` +3. You need to connect to the remote host at least once to add it to the known hosts file +4. Generate ssh keys +5. Set up params for tests + + +```commandline +conn_params = ConnectionParams( + host='remote_host', + username='username', + ssh_key=/path/to/your/ssh/key' +) +os_ops = RemoteOperations(conn_params) +``` +If you have different path to `PG_CONFIG` on your local and remote host you can set up `PG_CONFIG_REMOTE`, this value will be +using during work with remote host. + +`test_remote` - Tests for RemoteOperations class. + +`test_simple_remote` - Tests that create node and check it. The same as `test_simple`, but for remote node. \ No newline at end of file diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 5028bc75..e8386383 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -35,7 +35,7 @@ NodeStatus, \ ProcessType, \ IsolationLevel, \ - get_new_node, \ + get_remote_node, \ RemoteOperations from testgres import \ @@ -94,23 +94,19 @@ def removing(f): os_ops.rmdirs(f, ignore_errors=True) -def get_remote_node(name=None): - return get_new_node(name=name, conn_params=conn_params) - - class TestgresRemoteTests(unittest.TestCase): def test_node_repr(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: pattern = r"PostgresNode\(name='.+', port=.+, base_dir='.+'\)" self.assertIsNotNone(re.match(pattern, str(node))) def test_custom_init(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: # enable page checksums node.init(initdb_params=['-k']).start() - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init( allow_streaming=True, initdb_params=['--auth-local=reject', '--auth-host=reject']) @@ -125,13 +121,13 @@ def test_custom_init(self): self.assertFalse(any('trust' in s for s in lines)) def test_double_init(self): - with get_remote_node().init() as node: + with get_remote_node(conn_params=conn_params).init() as node: # can't initialize node more than once with self.assertRaises(InitNodeException): node.init() def test_init_after_cleanup(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init().start().execute('select 1') node.cleanup() node.init().start().execute('select 1') @@ -144,7 +140,7 @@ def test_init_unique_system_id(self): query = 'select system_identifier from pg_control_system()' with scoped_config(cache_initdb=False): - with get_remote_node().init().start() as node0: + with get_remote_node(conn_params=conn_params).init().start() as node0: id0 = node0.execute(query)[0] with scoped_config(cache_initdb=True, @@ -153,8 +149,8 @@ def test_init_unique_system_id(self): self.assertTrue(config.cached_initdb_unique) # spawn two nodes; ids must be different - with get_remote_node().init().start() as node1, \ - get_remote_node().init().start() as node2: + with get_remote_node(conn_params=conn_params).init().start() as node1, \ + get_remote_node(conn_params=conn_params).init().start() as node2: id1 = node1.execute(query)[0] id2 = node2.execute(query)[0] @@ -164,7 +160,7 @@ def test_init_unique_system_id(self): def test_node_exit(self): with self.assertRaises(QueryException): - with get_remote_node().init() as node: + with get_remote_node(conn_params=conn_params).init() as node: base_dir = node.base_dir node.safe_psql('select 1') @@ -172,26 +168,26 @@ def test_node_exit(self): self.assertTrue(os_ops.path_exists(base_dir)) os_ops.rmdirs(base_dir, ignore_errors=True) - with get_remote_node().init() as node: + with get_remote_node(conn_params=conn_params).init() as node: base_dir = node.base_dir # should have been removed by default self.assertFalse(os_ops.path_exists(base_dir)) def test_double_start(self): - with get_remote_node().init().start() as node: + with get_remote_node(conn_params=conn_params).init().start() as node: # can't start node more than once node.start() self.assertTrue(node.is_started) def test_uninitialized_start(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: # node is not initialized yet with self.assertRaises(StartNodeException): node.start() def test_restart(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init().start() # restart, ok @@ -207,7 +203,7 @@ def test_restart(self): node.restart() def test_reload(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init().start() # change client_min_messages and save old value @@ -223,7 +219,7 @@ def test_reload(self): self.assertNotEqual(cmm_old, cmm_new) def test_pg_ctl(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init().start() status = node.pg_ctl(['status']) @@ -235,7 +231,7 @@ def test_status(self): self.assertFalse(NodeStatus.Uninitialized) # check statuses after each operation - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: self.assertEqual(node.pid, 0) self.assertEqual(node.status(), NodeStatus.Uninitialized) @@ -260,7 +256,7 @@ def test_status(self): self.assertEqual(node.status(), NodeStatus.Uninitialized) def test_psql(self): - with get_remote_node().init().start() as node: + with get_remote_node(conn_params=conn_params).init().start() as node: # check returned values (1 arg) res = node.psql('select 1') self.assertEqual(res, (0, b'1\n', b'')) @@ -303,7 +299,7 @@ def test_psql(self): node.safe_psql('select 1') def test_transactions(self): - with get_remote_node().init().start() as node: + with get_remote_node(conn_params=conn_params).init().start() as node: with node.connect() as con: con.begin() con.execute('create table test(val int)') @@ -326,7 +322,7 @@ def test_transactions(self): con.commit() def test_control_data(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: # node is not initialized yet with self.assertRaises(ExecUtilException): node.get_control_data() @@ -339,7 +335,7 @@ def test_control_data(self): self.assertTrue(any('pg_control' in s for s in data.keys())) def test_backup_simple(self): - with get_remote_node() as master: + with get_remote_node(conn_params=conn_params) as master: # enable streaming for backups master.init(allow_streaming=True) @@ -359,7 +355,7 @@ def test_backup_simple(self): self.assertListEqual(res, [(1,), (2,), (3,), (4,)]) def test_backup_multiple(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(allow_streaming=True).start() with node.backup(xlog_method='fetch') as backup1, \ @@ -372,7 +368,7 @@ def test_backup_multiple(self): self.assertNotEqual(node1.base_dir, node2.base_dir) def test_backup_exhaust(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(allow_streaming=True).start() with node.backup(xlog_method='fetch') as backup: @@ -385,7 +381,7 @@ def test_backup_exhaust(self): backup.spawn_primary() def test_backup_wrong_xlog_method(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(allow_streaming=True).start() with self.assertRaises(BackupException, @@ -393,7 +389,7 @@ def test_backup_wrong_xlog_method(self): node.backup(xlog_method='wrong') def test_pg_ctl_wait_option(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init().start(wait=False) while True: try: @@ -405,7 +401,7 @@ def test_pg_ctl_wait_option(self): pass def test_replicate(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(allow_streaming=True).start() with node.replicate().start() as replica: @@ -421,7 +417,7 @@ def test_replicate(self): @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') def test_synchronous_replication(self): - with get_remote_node() as master: + with get_remote_node(conn_params=conn_params) as master: old_version = not pg_version_ge('9.6') master.init(allow_streaming=True).start() @@ -462,7 +458,7 @@ def test_synchronous_replication(self): @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') def test_logical_replication(self): - with get_remote_node() as node1, get_remote_node() as node2: + with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2: node1.init(allow_logical=True) node1.start() node2.init().start() @@ -532,7 +528,7 @@ def test_logical_replication(self): @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') def test_logical_catchup(self): """ Runs catchup for 100 times to be sure that it is consistent """ - with get_remote_node() as node1, get_remote_node() as node2: + with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2: node1.init(allow_logical=True) node1.start() node2.init().start() @@ -557,12 +553,12 @@ def test_logical_catchup(self): @unittest.skipIf(pg_version_ge('10'), 'requires <10') def test_logical_replication_fail(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: with self.assertRaises(InitNodeException): node.init(allow_logical=True) def test_replication_slots(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(allow_streaming=True).start() with node.replicate(slot='slot1').start() as replica: @@ -573,7 +569,7 @@ def test_replication_slots(self): node.replicate(slot='slot1') def test_incorrect_catchup(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(allow_streaming=True).start() # node has no master, can't catch up @@ -581,7 +577,7 @@ def test_incorrect_catchup(self): node.catchup() def test_promotion(self): - with get_remote_node() as master: + with get_remote_node(conn_params=conn_params) as master: master.init().start() master.safe_psql('create table abc(id serial)') @@ -598,12 +594,12 @@ def test_dump(self): query_create = 'create table test as select generate_series(1, 2) as val' query_select = 'select * from test order by val asc' - with get_remote_node().init().start() as node1: + with get_remote_node(conn_params=conn_params).init().start() as node1: node1.execute(query_create) for format in ['plain', 'custom', 'directory', 'tar']: with removing(node1.dump(format=format)) as dump: - with get_remote_node().init().start() as node3: + with get_remote_node(conn_params=conn_params).init().start() as node3: if format == 'directory': self.assertTrue(os_ops.isdir(dump)) else: @@ -614,13 +610,13 @@ def test_dump(self): self.assertListEqual(res, [(1,), (2,)]) def test_users(self): - with get_remote_node().init().start() as node: + with get_remote_node(conn_params=conn_params).init().start() as node: node.psql('create role test_user login') value = node.safe_psql('select 1', username='test_user') self.assertEqual(b'1\n', value) def test_poll_query_until(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init().start() get_time = 'select extract(epoch from now())' @@ -734,7 +730,7 @@ def test_logging(self): @unittest.skipUnless(util_exists('pgbench'), 'might be missing') def test_pgbench(self): - with get_remote_node().init().start() as node: + with get_remote_node(conn_params=conn_params).init().start() as node: # initialize pgbench DB and run benchmarks node.pgbench_init(scale=2, foreign_keys=True, options=['-q']).pgbench_run(time=2) @@ -801,7 +797,7 @@ def test_config_stack(self): self.assertEqual(TestgresConfig.cached_initdb_dir, d0) def test_unix_sockets(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(unix_sockets=False, allow_streaming=True) node.start() @@ -817,7 +813,7 @@ def test_unix_sockets(self): self.assertEqual(res_psql, b'1\n') def test_auto_name(self): - with get_remote_node().init(allow_streaming=True).start() as m: + with get_remote_node(conn_params=conn_params).init(allow_streaming=True).start() as m: with m.replicate().start() as r: # check that nodes are running self.assertTrue(m.status()) @@ -854,7 +850,7 @@ def test_file_tail(self): self.assertEqual(lines[0], s3) def test_isolation_levels(self): - with get_remote_node().init().start() as node: + with get_remote_node(conn_params=conn_params).init().start() as node: with node.connect() as con: # string levels con.begin('Read Uncommitted').commit() @@ -876,7 +872,7 @@ def test_ports_management(self): # check that no ports have been bound yet self.assertEqual(len(bound_ports), 0) - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: # check that we've just bound a port self.assertEqual(len(bound_ports), 1) @@ -909,7 +905,7 @@ def test_version_management(self): self.assertTrue(d > f) version = get_pg_version() - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: self.assertTrue(isinstance(version, six.string_types)) self.assertTrue(isinstance(node.version, PgVer)) self.assertEqual(node.version, PgVer(version)) @@ -932,7 +928,7 @@ def test_child_pids(self): ProcessType.WalReceiver, ] - with get_remote_node().init().start() as master: + with get_remote_node(conn_params=conn_params).init().start() as master: # master node doesn't have a source walsender! with self.assertRaises(TestgresException): pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy