Content-Length: 47261 | pFad | http://github.com/postgrespro/testgres/pull/149.patch
thub.com
From 1b10a5499b26cb21d28d2c14feda5d9ce98fe7ce Mon Sep 17 00:00:00 2001
From: vshepard
Date: Fri, 1 Nov 2024 23:30:38 +0100
Subject: [PATCH 1/9] Add ability to skip ssl when connect to PostgresNode
---
testgres/api.py | 4 +-
testgres/node.py | 13 ++---
testgres/operations/local_ops.py | 18 +------
testgres/operations/os_ops.py | 28 ++++++++--
testgres/operations/remote_ops.py | 21 +-------
testgres/utils.py | 4 +-
tests/test_remote.py | 28 ++++++++--
tests/test_simple_remote.py | 89 ++++++++++++++++---------------
8 files changed, 106 insertions(+), 99 deletions(-)
diff --git a/testgres/api.py b/testgres/api.py
index e4b1cdd5..10bfd669 100644
--- a/testgres/api.py
+++ b/testgres/api.py
@@ -42,7 +42,7 @@ def get_new_node(name=None, base_dir=None, **kwargs):
return PostgresNode(name=name, base_dir=base_dir, **kwargs)
-def get_remote_node(name=None, conn_params=None):
+def get_remote_node(name=None):
"""
Simply a wrapper around :class:`.PostgresNode` constructor for remote node.
See :meth:`.PostgresNode.__init__` for details.
@@ -51,4 +51,4 @@ def get_remote_node(name=None, conn_params=None):
ssh_key=None,
username=default_username())
"""
- return get_new_node(name=name, conn_params=conn_params)
+ return get_new_node(name=name)
diff --git a/testgres/node.py b/testgres/node.py
index c8c8c087..78ebd87b 100644
--- a/testgres/node.py
+++ b/testgres/node.py
@@ -126,7 +126,8 @@ def __repr__(self):
class PostgresNode(object):
- def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionParams = ConnectionParams(), bin_dir=None, prefix=None):
+ def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionParams = ConnectionParams(),
+ bin_dir=None, prefix=None):
"""
PostgresNode constructor.
@@ -150,13 +151,9 @@ def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionP
self.name = name or generate_app_name()
if testgres_config.os_ops:
self.os_ops = testgres_config.os_ops
- elif conn_params.ssh_key:
- self.os_ops = RemoteOperations(conn_params)
- else:
- self.os_ops = LocalOperations(conn_params)
self.host = self.os_ops.host
- self.port = port or reserve_port()
+ self.port = port or self.os_ops.port or reserve_port()
self.ssh_key = self.os_ops.ssh_key
@@ -1005,7 +1002,7 @@ def psql(self,
# select query source
if query:
- if self.os_ops.remote:
+ if self.os_ops.conn_params.remote:
psql_params.extend(("-c", '"{}"'.format(query)))
else:
psql_params.extend(("-c", query))
@@ -1016,7 +1013,7 @@ def psql(self,
# should be the last one
psql_params.append(dbname)
- if not self.os_ops.remote:
+ if not self.os_ops.conn_params.remote:
# start psql process
process = subprocess.Popen(psql_params,
stdin=subprocess.PIPE,
diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py
index a0a9926d..796e15c2 100644
--- a/testgres/operations/local_ops.py
+++ b/testgres/operations/local_ops.py
@@ -40,12 +40,7 @@ class LocalOperations(OsOperations):
def __init__(self, conn_params=None):
if conn_params is None:
conn_params = ConnectionParams()
- super(LocalOperations, self).__init__(conn_params.username)
- self.conn_params = conn_params
- self.host = conn_params.host
- self.ssh_key = None
- self.remote = False
- self.username = conn_params.username or getpass.getuser()
+ super(LocalOperations, self).__init__(conn_params)
@staticmethod
def _raise_exec_exception(message, command, exit_code, output):
@@ -305,14 +300,3 @@ def get_pid(self):
def get_process_children(self, pid):
return psutil.Process(pid).children()
-
- # Database control
- def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
- conn = pglib.connect(
- host=host,
- port=port,
- database=dbname,
- user=user,
- password=password,
- )
- return conn
diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py
index 34242040..f027bfef 100644
--- a/testgres/operations/os_ops.py
+++ b/testgres/operations/os_ops.py
@@ -12,11 +12,16 @@
class ConnectionParams:
- def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None):
+ def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None, remote=False, skip_ssl=False):
+ """
+ skip_ssl: if is True, the connection is established without SSL.
+ """
+ self.remote = remote
self.host = host
self.port = port
self.ssh_key = ssh_key
self.username = username
+ self.skip_ssl = skip_ssl
def get_default_encoding():
@@ -26,9 +31,12 @@ def get_default_encoding():
class OsOperations:
- def __init__(self, username=None):
- self.ssh_key = None
- self.username = username or getpass.getuser()
+ def __init__(self, conn_params=ConnectionParams()):
+ self.ssh_key = conn_params.ssh_key
+ self.username = conn_params.username or getpass.getuser()
+ self.host = conn_params.host
+ self.port = conn_params.port
+ self.conn_params = conn_params
# Command execution
def exec_command(self, cmd, **kwargs):
@@ -115,4 +123,14 @@ def get_process_children(self, pid):
# Database control
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
- raise NotImplementedError()
+ ssl_options = {"sslmode": "disable"} if self.conn_params.skip_ssl and 'psycopg2' in globals() else {}
+ conn = pglib.connect(
+ host=host,
+ port=port,
+ database=dbname,
+ user=user,
+ password=password,
+ **({"ssl_context": None} if self.conn_params.skip_ssl and 'pg8000' in globals() else ssl_options)
+ )
+
+ return conn
diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py
index 20095051..24a8b9fe 100644
--- a/testgres/operations/remote_ops.py
+++ b/testgres/operations/remote_ops.py
@@ -37,23 +37,17 @@ def cmdline(self):
class RemoteOperations(OsOperations):
def __init__(self, conn_params: ConnectionParams):
-
if not platform.system().lower() == "linux":
raise EnvironmentError("Remote operations are supported only on Linux!")
+ super().__init__(conn_params)
- super().__init__(conn_params.username)
- self.conn_params = conn_params
- self.host = conn_params.host
- self.port = conn_params.port
- self.ssh_key = conn_params.ssh_key
self.ssh_args = []
if self.ssh_key:
self.ssh_args += ["-i", self.ssh_key]
if self.port:
self.ssh_args += ["-p", self.port]
- self.remote = True
self.username = conn_params.username or getpass.getuser()
- self.ssh_dest = f"{self.username}@{self.host}" if conn_params.username else self.host
+ self.ssh_dest = f"{self.username}@{self.host}" if self.username else self.host
def __enter__(self):
return self
@@ -361,17 +355,6 @@ def get_process_children(self, pid):
else:
raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}")
- # Database control
- def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
- conn = pglib.connect(
- host=host,
- port=port,
- database=dbname,
- user=user,
- password=password,
- )
- return conn
-
def normalize_error(error):
if isinstance(error, bytes):
diff --git a/testgres/utils.py b/testgres/utils.py
index a4ee7877..aa61d270 100644
--- a/testgres/utils.py
+++ b/testgres/utils.py
@@ -97,7 +97,7 @@ def get_bin_path(filename):
# check if it's already absolute
if os.path.isabs(filename):
return filename
- if tconf.os_ops.remote:
+ if tconf.os_ops.conn_params.remote:
pg_config = os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG")
else:
# try PG_CONFIG - get from local machine
@@ -154,7 +154,7 @@ def cache_pg_config_data(cmd):
return _pg_config_data
# try specified pg_config path or PG_CONFIG
- if tconf.os_ops.remote:
+ if tconf.os_ops.conn_params.remote:
pg_config = pg_config_path or os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG")
else:
# try PG_CONFIG - get from local machine
diff --git a/tests/test_remote.py b/tests/test_remote.py
index e0e4a555..bb13108d 100755
--- a/tests/test_remote.py
+++ b/tests/test_remote.py
@@ -2,7 +2,7 @@
import pytest
-from testgres import ExecUtilException
+from testgres import ExecUtilException, get_remote_node, testgres_config
from testgres import RemoteOperations
from testgres import ConnectionParams
@@ -34,7 +34,7 @@ def test_exec_command_failure(self):
exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True)
except ExecUtilException as e:
error = e.message
- assert error == b'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n'
+ assert error == 'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n'
def test_is_executable_true(self):
"""
@@ -87,7 +87,7 @@ def test_makedirs_and_rmdirs_failure(self):
exit_status, result, error = self.operations.rmdirs(path, verbose=True)
except ExecUtilException as e:
error = e.message
- assert error == b"Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n"
+ assert error == "Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n"
def test_listdir(self):
"""
@@ -192,3 +192,25 @@ def test_isfile_false(self):
response = self.operations.isfile(filename)
assert response is False
+
+ def test_skip_ssl(self):
+ conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1',
+ username=os.getenv('USER'),
+ remote=True,
+ skip_ssl=True)
+ os_ops = RemoteOperations(conn_params)
+ testgres_config.set_os_ops(os_ops=os_ops)
+ with get_remote_node().init().start() as node:
+ with node.connect() as con:
+ con.begin()
+ con.execute('create table test(val int)')
+ con.execute('insert into test values (1)')
+ con.commit()
+
+ con.begin()
+ con.execute('insert into test values (2)')
+ res = con.execute('select * from test order by val asc')
+ if isinstance(res, list):
+ res.sort()
+ assert res == [(1,), (2,)]
+
diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py
index d51820ba..0d6f3dd5 100755
--- a/tests/test_simple_remote.py
+++ b/tests/test_simple_remote.py
@@ -96,16 +96,16 @@ def removing(f):
class TestgresRemoteTests(unittest.TestCase):
def test_node_repr(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
pattern = r"PostgresNode\(name='.+', port=.+, base_dir='.+'\)"
self.assertIsNotNone(re.match(pattern, str(node)))
def test_custom_init(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
# enable page checksums
node.init(initdb_params=['-k']).start()
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init(
allow_streaming=True,
initdb_params=['--auth-local=reject', '--auth-host=reject'])
@@ -120,13 +120,13 @@ def test_custom_init(self):
self.assertFalse(any('trust' in s for s in lines))
def test_double_init(self):
- with get_remote_node(conn_params=conn_params).init() as node:
+ with get_remote_node().init() as node:
# can't initialize node more than once
with self.assertRaises(InitNodeException):
node.init()
def test_init_after_cleanup(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init().start().execute('select 1')
node.cleanup()
node.init().start().execute('select 1')
@@ -138,7 +138,7 @@ def test_init_unique_system_id(self):
query = 'select system_identifier from pg_control_system()'
with scoped_config(cache_initdb=False):
- with get_remote_node(conn_params=conn_params).init().start() as node0:
+ with get_remote_node().init().start() as node0:
id0 = node0.execute(query)[0]
with scoped_config(cache_initdb=True,
@@ -147,8 +147,8 @@ def test_init_unique_system_id(self):
self.assertTrue(config.cached_initdb_unique)
# spawn two nodes; ids must be different
- with get_remote_node(conn_params=conn_params).init().start() as node1, \
- get_remote_node(conn_params=conn_params).init().start() as node2:
+ with get_remote_node().init().start() as node1, \
+ get_remote_node().init().start() as node2:
id1 = node1.execute(query)[0]
id2 = node2.execute(query)[0]
@@ -158,7 +158,7 @@ def test_init_unique_system_id(self):
def test_node_exit(self):
with self.assertRaises(QueryException):
- with get_remote_node(conn_params=conn_params).init() as node:
+ with get_remote_node().init() as node:
base_dir = node.base_dir
node.safe_psql('select 1')
@@ -166,26 +166,26 @@ def test_node_exit(self):
self.assertTrue(os_ops.path_exists(base_dir))
os_ops.rmdirs(base_dir, ignore_errors=True)
- with get_remote_node(conn_params=conn_params).init() as node:
+ with get_remote_node().init() as node:
base_dir = node.base_dir
# should have been removed by default
self.assertFalse(os_ops.path_exists(base_dir))
def test_double_start(self):
- with get_remote_node(conn_params=conn_params).init().start() as node:
+ with get_remote_node().init().start() as node:
# can't start node more than once
node.start()
self.assertTrue(node.is_started)
def test_uninitialized_start(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
# node is not initialized yet
with self.assertRaises(StartNodeException):
node.start()
def test_restart(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init().start()
# restart, ok
@@ -201,7 +201,7 @@ def test_restart(self):
node.restart()
def test_reload(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init().start()
# change client_min_messages and save old value
@@ -217,7 +217,7 @@ def test_reload(self):
self.assertNotEqual(cmm_old, cmm_new)
def test_pg_ctl(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init().start()
status = node.pg_ctl(['status'])
@@ -229,7 +229,7 @@ def test_status(self):
self.assertFalse(NodeStatus.Uninitialized)
# check statuses after each operation
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
self.assertEqual(node.pid, 0)
self.assertEqual(node.status(), NodeStatus.Uninitialized)
@@ -254,7 +254,7 @@ def test_status(self):
self.assertEqual(node.status(), NodeStatus.Uninitialized)
def test_psql(self):
- with get_remote_node(conn_params=conn_params).init().start() as node:
+ with get_remote_node().init().start() as node:
# check returned values (1 arg)
res = node.psql('select 1')
self.assertEqual(res, (0, b'1\n', b''))
@@ -297,7 +297,7 @@ def test_psql(self):
node.safe_psql('select 1')
def test_transactions(self):
- with get_remote_node(conn_params=conn_params).init().start() as node:
+ with get_remote_node().init().start() as node:
with node.connect() as con:
con.begin()
con.execute('create table test(val int)')
@@ -320,7 +320,7 @@ def test_transactions(self):
con.commit()
def test_control_data(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
# node is not initialized yet
with self.assertRaises(ExecUtilException):
node.get_control_data()
@@ -333,7 +333,7 @@ def test_control_data(self):
self.assertTrue(any('pg_control' in s for s in data.keys()))
def test_backup_simple(self):
- with get_remote_node(conn_params=conn_params) as master:
+ with get_remote_node() as master:
# enable streaming for backups
master.init(allow_streaming=True)
@@ -353,7 +353,7 @@ def test_backup_simple(self):
self.assertListEqual(res, [(1,), (2,), (3,), (4,)])
def test_backup_multiple(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init(allow_streaming=True).start()
with node.backup(xlog_method='fetch') as backup1, \
@@ -366,7 +366,7 @@ def test_backup_multiple(self):
self.assertNotEqual(node1.base_dir, node2.base_dir)
def test_backup_exhaust(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init(allow_streaming=True).start()
with node.backup(xlog_method='fetch') as backup:
@@ -379,7 +379,7 @@ def test_backup_exhaust(self):
backup.spawn_primary()
def test_backup_wrong_xlog_method(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init(allow_streaming=True).start()
with self.assertRaises(BackupException,
@@ -387,7 +387,7 @@ def test_backup_wrong_xlog_method(self):
node.backup(xlog_method='wrong')
def test_pg_ctl_wait_option(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init().start(wait=False)
while True:
try:
@@ -399,7 +399,7 @@ def test_pg_ctl_wait_option(self):
pass
def test_replicate(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init(allow_streaming=True).start()
with node.replicate().start() as replica:
@@ -415,7 +415,7 @@ def test_replicate(self):
@unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+')
def test_synchronous_replication(self):
- with get_remote_node(conn_params=conn_params) as master:
+ with get_remote_node() as master:
old_version = not pg_version_ge('9.6')
master.init(allow_streaming=True).start()
@@ -456,7 +456,7 @@ def test_synchronous_replication(self):
@unittest.skipUnless(pg_version_ge('10'), 'requires 10+')
def test_logical_replication(self):
- with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2:
+ with get_remote_node() as node1, get_remote_node() as node2:
node1.init(allow_logical=True)
node1.start()
node2.init().start()
@@ -526,7 +526,7 @@ def test_logical_replication(self):
@unittest.skipUnless(pg_version_ge('10'), 'requires 10+')
def test_logical_catchup(self):
""" Runs catchup for 100 times to be sure that it is consistent """
- with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2:
+ with get_remote_node() as node1, get_remote_node() as node2:
node1.init(allow_logical=True)
node1.start()
node2.init().start()
@@ -551,12 +551,12 @@ def test_logical_catchup(self):
@unittest.skipIf(pg_version_ge('10'), 'requires <10')
def test_logical_replication_fail(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
with self.assertRaises(InitNodeException):
node.init(allow_logical=True)
def test_replication_slots(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init(allow_streaming=True).start()
with node.replicate(slot='slot1').start() as replica:
@@ -567,7 +567,7 @@ def test_replication_slots(self):
node.replicate(slot='slot1')
def test_incorrect_catchup(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init(allow_streaming=True).start()
# node has no master, can't catch up
@@ -575,7 +575,7 @@ def test_incorrect_catchup(self):
node.catchup()
def test_promotion(self):
- with get_remote_node(conn_params=conn_params) as master:
+ with get_remote_node() as master:
master.init().start()
master.safe_psql('create table abc(id serial)')
@@ -592,12 +592,12 @@ def test_dump(self):
query_create = 'create table test as select generate_series(1, 2) as val'
query_select = 'select * from test order by val asc'
- with get_remote_node(conn_params=conn_params).init().start() as node1:
+ with get_remote_node().init().start() as node1:
node1.execute(query_create)
for format in ['plain', 'custom', 'directory', 'tar']:
with removing(node1.dump(format=format)) as dump:
- with get_remote_node(conn_params=conn_params).init().start() as node3:
+ with get_remote_node().init().start() as node3:
if format == 'directory':
self.assertTrue(os_ops.isdir(dump))
else:
@@ -608,13 +608,13 @@ def test_dump(self):
self.assertListEqual(res, [(1,), (2,)])
def test_users(self):
- with get_remote_node(conn_params=conn_params).init().start() as node:
+ with get_remote_node().init().start() as node:
node.psql('create role test_user login')
value = node.safe_psql('select 1', username='test_user')
self.assertEqual(b'1\n', value)
def test_poll_query_until(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init().start()
get_time = 'select extract(epoch from now())'
@@ -728,7 +728,7 @@ def test_logging(self):
@unittest.skipUnless(util_exists('pgbench'), 'might be missing')
def test_pgbench(self):
- with get_remote_node(conn_params=conn_params).init().start() as node:
+ with get_remote_node().init().start() as node:
# initialize pgbench DB and run benchmarks
node.pgbench_init(scale=2, foreign_keys=True,
options=['-q']).pgbench_run(time=2)
@@ -796,7 +796,7 @@ def test_config_stack(self):
self.assertEqual(TestgresConfig.cached_initdb_dir, d0)
def test_unix_sockets(self):
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
node.init(unix_sockets=False, allow_streaming=True)
node.start()
@@ -812,7 +812,7 @@ def test_unix_sockets(self):
self.assertEqual(res_psql, b'1\n')
def test_auto_name(self):
- with get_remote_node(conn_params=conn_params).init(allow_streaming=True).start() as m:
+ with get_remote_node().init(allow_streaming=True).start() as m:
with m.replicate().start() as r:
# check that nodes are running
self.assertTrue(m.status())
@@ -849,7 +849,7 @@ def test_file_tail(self):
self.assertEqual(lines[0], s3)
def test_isolation_levels(self):
- with get_remote_node(conn_params=conn_params).init().start() as node:
+ with get_remote_node().init().start() as node:
with node.connect() as con:
# string levels
con.begin('Read Uncommitted').commit()
@@ -871,7 +871,7 @@ def test_ports_management(self):
# check that no ports have been bound yet
self.assertEqual(len(bound_ports), 0)
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
# check that we've just bound a port
self.assertEqual(len(bound_ports), 1)
@@ -904,7 +904,7 @@ def test_version_management(self):
self.assertTrue(d > f)
version = get_pg_version()
- with get_remote_node(conn_params=conn_params) as node:
+ with get_remote_node() as node:
self.assertTrue(isinstance(version, six.string_types))
self.assertTrue(isinstance(node.version, PgVer))
self.assertEqual(node.version, PgVer(version))
@@ -922,12 +922,15 @@ def test_child_pids(self):
if pg_version_ge('10'):
master_processes.append(ProcessType.LogicalReplicationLauncher)
+ if pg_version_ge('14'):
+ master_processes.remove(ProcessType.StatsCollector)
+
repl_processes = [
ProcessType.Startup,
ProcessType.WalReceiver,
]
- with get_remote_node(conn_params=conn_params).init().start() as master:
+ with get_remote_node().init().start() as master:
# master node doesn't have a source walsender!
with self.assertRaises(TestgresException):
From 104a127152cc4d1e4aad5013d54050294a29eec5 Mon Sep 17 00:00:00 2001
From: vshepard
Date: Fri, 1 Nov 2024 23:45:34 +0100
Subject: [PATCH 2/9] Don't reserve a new port if port was set up
---
testgres/node.py | 88 ++++++++++++++++++++++++++++++------------------
1 file changed, 56 insertions(+), 32 deletions(-)
diff --git a/testgres/node.py b/testgres/node.py
index 78ebd87b..3028e1bc 100644
--- a/testgres/node.py
+++ b/testgres/node.py
@@ -96,7 +96,6 @@
from .operations.os_ops import ConnectionParams
from .operations.local_ops import LocalOperations
-from .operations.remote_ops import RemoteOperations
InternalError = pglib.InternalError
ProgrammingError = pglib.ProgrammingError
@@ -487,7 +486,7 @@ def init(self, initdb_params=None, cached=True, **kwargs):
os_ops=self.os_ops,
params=initdb_params,
bin_path=self.bin_dir,
- cached=False)
+ cached=cached)
# initialize default config files
self.default_conf(**kwargs)
@@ -717,9 +716,9 @@ def slow_start(self, replica=False, dbname='template1', username=None, max_attem
OperationalError},
max_attempts=max_attempts)
- def start(self, params=[], wait=True):
+ def start(self, params=None, wait: bool = True) -> 'PostgresNode':
"""
- Starts the PostgreSQL node using pg_ctl if node has not been started.
+ Starts the PostgreSQL node using pg_ctl if the node has not been started.
By default, it waits for the operation to complete before returning.
Optionally, it can return immediately without waiting for the start operation
to complete by setting the `wait` parameter to False.
@@ -731,45 +730,62 @@ def start(self, params=[], wait=True):
Returns:
This instance of :class:`.PostgresNode`.
"""
+ if params is None:
+ params = []
if self.is_started:
return self
_params = [
- self._get_bin_path("pg_ctl"),
- "-D", self.data_dir,
- "-l", self.pg_log_file,
- "-w" if wait else '-W', # --wait or --no-wait
- "start"
- ] + params # yapf: disable
+ self._get_bin_path("pg_ctl"),
+ "-D", self.data_dir,
+ "-l", self.pg_log_file,
+ "-w" if wait else '-W', # --wait or --no-wait
+ "start"
+ ] + params # yapf: disable
- startup_retries = 5
- while True:
+ max_retries = 5
+ sleep_interval = 5 # seconds
+
+ for attempt in range(max_retries):
try:
exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True)
if error and 'does not exist' in error:
raise Exception
+ break # Exit the loop if successful
except Exception as e:
- files = self._collect_special_files()
- if any(len(file) > 1 and 'Is another postmaster already '
- 'running on port' in file[1].decode() for
- file in files):
- logging.warning("Detected an issue with connecting to port {0}. "
- "Trying another port after a 5-second sleep...".format(self.port))
- self.port = reserve_port()
- options = {'port': str(self.port)}
- self.set_auto_conf(options)
- startup_retries -= 1
- time.sleep(5)
- continue
-
- msg = 'Cannot start node'
- raise_from(StartNodeException(msg, files), e)
- break
+ if self._handle_port_conflict():
+ if attempt < max_retries - 1:
+ logging.info(f"Retrying start operation (Attempt {attempt + 2}/{max_retries})...")
+ time.sleep(sleep_interval)
+ continue
+ else:
+ logging.error("Reached maximum retry attempts. Unable to start node.")
+ raise StartNodeException("Cannot start node after multiple attempts",
+ self._collect_special_files()) from e
+ raise StartNodeException("Cannot start node", self._collect_special_files()) from e
+
self._maybe_start_logger()
self.is_started = True
return self
- def stop(self, params=[], wait=True):
+ def _handle_port_conflict(self) -> bool:
+ """
+ Checks for a port conflict and attempts to resolve it by changing the port.
+ Returns True if the port was changed, False otherwise.
+ """
+ files = self._collect_special_files()
+ if any(len(file) > 1 and 'Is another postmaster already running on port' in file[1].decode() for file in files):
+ logging.warning(f"Port conflict detected on port {self.port}.")
+ if self._should_free_port:
+ logging.warning("Port reservation skipped due to _should_free_port setting.")
+ return False
+ self.port = reserve_port()
+ self.set_auto_conf({'port': str(self.port)})
+ logging.info(f"Port changed to {self.port}.")
+ return True
+ return False
+
+ def stop(self, params=None, wait=True):
"""
Stops the PostgreSQL node using pg_ctl if the node has been started.
@@ -780,6 +796,8 @@ def stop(self, params=[], wait=True):
Returns:
This instance of :class:`.PostgresNode`.
"""
+ if params is None:
+ params = []
if not self.is_started:
return self
@@ -812,7 +830,7 @@ def kill(self, someone=None):
os.kill(self.auxiliary_pids[someone][0], sig)
self.is_started = False
- def restart(self, params=[]):
+ def restart(self, params=None):
"""
Restart this node using pg_ctl.
@@ -823,6 +841,8 @@ def restart(self, params=[]):
This instance of :class:`.PostgresNode`.
"""
+ if params is None:
+ params = []
_params = [
self._get_bin_path("pg_ctl"),
"-D", self.data_dir,
@@ -844,7 +864,7 @@ def restart(self, params=[]):
return self
- def reload(self, params=[]):
+ def reload(self, params=None):
"""
Asynchronously reload config files using pg_ctl.
@@ -855,6 +875,8 @@ def reload(self, params=[]):
This instance of :class:`.PostgresNode`.
"""
+ if params is None:
+ params = []
_params = [
self._get_bin_path("pg_ctl"),
"-D", self.data_dir,
@@ -1587,7 +1609,7 @@ def pgbench_table_checksums(self, dbname="postgres",
return {(table, self.table_checksum(table, dbname))
for table in pgbench_tables}
- def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}):
+ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options=None):
"""
Update or remove configuration options in the specified configuration file,
updates the options specified in the options dictionary, removes any options
@@ -1603,6 +1625,8 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}):
Defaults to an empty set.
"""
# parse postgresql.auto.conf
+ if rm_options is None:
+ rm_options = {}
path = os.path.join(self.data_dir, config)
lines = self.os_ops.readlines(path)
From 19ef23f130cb0c534aa4128faf80237c6c0c5666 Mon Sep 17 00:00:00 2001
From: vshepard
Date: Mon, 18 Nov 2024 11:39:50 +0100
Subject: [PATCH 3/9] Fix flake8 style
---
testgres/node.py | 12 +++++-------
testgres/operations/local_ops.py | 3 +--
testgres/operations/remote_ops.py | 9 ---------
3 files changed, 6 insertions(+), 18 deletions(-)
diff --git a/testgres/node.py b/testgres/node.py
index 3028e1bc..d580e49e 100644
--- a/testgres/node.py
+++ b/testgres/node.py
@@ -735,13 +735,11 @@ def start(self, params=None, wait: bool = True) -> 'PostgresNode':
if self.is_started:
return self
- _params = [
- self._get_bin_path("pg_ctl"),
- "-D", self.data_dir,
- "-l", self.pg_log_file,
- "-w" if wait else '-W', # --wait or --no-wait
- "start"
- ] + params # yapf: disable
+ _params = [self._get_bin_path("pg_ctl"),
+ "-D", self.data_dir,
+ "-l", self.pg_log_file,
+ "-w" if wait else '-W', # --wait or --no-wait
+ "start"] + params # yapf: disable
max_retries = 5
sleep_interval = 5 # seconds
diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py
index 796e15c2..7d3a99eb 100644
--- a/testgres/operations/local_ops.py
+++ b/testgres/operations/local_ops.py
@@ -1,4 +1,3 @@
-import getpass
import logging
import os
import shutil
@@ -10,7 +9,7 @@
import psutil
from ..exceptions import ExecUtilException
-from .os_ops import ConnectionParams, OsOperations, pglib, get_default_encoding
+from .os_ops import ConnectionParams, OsOperations, get_default_encoding
try:
from shutil import which as find_executable
diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py
index 24a8b9fe..ed88d1e4 100644
--- a/testgres/operations/remote_ops.py
+++ b/testgres/operations/remote_ops.py
@@ -4,15 +4,6 @@
import subprocess
import tempfile
-# we support both pg8000 and psycopg2
-try:
- import psycopg2 as pglib
-except ImportError:
- try:
- import pg8000 as pglib
- except ImportError:
- raise ImportError("You must have psycopg2 or pg8000 modules installed")
-
from ..exceptions import ExecUtilException
from .os_ops import OsOperations, ConnectionParams, get_default_encoding
From 43c91c06b886830ee0d0b1970595e9121c671b5b Mon Sep 17 00:00:00 2001
From: vshepard
Date: Mon, 18 Nov 2024 11:40:04 +0100
Subject: [PATCH 4/9] Fix test_the_same_port
---
tests/test_remote.py | 1 -
tests/test_simple.py | 9 ++++-----
2 files changed, 4 insertions(+), 6 deletions(-)
diff --git a/tests/test_remote.py b/tests/test_remote.py
index bb13108d..bab08f93 100755
--- a/tests/test_remote.py
+++ b/tests/test_remote.py
@@ -213,4 +213,3 @@ def test_skip_ssl(self):
if isinstance(res, list):
res.sort()
assert res == [(1,), (2,)]
-
diff --git a/tests/test_simple.py b/tests/test_simple.py
index 8f85a23b..068dbca2 100644
--- a/tests/test_simple.py
+++ b/tests/test_simple.py
@@ -1039,10 +1039,9 @@ def test_parse_pg_version(self):
def test_the_same_port(self):
with get_new_node() as node:
node.init().start()
-
- with get_new_node() as node2:
- node2.port = node.port
- node2.init().start()
+ with get_new_node() as node2:
+ node2.port = node.port
+ node2.init().start()
def test_make_simple_with_bin_dir(self):
with get_new_node() as node:
@@ -1059,7 +1058,7 @@ def test_make_simple_with_bin_dir(self):
wrong_bin_dir.slow_start()
raise RuntimeError("Error was expected.") # We should not reach this
except FileNotFoundError:
- pass # Expected error
+ pass # Expected error
if __name__ == '__main__':
From 1e8d91280bd2a42462b18e69c25ba1db7b984663 Mon Sep 17 00:00:00 2001
From: vshepard
Date: Mon, 18 Nov 2024 14:48:08 +0100
Subject: [PATCH 5/9] Fix failed test_ports_management
---
testgres/node.py | 5 +++--
tests/test_simple.py | 1 +
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/testgres/node.py b/testgres/node.py
index d580e49e..53de3163 100644
--- a/testgres/node.py
+++ b/testgres/node.py
@@ -152,8 +152,6 @@ def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionP
self.os_ops = testgres_config.os_ops
self.host = self.os_ops.host
- self.port = port or self.os_ops.port or reserve_port()
-
self.ssh_key = self.os_ops.ssh_key
# defaults for __exit__()
@@ -161,6 +159,8 @@ def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionP
self.cleanup_on_bad_exit = testgres_config.node_cleanup_on_bad_exit
self.shutdown_max_attempts = 3
+ self.port = port or self.os_ops.port or reserve_port()
+
# NOTE: for compatibility
self.utils_log_name = self.utils_log_file
self.pg_log_name = self.pg_log_file
@@ -810,6 +810,7 @@ def stop(self, params=None, wait=True):
self._maybe_stop_logger()
self.is_started = False
+ release_port(self.port)
return self
def kill(self, someone=None):
diff --git a/tests/test_simple.py b/tests/test_simple.py
index 068dbca2..b8c07958 100644
--- a/tests/test_simple.py
+++ b/tests/test_simple.py
@@ -1052,6 +1052,7 @@ def test_make_simple_with_bin_dir(self):
correct_bin_dir = app.make_simple(base_dir=node.base_dir, bin_dir=bin_dir)
correct_bin_dir.slow_start()
correct_bin_dir.safe_psql("SELECT 1;")
+ correct_bin_dir.stop()
try:
wrong_bin_dir = app.make_empty(base_dir=node.base_dir, bin_dir="wrong/path")
From f1d28b44952a4aa1cf6eff4b0e1e4761b35e89c9 Mon Sep 17 00:00:00 2001
From: vshepard
Date: Mon, 18 Nov 2024 18:14:01 +0100
Subject: [PATCH 6/9] Add env variable TESTGRES_SKIP_SSL
---
testgres/operations/os_ops.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py
index f027bfef..5dfd7841 100644
--- a/testgres/operations/os_ops.py
+++ b/testgres/operations/os_ops.py
@@ -1,5 +1,6 @@
import getpass
import locale
+import os
import sys
try:
@@ -12,15 +13,17 @@
class ConnectionParams:
- def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None, remote=False, skip_ssl=False):
+ def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None, remote=False, skip_ssl=None):
"""
- skip_ssl: if is True, the connection is established without SSL.
+ skip_ssl: if is True, the connection to database is established without SSL.
"""
self.remote = remote
self.host = host
self.port = port
self.ssh_key = ssh_key
self.username = username
+ if skip_ssl is None:
+ skip_ssl = os.getenv("TESTGRES_SKIP_SSL", False)
self.skip_ssl = skip_ssl
From e729c2f2fecc0a63fef96c50cfcf4624e66901ef Mon Sep 17 00:00:00 2001
From: vshepard
Date: Mon, 18 Nov 2024 20:00:12 +0100
Subject: [PATCH 7/9] Fix sys.modules instead globals
---
testgres/operations/os_ops.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py
index 5dfd7841..1e1575ce 100644
--- a/testgres/operations/os_ops.py
+++ b/testgres/operations/os_ops.py
@@ -126,14 +126,14 @@ def get_process_children(self, pid):
# Database control
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
- ssl_options = {"sslmode": "disable"} if self.conn_params.skip_ssl and 'psycopg2' in globals() else {}
+ ssl_options = {"sslmode": "disable"} if self.conn_params.skip_ssl and 'psycopg2' in sys.modules else {}
conn = pglib.connect(
host=host,
port=port,
database=dbname,
user=user,
password=password,
- **({"ssl_context": None} if self.conn_params.skip_ssl and 'pg8000' in globals() else ssl_options)
+ **({"ssl_context": None} if self.conn_params.skip_ssl and 'pg8000' in sys.modules else ssl_options)
)
return conn
From dc4b4c3dbf26e1d019b7fe0395a2fb8f933b0c38 Mon Sep 17 00:00:00 2001
From: vshepard
Date: Mon, 18 Nov 2024 20:16:26 +0100
Subject: [PATCH 8/9] Move _get_ssl_options in separate function
---
testgres/operations/os_ops.py | 15 +++++++++++++--
tests/test_simple.py | 8 +++++---
2 files changed, 18 insertions(+), 5 deletions(-)
diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py
index 1e1575ce..b77762df 100644
--- a/testgres/operations/os_ops.py
+++ b/testgres/operations/os_ops.py
@@ -124,16 +124,27 @@ def get_pid(self):
def get_process_children(self, pid):
raise NotImplementedError()
+ def _get_ssl_options(self):
+ """
+ Determine the SSL options based on available modules.
+ """
+ if self.conn_params.skip_ssl:
+ if 'psycopg2' in sys.modules:
+ return {"sslmode": "disable"}
+ elif 'pg8000' in sys.modules:
+ return {"ssl_context": None}
+ return {}
+
# Database control
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
- ssl_options = {"sslmode": "disable"} if self.conn_params.skip_ssl and 'psycopg2' in sys.modules else {}
+ ssl_options = self._get_ssl_options()
conn = pglib.connect(
host=host,
port=port,
database=dbname,
user=user,
password=password,
- **({"ssl_context": None} if self.conn_params.skip_ssl and 'pg8000' in sys.modules else ssl_options)
+ **ssl_options
)
return conn
diff --git a/tests/test_simple.py b/tests/test_simple.py
index b8c07958..9cc48e7e 100644
--- a/tests/test_simple.py
+++ b/tests/test_simple.py
@@ -1039,9 +1039,11 @@ def test_parse_pg_version(self):
def test_the_same_port(self):
with get_new_node() as node:
node.init().start()
- with get_new_node() as node2:
- node2.port = node.port
- node2.init().start()
+ with get_new_node() as node2:
+ node2.port = node.port
+ # _should_free_port is true if port was set up manually
+ node2._should_free_port = False
+ node2.init().start()
def test_make_simple_with_bin_dir(self):
with get_new_node() as node:
From fa6d7519018de0d94a3474d332d5f049cd7889ea Mon Sep 17 00:00:00 2001
From: "d.kovalenko"
Date: Tue, 24 Dec 2024 18:21:41 +0300
Subject: [PATCH 9/9] [BUG FIX]
TestgresRemoteTests.test_safe_psql__expect_error is corrected
get_remote_node() must be called without any parameters.
---
tests/test_simple_remote.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py
index 8d700e52..da671a5d 100755
--- a/tests/test_simple_remote.py
+++ b/tests/test_simple_remote.py
@@ -297,7 +297,7 @@ def test_psql(self):
node.safe_psql('select 1')
def test_safe_psql__expect_error(self):
- with get_remote_node(conn_params=conn_params).init().start() as node:
+ with get_remote_node().init().start() as node:
err = node.safe_psql('select_or_not_select 1', expect_error=True)
self.assertTrue(type(err) == str) # noqa: E721
self.assertIn('select_or_not_select', err)
--- a PPN by Garber Painting Akron. With Image Size Reduction included!Fetched URL: http://github.com/postgrespro/testgres/pull/149.patch
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy