diff --git a/README.md b/README.md index 6b26ba96..29b974dc 100644 --- a/README.md +++ b/README.md @@ -173,6 +173,33 @@ with testgres.get_new_node().init() as master: Note that `default_conf()` is called by `init()` function; both of them overwrite the configuration file, which means that they should be called before `append_conf()`. +### Remote mode +Testgres supports the creation of PostgreSQL nodes on a remote host. This is useful when you want to run distributed tests involving multiple nodes spread across different machines. + +To use this feature, you need to use the RemoteOperations class. +Here is an example of how you might set this up: + +```python +from testgres import ConnectionParams, RemoteOperations, TestgresConfig, get_remote_node + +# Set up connection params +conn_params = ConnectionParams( + host='your_host', # replace with your host + username='user_name', # replace with your username + ssh_key='path_to_ssh_key' # replace with your SSH key path +) +os_ops = RemoteOperations(conn_params) + +# Add remote testgres config before test +TestgresConfig.set_os_ops(os_ops=os_ops) + +# Proceed with your test +def test_basic_query(self): + with get_remote_node(conn_params=conn_params) as node: + node.init().start() + res = node.execute('SELECT 1') + self.assertEqual(res, [(1,)]) +``` ## Authors diff --git a/setup.py b/setup.py index 6d0c2256..8cb0f70a 100755 --- a/setup.py +++ b/setup.py @@ -12,6 +12,9 @@ "six>=1.9.0", "psutil", "packaging", + "paramiko", + "fabric", + "sshtunnel" ] # Add compatibility enum class @@ -27,9 +30,9 @@ readme = f.read() setup( - version='1.8.9', + version='1.9.0', name='testgres', - packages=['testgres'], + packages=['testgres', 'testgres.operations'], description='Testing utility for PostgreSQL and its extensions', url='https://github.com/postgrespro/testgres', long_description=readme, diff --git a/testgres/__init__.py b/testgres/__init__.py index 1b33ba3b..b63c7df1 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -1,4 +1,4 @@ -from .api import get_new_node +from .api import get_new_node, get_remote_node from .backup import NodeBackup from .config import \ @@ -46,8 +46,13 @@ First, \ Any +from .operations.os_ops import OsOperations, ConnectionParams +from .operations.local_ops import LocalOperations +from .operations.remote_ops import RemoteOperations + __all__ = [ "get_new_node", + "get_remote_node", "NodeBackup", "TestgresConfig", "configure_testgres", "scoped_config", "push_config", "pop_config", "NodeConnection", "DatabaseError", "InternalError", "ProgrammingError", "OperationalError", @@ -56,4 +61,5 @@ "PostgresNode", "NodeApp", "reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version", "First", "Any", + "OsOperations", "LocalOperations", "RemoteOperations", "ConnectionParams" ] diff --git a/testgres/api.py b/testgres/api.py index e90cf7bd..e4b1cdd5 100644 --- a/testgres/api.py +++ b/testgres/api.py @@ -40,3 +40,15 @@ def get_new_node(name=None, base_dir=None, **kwargs): """ # NOTE: leave explicit 'name' and 'base_dir' for compatibility return PostgresNode(name=name, base_dir=base_dir, **kwargs) + + +def get_remote_node(name=None, conn_params=None): + """ + Simply a wrapper around :class:`.PostgresNode` constructor for remote node. + See :meth:`.PostgresNode.__init__` for details. 
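+
+    A typical call (a sketch; `conn_params` is described just below):
+
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init().start()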
+ For remote connection you can add the next parameter: + conn_params = ConnectionParams(host='127.0.0.1', + ssh_key=None, + username=default_username()) + """ + return get_new_node(name=name, conn_params=conn_params) diff --git a/testgres/backup.py b/testgres/backup.py index a725a1df..a89e214d 100644 --- a/testgres/backup.py +++ b/testgres/backup.py @@ -2,9 +2,7 @@ import os -from shutil import rmtree, copytree from six import raise_from -from tempfile import mkdtemp from .enums import XLogMethod @@ -15,8 +13,6 @@ PG_CONF_FILE, \ BACKUP_LOG_FILE -from .defaults import default_username - from .exceptions import BackupException from .utils import \ @@ -47,7 +43,7 @@ def __init__(self, username: database user name. xlog_method: none | fetch | stream (see docs) """ - + self.os_ops = node.os_ops if not node.status(): raise BackupException('Node must be running') @@ -60,8 +56,8 @@ def __init__(self, raise BackupException(msg) # Set default arguments - username = username or default_username() - base_dir = base_dir or mkdtemp(prefix=TMP_BACKUP) + username = username or self.os_ops.get_user() + base_dir = base_dir or self.os_ops.mkdtemp(prefix=TMP_BACKUP) # public self.original_node = node @@ -107,14 +103,14 @@ def _prepare_dir(self, destroy): available = not destroy if available: - dest_base_dir = mkdtemp(prefix=TMP_NODE) + dest_base_dir = self.os_ops.mkdtemp(prefix=TMP_NODE) data1 = os.path.join(self.base_dir, DATA_DIR) data2 = os.path.join(dest_base_dir, DATA_DIR) try: # Copy backup to new data dir - copytree(data1, data2) + self.os_ops.copytree(data1, data2) except Exception as e: raise_from(BackupException('Failed to copy files'), e) else: @@ -143,7 +139,7 @@ def spawn_primary(self, name=None, destroy=True): # Build a new PostgresNode NodeClass = self.original_node.__class__ - with clean_on_error(NodeClass(name=name, base_dir=base_dir)) as node: + with clean_on_error(NodeClass(name=name, base_dir=base_dir, conn_params=self.original_node.os_ops.conn_params)) as node: # New nodes should always remove dir tree node._should_rm_dirs = True @@ -185,4 +181,4 @@ def cleanup(self): if self._available: self._available = False - rmtree(self.base_dir, ignore_errors=True) + self.os_ops.rmdirs(self.base_dir, ignore_errors=True) diff --git a/testgres/cache.py b/testgres/cache.py index c3cd9971..bf8658c9 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -1,9 +1,7 @@ # coding: utf-8 -import io import os -from shutil import copytree from six import raise_from from .config import testgres_config @@ -20,12 +18,16 @@ get_bin_path, \ execute_utility +from .operations.local_ops import LocalOperations +from .operations.os_ops import OsOperations -def cached_initdb(data_dir, logfile=None, params=None): + +def cached_initdb(data_dir, logfile=None, params=None, os_ops: OsOperations = LocalOperations()): """ Perform initdb or use cached node files. 
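+
+    A typical call (sketch; the paths are illustrative):
+
+        cached_initdb(data_dir='/tmp/tgsn_xyz/data',
+                      logfile='/tmp/tgsn_xyz/logs/utils.log',
+                      os_ops=LocalOperations())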
""" - def call_initdb(initdb_dir, log=None): + + def call_initdb(initdb_dir, log=logfile): try: _params = [get_bin_path("initdb"), "-D", initdb_dir, "-N"] execute_utility(_params + (params or []), log) @@ -39,13 +41,14 @@ def call_initdb(initdb_dir, log=None): cached_data_dir = testgres_config.cached_initdb_dir # Initialize cached initdb - if not os.path.exists(cached_data_dir) or \ - not os.listdir(cached_data_dir): + + if not os_ops.path_exists(cached_data_dir) or \ + not os_ops.listdir(cached_data_dir): call_initdb(cached_data_dir) try: # Copy cached initdb to current data dir - copytree(cached_data_dir, data_dir) + os_ops.copytree(cached_data_dir, data_dir) # Assign this node a unique system id if asked to if testgres_config.cached_initdb_unique: @@ -53,8 +56,8 @@ def call_initdb(initdb_dir, log=None): # Some users might rely upon unique system ids, but # our initdb caching mechanism breaks this contract. pg_control = os.path.join(data_dir, XLOG_CONTROL_FILE) - with io.open(pg_control, "r+b") as f: - f.write(generate_system_id()) # overwrite id + system_id = generate_system_id() + os_ops.write(pg_control, system_id, truncate=True, binary=True, read_and_write=True) # XXX: build new WAL segment with our system id _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"] diff --git a/testgres/config.py b/testgres/config.py index cfcdadc2..b6c43926 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -5,10 +5,10 @@ import tempfile from contextlib import contextmanager -from shutil import rmtree -from tempfile import mkdtemp from .consts import TMP_CACHE +from .operations.os_ops import OsOperations +from .operations.local_ops import LocalOperations class GlobalConfig(object): @@ -43,6 +43,9 @@ class GlobalConfig(object): _cached_initdb_dir = None """ underlying class attribute for cached_initdb_dir property """ + + os_ops = LocalOperations() + """ OsOperation object that allows work on remote host """ @property def cached_initdb_dir(self): """ path to a temp directory for cached initdb. 
""" @@ -54,6 +57,7 @@ def cached_initdb_dir(self, value): if value: cached_initdb_dirs.add(value) + return testgres_config.cached_initdb_dir @property def temp_dir(self): @@ -118,6 +122,11 @@ def copy(self): return copy.copy(self) + @staticmethod + def set_os_ops(os_ops: OsOperations): + testgres_config.os_ops = os_ops + testgres_config.cached_initdb_dir = os_ops.mkdtemp(prefix=TMP_CACHE) + # cached dirs to be removed cached_initdb_dirs = set() @@ -135,7 +144,7 @@ def copy(self): @atexit.register def _rm_cached_initdb_dirs(): for d in cached_initdb_dirs: - rmtree(d, ignore_errors=True) + testgres_config.os_ops.rmdirs(d, ignore_errors=True) def push_config(**options): @@ -198,4 +207,4 @@ def configure_testgres(**options): # NOTE: assign initial cached dir for initdb -testgres_config.cached_initdb_dir = mkdtemp(prefix=TMP_CACHE) +testgres_config.cached_initdb_dir = testgres_config.os_ops.mkdtemp(prefix=TMP_CACHE) diff --git a/testgres/connection.py b/testgres/connection.py index ee2a2128..aeb040ce 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -41,11 +41,11 @@ def __init__(self, self._node = node - self._connection = pglib.connect(database=dbname, - user=username, - password=password, - host=node.host, - port=node.port) + self._connection = node.os_ops.db_connect(dbname=dbname, + user=username, + password=password, + host=node.host, + port=node.port) self._connection.autocommit = autocommit self._cursor = self.connection.cursor() @@ -103,16 +103,15 @@ def rollback(self): def execute(self, query, *args): self.cursor.execute(query, args) - try: res = self.cursor.fetchall() - # pg8000 might return tuples if isinstance(res, tuple): res = [tuple(t) for t in res] return res - except Exception: + except Exception as e: + print("Error executing query: {}".format(e)) return None def close(self): diff --git a/testgres/defaults.py b/testgres/defaults.py index 8d5b892e..d77361d7 100644 --- a/testgres/defaults.py +++ b/testgres/defaults.py @@ -1,9 +1,9 @@ import datetime -import getpass -import os import struct import uuid +from .config import testgres_config as tconf + def default_dbname(): """ @@ -17,8 +17,7 @@ def default_username(): """ Return default username (current user). 
""" - - return getpass.getuser() + return tconf.os_ops.get_user() def generate_app_name(): @@ -44,7 +43,7 @@ def generate_system_id(): system_id = 0 system_id |= (secs << 32) system_id |= (usecs << 12) - system_id |= (os.getpid() & 0xFFF) + system_id |= (tconf.os_ops.get_pid() & 0xFFF) # pack ULL in native byte order return struct.pack('=Q', system_id) diff --git a/testgres/node.py b/testgres/node.py index 659a62f8..6483514b 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -1,18 +1,14 @@ # coding: utf-8 -import io import os import random -import shutil import signal +import subprocess import threading from queue import Queue -import psutil -import subprocess import time - try: from collections.abc import Iterable except ImportError: @@ -27,9 +23,7 @@ except ImportError: raise ImportError("You must have psycopg2 or pg8000 modules installed") -from shutil import rmtree from six import raise_from, iteritems, text_type -from tempfile import mkstemp, mkdtemp from .enums import \ NodeStatus, \ @@ -93,7 +87,6 @@ eprint, \ get_bin_path, \ get_pg_version, \ - file_tail, \ reserve_port, \ release_port, \ execute_utility, \ @@ -102,6 +95,10 @@ from .backup import NodeBackup +from .operations.os_ops import ConnectionParams +from .operations.local_ops import LocalOperations +from .operations.remote_ops import RemoteOperations + InternalError = pglib.InternalError ProgrammingError = pglib.ProgrammingError OperationalError = pglib.OperationalError @@ -130,7 +127,7 @@ def __repr__(self): class PostgresNode(object): - def __init__(self, name=None, port=None, base_dir=None): + def __init__(self, name=None, port=None, base_dir=None, conn_params: ConnectionParams = ConnectionParams()): """ PostgresNode constructor. @@ -148,10 +145,19 @@ def __init__(self, name=None, port=None, base_dir=None): self._master = None # basic - self.host = '127.0.0.1' self.name = name or generate_app_name() + if testgres_config.os_ops: + self.os_ops = testgres_config.os_ops + elif conn_params.ssh_key: + self.os_ops = RemoteOperations(conn_params) + else: + self.os_ops = LocalOperations(conn_params) + self.port = port or reserve_port() + self.host = self.os_ops.host + self.ssh_key = self.os_ops.ssh_key + # defaults for __exit__() self.cleanup_on_good_exit = testgres_config.node_cleanup_on_good_exit self.cleanup_on_bad_exit = testgres_config.node_cleanup_on_bad_exit @@ -195,8 +201,9 @@ def pid(self): if self.status(): pid_file = os.path.join(self.data_dir, PG_PID_FILE) - with io.open(pid_file) as f: - return int(f.readline()) + lines = self.os_ops.readlines(pid_file) + pid = int(lines[0]) if lines else None + return pid # for clarity return 0 @@ -236,7 +243,7 @@ def child_processes(self): """ # get a list of postmaster's children - children = psutil.Process(self.pid).children() + children = self.os_ops.get_process_children(self.pid) return [ProcessProxy(p) for p in children] @@ -274,11 +281,11 @@ def master(self): @property def base_dir(self): if not self._base_dir: - self._base_dir = mkdtemp(prefix=TMP_NODE) + self._base_dir = self.os_ops.mkdtemp(prefix=TMP_NODE) # NOTE: it's safe to create a new dir - if not os.path.exists(self._base_dir): - os.makedirs(self._base_dir) + if not self.os_ops.path_exists(self._base_dir): + self.os_ops.makedirs(self._base_dir) return self._base_dir @@ -287,8 +294,8 @@ def logs_dir(self): path = os.path.join(self.base_dir, LOGS_DIR) # NOTE: it's safe to create a new dir - if not os.path.exists(path): - os.makedirs(path) + if not self.os_ops.path_exists(path): + self.os_ops.makedirs(path) return 
path @@ -365,9 +372,7 @@ def _create_recovery_conf(self, username, slot=None): # Since 12 recovery.conf had disappeared if self.version >= PgVer('12'): signal_name = os.path.join(self.data_dir, "standby.signal") - # cross-python touch(). It is vulnerable to races, but who cares? - with open(signal_name, 'a'): - os.utime(signal_name, None) + self.os_ops.touch(signal_name) else: line += "standby_mode=on\n" @@ -425,19 +430,14 @@ def _collect_special_files(self): for f, num_lines in files: # skip missing files - if not os.path.exists(f): + if not self.os_ops.path_exists(f): continue - with io.open(f, "rb") as _f: - if num_lines > 0: - # take last N lines of file - lines = b''.join(file_tail(_f, num_lines)).decode('utf-8') - else: - # read whole file - lines = _f.read().decode('utf-8') + file_lines = self.os_ops.readlines(f, num_lines, binary=True, encoding=None) + lines = b''.join(file_lines) - # fill list - result.append((f, lines)) + # fill list + result.append((f, lines)) return result @@ -456,9 +456,11 @@ def init(self, initdb_params=None, **kwargs): """ # initialize this PostgreSQL node - cached_initdb(data_dir=self.data_dir, - logfile=self.utils_log_file, - params=initdb_params) + cached_initdb( + data_dir=self.data_dir, + logfile=self.utils_log_file, + os_ops=self.os_ops, + params=initdb_params) # initialize default config files self.default_conf(**kwargs) @@ -489,43 +491,41 @@ def default_conf(self, hba_conf = os.path.join(self.data_dir, HBA_CONF_FILE) # filter lines in hba file - with io.open(hba_conf, "r+") as conf: - # get rid of comments and blank lines - lines = [ - s for s in conf.readlines() - if len(s.strip()) > 0 and not s.startswith('#') - ] - - # write filtered lines - conf.seek(0) - conf.truncate() - conf.writelines(lines) - - # replication-related settings - if allow_streaming: - # get auth method for host or local users - def get_auth_method(t): - return next((s.split()[-1] - for s in lines if s.startswith(t)), 'trust') - - # get auth methods - auth_local = get_auth_method('local') - auth_host = get_auth_method('host') - - new_lines = [ - u"local\treplication\tall\t\t\t{}\n".format(auth_local), - u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), - u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host) - ] # yapf: disable - - # write missing lines - for line in new_lines: - if line not in lines: - conf.write(line) + # get rid of comments and blank lines + hba_conf_file = self.os_ops.readlines(hba_conf) + lines = [ + s for s in hba_conf_file + if len(s.strip()) > 0 and not s.startswith('#') + ] + + # write filtered lines + self.os_ops.write(hba_conf, lines, truncate=True) + + # replication-related settings + if allow_streaming: + # get auth method for host or local users + def get_auth_method(t): + return next((s.split()[-1] + for s in lines if s.startswith(t)), 'trust') + + # get auth methods + auth_local = get_auth_method('local') + auth_host = get_auth_method('host') + subnet_base = ".".join(self.os_ops.host.split('.')[:-1] + ['0']) + + new_lines = [ + u"local\treplication\tall\t\t\t{}\n".format(auth_local), + u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), + u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host), + u"host\treplication\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host), + u"host\tall\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host) + ] # yapf: disable + + # write missing lines + self.os_ops.write(hba_conf, new_lines) # overwrite config file - with io.open(postgres_conf, "w") as conf: - conf.truncate() + 
self.os_ops.write(postgres_conf, '', truncate=True) self.append_conf(fsync=fsync, max_worker_processes=MAX_WORKER_PROCESSES, @@ -595,15 +595,17 @@ def append_conf(self, line='', filename=PG_CONF_FILE, **kwargs): value = 'on' if value else 'off' elif not str(value).replace('.', '', 1).isdigit(): value = "'{}'".format(value) - - # format a new config line - lines.append('{} = {}'.format(option, value)) + if value == '*': + lines.append("{} = '*'".format(option)) + else: + # format a new config line + lines.append('{} = {}'.format(option, value)) config_name = os.path.join(self.data_dir, filename) - with io.open(config_name, 'a') as conf: - for line in lines: - conf.write(text_type(line)) - conf.write(text_type('\n')) + conf_text = '' + for line in lines: + conf_text += text_type(line) + '\n' + self.os_ops.write(config_name, conf_text) return self @@ -621,7 +623,11 @@ def status(self): "-D", self.data_dir, "status" ] # yapf: disable - execute_utility(_params, self.utils_log_file) + status_code, out, err = execute_utility(_params, self.utils_log_file, verbose=True) + if 'does not exist' in err: + return NodeStatus.Uninitialized + elif 'no server running' in out: + return NodeStatus.Stopped return NodeStatus.Running except ExecUtilException as e: @@ -653,7 +659,7 @@ def get_control_data(self): return out_dict - def slow_start(self, replica=False, dbname='template1', username=default_username()): + def slow_start(self, replica=False, dbname='template1', username=default_username(), max_attempts=0): """ Starts the PostgreSQL instance and then polls the instance until it reaches the expected state (primary or replica). The state is checked @@ -664,6 +670,7 @@ def slow_start(self, replica=False, dbname='template1', username=default_usernam username: replica: If True, waits for the instance to be in recovery (i.e., replica mode). If False, waits for the instance to be in primary mode. Default is False. 
+ max_attempts: """ self.start() @@ -678,7 +685,8 @@ def slow_start(self, replica=False, dbname='template1', username=default_usernam suppress={InternalError, QueryException, ProgrammingError, - OperationalError}) + OperationalError}, + max_attempts=max_attempts) def start(self, params=[], wait=True): """ @@ -706,12 +714,13 @@ def start(self, params=[], wait=True): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file) - except ExecUtilException as e: + exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True) + if 'does not exist' in error: + raise Exception + except Exception as e: msg = 'Cannot start node' files = self._collect_special_files() raise_from(StartNodeException(msg, files), e) - self._maybe_start_logger() self.is_started = True return self @@ -779,7 +788,9 @@ def restart(self, params=[]): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file) + error_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True) + if 'could not start server' in error: + raise ExecUtilException except ExecUtilException as e: msg = 'Cannot restart node' files = self._collect_special_files() @@ -895,7 +906,7 @@ def cleanup(self, max_attempts=3): else: rm_dir = self.data_dir # just data, save logs - rmtree(rm_dir, ignore_errors=True) + self.os_ops.rmdirs(rm_dir, ignore_errors=True) return self @@ -948,7 +959,10 @@ def psql(self, # select query source if query: - psql_params.extend(("-c", query)) + if self.os_ops.remote: + psql_params.extend(("-c", '"{}"'.format(query))) + else: + psql_params.extend(("-c", query)) elif filename: psql_params.extend(("-f", filename)) else: @@ -956,16 +970,20 @@ def psql(self, # should be the last one psql_params.append(dbname) + if not self.os_ops.remote: + # start psql process + process = subprocess.Popen(psql_params, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + # wait until it finishes and get stdout and stderr + out, err = process.communicate(input=input) + return process.returncode, out, err + else: + status_code, out, err = self.os_ops.exec_command(psql_params, verbose=True, input=input) - # start psql process - process = subprocess.Popen(psql_params, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - - # wait until it finishes and get stdout and stderr - out, err = process.communicate(input=input) - return process.returncode, out, err + return status_code, out, err @method_decorator(positional_args_hack(['dbname', 'query'])) def safe_psql(self, query=None, expect_error=False, **kwargs): @@ -989,15 +1007,19 @@ def safe_psql(self, query=None, expect_error=False, **kwargs): # force this setting kwargs['ON_ERROR_STOP'] = 1 - - ret, out, err = self.psql(query=query, **kwargs) + try: + ret, out, err = self.psql(query=query, **kwargs) + except ExecUtilException as e: + ret = e.exit_code + out = e.out + err = e.message if ret: if expect_error: out = (err or b'').decode('utf-8') else: raise QueryException((err or b'').decode('utf-8'), query) elif expect_error: - assert False, f"Exception was expected, but query finished successfully: `{query}` " + assert False, "Exception was expected, but query finished successfully: `{}` ".format(query) return out @@ -1031,10 +1053,9 @@ def dump(self, # Generate tmpfile or tmpdir def tmpfile(): if format == DumpFormat.Directory: - fname = mkdtemp(prefix=TMP_DUMP) + fname = self.os_ops.mkdtemp(prefix=TMP_DUMP) else: - fd, fname = mkstemp(prefix=TMP_DUMP) - os.close(fd) + fname = 
self.os_ops.mkstemp(prefix=TMP_DUMP) return fname # Set default arguments @@ -1119,9 +1140,9 @@ def poll_query_until(self, # sanity checks assert max_attempts >= 0 assert sleep_time > 0 - attempts = 0 while max_attempts == 0 or attempts < max_attempts: + print(f"Pooling {attempts}") try: res = self.execute(dbname=dbname, query=query, @@ -1350,7 +1371,7 @@ def pgbench(self, # should be the last one _params.append(dbname) - proc = subprocess.Popen(_params, stdout=stdout, stderr=stderr) + proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, proc=True) return proc @@ -1523,18 +1544,16 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): # parse postgresql.auto.conf path = os.path.join(self.data_dir, config) - with open(path, 'r') as f: - raw_content = f.read() - + lines = self.os_ops.readlines(path) current_options = {} current_directives = [] - for line in raw_content.splitlines(): + for line in lines: # ignore comments if line.startswith('#'): continue - if line == '': + if line.strip() == '': continue if line.startswith('include'): @@ -1564,22 +1583,22 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): for directive in current_directives: auto_conf += directive + "\n" - with open(path, 'wt') as f: - f.write(auto_conf) + self.os_ops.write(path, auto_conf, truncate=True) class NodeApp: - def __init__(self, test_path, nodes_to_cleanup): + def __init__(self, test_path, nodes_to_cleanup, os_ops=LocalOperations()): self.test_path = test_path self.nodes_to_cleanup = nodes_to_cleanup + self.os_ops = os_ops def make_empty( self, base_dir=None): real_base_dir = os.path.join(self.test_path, base_dir) - shutil.rmtree(real_base_dir, ignore_errors=True) - os.makedirs(real_base_dir) + self.os_ops.rmdirs(real_base_dir, ignore_errors=True) + self.os_ops.makedirs(real_base_dir) node = PostgresNode(base_dir=real_base_dir) node.should_rm_dirs = True @@ -1602,27 +1621,24 @@ def make_simple( initdb_params=initdb_params, allow_streaming=set_replication) # set major version - with open(os.path.join(node.data_dir, 'PG_VERSION')) as f: - node.major_version_str = str(f.read().rstrip()) - node.major_version = float(node.major_version_str) - - # Sane default parameters - options = {} - options['max_connections'] = 100 - options['shared_buffers'] = '10MB' - options['fsync'] = 'off' - - options['wal_level'] = 'logical' - options['hot_standby'] = 'off' - - options['log_line_prefix'] = '%t [%p]: [%l-1] ' - options['log_statement'] = 'none' - options['log_duration'] = 'on' - options['log_min_duration_statement'] = 0 - options['log_connections'] = 'on' - options['log_disconnections'] = 'on' - options['restart_after_crash'] = 'off' - options['autovacuum'] = 'off' + pg_version_file = self.os_ops.read(os.path.join(node.data_dir, 'PG_VERSION')) + node.major_version_str = str(pg_version_file.rstrip()) + node.major_version = float(node.major_version_str) + + # Set default parameters + options = {'max_connections': 100, + 'shared_buffers': '10MB', + 'fsync': 'off', + 'wal_level': 'logical', + 'hot_standby': 'off', + 'log_line_prefix': '%t [%p]: [%l-1] ', + 'log_statement': 'none', + 'log_duration': 'on', + 'log_min_duration_statement': 0, + 'log_connections': 'on', + 'log_disconnections': 'on', + 'restart_after_crash': 'off', + 'autovacuum': 'off'} # Allow replication in pg_hba.conf if set_replication: diff --git a/testgres/operations/__init__.py b/testgres/operations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py
new file mode 100644
index 00000000..89071282
--- /dev/null
+++ b/testgres/operations/local_ops.py
@@ -0,0 +1,269 @@
+import getpass
+import os
+import shutil
+import stat
+import subprocess
+import tempfile
+
+import psutil
+
+from ..exceptions import ExecUtilException
+from .os_ops import ConnectionParams, OsOperations
+from .os_ops import pglib
+
+# shutil.rmtree exists on every supported Python; only shutil.which needs
+# a Python 2 fallback (distutils has no rmtree to import)
+from shutil import rmtree
+
+try:
+    from shutil import which as find_executable
+except ImportError:
+    from distutils.spawn import find_executable
+
+
+CMD_TIMEOUT_SEC = 60
+error_markers = [b'error', b'Permission denied', b'fatal']
+
+
+class LocalOperations(OsOperations):
+    def __init__(self, conn_params=None):
+        if conn_params is None:
+            conn_params = ConnectionParams()
+        super(LocalOperations, self).__init__(conn_params.username)
+        self.conn_params = conn_params
+        self.host = conn_params.host
+        self.ssh_key = None
+        self.remote = False
+        self.username = conn_params.username or self.get_user()
+
+    # Command execution
+    def exec_command(self, cmd, wait_exit=False, verbose=False,
+                     expect_error=False, encoding=None, shell=False, text=False,
+                     input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None):
+        """
+        Execute a command in a subprocess.
+
+        Args:
+        - cmd: The command to execute.
+        - wait_exit: Whether to wait for the subprocess to exit before returning.
+        - verbose: Whether to return the (exit status, stdout, stderr) triple instead of stdout alone.
+        - expect_error: Whether to raise an error if the subprocess exits with an error status.
+        - encoding: The encoding to use for decoding the subprocess output.
+        - shell: Whether to use shell when executing the subprocess.
+        - text: Whether to return str instead of bytes for the subprocess output.
+        - input: The input to pass to the subprocess.
+        - stdout: The stdout to use for the subprocess.
+        - stderr: The stderr to use for the subprocess.
+        - proc: If set, return the Popen object right away instead of waiting for completion.
+        :return: The output of the subprocess.
+        """
+        if os.name == 'nt':
+            with tempfile.NamedTemporaryFile() as buf:
+                process = subprocess.Popen(cmd, stdout=buf, stderr=subprocess.STDOUT)
+                process.communicate()
+                buf.seek(0)
+                result = buf.read().decode(encoding) if encoding else buf.read()
+                return result
+        else:
+            process = subprocess.Popen(
+                cmd,
+                shell=shell,
+                stdout=stdout,
+                stderr=stderr,
+            )
+            if proc:
+                return process
+            result, error = process.communicate(input)
+            exit_status = process.returncode
+
+            error_found = exit_status != 0 or any(marker in error for marker in error_markers)
+
+            if encoding:
+                result = result.decode(encoding)
+                error = error.decode(encoding)
+
+            if expect_error:
+                raise Exception(result, error)
+
+            if exit_status != 0 or error_found:
+                if exit_status == 0:
+                    exit_status = 1
+                raise ExecUtilException(message='Utility exited with non-zero code. 
Error `{}`'.format(error), + command=cmd, + exit_code=exit_status, + out=result) + if verbose: + return exit_status, result, error + else: + return result + + # Environment setup + def environ(self, var_name): + return os.environ.get(var_name) + + def find_executable(self, executable): + return find_executable(executable) + + def is_executable(self, file): + # Check if the file is executable + return os.stat(file).st_mode & stat.S_IXUSR + + def set_env(self, var_name, var_val): + # Check if the directory is already in PATH + os.environ[var_name] = var_val + + # Get environment variables + def get_user(self): + return getpass.getuser() + + def get_name(self): + return os.name + + # Work with dirs + def makedirs(self, path, remove_existing=False): + if remove_existing: + shutil.rmtree(path, ignore_errors=True) + try: + os.makedirs(path) + except FileExistsError: + pass + + def rmdirs(self, path, ignore_errors=True): + return rmtree(path, ignore_errors=ignore_errors) + + def listdir(self, path): + return os.listdir(path) + + def path_exists(self, path): + return os.path.exists(path) + + @property + def pathsep(self): + os_name = self.get_name() + if os_name == "posix": + pathsep = ":" + elif os_name == "nt": + pathsep = ";" + else: + raise Exception("Unsupported operating system: {}".format(os_name)) + return pathsep + + def mkdtemp(self, prefix=None): + return tempfile.mkdtemp(prefix='{}'.format(prefix)) + + def mkstemp(self, prefix=None): + fd, filename = tempfile.mkstemp(prefix=prefix) + os.close(fd) # Close the file descriptor immediately after creating the file + return filename + + def copytree(self, src, dst): + return shutil.copytree(src, dst) + + # Work with files + def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + """ + Write data to a file locally + Args: + filename: The file path where the data will be written. + data: The data to be written to the file. + truncate: If True, the file will be truncated before writing ('w' or 'wb' option); + if False (default), data will be appended ('a' or 'ab' option). + binary: If True, the data will be written in binary mode ('wb' or 'ab' option); + if False (default), the data will be written in text mode ('w' or 'a' option). + read_and_write: If True, the file will be opened with read and write permissions ('r+' option); + if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) + """ + # If it is a bytes str or list + if isinstance(data, bytes) or isinstance(data, list) and all(isinstance(item, bytes) for item in data): + binary = True + mode = "wb" if binary else "w" + if not truncate: + mode = "ab" if binary else "a" + if read_and_write: + mode = "r+b" if binary else "r+" + + with open(filename, mode) as file: + if isinstance(data, list): + file.writelines(data) + else: + file.write(data) + + def touch(self, filename): + """ + Create a new file or update the access and modification times of an existing file. + Args: + filename (str): The name of the file to touch. + + This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. + """ + # cross-python touch(). It is vulnerable to races, but who cares? + with open(filename, "a"): + os.utime(filename, None) + + def read(self, filename, encoding=None): + with open(filename, "r", encoding=encoding) as file: + return file.read() + + def readlines(self, filename, num_lines=0, binary=False, encoding=None): + """ + Read lines from a local file. 
+ If num_lines is greater than 0, only the last num_lines lines will be read. + """ + assert num_lines >= 0 + mode = 'rb' if binary else 'r' + if num_lines == 0: + with open(filename, mode, encoding=encoding) as file: # open in binary mode + return file.readlines() + + else: + bufsize = 8192 + buffers = 1 + + with open(filename, mode, encoding=encoding) as file: # open in binary mode + file.seek(0, os.SEEK_END) + end_pos = file.tell() + + while True: + offset = max(0, end_pos - bufsize * buffers) + file.seek(offset, os.SEEK_SET) + pos = file.tell() + lines = file.readlines() + cur_lines = len(lines) + + if cur_lines >= num_lines or pos == 0: + return lines[-num_lines:] # get last num_lines from lines + + buffers = int( + buffers * max(2, int(num_lines / max(cur_lines, 1))) + ) # Adjust buffer size + + def isfile(self, remote_file): + return os.path.isfile(remote_file) + + def isdir(self, dirname): + return os.path.isdir(dirname) + + def remove_file(self, filename): + return os.remove(filename) + + # Processes control + def kill(self, pid, signal): + # Kill the process + cmd = "kill -{} {}".format(signal, pid) + return self.exec_command(cmd) + + def get_pid(self): + # Get current process id + return os.getpid() + + def get_process_children(self, pid): + return psutil.Process(pid).children() + + # Database control + def db_connect(self, dbname, user, password=None, host="localhost", port=5432): + conn = pglib.connect( + host=host, + port=port, + database=dbname, + user=user, + password=password, + ) + return conn diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py new file mode 100644 index 00000000..9261cacf --- /dev/null +++ b/testgres/operations/os_ops.py @@ -0,0 +1,101 @@ +try: + import psycopg2 as pglib # noqa: F401 +except ImportError: + try: + import pg8000 as pglib # noqa: F401 + except ImportError: + raise ImportError("You must have psycopg2 or pg8000 modules installed") + + +class ConnectionParams: + def __init__(self, host='127.0.0.1', ssh_key=None, username=None): + self.host = host + self.ssh_key = ssh_key + self.username = username + + +class OsOperations: + def __init__(self, username=None): + self.ssh_key = None + self.username = username + + # Command execution + def exec_command(self, cmd, **kwargs): + raise NotImplementedError() + + # Environment setup + def environ(self, var_name): + raise NotImplementedError() + + def find_executable(self, executable): + raise NotImplementedError() + + def is_executable(self, file): + # Check if the file is executable + raise NotImplementedError() + + def set_env(self, var_name, var_val): + # Check if the directory is already in PATH + raise NotImplementedError() + + # Get environment variables + def get_user(self): + raise NotImplementedError() + + def get_name(self): + raise NotImplementedError() + + # Work with dirs + def makedirs(self, path, remove_existing=False): + raise NotImplementedError() + + def rmdirs(self, path, ignore_errors=True): + raise NotImplementedError() + + def listdir(self, path): + raise NotImplementedError() + + def path_exists(self, path): + raise NotImplementedError() + + @property + def pathsep(self): + raise NotImplementedError() + + def mkdtemp(self, prefix=None): + raise NotImplementedError() + + def copytree(self, src, dst): + raise NotImplementedError() + + # Work with files + def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + raise NotImplementedError() + + def touch(self, filename): + raise NotImplementedError() + + def read(self, filename): + 
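+        # implemented by LocalOperations / RemoteOperations; returns the file contents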
raise NotImplementedError() + + def readlines(self, filename): + raise NotImplementedError() + + def isfile(self, remote_file): + raise NotImplementedError() + + # Processes control + def kill(self, pid, signal): + # Kill the process + raise NotImplementedError() + + def get_pid(self): + # Get current process id + raise NotImplementedError() + + def get_process_children(self, pid): + raise NotImplementedError() + + # Database control + def db_connect(self, dbname, user, password=None, host="localhost", port=5432): + raise NotImplementedError() diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py new file mode 100644 index 00000000..6815c7f1 --- /dev/null +++ b/testgres/operations/remote_ops.py @@ -0,0 +1,448 @@ +import os +import tempfile +import time +from typing import Optional + +import sshtunnel + +import paramiko +from paramiko import SSHClient + +from ..exceptions import ExecUtilException + +from .os_ops import OsOperations, ConnectionParams +from .os_ops import pglib + +sshtunnel.SSH_TIMEOUT = 5.0 +sshtunnel.TUNNEL_TIMEOUT = 5.0 + + +error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory'] + + +class PsUtilProcessProxy: + def __init__(self, ssh, pid): + self.ssh = ssh + self.pid = pid + + def kill(self): + command = "kill {}".format(self.pid) + self.ssh.exec_command(command) + + def cmdline(self): + command = "ps -p {} -o cmd --no-headers".format(self.pid) + stdin, stdout, stderr = self.ssh.exec_command(command) + cmdline = stdout.read().decode('utf-8').strip() + return cmdline.split() + + +class RemoteOperations(OsOperations): + def __init__(self, conn_params: ConnectionParams): + super().__init__(conn_params.username) + self.conn_params = conn_params + self.host = conn_params.host + self.ssh_key = conn_params.ssh_key + self.ssh = self.ssh_connect() + self.remote = True + self.username = conn_params.username or self.get_user() + self.tunnel = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close_tunnel() + if getattr(self, 'ssh', None): + self.ssh.close() + + def __del__(self): + if getattr(self, 'ssh', None): + self.ssh.close() + + def close_tunnel(self): + if getattr(self, 'tunnel', None): + self.tunnel.stop(force=True) + start_time = time.time() + while self.tunnel.is_active: + if time.time() - start_time > sshtunnel.TUNNEL_TIMEOUT: + break + time.sleep(0.5) + + def ssh_connect(self) -> Optional[SSHClient]: + key = self._read_ssh_key() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.host, username=self.username, pkey=key) + return ssh + + def _read_ssh_key(self): + try: + with open(self.ssh_key, "r") as f: + key_data = f.read() + if "BEGIN OPENSSH PRIVATE KEY" in key_data: + key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) + else: + key = paramiko.RSAKey.from_private_key_file(self.ssh_key) + return key + except FileNotFoundError: + raise ExecUtilException(message="No such file or directory: '{}'".format(self.ssh_key)) + except Exception as e: + ExecUtilException(message="An error occurred while reading the ssh key: {}".format(e)) + + def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, + encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None, + stderr=None, proc=None): + """ + Execute a command in the SSH session. + Args: + - cmd (str): The command to be executed. 
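+
+        Example (sketch; `ops` is a connected RemoteOperations instance):
+
+            exit_status, out, err = ops.exec_command('ls -l /tmp', verbose=True)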
+ """ + if self.ssh is None or not self.ssh.get_transport() or not self.ssh.get_transport().is_active(): + self.ssh = self.ssh_connect() + + if isinstance(cmd, list): + cmd = ' '.join(item.decode('utf-8') if isinstance(item, bytes) else item for item in cmd) + if input: + stdin, stdout, stderr = self.ssh.exec_command(cmd) + stdin.write(input) + stdin.flush() + else: + stdin, stdout, stderr = self.ssh.exec_command(cmd) + exit_status = 0 + if wait_exit: + exit_status = stdout.channel.recv_exit_status() + + if encoding: + result = stdout.read().decode(encoding) + error = stderr.read().decode(encoding) + else: + result = stdout.read() + error = stderr.read() + + if expect_error: + raise Exception(result, error) + + if encoding: + error_found = exit_status != 0 or any( + marker.decode(encoding) in error for marker in error_markers) + else: + error_found = exit_status != 0 or any( + marker in error for marker in error_markers) + + if error_found: + if exit_status == 0: + exit_status = 1 + if encoding: + message = "Utility exited with non-zero code. Error: {}".format(error.decode(encoding)) + else: + message = b"Utility exited with non-zero code. Error: " + error + raise ExecUtilException(message=message, + command=cmd, + exit_code=exit_status, + out=result) + + if verbose: + return exit_status, result, error + else: + return result + + # Environment setup + def environ(self, var_name: str) -> str: + """ + Get the value of an environment variable. + Args: + - var_name (str): The name of the environment variable. + """ + cmd = "echo ${}".format(var_name) + return self.exec_command(cmd, encoding='utf-8').strip() + + def find_executable(self, executable): + search_paths = self.environ("PATH") + if not search_paths: + return None + + search_paths = search_paths.split(self.pathsep) + for path in search_paths: + remote_file = os.path.join(path, executable) + if self.isfile(remote_file): + return remote_file + + return None + + def is_executable(self, file): + # Check if the file is executable + is_exec = self.exec_command("test -x {} && echo OK".format(file)) + return is_exec == b"OK\n" + + def set_env(self, var_name: str, var_val: str): + """ + Set the value of an environment variable. + Args: + - var_name (str): The name of the environment variable. + - var_val (str): The value to be set for the environment variable. + """ + return self.exec_command("export {}={}".format(var_name, var_val)) + + # Get environment variables + def get_user(self): + return self.exec_command("echo $USER", encoding='utf-8').strip() + + def get_name(self): + cmd = 'python3 -c "import os; print(os.name)"' + return self.exec_command(cmd, encoding='utf-8').strip() + + # Work with dirs + def makedirs(self, path, remove_existing=False): + """ + Create a directory in the remote server. + Args: + - path (str): The path to the directory to be created. + - remove_existing (bool): If True, the existing directory at the path will be removed. + """ + if remove_existing: + cmd = "rm -rf {} && mkdir -p {}".format(path, path) + else: + cmd = "mkdir -p {}".format(path) + try: + exit_status, result, error = self.exec_command(cmd, verbose=True) + except ExecUtilException as e: + raise Exception("Couldn't create dir {} because of error {}".format(path, e.message)) + if exit_status != 0: + raise Exception("Couldn't create dir {} because of error {}".format(path, error)) + return result + + def rmdirs(self, path, verbose=False, ignore_errors=True): + """ + Remove a directory in the remote server. 
+ Args: + - path (str): The path to the directory to be removed. + - verbose (bool): If True, return exit status, result, and error. + - ignore_errors (bool): If True, do not raise error if directory does not exist. + """ + cmd = "rm -rf {}".format(path) + exit_status, result, error = self.exec_command(cmd, verbose=True) + if verbose: + return exit_status, result, error + else: + return result + + def listdir(self, path): + """ + List all files and directories in a directory. + Args: + path (str): The path to the directory. + """ + result = self.exec_command("ls {}".format(path)) + return result.splitlines() + + def path_exists(self, path): + result = self.exec_command("test -e {}; echo $?".format(path), encoding='utf-8') + return int(result.strip()) == 0 + + @property + def pathsep(self): + os_name = self.get_name() + if os_name == "posix": + pathsep = ":" + elif os_name == "nt": + pathsep = ";" + else: + raise Exception("Unsupported operating system: {}".format(os_name)) + return pathsep + + def mkdtemp(self, prefix=None): + """ + Creates a temporary directory in the remote server. + Args: + - prefix (str): The prefix of the temporary directory name. + """ + if prefix: + temp_dir = self.exec_command("mktemp -d {}XXXXX".format(prefix), encoding='utf-8') + else: + temp_dir = self.exec_command("mktemp -d", encoding='utf-8') + + if temp_dir: + if not os.path.isabs(temp_dir): + temp_dir = os.path.join('/home', self.username, temp_dir.strip()) + return temp_dir + else: + raise ExecUtilException("Could not create temporary directory.") + + def mkstemp(self, prefix=None): + if prefix: + temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding='utf-8') + else: + temp_dir = self.exec_command("mktemp", encoding='utf-8') + + if temp_dir: + if not os.path.isabs(temp_dir): + temp_dir = os.path.join('/home', self.username, temp_dir.strip()) + return temp_dir + else: + raise ExecUtilException("Could not create temporary directory.") + + def copytree(self, src, dst): + if not os.path.isabs(dst): + dst = os.path.join('~', dst) + if self.isdir(dst): + raise FileExistsError("Directory {} already exists.".format(dst)) + return self.exec_command("cp -r {} {}".format(src, dst)) + + # Work with files + def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding='utf-8'): + """ + Write data to a file on a remote host + + Args: + - filename (str): The file path where the data will be written. + - data (bytes or str): The data to be written to the file. + - truncate (bool): If True, the file will be truncated before writing ('w' or 'wb' option); + if False (default), data will be appended ('a' or 'ab' option). + - binary (bool): If True, the data will be written in binary mode ('wb' or 'ab' option); + if False (default), the data will be written in text mode ('w' or 'a' option). + - read_and_write (bool): If True, the file will be opened with read and write permissions ('r+' option); + if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option). 
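+        - encoding (str): Encoding used to convert between str and bytes payloads; defaults to 'utf-8'.
+
+        Example (sketch; `ops` is a RemoteOperations instance):
+
+            ops.write('/tmp/pg.conf', ["fsync = off\n"], truncate=True)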
+ """ + mode = "wb" if binary else "w" + if not truncate: + mode = "ab" if binary else "a" + if read_and_write: + mode = "r+b" if binary else "r+" + + with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file: + if not truncate: + with self.ssh_connect() as ssh: + sftp = ssh.open_sftp() + try: + sftp.get(filename, tmp_file.name) + tmp_file.seek(0, os.SEEK_END) + except FileNotFoundError: + pass # File does not exist yet, we'll create it + sftp.close() + if isinstance(data, bytes) and not binary: + data = data.decode(encoding) + elif isinstance(data, str) and binary: + data = data.encode(encoding) + if isinstance(data, list): + # ensure each line ends with a newline + data = [(s if isinstance(s, str) else s.decode('utf-8')).rstrip('\n') + '\n' for s in data] + tmp_file.writelines(data) + else: + tmp_file.write(data) + tmp_file.flush() + + with self.ssh_connect() as ssh: + sftp = ssh.open_sftp() + remote_directory = os.path.dirname(filename) + try: + sftp.stat(remote_directory) + except IOError: + sftp.mkdir(remote_directory) + sftp.put(tmp_file.name, filename) + sftp.close() + + os.remove(tmp_file.name) + + def touch(self, filename): + """ + Create a new file or update the access and modification times of an existing file on the remote server. + + Args: + filename (str): The name of the file to touch. + + This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. + """ + self.exec_command("touch {}".format(filename)) + + def read(self, filename, binary=False, encoding=None): + cmd = "cat {}".format(filename) + result = self.exec_command(cmd, encoding=encoding) + + if not binary and result: + result = result.decode(encoding or 'utf-8') + + return result + + def readlines(self, filename, num_lines=0, binary=False, encoding=None): + if num_lines > 0: + cmd = "tail -n {} {}".format(num_lines, filename) + else: + cmd = "cat {}".format(filename) + + result = self.exec_command(cmd, encoding=encoding) + + if not binary and result: + lines = result.decode(encoding or 'utf-8').splitlines() + else: + lines = result.splitlines() + + return lines + + def isfile(self, remote_file): + stdout = self.exec_command("test -f {}; echo $?".format(remote_file)) + result = int(stdout.strip()) + return result == 0 + + def isdir(self, dirname): + cmd = "if [ -d {} ]; then echo True; else echo False; fi".format(dirname) + response = self.exec_command(cmd) + return response.strip() == b"True" + + def remove_file(self, filename): + cmd = "rm {}".format(filename) + return self.exec_command(cmd) + + # Processes control + def kill(self, pid, signal): + # Kill the process + cmd = "kill -{} {}".format(signal, pid) + return self.exec_command(cmd) + + def get_pid(self): + # Get current process id + return int(self.exec_command("echo $$", encoding='utf-8')) + + def get_process_children(self, pid): + command = "pgrep -P {}".format(pid) + stdin, stdout, stderr = self.ssh.exec_command(command) + children = stdout.readlines() + return [PsUtilProcessProxy(self.ssh, int(child_pid.strip())) for child_pid in children] + + # Database control + def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, ssh_key=None): + """ + Connects to a PostgreSQL database on the remote system. + Args: + - dbname (str): The name of the database to connect to. + - user (str): The username for the database connection. + - password (str, optional): The password for the database connection. Defaults to None. + - host (str, optional): The IP address of the remote system. 
Defaults to "localhost". + - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432. + + This function establishes a connection to a PostgreSQL database on the remote system using the specified + parameters. It returns a connection object that can be used to interact with the database. + """ + self.close_tunnel() + self.tunnel = sshtunnel.open_tunnel( + (host, 22), # Remote server IP and SSH port + ssh_username=user or self.username, + ssh_pkey=ssh_key or self.ssh_key, + remote_bind_address=(host, port), # PostgreSQL server IP and PostgreSQL port + local_bind_address=('localhost', port) # Local machine IP and available port + ) + + self.tunnel.start() + + try: + conn = pglib.connect( + host=host, # change to 'localhost' because we're connecting through a local ssh tunnel + port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel + database=dbname, + user=user or self.username, + password=password + ) + + return conn + except Exception as e: + self.tunnel.stop() + raise ExecUtilException("Could not create db tunnel. {}".format(e)) diff --git a/testgres/pubsub.py b/testgres/pubsub.py index da85caac..1be673bb 100644 --- a/testgres/pubsub.py +++ b/testgres/pubsub.py @@ -214,4 +214,4 @@ def catchup(self, username=None): username=username or self.pub.username, max_attempts=LOGICAL_REPL_MAX_CATCHUP_ATTEMPTS) except Exception as e: - raise_from(CatchUpException("Failed to catch up", query), e) + raise_from(CatchUpException("Failed to catch up"), e) diff --git a/testgres/utils.py b/testgres/utils.py index 9760908d..5e12eba9 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -3,24 +3,18 @@ from __future__ import division from __future__ import print_function -import io import os import port_for -import subprocess import sys -import tempfile from contextlib import contextmanager from packaging.version import Version, InvalidVersion import re -try: - from shutil import which as find_executable -except ImportError: - from distutils.spawn import find_executable + from six import iteritems -from .config import testgres_config from .exceptions import ExecUtilException +from .config import testgres_config as tconf # rows returned by PG_CONFIG _pg_config_data = {} @@ -58,7 +52,7 @@ def release_port(port): bound_ports.discard(port) -def execute_utility(args, logfile=None): +def execute_utility(args, logfile=None, verbose=False): """ Execute utility (pg_ctl, pg_dump etc). @@ -69,63 +63,28 @@ def execute_utility(args, logfile=None): Returns: stdout of executed utility. 
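+        With verbose=True, an (exit_status, stdout, stderr) triple is returned instead.
+
+    Example (sketch):
+
+        out = execute_utility([get_bin_path('pg_ctl'), '--version'])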
""" - - # run utility - if os.name == 'nt': - # using output to a temporary file in Windows - buf = tempfile.NamedTemporaryFile() - - process = subprocess.Popen( - args, # util + params - stdout=buf, - stderr=subprocess.STDOUT) - process.communicate() - - # get result - buf.file.flush() - buf.file.seek(0) - out = buf.file.read() - buf.close() - else: - process = subprocess.Popen( - args, # util + params - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - - # get result - out, _ = process.communicate() - + exit_status, out, error = tconf.os_ops.exec_command(args, verbose=True) # decode result - out = '' if not out else out.decode('utf-8') - - # format command - command = u' '.join(args) + out = '' if not out else out + if isinstance(out, bytes): + out = out.decode('utf-8') + if isinstance(error, bytes): + error = error.decode('utf-8') # write new log entry if possible if logfile: try: - with io.open(logfile, 'a') as file_out: - file_out.write(command) - - if out: - # comment-out lines - lines = ('# ' + line for line in out.splitlines(True)) - file_out.write(u'\n') - file_out.writelines(lines) - - file_out.write(u'\n') + tconf.os_ops.write(filename=logfile, data=args, truncate=True) + if out: + # comment-out lines + lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n'] + tconf.os_ops.write(filename=logfile, data=lines) except IOError: - pass - - exit_code = process.returncode - if exit_code: - message = 'Utility exited with non-zero code' - raise ExecUtilException(message=message, - command=command, - exit_code=exit_code, - out=out) - - return out + raise ExecUtilException("Problem with writing to logfile `{}` during run command `{}`".format(logfile, args)) + if verbose: + return exit_status, out, error + else: + return out def get_bin_path(filename): @@ -133,23 +92,25 @@ def get_bin_path(filename): Return absolute path to an executable using PG_BIN or PG_CONFIG. This function does nothing if 'filename' is already absolute. """ - # check if it's already absolute if os.path.isabs(filename): return filename + if tconf.os_ops.remote: + pg_config = os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") + else: + # try PG_CONFIG - get from local machine + pg_config = os.environ.get("PG_CONFIG") - # try PG_CONFIG - pg_config = os.environ.get("PG_CONFIG") if pg_config: bindir = get_pg_config()["BINDIR"] return os.path.join(bindir, filename) # try PG_BIN - pg_bin = os.environ.get("PG_BIN") + pg_bin = tconf.os_ops.environ("PG_BIN") if pg_bin: return os.path.join(pg_bin, filename) - pg_config_path = find_executable('pg_config') + pg_config_path = tconf.os_ops.find_executable('pg_config') if pg_config_path: bindir = get_pg_config(pg_config_path)["BINDIR"] return os.path.join(bindir, filename) @@ -160,11 +121,12 @@ def get_bin_path(filename): def get_pg_config(pg_config_path=None): """ Return output of pg_config (provided that it is installed). - NOTE: this fuction caches the result by default (see GlobalConfig). + NOTE: this function caches the result by default (see GlobalConfig). 
""" + def cache_pg_config_data(cmd): # execute pg_config and get the output - out = subprocess.check_output([cmd]).decode('utf-8') + out = tconf.os_ops.exec_command(cmd, encoding='utf-8') data = {} for line in out.splitlines(): @@ -179,7 +141,7 @@ def cache_pg_config_data(cmd): return data # drop cache if asked to - if not testgres_config.cache_pg_config: + if not tconf.cache_pg_config: global _pg_config_data _pg_config_data = {} @@ -188,7 +150,11 @@ def cache_pg_config_data(cmd): return _pg_config_data # try specified pg_config path or PG_CONFIG - pg_config = pg_config_path or os.environ.get("PG_CONFIG") + if tconf.os_ops.remote: + pg_config = pg_config_path or os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") + else: + # try PG_CONFIG - get from local machine + pg_config = pg_config_path or os.environ.get("PG_CONFIG") if pg_config: return cache_pg_config_data(pg_config) @@ -209,7 +175,7 @@ def get_pg_version(): # get raw version (e.g. postgres (PostgreSQL) 9.5.7) _params = [get_bin_path('postgres'), '--version'] - raw_ver = subprocess.check_output(_params).decode('utf-8') + raw_ver = tconf.os_ops.exec_command(_params, encoding='utf-8') # cook version of PostgreSQL version = raw_ver.strip().split(' ')[-1] \ diff --git a/tests/README.md b/tests/README.md index a6d50992..d89efc7e 100644 --- a/tests/README.md +++ b/tests/README.md @@ -27,3 +27,32 @@ export PYTHON_VERSION=3 # or 2 # Run tests ./run_tests.sh ``` + + +#### Remote host tests + +1. Start remote host or docker container +2. Make sure that you run ssh +```commandline +sudo apt-get install openssh-server +sudo systemctl start sshd +``` +3. You need to connect to the remote host at least once to add it to the known hosts file +4. Generate ssh keys +5. Set up params for tests + + +```commandline +conn_params = ConnectionParams( + host='remote_host', + username='username', + ssh_key=/path/to/your/ssh/key' +) +os_ops = RemoteOperations(conn_params) +``` +If you have different path to `PG_CONFIG` on your local and remote host you can set up `PG_CONFIG_REMOTE`, this value will be +using during work with remote host. + +`test_remote` - Tests for RemoteOperations class. + +`test_simple_remote` - Tests that create node and check it. The same as `test_simple`, but for remote node. \ No newline at end of file diff --git a/tests/test_remote.py b/tests/test_remote.py new file mode 100755 index 00000000..3794349c --- /dev/null +++ b/tests/test_remote.py @@ -0,0 +1,198 @@ +import os + +import pytest + +from testgres import ExecUtilException +from testgres import RemoteOperations +from testgres import ConnectionParams + + +class TestRemoteOperations: + + @pytest.fixture(scope="function", autouse=True) + def setup(self): + conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3', + username='dev', + ssh_key=os.getenv( + 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519') + self.operations = RemoteOperations(conn_params) + + yield + self.operations.__del__() + + def test_exec_command_success(self): + """ + Test exec_command for successful command execution. + """ + cmd = "python3 --version" + response = self.operations.exec_command(cmd, wait_exit=True) + + assert b'Python 3.' in response + + def test_exec_command_failure(self): + """ + Test exec_command for command execution failure. 
+
+`test_remote` - Tests for the RemoteOperations class.
+
+`test_simple_remote` - Tests that create a node and check it. The same as `test_simple`, but for a remote node.
\ No newline at end of file
diff --git a/tests/test_remote.py b/tests/test_remote.py
new file mode 100755
index 00000000..3794349c
--- /dev/null
+++ b/tests/test_remote.py
@@ -0,0 +1,198 @@
+import os
+
+import pytest
+
+from testgres import ExecUtilException
+from testgres import RemoteOperations
+from testgres import ConnectionParams
+
+
+class TestRemoteOperations:
+
+    @pytest.fixture(scope="function", autouse=True)
+    def setup(self):
+        conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3',
+                                       username='dev',
+                                       ssh_key=os.getenv(
+                                           'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519')
+        self.operations = RemoteOperations(conn_params)
+
+        yield
+        self.operations.__del__()
+
+    def test_exec_command_success(self):
+        """
+        Test exec_command for successful command execution.
+        """
+        cmd = "python3 --version"
+        response = self.operations.exec_command(cmd, wait_exit=True)
+
+        assert b'Python 3.' in response
+
+    def test_exec_command_failure(self):
+        """
+        Test exec_command for command execution failure.
+        """
+        cmd = "nonexistent_command"
+        try:
+            exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True)
+        except ExecUtilException as e:
+            error = e.message
+        assert error == b'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n'
+
+    def test_is_executable_true(self):
+        """
+        Test is_executable for an existing executable.
+        """
+        cmd = "postgres"
+        response = self.operations.is_executable(cmd)
+
+        assert response is True
+
+    def test_is_executable_false(self):
+        """
+        Test is_executable for a non-executable.
+        """
+        cmd = "python"
+        response = self.operations.is_executable(cmd)
+
+        assert response is False
+
+    def test_makedirs_and_rmdirs_success(self):
+        """
+        Test makedirs and rmdirs for successful directory creation and removal.
+        """
+        cmd = "pwd"
+        pwd = self.operations.exec_command(cmd, wait_exit=True, encoding='utf-8').strip()
+
+        path = "{}/test_dir".format(pwd)
+
+        # Test makedirs
+        self.operations.makedirs(path)
+        assert self.operations.path_exists(path)
+
+        # Test rmdirs
+        self.operations.rmdirs(path)
+        assert not self.operations.path_exists(path)
+
+    def test_makedirs_and_rmdirs_failure(self):
+        """
+        Test makedirs and rmdirs for directory creation and removal failure.
+        """
+        # Try to create a directory in a read-only location
+        path = "/root/test_dir"
+
+        # Test makedirs
+        with pytest.raises(Exception):
+            self.operations.makedirs(path)
+
+        # Test rmdirs
+        try:
+            exit_status, result, error = self.operations.rmdirs(path, verbose=True)
+        except ExecUtilException as e:
+            error = e.message
+        assert error == b"Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n"
+
+    def test_listdir(self):
+        """
+        Test listdir for listing directory contents.
+        """
+        path = "/etc"
+        files = self.operations.listdir(path)
+
+        assert isinstance(files, list)
+
+    def test_path_exists_true(self):
+        """
+        Test path_exists for an existing path.
+        """
+        path = "/etc"
+        response = self.operations.path_exists(path)
+
+        assert response is True
+
+    def test_path_exists_false(self):
+        """
+        Test path_exists for a non-existing path.
+        """
+        path = "/nonexistent_path"
+        response = self.operations.path_exists(path)
+
+        assert response is False
+
+    def test_write_text_file(self):
+        """
+        Test write for writing data to a text file.
+        """
+        filename = "/tmp/test_file.txt"
+        data = "Hello, world!"
+
+        self.operations.write(filename, data, truncate=True)
+        self.operations.write(filename, data)
+
+        response = self.operations.read(filename)
+
+        assert response == data + data
+
+    def test_write_binary_file(self):
+        """
+        Test write for writing data to a binary file.
+        """
+        filename = "/tmp/test_file.bin"
+        data = b"\x00\x01\x02\x03"
+
+        self.operations.write(filename, data, binary=True, truncate=True)
+
+        response = self.operations.read(filename, binary=True)
+
+        assert response == data
+
+    def test_read_text_file(self):
+        """
+        Test read for reading data from a text file.
+        """
+        filename = "/etc/hosts"
+
+        response = self.operations.read(filename)
+
+        assert isinstance(response, str)
+
+    def test_read_binary_file(self):
+        """
+        Test read for reading data from a binary file.
+        """
+        filename = "/usr/bin/python3"
+
+        response = self.operations.read(filename, binary=True)
+
+        assert isinstance(response, bytes)
+
+    def test_touch(self):
+        """
+        Test touch for creating a new file or updating access and modification times of an existing file.
+ """ + filename = "/tmp/test_file.txt" + + self.operations.touch(filename) + + assert self.operations.isfile(filename) + + def test_isfile_true(self): + """ + Test isfile for an existing file. + """ + filename = "/etc/hosts" + + response = self.operations.isfile(filename) + + assert response is True + + def test_isfile_false(self): + """ + Test isfile for a non-existing file. + """ + filename = "/nonexistent_file.txt" + + response = self.operations.isfile(filename) + + assert response is False diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py new file mode 100755 index 00000000..e8386383 --- /dev/null +++ b/tests/test_simple_remote.py @@ -0,0 +1,996 @@ +#!/usr/bin/env python +# coding: utf-8 + +import os +import re +import subprocess +import tempfile + +import testgres +import time +import six +import unittest +import psutil + +import logging.config + +from contextlib import contextmanager + +from testgres.exceptions import \ + InitNodeException, \ + StartNodeException, \ + ExecUtilException, \ + BackupException, \ + QueryException, \ + TimeoutException, \ + TestgresException + +from testgres.config import \ + TestgresConfig, \ + configure_testgres, \ + scoped_config, \ + pop_config, testgres_config + +from testgres import \ + NodeStatus, \ + ProcessType, \ + IsolationLevel, \ + get_remote_node, \ + RemoteOperations + +from testgres import \ + get_bin_path, \ + get_pg_config, \ + get_pg_version + +from testgres import \ + First, \ + Any + +# NOTE: those are ugly imports +from testgres import bound_ports +from testgres.utils import PgVer +from testgres.node import ProcessProxy, ConnectionParams + +conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3', + username='dev', + ssh_key=os.getenv( + 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519') +os_ops = RemoteOperations(conn_params) +testgres_config.set_os_ops(os_ops=os_ops) + + +def pg_version_ge(version): + cur_ver = PgVer(get_pg_version()) + min_ver = PgVer(version) + return cur_ver >= min_ver + + +def util_exists(util): + def good_properties(f): + return (os_ops.path_exists(f) and # noqa: W504 + os_ops.isfile(f) and # noqa: W504 + os_ops.is_executable(f)) # yapf: disable + + # try to resolve it + if good_properties(get_bin_path(util)): + return True + + # check if util is in PATH + for path in os_ops.environ("PATH").split(os_ops.pathsep): + if good_properties(os.path.join(path, util)): + return True + + +@contextmanager +def removing(f): + try: + yield f + finally: + if os_ops.isfile(f): + os_ops.remove_file(f) + + elif os_ops.isdir(f): + os_ops.rmdirs(f, ignore_errors=True) + + +class TestgresRemoteTests(unittest.TestCase): + + def test_node_repr(self): + with get_remote_node(conn_params=conn_params) as node: + pattern = r"PostgresNode\(name='.+', port=.+, base_dir='.+'\)" + self.assertIsNotNone(re.match(pattern, str(node))) + + def test_custom_init(self): + with get_remote_node(conn_params=conn_params) as node: + # enable page checksums + node.init(initdb_params=['-k']).start() + + with get_remote_node(conn_params=conn_params) as node: + node.init( + allow_streaming=True, + initdb_params=['--auth-local=reject', '--auth-host=reject']) + + hba_file = os.path.join(node.data_dir, 'pg_hba.conf') + lines = os_ops.readlines(hba_file) + + # check number of lines + self.assertGreaterEqual(len(lines), 6) + + # there should be no trust entries at all + self.assertFalse(any('trust' in s for s in lines)) + + def test_double_init(self): + with 
get_remote_node(conn_params=conn_params).init() as node: + # can't initialize node more than once + with self.assertRaises(InitNodeException): + node.init() + + def test_init_after_cleanup(self): + with get_remote_node(conn_params=conn_params) as node: + node.init().start().execute('select 1') + node.cleanup() + node.init().start().execute('select 1') + + @unittest.skipUnless(util_exists('pg_resetwal'), 'might be missing') + @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') + def test_init_unique_system_id(self): + # FAIL + # this function exists in PostgreSQL 9.6+ + query = 'select system_identifier from pg_control_system()' + + with scoped_config(cache_initdb=False): + with get_remote_node(conn_params=conn_params).init().start() as node0: + id0 = node0.execute(query)[0] + + with scoped_config(cache_initdb=True, + cached_initdb_unique=True) as config: + self.assertTrue(config.cache_initdb) + self.assertTrue(config.cached_initdb_unique) + + # spawn two nodes; ids must be different + with get_remote_node(conn_params=conn_params).init().start() as node1, \ + get_remote_node(conn_params=conn_params).init().start() as node2: + id1 = node1.execute(query)[0] + id2 = node2.execute(query)[0] + + # ids must increase + self.assertGreater(id1, id0) + self.assertGreater(id2, id1) + + def test_node_exit(self): + with self.assertRaises(QueryException): + with get_remote_node(conn_params=conn_params).init() as node: + base_dir = node.base_dir + node.safe_psql('select 1') + + # we should save the DB for "debugging" + self.assertTrue(os_ops.path_exists(base_dir)) + os_ops.rmdirs(base_dir, ignore_errors=True) + + with get_remote_node(conn_params=conn_params).init() as node: + base_dir = node.base_dir + + # should have been removed by default + self.assertFalse(os_ops.path_exists(base_dir)) + + def test_double_start(self): + with get_remote_node(conn_params=conn_params).init().start() as node: + # can't start node more than once + node.start() + self.assertTrue(node.is_started) + + def test_uninitialized_start(self): + with get_remote_node(conn_params=conn_params) as node: + # node is not initialized yet + with self.assertRaises(StartNodeException): + node.start() + + def test_restart(self): + with get_remote_node(conn_params=conn_params) as node: + node.init().start() + + # restart, ok + res = node.execute('select 1') + self.assertEqual(res, [(1,)]) + node.restart() + res = node.execute('select 2') + self.assertEqual(res, [(2,)]) + + # restart, fail + with self.assertRaises(StartNodeException): + node.append_conf('pg_hba.conf', 'DUMMY') + node.restart() + + def test_reload(self): + with get_remote_node(conn_params=conn_params) as node: + node.init().start() + + # change client_min_messages and save old value + cmm_old = node.execute('show client_min_messages') + node.append_conf(client_min_messages='DEBUG1') + + # reload config + node.reload() + + # check new value + cmm_new = node.execute('show client_min_messages') + self.assertEqual('debug1', cmm_new[0][0].lower()) + self.assertNotEqual(cmm_old, cmm_new) + + def test_pg_ctl(self): + with get_remote_node(conn_params=conn_params) as node: + node.init().start() + + status = node.pg_ctl(['status']) + self.assertTrue('PID' in status) + + def test_status(self): + self.assertTrue(NodeStatus.Running) + self.assertFalse(NodeStatus.Stopped) + self.assertFalse(NodeStatus.Uninitialized) + + # check statuses after each operation + with get_remote_node(conn_params=conn_params) as node: + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), 
NodeStatus.Uninitialized) + + node.init() + + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Stopped) + + node.start() + + self.assertNotEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Running) + + node.stop() + + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Stopped) + + node.cleanup() + + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Uninitialized) + + def test_psql(self): + with get_remote_node(conn_params=conn_params).init().start() as node: + # check returned values (1 arg) + res = node.psql('select 1') + self.assertEqual(res, (0, b'1\n', b'')) + + # check returned values (2 args) + res = node.psql('postgres', 'select 2') + self.assertEqual(res, (0, b'2\n', b'')) + + # check returned values (named) + res = node.psql(query='select 3', dbname='postgres') + self.assertEqual(res, (0, b'3\n', b'')) + + # check returned values (1 arg) + res = node.safe_psql('select 4') + self.assertEqual(res, b'4\n') + + # check returned values (2 args) + res = node.safe_psql('postgres', 'select 5') + self.assertEqual(res, b'5\n') + + # check returned values (named) + res = node.safe_psql(query='select 6', dbname='postgres') + self.assertEqual(res, b'6\n') + + # check feeding input + node.safe_psql('create table horns (w int)') + node.safe_psql('copy horns from stdin (format csv)', + input=b"1\n2\n3\n\\.\n") + _sum = node.safe_psql('select sum(w) from horns') + self.assertEqual(_sum, b'6\n') + + # check psql's default args, fails + with self.assertRaises(QueryException): + node.psql() + + node.stop() + + # check psql on stopped node, fails + with self.assertRaises(QueryException): + node.safe_psql('select 1') + + def test_transactions(self): + with get_remote_node(conn_params=conn_params).init().start() as node: + with node.connect() as con: + con.begin() + con.execute('create table test(val int)') + con.execute('insert into test values (1)') + con.commit() + + con.begin() + con.execute('insert into test values (2)') + res = con.execute('select * from test order by val asc') + self.assertListEqual(res, [(1,), (2,)]) + con.rollback() + + con.begin() + res = con.execute('select * from test') + self.assertListEqual(res, [(1,)]) + con.rollback() + + con.begin() + con.execute('drop table test') + con.commit() + + def test_control_data(self): + with get_remote_node(conn_params=conn_params) as node: + # node is not initialized yet + with self.assertRaises(ExecUtilException): + node.get_control_data() + + node.init() + data = node.get_control_data() + + # check returned dict + self.assertIsNotNone(data) + self.assertTrue(any('pg_control' in s for s in data.keys())) + + def test_backup_simple(self): + with get_remote_node(conn_params=conn_params) as master: + # enable streaming for backups + master.init(allow_streaming=True) + + # node must be running + with self.assertRaises(BackupException): + master.backup() + + # it's time to start node + master.start() + + # fill node with some data + master.psql('create table test as select generate_series(1, 4) i') + + with master.backup(xlog_method='stream') as backup: + with backup.spawn_primary().start() as slave: + res = slave.execute('select * from test order by i asc') + self.assertListEqual(res, [(1,), (2,), (3,), (4,)]) + + def test_backup_multiple(self): + with get_remote_node(conn_params=conn_params) as node: + node.init(allow_streaming=True).start() + + with node.backup(xlog_method='fetch') as backup1, \ + node.backup(xlog_method='fetch') as backup2: + 
self.assertNotEqual(backup1.base_dir, backup2.base_dir)
+
+            with node.backup(xlog_method='fetch') as backup:
+                with backup.spawn_primary('node1', destroy=False) as node1, \
+                        backup.spawn_primary('node2', destroy=False) as node2:
+                    self.assertNotEqual(node1.base_dir, node2.base_dir)
+
+    def test_backup_exhaust(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init(allow_streaming=True).start()
+
+            with node.backup(xlog_method='fetch') as backup:
+                # exhaust backup by creating new node
+                with backup.spawn_primary():
+                    pass
+
+                # now let's try to create one more node
+                with self.assertRaises(BackupException):
+                    backup.spawn_primary()
+
+    def test_backup_wrong_xlog_method(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init(allow_streaming=True).start()
+
+            with self.assertRaises(BackupException,
+                                   msg='Invalid xlog_method "wrong"'):
+                node.backup(xlog_method='wrong')
+
+    def test_pg_ctl_wait_option(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init().start(wait=False)
+            while True:
+                try:
+                    node.stop(wait=False)
+                    break
+                except ExecUtilException:
+                    # it's ok to get this exception here since node
+                    # could be not started yet
+                    pass
+
+    def test_replicate(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init(allow_streaming=True).start()
+
+            with node.replicate().start() as replica:
+                res = replica.execute('select 1')
+                self.assertListEqual(res, [(1,)])
+
+                node.execute('create table test (val int)', commit=True)
+
+                replica.catchup()
+
+                res = node.execute('select * from test')
+                self.assertListEqual(res, [])
+
+    @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+')
+    def test_synchronous_replication(self):
+        with get_remote_node(conn_params=conn_params) as master:
+            old_version = not pg_version_ge('9.6')
+
+            master.init(allow_streaming=True).start()
+
+            if not old_version:
+                master.append_conf('synchronous_commit = remote_apply')
+
+            # create standby
+            with master.replicate() as standby1, master.replicate() as standby2:
+                standby1.start()
+                standby2.start()
+
+                # check formatting
+                self.assertEqual(
+                    '1 ("{}", "{}")'.format(standby1.name, standby2.name),
+                    str(First(1, (standby1, standby2))))  # yapf: disable
+                self.assertEqual(
+                    'ANY 1 ("{}", "{}")'.format(standby1.name, standby2.name),
+                    str(Any(1, (standby1, standby2))))  # yapf: disable
+
+                # set synchronous_standby_names
+                master.set_synchronous_standbys(First(2, [standby1, standby2]))
+                master.restart()
+
+                # the following part of the test is only applicable to newer
+                # versions of PostgreSQL
+                if not old_version:
+                    master.safe_psql('create table abc(a int)')
+
+                    # Create a large transaction that will take some time to apply
+                    # on the standby to check that it applies synchronously
+                    # (if synchronous_commit is set to 'on' or a lower level, the
+                    # standby most likely won't catch up fast enough and the test will fail)
+                    master.safe_psql(
+                        'insert into abc select generate_series(1, 1000000)')
+                    res = standby1.safe_psql('select count(*) from abc')
+                    self.assertEqual(res, b'1000000\n')
+
+    @unittest.skipUnless(pg_version_ge('10'), 'requires 10+')
+    def test_logical_replication(self):
+        with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2:
+            node1.init(allow_logical=True)
+            node1.start()
+            node2.init().start()
+
+            create_table = 'create table test (a int, b int)'
+            node1.safe_psql(create_table)
+            node2.safe_psql(create_table)
+
+            # create publication / create subscription
+            pub = node1.publish('mypub')
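+            # publish()/subscribe() are thin wrappers around CREATE PUBLICATION
+            # and CREATE SUBSCRIPTION on the respective nodes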
+            sub = node2.subscribe(pub, 'mysub')
+
+            node1.safe_psql('insert into test values (1, 1), (2, 2)')
+
+            # wait until changes apply on subscriber and check them
+            sub.catchup()
+            res = node2.execute('select * from test')
+            self.assertListEqual(res, [(1, 1), (2, 2)])
+
+            # disable and put some new data
+            sub.disable()
+            node1.safe_psql('insert into test values (3, 3)')
+
+            # enable and ensure that data is successfully transferred
+            sub.enable()
+            sub.catchup()
+            res = node2.execute('select * from test')
+            self.assertListEqual(res, [(1, 1), (2, 2), (3, 3)])
+
+            # Add new tables. Since we added "all tables" to the publication
+            # (default behaviour of the publish() method) we don't need
+            # to explicitly perform pub.add_tables()
+            create_table = 'create table test2 (c char)'
+            node1.safe_psql(create_table)
+            node2.safe_psql(create_table)
+            sub.refresh()
+
+            # put new data
+            node1.safe_psql('insert into test2 values (\'a\'), (\'b\')')
+            sub.catchup()
+            res = node2.execute('select * from test2')
+            self.assertListEqual(res, [('a',), ('b',)])
+
+            # drop subscription
+            sub.drop()
+            pub.drop()
+
+            # create new publication and subscription for a specific table
+            # (omitting the data copy as it's already done)
+            pub = node1.publish('newpub', tables=['test'])
+            sub = node2.subscribe(pub, 'newsub', copy_data=False)
+
+            node1.safe_psql('insert into test values (4, 4)')
+            sub.catchup()
+            res = node2.execute('select * from test')
+            self.assertListEqual(res, [(1, 1), (2, 2), (3, 3), (4, 4)])
+
+            # explicitly add table
+            with self.assertRaises(ValueError):
+                pub.add_tables([])  # fail
+            pub.add_tables(['test2'])
+            node1.safe_psql('insert into test2 values (\'c\')')
+            sub.catchup()
+            res = node2.execute('select * from test2')
+            self.assertListEqual(res, [('a',), ('b',)])
+
+    @unittest.skipUnless(pg_version_ge('10'), 'requires 10+')
+    def test_logical_catchup(self):
+        """ Run catchup 100 times to make sure it is consistent """
+        with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2:
+            node1.init(allow_logical=True)
+            node1.start()
+            node2.init().start()
+
+            create_table = 'create table test (key int primary key, val int); '
+            node1.safe_psql(create_table)
+            node1.safe_psql('alter table test replica identity default')
+            node2.safe_psql(create_table)
+
+            # create publication / create subscription
+            sub = node2.subscribe(node1.publish('mypub'), 'mysub')
+
+            for i in range(0, 100):
+                node1.execute('insert into test values ({0}, {0})'.format(i))
+                sub.catchup()
+                res = node2.execute('select * from test')
+                self.assertListEqual(res, [(
+                    i,
+                    i,
+                )])
+                node1.execute('delete from test')
+
+    @unittest.skipIf(pg_version_ge('10'), 'requires <10')
+    def test_logical_replication_fail(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            with self.assertRaises(InitNodeException):
+                node.init(allow_logical=True)
+
+    def test_replication_slots(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init(allow_streaming=True).start()
+
+            with node.replicate(slot='slot1').start() as replica:
+                replica.execute('select 1')
+
+                # cannot create new slot with the same name
+                with self.assertRaises(TestgresException):
+                    node.replicate(slot='slot1')
+
+    def test_incorrect_catchup(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init(allow_streaming=True).start()
+
+            # node has no master, can't catch up
+            with self.assertRaises(TestgresException):
+                node.catchup()
+
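+    # scenario: stop the master, promote the replica, verify it accepts writes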
+    def test_promotion(self):
+        with get_remote_node(conn_params=conn_params) as master:
+            master.init().start()
+            master.safe_psql('create table abc(id serial)')
+
+            with master.replicate().start() as replica:
+                master.stop()
+                replica.promote()
+
+                # make sure the standby becomes a writable master
+                replica.safe_psql('insert into abc values (1)')
+                res = replica.safe_psql('select * from abc')
+                self.assertEqual(res, b'1\n')
+
+    def test_dump(self):
+        query_create = 'create table test as select generate_series(1, 2) as val'
+        query_select = 'select * from test order by val asc'
+
+        with get_remote_node(conn_params=conn_params).init().start() as node1:
+
+            node1.execute(query_create)
+            for format in ['plain', 'custom', 'directory', 'tar']:
+                with removing(node1.dump(format=format)) as dump:
+                    with get_remote_node(conn_params=conn_params).init().start() as node3:
+                        if format == 'directory':
+                            self.assertTrue(os_ops.isdir(dump))
+                        else:
+                            self.assertTrue(os_ops.isfile(dump))
+                        # restore dump
+                        node3.restore(filename=dump)
+                        res = node3.execute(query_select)
+                        self.assertListEqual(res, [(1,), (2,)])
+
+    def test_users(self):
+        with get_remote_node(conn_params=conn_params).init().start() as node:
+            node.psql('create role test_user login')
+            value = node.safe_psql('select 1', username='test_user')
+            self.assertEqual(b'1\n', value)
+
+    def test_poll_query_until(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init().start()
+
+            get_time = 'select extract(epoch from now())'
+            check_time = 'select extract(epoch from now()) - {} >= 5'
+
+            start_time = node.execute(get_time)[0][0]
+            node.poll_query_until(query=check_time.format(start_time))
+            end_time = node.execute(get_time)[0][0]
+
+            self.assertTrue(end_time - start_time >= 5)
+
+            # check 0 columns
+            with self.assertRaises(QueryException):
+                node.poll_query_until(
+                    query='select from pg_catalog.pg_class limit 1')
+
+            # check None, fail
+            with self.assertRaises(QueryException):
+                node.poll_query_until(query='create table abc (val int)')
+
+            # check None, ok
+            node.poll_query_until(query='create table def()',
+                                  expected=None)  # returns nothing
+
+            # check 0 rows equivalent to expected=None
+            node.poll_query_until(
+                query='select * from pg_catalog.pg_class where true = false',
+                expected=None)
+
+            # check arbitrary expected value, fail
+            with self.assertRaises(TimeoutException):
+                node.poll_query_until(query='select 3',
+                                      expected=1,
+                                      max_attempts=3,
+                                      sleep_time=0.01)
+
+            # check arbitrary expected value, ok
+            node.poll_query_until(query='select 2', expected=2)
+
+            # check timeout
+            with self.assertRaises(TimeoutException):
+                node.poll_query_until(query='select 1 > 2',
+                                      max_attempts=3,
+                                      sleep_time=0.01)
+
+            # check ProgrammingError, fail
+            with self.assertRaises(testgres.ProgrammingError):
+                node.poll_query_until(query='dummy1')
+
+            # check ProgrammingError, ok
+            with self.assertRaises(TimeoutException):
+                node.poll_query_until(query='dummy2',
+                                      max_attempts=3,
+                                      sleep_time=0.01,
+                                      suppress={testgres.ProgrammingError})
+
+            # check 1 arg, ok
+            node.poll_query_until('select true')
+
+    def test_logging(self):
+        # FAIL
+        logfile = tempfile.NamedTemporaryFile('w', delete=True)
+
+        log_conf = {
+            'version': 1,
+            'handlers': {
+                'file': {
+                    'class': 'logging.FileHandler',
+                    'filename': logfile.name,
+                    'formatter': 'base_format',
+                    'level': logging.DEBUG,
+                },
+            },
+            'formatters': {
+                'base_format': {
+                    'format': '%(node)-5s: %(message)s',
+                },
+            },
+            'root': {
+                'handlers': ('file',),
+                'level': 'DEBUG',
+            },
+        }
+
+        logging.config.dictConfig(log_conf)
+
+        with
scoped_config(use_python_logging=True): + node_name = 'master' + + with get_remote_node(name=node_name) as master: + master.init().start() + + # execute a dummy query a few times + for i in range(20): + master.execute('select 1') + time.sleep(0.01) + + # let logging worker do the job + time.sleep(0.1) + + # check that master's port is found + with open(logfile.name, 'r') as log: + lines = log.readlines() + self.assertTrue(any(node_name in s for s in lines)) + + # test logger after stop/start/restart + master.stop() + master.start() + master.restart() + self.assertTrue(master._logger.is_alive()) + + @unittest.skipUnless(util_exists('pgbench'), 'might be missing') + def test_pgbench(self): + with get_remote_node(conn_params=conn_params).init().start() as node: + # initialize pgbench DB and run benchmarks + node.pgbench_init(scale=2, foreign_keys=True, + options=['-q']).pgbench_run(time=2) + + # run TPC-B benchmark + out = node.pgbench(stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + options=['-T3']) + self.assertTrue(b'tps = ' in out) + + def test_pg_config(self): + # check same instances + a = get_pg_config() + b = get_pg_config() + self.assertEqual(id(a), id(b)) + + # save right before config change + c1 = get_pg_config() + # modify setting for this scope + with scoped_config(cache_pg_config=False) as config: + # sanity check for value + self.assertFalse(config.cache_pg_config) + + # save right after config change + c2 = get_pg_config() + + # check different instances after config change + self.assertNotEqual(id(c1), id(c2)) + + # check different instances + a = get_pg_config() + b = get_pg_config() + self.assertNotEqual(id(a), id(b)) + + def test_config_stack(self): + # no such option + with self.assertRaises(TypeError): + configure_testgres(dummy=True) + + # we have only 1 config in stack + with self.assertRaises(IndexError): + pop_config() + + d0 = TestgresConfig.cached_initdb_dir + d1 = 'dummy_abc' + d2 = 'dummy_def' + + with scoped_config(cached_initdb_dir=d1) as c1: + self.assertEqual(c1.cached_initdb_dir, d1) + + with scoped_config(cached_initdb_dir=d2) as c2: + stack_size = len(testgres.config.config_stack) + + # try to break a stack + with self.assertRaises(TypeError): + with scoped_config(dummy=True): + pass + + self.assertEqual(c2.cached_initdb_dir, d2) + self.assertEqual(len(testgres.config.config_stack), stack_size) + + self.assertEqual(c1.cached_initdb_dir, d1) + + self.assertEqual(TestgresConfig.cached_initdb_dir, d0) + + def test_unix_sockets(self): + with get_remote_node(conn_params=conn_params) as node: + node.init(unix_sockets=False, allow_streaming=True) + node.start() + + res_exec = node.execute('select 1') + res_psql = node.safe_psql('select 1') + self.assertEqual(res_exec, [(1,)]) + self.assertEqual(res_psql, b'1\n') + + with node.replicate().start() as r: + res_exec = r.execute('select 1') + res_psql = r.safe_psql('select 1') + self.assertEqual(res_exec, [(1,)]) + self.assertEqual(res_psql, b'1\n') + + def test_auto_name(self): + with get_remote_node(conn_params=conn_params).init(allow_streaming=True).start() as m: + with m.replicate().start() as r: + # check that nodes are running + self.assertTrue(m.status()) + self.assertTrue(r.status()) + + # check their names + self.assertNotEqual(m.name, r.name) + self.assertTrue('testgres' in m.name) + self.assertTrue('testgres' in r.name) + + def test_file_tail(self): + from testgres.utils import file_tail + + s1 = "the quick brown fox jumped over that lazy dog\n" + s2 = "abc\n" + s3 = "def\n" + + with 
tempfile.NamedTemporaryFile(mode='r+', delete=True) as f: + sz = 0 + while sz < 3 * 8192: + sz += len(s1) + f.write(s1) + f.write(s2) + f.write(s3) + + f.seek(0) + lines = file_tail(f, 3) + self.assertEqual(lines[0], s1) + self.assertEqual(lines[1], s2) + self.assertEqual(lines[2], s3) + + f.seek(0) + lines = file_tail(f, 1) + self.assertEqual(lines[0], s3) + + def test_isolation_levels(self): + with get_remote_node(conn_params=conn_params).init().start() as node: + with node.connect() as con: + # string levels + con.begin('Read Uncommitted').commit() + con.begin('Read Committed').commit() + con.begin('Repeatable Read').commit() + con.begin('Serializable').commit() + + # enum levels + con.begin(IsolationLevel.ReadUncommitted).commit() + con.begin(IsolationLevel.ReadCommitted).commit() + con.begin(IsolationLevel.RepeatableRead).commit() + con.begin(IsolationLevel.Serializable).commit() + + # check wrong level + with self.assertRaises(QueryException): + con.begin('Garbage').commit() + + def test_ports_management(self): + # check that no ports have been bound yet + self.assertEqual(len(bound_ports), 0) + + with get_remote_node(conn_params=conn_params) as node: + # check that we've just bound a port + self.assertEqual(len(bound_ports), 1) + + # check that bound_ports contains our port + port_1 = list(bound_ports)[0] + port_2 = node.port + self.assertEqual(port_1, port_2) + + # check that port has been freed successfully + self.assertEqual(len(bound_ports), 0) + + def test_exceptions(self): + str(StartNodeException('msg', [('file', 'lines')])) + str(ExecUtilException('msg', 'cmd', 1, 'out')) + str(QueryException('msg', 'query')) + + def test_version_management(self): + a = PgVer('10.0') + b = PgVer('10') + c = PgVer('9.6.5') + d = PgVer('15.0') + e = PgVer('15rc1') + f = PgVer('15beta4') + + self.assertTrue(a == b) + self.assertTrue(b > c) + self.assertTrue(a > c) + self.assertTrue(d > e) + self.assertTrue(e > f) + self.assertTrue(d > f) + + version = get_pg_version() + with get_remote_node(conn_params=conn_params) as node: + self.assertTrue(isinstance(version, six.string_types)) + self.assertTrue(isinstance(node.version, PgVer)) + self.assertEqual(node.version, PgVer(version)) + + def test_child_pids(self): + master_processes = [ + ProcessType.AutovacuumLauncher, + ProcessType.BackgroundWriter, + ProcessType.Checkpointer, + ProcessType.StatsCollector, + ProcessType.WalSender, + ProcessType.WalWriter, + ] + + if pg_version_ge('10'): + master_processes.append(ProcessType.LogicalReplicationLauncher) + + repl_processes = [ + ProcessType.Startup, + ProcessType.WalReceiver, + ] + + with get_remote_node(conn_params=conn_params).init().start() as master: + + # master node doesn't have a source walsender! 
+ with self.assertRaises(TestgresException): + master.source_walsender + + with master.connect() as con: + self.assertGreater(con.pid, 0) + + with master.replicate().start() as replica: + + # test __str__ method + str(master.child_processes[0]) + + master_pids = master.auxiliary_pids + for ptype in master_processes: + self.assertIn(ptype, master_pids) + + replica_pids = replica.auxiliary_pids + for ptype in repl_processes: + self.assertIn(ptype, replica_pids) + + # there should be exactly 1 source walsender for replica + self.assertEqual(len(master_pids[ProcessType.WalSender]), 1) + pid1 = master_pids[ProcessType.WalSender][0] + pid2 = replica.source_walsender.pid + self.assertEqual(pid1, pid2) + + replica.stop() + + # there should be no walsender after we've stopped replica + with self.assertRaises(TestgresException): + replica.source_walsender + + def test_child_process_dies(self): + # test for FileNotFound exception during child_processes() function + with subprocess.Popen(["sleep", "60"]) as process: + self.assertEqual(process.poll(), None) + # collect list of processes currently running + children = psutil.Process(os.getpid()).children() + # kill a process, so received children dictionary becomes invalid + process.kill() + process.wait() + # try to handle children list -- missing processes will have ptype "ProcessType.Unknown" + [ProcessProxy(p) for p in children] + + +if __name__ == '__main__': + if os_ops.environ('ALT_CONFIG'): + suite = unittest.TestSuite() + + # Small subset of tests for alternative configs (PG_BIN or PG_CONFIG) + suite.addTest(TestgresRemoteTests('test_pg_config')) + suite.addTest(TestgresRemoteTests('test_pg_ctl')) + suite.addTest(TestgresRemoteTests('test_psql')) + suite.addTest(TestgresRemoteTests('test_replicate')) + + print('Running tests for alternative config:') + for t in suite: + print(t) + print() + + runner = unittest.TextTestRunner() + runner.run(suite) + else: + unittest.main()