Content-Length: 57470 | pFad | http://github.com/postgrespro/testgres/pull/99.patch
thub.com
From edb5708ac4c2aafac9717911ee0ac7f3ea30e5df Mon Sep 17 00:00:00 2001
From: vshepard
Date: Thu, 14 Dec 2023 14:21:55 +0100
Subject: [PATCH 1/4] Fix initdb error on Windows
---
setup.py | 6 +-
testgres/operations/local_ops.py | 100 +++++++++++++++++++++++++-----
testgres/operations/os_ops.py | 6 ++
testgres/operations/remote_ops.py | 34 +++++-----
testgres/utils.py | 48 ++++++++++++--
tests/test_simple.py | 41 +++++++++---
6 files changed, 184 insertions(+), 51 deletions(-)
mode change 100755 => 100644 tests/test_simple.py
diff --git a/setup.py b/setup.py
index 16d4c300..9a01bf16 100755
--- a/setup.py
+++ b/setup.py
@@ -27,7 +27,7 @@
readme = f.read()
setup(
- version='1.9.2',
+ version='1.9.3',
name='testgres',
packages=['testgres', 'testgres.operations'],
description='Testing utility for PostgreSQL and its extensions',
@@ -35,8 +35,8 @@
long_description=readme,
long_description_content_type='text/markdown',
license='PostgreSQL',
- author='Ildar Musin',
- author_email='zildermann@gmail.com',
+ author='Postgres Professional',
+ author_email='testgres@postgrespro.ru',
keywords=['test', 'testing', 'postgresql'],
install_requires=install_requires,
classifiers=[],
diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py
index 36b14058..8006b6f1 100644
--- a/testgres/operations/local_ops.py
+++ b/testgres/operations/local_ops.py
@@ -8,8 +8,7 @@
import psutil
from ..exceptions import ExecUtilException
-from .os_ops import ConnectionParams, OsOperations
-from .os_ops import pglib
+from .os_ops import ConnectionParams, OsOperations, pglib, get_default_encoding
try:
from shutil import which as find_executable
@@ -22,6 +21,12 @@
error_markers = [b'error', b'Permission denied', b'fatal']
+def has_errors(output):
+ if isinstance(output, str):
+ output = output.encode(get_default_encoding())
+ return any(marker in output for marker in error_markers)
+
+
class LocalOperations(OsOperations):
def __init__(self, conn_params=None):
if conn_params is None:
@@ -33,7 +38,38 @@ def __init__(self, conn_params=None):
self.remote = False
self.username = conn_params.username or self.get_user()
- # Command execution
+ @staticmethod
+ def _run_command(cmd, shell, input, timeout, encoding, temp_file=None):
+ """Execute a command and return the process."""
+ if temp_file is not None:
+ stdout = temp_file
+ stderr = subprocess.STDOUT
+ else:
+ stdout = subprocess.PIPE
+ stderr = subprocess.PIPE
+
+ process = subprocess.Popen(
+ cmd,
+ shell=shell,
+ stdin=subprocess.PIPE if input is not None else None,
+ stdout=stdout,
+ stderr=stderr,
+ )
+
+ try:
+ return process.communicate(input=input.encode(encoding) if input else None, timeout=timeout), process
+ except subprocess.TimeoutExpired:
+ process.kill()
+ raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
+
+ @staticmethod
+ def _raise_exec_exception(message, command, exit_code, output):
+ """Raise an ExecUtilException."""
+ raise ExecUtilException(message=message.format(output),
+ command=command,
+ exit_code=exit_code,
+ out=output)
+
def exec_command(self, cmd, wait_exit=False, verbose=False,
expect_error=False, encoding=None, shell=False, text=False,
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
@@ -56,16 +92,15 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
:return: The output of the subprocess.
"""
if os.name == 'nt':
- with tempfile.NamedTemporaryFile() as buf:
- process = subprocess.Popen(cmd, stdout=buf, stderr=subprocess.STDOUT)
- process.communicate()
- buf.seek(0)
- result = buf.read().decode(encoding)
- return result
+ return self._exec_command_windows(cmd, wait_exit=wait_exit, verbose=verbose,
+ expect_error=expect_error, encoding=encoding, shell=shell, text=text,
+ input=input, stdin=stdin, stdout=stdout, stderr=stderr,
+ get_process=get_process, timeout=timeout)
else:
process = subprocess.Popen(
cmd,
shell=shell,
+ stdin=stdin,
stdout=stdout,
stderr=stderr,
)
@@ -79,7 +114,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
exit_status = process.returncode
- error_found = exit_status != 0 or any(marker in error for marker in error_markers)
+ error_found = exit_status != 0 or has_errors(error)
if encoding:
result = result.decode(encoding)
@@ -91,15 +126,50 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
if exit_status != 0 or error_found:
if exit_status == 0:
exit_status = 1
- raise ExecUtilException(message='Utility exited with non-zero code. Error `{}`'.format(error),
- command=cmd,
- exit_code=exit_status,
- out=result)
+ self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, exit_status, result)
if verbose:
return exit_status, result, error
else:
return result
+ @staticmethod
+ def _process_output(process, encoding, temp_file=None):
+ """Process the output of a command."""
+ if temp_file is not None:
+ temp_file.seek(0)
+ output = temp_file.read()
+ else:
+ output = process.stdout.read()
+
+ if encoding:
+ output = output.decode(encoding)
+
+ return output
+
+ def _exec_command_windows(self, cmd, wait_exit=False, verbose=False,
+ expect_error=False, encoding=None, shell=False, text=False,
+ input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+ get_process=None, timeout=None):
+ with tempfile.NamedTemporaryFile(mode='w+b') as temp_file:
+ _, process = self._run_command(cmd, shell, input, timeout, encoding, temp_file)
+ if get_process:
+ return process
+ output = self._process_output(process, encoding, temp_file)
+
+ if process.returncode != 0 or has_errors(output):
+ if process.returncode == 0:
+ process.returncode = 1
+ if expect_error:
+ if verbose:
+ return process.returncode, output, output
+ else:
+ return output
+ else:
+ self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, process.returncode,
+ output)
+
+ return (process.returncode, output, output) if verbose else output
+
# Environment setup
def environ(self, var_name):
return os.environ.get(var_name)
@@ -210,7 +280,7 @@ def read(self, filename, encoding=None, binary=False):
if binary:
return content
if isinstance(content, bytes):
- return content.decode(encoding or 'utf-8')
+ return content.decode(encoding or get_default_encoding())
return content
def readlines(self, filename, num_lines=0, binary=False, encoding=None):
diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py
index 9261cacf..6ee07170 100644
--- a/testgres/operations/os_ops.py
+++ b/testgres/operations/os_ops.py
@@ -1,3 +1,5 @@
+import locale
+
try:
import psycopg2 as pglib # noqa: F401
except ImportError:
@@ -14,6 +16,10 @@ def __init__(self, host='127.0.0.1', ssh_key=None, username=None):
self.username = username
+def get_default_encoding():
+ return locale.getdefaultlocale()[1] or 'UTF-8'
+
+
class OsOperations:
def __init__(self, username=None):
self.ssh_key = None
diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py
index 0a545834..41f643ae 100644
--- a/testgres/operations/remote_ops.py
+++ b/testgres/operations/remote_ops.py
@@ -1,4 +1,3 @@
-import locale
import logging
import os
import subprocess
@@ -15,12 +14,7 @@
raise ImportError("You must have psycopg2 or pg8000 modules installed")
from ..exceptions import ExecUtilException
-
-from .os_ops import OsOperations, ConnectionParams
-
-ConsoleEncoding = locale.getdefaultlocale()[1]
-if not ConsoleEncoding:
- ConsoleEncoding = 'UTF-8'
+from .os_ops import OsOperations, ConnectionParams, get_default_encoding
error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory']
@@ -36,7 +30,7 @@ def kill(self):
def cmdline(self):
command = "ps -p {} -o cmd --no-headers".format(self.pid)
- stdin, stdout, stderr = self.ssh.exec_command(command, verbose=True, encoding=ConsoleEncoding)
+ stdin, stdout, stderr = self.ssh.exec_command(command, verbose=True, encoding=get_default_encoding())
cmdline = stdout.strip()
return cmdline.split()
@@ -145,7 +139,7 @@ def environ(self, var_name: str) -> str:
- var_name (str): The name of the environment variable.
"""
cmd = "echo ${}".format(var_name)
- return self.exec_command(cmd, encoding=ConsoleEncoding).strip()
+ return self.exec_command(cmd, encoding=get_default_encoding()).strip()
def find_executable(self, executable):
search_paths = self.environ("PATH")
@@ -176,11 +170,11 @@ def set_env(self, var_name: str, var_val: str):
# Get environment variables
def get_user(self):
- return self.exec_command("echo $USER", encoding=ConsoleEncoding).strip()
+ return self.exec_command("echo $USER", encoding=get_default_encoding()).strip()
def get_name(self):
cmd = 'python3 -c "import os; print(os.name)"'
- return self.exec_command(cmd, encoding=ConsoleEncoding).strip()
+ return self.exec_command(cmd, encoding=get_default_encoding()).strip()
# Work with dirs
def makedirs(self, path, remove_existing=False):
@@ -227,7 +221,7 @@ def listdir(self, path):
return result.splitlines()
def path_exists(self, path):
- result = self.exec_command("test -e {}; echo $?".format(path), encoding=ConsoleEncoding)
+ result = self.exec_command("test -e {}; echo $?".format(path), encoding=get_default_encoding())
return int(result.strip()) == 0
@property
@@ -264,9 +258,9 @@ def mkdtemp(self, prefix=None):
def mkstemp(self, prefix=None):
if prefix:
- temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding=ConsoleEncoding)
+ temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding=get_default_encoding())
else:
- temp_dir = self.exec_command("mktemp", encoding=ConsoleEncoding)
+ temp_dir = self.exec_command("mktemp", encoding=get_default_encoding())
if temp_dir:
if not os.path.isabs(temp_dir):
@@ -283,7 +277,9 @@ def copytree(self, src, dst):
return self.exec_command("cp -r {} {}".format(src, dst))
# Work with files
- def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding=ConsoleEncoding):
+ def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding=None):
+ if not encoding:
+ encoding = get_default_encoding()
mode = "wb" if binary else "w"
if not truncate:
mode = "ab" if binary else "a"
@@ -302,7 +298,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
data = data.encode(encoding)
if isinstance(data, list):
- data = [(s if isinstance(s, str) else s.decode(ConsoleEncoding)).rstrip('\n') + '\n' for s in data]
+ data = [(s if isinstance(s, str) else s.decode(get_default_encoding())).rstrip('\n') + '\n' for s in data]
tmp_file.writelines(data)
else:
tmp_file.write(data)
@@ -334,7 +330,7 @@ def read(self, filename, binary=False, encoding=None):
result = self.exec_command(cmd, encoding=encoding)
if not binary and result:
- result = result.decode(encoding or ConsoleEncoding)
+ result = result.decode(encoding or get_default_encoding())
return result
@@ -347,7 +343,7 @@ def readlines(self, filename, num_lines=0, binary=False, encoding=None):
result = self.exec_command(cmd, encoding=encoding)
if not binary and result:
- lines = result.decode(encoding or ConsoleEncoding).splitlines()
+ lines = result.decode(encoding or get_default_encoding()).splitlines()
else:
lines = result.splitlines()
@@ -375,7 +371,7 @@ def kill(self, pid, signal):
def get_pid(self):
# Get current process id
- return int(self.exec_command("echo $$", encoding=ConsoleEncoding))
+ return int(self.exec_command("echo $$", encoding=get_default_encoding()))
def get_process_children(self, pid):
command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"pgrep -P {pid}"]
diff --git a/testgres/utils.py b/testgres/utils.py
index db75fadc..24f25b11 100644
--- a/testgres/utils.py
+++ b/testgres/utils.py
@@ -4,13 +4,16 @@
from __future__ import print_function
import os
-import port_for
+import random
+import socket
+
import sys
from contextlib import contextmanager
from packaging.version import Version, InvalidVersion
import re
+from port_for import PortForException
from six import iteritems
from .exceptions import ExecUtilException
@@ -37,13 +40,49 @@ def reserve_port():
"""
Generate a new port and add it to 'bound_ports'.
"""
-
- port = port_for.select_random(exclude_ports=bound_ports)
+ port = select_random(exclude_ports=bound_ports)
bound_ports.add(port)
return port
+def select_random(
+ ports=None,
+ exclude_ports=None,
+) -> int:
+ """
+ Return random unused port number.
+ Standard function from port_for does not work on Windows because of error
+ 'port_for.exceptions.PortForException: Can't select a port'
+ We should update it.
+ """
+ if ports is None:
+ ports = set(range(1024, 65535))
+
+ if exclude_ports is None:
+ exclude_ports = set()
+
+ ports.difference_update(set(exclude_ports))
+
+ sampled_ports = random.sample(tuple(ports), min(len(ports), 100))
+
+ for port in sampled_ports:
+ if is_port_free(port):
+ return port
+
+ raise PortForException("Can't select a port")
+
+
+def is_port_free(port: int) -> bool:
+ """Check if a port is free to use."""
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ try:
+ s.bind(("", port))
+ return True
+ except OSError:
+ return False
+
+
def release_port(port):
"""
Free port provided by reserve_port().
@@ -80,7 +119,8 @@ def execute_utility(args, logfile=None, verbose=False):
lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n']
tconf.os_ops.write(filename=logfile, data=lines)
except IOError:
- raise ExecUtilException("Problem with writing to logfile `{}` during run command `{}`".format(logfile, args))
+ raise ExecUtilException(
+ "Problem with writing to logfile `{}` during run command `{}`".format(logfile, args))
if verbose:
return exit_status, out, error
else:
diff --git a/tests/test_simple.py b/tests/test_simple.py
old mode 100755
new mode 100644
index 45c28a21..4b4ab7ef
--- a/tests/test_simple.py
+++ b/tests/test_simple.py
@@ -74,6 +74,24 @@ def good_properties(f):
return True
+def rm_carriage_returns(out):
+ """
+ In Windows we have additional '\r' symbols in output.
+ Let's get rid of them.
+ """
+ if os.name == 'nt':
+ if isinstance(out, (int, float, complex)):
+ return out
+ elif isinstance(out, tuple):
+ return tuple(rm_carriage_returns(item) for item in out)
+ elif isinstance(out, bytes):
+ return out.replace(b'\r', b'')
+ else:
+ return out.replace('\r', '')
+ else:
+ return out
+
+
@contextmanager
def removing(f):
try:
@@ -254,34 +272,34 @@ def test_psql(self):
# check returned values (1 arg)
res = node.psql('select 1')
- self.assertEqual(res, (0, b'1\n', b''))
+ self.assertEqual(rm_carriage_returns(res), (0, b'1\n', b''))
# check returned values (2 args)
res = node.psql('postgres', 'select 2')
- self.assertEqual(res, (0, b'2\n', b''))
+ self.assertEqual(rm_carriage_returns(res), (0, b'2\n', b''))
# check returned values (named)
res = node.psql(query='select 3', dbname='postgres')
- self.assertEqual(res, (0, b'3\n', b''))
+ self.assertEqual(rm_carriage_returns(res), (0, b'3\n', b''))
# check returned values (1 arg)
res = node.safe_psql('select 4')
- self.assertEqual(res, b'4\n')
+ self.assertEqual(rm_carriage_returns(res), b'4\n')
# check returned values (2 args)
res = node.safe_psql('postgres', 'select 5')
- self.assertEqual(res, b'5\n')
+ self.assertEqual(rm_carriage_returns(res), b'5\n')
# check returned values (named)
res = node.safe_psql(query='select 6', dbname='postgres')
- self.assertEqual(res, b'6\n')
+ self.assertEqual(rm_carriage_returns(res), b'6\n')
# check feeding input
node.safe_psql('create table horns (w int)')
node.safe_psql('copy horns from stdin (format csv)',
input=b"1\n2\n3\n\\.\n")
_sum = node.safe_psql('select sum(w) from horns')
- self.assertEqual(_sum, b'6\n')
+ self.assertEqual(rm_carriage_returns(_sum), b'6\n')
# check psql's default args, fails
with self.assertRaises(QueryException):
@@ -455,7 +473,7 @@ def test_synchronous_replication(self):
master.safe_psql(
'insert into abc select generate_series(1, 1000000)')
res = standby1.safe_psql('select count(*) from abc')
- self.assertEqual(res, b'1000000\n')
+ self.assertEqual(rm_carriage_returns(res), b'1000000\n')
@unittest.skipUnless(pg_version_ge('10'), 'requires 10+')
def test_logical_replication(self):
@@ -589,7 +607,7 @@ def test_promotion(self):
# make standby becomes writable master
replica.safe_psql('insert into abc values (1)')
res = replica.safe_psql('select * from abc')
- self.assertEqual(res, b'1\n')
+ self.assertEqual(rm_carriage_returns(res), b'1\n')
def test_dump(self):
query_create = 'create table test as select generate_series(1, 2) as val'
@@ -614,6 +632,7 @@ def test_users(self):
with get_new_node().init().start() as node:
node.psql('create role test_user login')
value = node.safe_psql('select 1', username='test_user')
+ value = rm_carriage_returns(value)
self.assertEqual(value, b'1\n')
def test_poll_query_until(self):
@@ -977,7 +996,9 @@ def test_child_pids(self):
def test_child_process_dies(self):
# test for FileNotFound exception during child_processes() function
- with subprocess.Popen(["sleep", "60"]) as process:
+ cmd = ["timeout", "60"] if os.name == 'nt' else ["sleep", "60"]
+
+ with subprocess.Popen(cmd, shell=True) as process: # shell=True might be needed on Windows
self.assertEqual(process.poll(), None)
# collect list of processes currently running
children = psutil.Process(os.getpid()).children()
From 036b924230f9e1100f8b234ada531c0031bf2c4d Mon Sep 17 00:00:00 2001
From: vshepard
Date: Sun, 17 Dec 2023 17:59:56 +0100
Subject: [PATCH 2/4] Fix initdb error on Windows - fix pgbench
---
testgres/operations/local_ops.py | 17 +++++-------
testgres/operations/remote_ops.py | 21 ++++++++-------
testgres/utils.py | 44 +++----------------------------
tests/test_remote.py | 9 +++----
tests/test_simple.py | 2 ++
tests/test_simple_remote.py | 7 +++--
6 files changed, 31 insertions(+), 69 deletions(-)
diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py
index 8006b6f1..924a9287 100644
--- a/testgres/operations/local_ops.py
+++ b/testgres/operations/local_ops.py
@@ -55,12 +55,7 @@ def _run_command(cmd, shell, input, timeout, encoding, temp_file=None):
stdout=stdout,
stderr=stderr,
)
-
- try:
- return process.communicate(input=input.encode(encoding) if input else None, timeout=timeout), process
- except subprocess.TimeoutExpired:
- process.kill()
- raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
+ return process
@staticmethod
def _raise_exec_exception(message, command, exit_code, output):
@@ -151,10 +146,12 @@ def _exec_command_windows(self, cmd, wait_exit=False, verbose=False,
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
get_process=None, timeout=None):
with tempfile.NamedTemporaryFile(mode='w+b') as temp_file:
- _, process = self._run_command(cmd, shell, input, timeout, encoding, temp_file)
- if get_process:
- return process
- output = self._process_output(process, encoding, temp_file)
+ process = self._run_command(cmd, shell, input, timeout, encoding, temp_file)
+ try:
+ output = process.communicate(input=input.encode(encoding) if input else None, timeout=timeout)
+ except subprocess.TimeoutExpired:
+ process.kill()
+ raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
if process.returncode != 0 or has_errors(output):
if process.returncode == 0:
diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py
index 41f643ae..01251e1c 100644
--- a/testgres/operations/remote_ops.py
+++ b/testgres/operations/remote_ops.py
@@ -45,6 +45,10 @@ def __init__(self, conn_params: ConnectionParams):
self.conn_params = conn_params
self.host = conn_params.host
self.ssh_key = conn_params.ssh_key
+ if self.ssh_key:
+ self.ssh_cmd = ["-i", self.ssh_key]
+ else:
+ self.ssh_cmd = []
self.remote = True
self.username = conn_params.username or self.get_user()
self.add_known_host(self.host)
@@ -91,9 +95,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
"""
ssh_cmd = []
if isinstance(cmd, str):
- ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key, cmd]
+ ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + [cmd]
elif isinstance(cmd, list):
- ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key] + cmd
+ ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + cmd
process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if get_process:
return process
@@ -242,9 +246,9 @@ def mkdtemp(self, prefix=None):
- prefix (str): The prefix of the temporary directory name.
"""
if prefix:
- command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
+ command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
else:
- command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", "mktemp -d"]
+ command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", "mktemp -d"]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
@@ -288,7 +292,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file:
if not truncate:
- scp_cmd = ['scp', '-i', self.ssh_key, f"{self.username}@{self.host}:{filename}", tmp_file.name]
+ scp_cmd = ['scp'] + self.ssh_cmd + [f"{self.username}@{self.host}:{filename}", tmp_file.name]
subprocess.run(scp_cmd, check=False) # The file might not exist yet
tmp_file.seek(0, os.SEEK_END)
@@ -304,12 +308,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
tmp_file.write(data)
tmp_file.flush()
-
- scp_cmd = ['scp', '-i', self.ssh_key, tmp_file.name, f"{self.username}@{self.host}:{filename}"]
+ scp_cmd = ['scp'] + self.ssh_cmd + [tmp_file.name, f"{self.username}@{self.host}:{filename}"]
subprocess.run(scp_cmd, check=True)
remote_directory = os.path.dirname(filename)
- mkdir_cmd = ['ssh', '-i', self.ssh_key, f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
+ mkdir_cmd = ['ssh'] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
subprocess.run(mkdir_cmd, check=True)
os.remove(tmp_file.name)
@@ -374,7 +377,7 @@ def get_pid(self):
return int(self.exec_command("echo $$", encoding=get_default_encoding()))
def get_process_children(self, pid):
- command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"pgrep -P {pid}"]
+ command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"pgrep -P {pid}"]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
diff --git a/testgres/utils.py b/testgres/utils.py
index 24f25b11..7ce84fea 100644
--- a/testgres/utils.py
+++ b/testgres/utils.py
@@ -4,8 +4,6 @@
from __future__ import print_function
import os
-import random
-import socket
import sys
@@ -13,9 +11,9 @@
from packaging.version import Version, InvalidVersion
import re
-from port_for import PortForException
from six import iteritems
+from helpers.port_manager import PortManager
from .exceptions import ExecUtilException
from .config import testgres_config as tconf
@@ -40,49 +38,13 @@ def reserve_port():
"""
Generate a new port and add it to 'bound_ports'.
"""
- port = select_random(exclude_ports=bound_ports)
+ port_mng = PortManager()
+ port = port_mng.find_free_port(exclude_ports=bound_ports)
bound_ports.add(port)
return port
-def select_random(
- ports=None,
- exclude_ports=None,
-) -> int:
- """
- Return random unused port number.
- Standard function from port_for does not work on Windows because of error
- 'port_for.exceptions.PortForException: Can't select a port'
- We should update it.
- """
- if ports is None:
- ports = set(range(1024, 65535))
-
- if exclude_ports is None:
- exclude_ports = set()
-
- ports.difference_update(set(exclude_ports))
-
- sampled_ports = random.sample(tuple(ports), min(len(ports), 100))
-
- for port in sampled_ports:
- if is_port_free(port):
- return port
-
- raise PortForException("Can't select a port")
-
-
-def is_port_free(port: int) -> bool:
- """Check if a port is free to use."""
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- try:
- s.bind(("", port))
- return True
- except OSError:
- return False
-
-
def release_port(port):
"""
Free port provided by reserve_port().
diff --git a/tests/test_remote.py b/tests/test_remote.py
index 2e0f0676..e0e4a555 100755
--- a/tests/test_remote.py
+++ b/tests/test_remote.py
@@ -11,10 +11,9 @@ class TestRemoteOperations:
@pytest.fixture(scope="function", autouse=True)
def setup(self):
- conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3',
- username='dev',
- ssh_key=os.getenv(
- 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519')
+ conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1',
+ username=os.getenv('USER'),
+ ssh_key=os.getenv('RDBMS_TESTPOOL_SSHKEY'))
self.operations = RemoteOperations(conn_params)
def test_exec_command_success(self):
@@ -41,7 +40,7 @@ def test_is_executable_true(self):
"""
Test is_executable for an existing executable.
"""
- cmd = "postgres"
+ cmd = os.getenv('PG_CONFIG')
response = self.operations.is_executable(cmd)
assert response is True
diff --git a/tests/test_simple.py b/tests/test_simple.py
index 4b4ab7ef..8e3abf1c 100644
--- a/tests/test_simple.py
+++ b/tests/test_simple.py
@@ -763,6 +763,8 @@ def test_pgbench(self):
out, _ = proc.communicate()
out = out.decode('utf-8')
+ proc.stdout.close()
+
self.assertTrue('tps' in out)
def test_pg_config(self):
diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py
index 1042f3c4..d51820ba 100755
--- a/tests/test_simple_remote.py
+++ b/tests/test_simple_remote.py
@@ -52,10 +52,9 @@
from testgres.utils import PgVer
from testgres.node import ProcessProxy, ConnectionParams
-conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3',
- username='dev',
- ssh_key=os.getenv(
- 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519')
+conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1',
+ username=os.getenv('USER'),
+ ssh_key=os.getenv('RDBMS_TESTPOOL_SSHKEY'))
os_ops = RemoteOperations(conn_params)
testgres_config.set_os_ops(os_ops=os_ops)
From 05cd996c9864c4da57d416cc2ef137728af070b2 Mon Sep 17 00:00:00 2001
From: vshepard
Date: Sun, 17 Dec 2023 17:59:56 +0100
Subject: [PATCH 3/4] Fix initdb error on Windows - fix pgbench
---
testgres/helpers/__init__.py | 0
testgres/helpers/port_manager.py | 40 ++++++++++++++++++++++++++++
testgres/operations/local_ops.py | 17 +++++-------
testgres/operations/remote_ops.py | 21 ++++++++-------
testgres/utils.py | 44 +++----------------------------
tests/test_remote.py | 9 +++----
tests/test_simple.py | 2 ++
tests/test_simple_remote.py | 7 +++--
8 files changed, 71 insertions(+), 69 deletions(-)
create mode 100644 testgres/helpers/__init__.py
create mode 100644 testgres/helpers/port_manager.py
diff --git a/testgres/helpers/__init__.py b/testgres/helpers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/testgres/helpers/port_manager.py b/testgres/helpers/port_manager.py
new file mode 100644
index 00000000..6afdf8a9
--- /dev/null
+++ b/testgres/helpers/port_manager.py
@@ -0,0 +1,40 @@
+import socket
+import random
+from typing import Set, Iterable, Optional
+
+
+class PortForException(Exception):
+ pass
+
+
+class PortManager:
+ def __init__(self, ports_range=(1024, 65535)):
+ self.ports_range = ports_range
+
+ @staticmethod
+ def is_port_free(port: int) -> bool:
+ """Check if a port is free to use."""
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ try:
+ s.bind(("", port))
+ return True
+ except OSError:
+ return False
+
+ def find_free_port(self, ports: Optional[Set[int]] = None, exclude_ports: Optional[Iterable[int]] = None) -> int:
+ """Return a random unused port number."""
+ if ports is None:
+ ports = set(range(1024, 65535))
+
+ if exclude_ports is None:
+ exclude_ports = set()
+
+ ports.difference_update(set(exclude_ports))
+
+ sampled_ports = random.sample(tuple(ports), min(len(ports), 100))
+
+ for port in sampled_ports:
+ if self.is_port_free(port):
+ return port
+
+ raise PortForException("Can't select a port")
diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py
index 8006b6f1..924a9287 100644
--- a/testgres/operations/local_ops.py
+++ b/testgres/operations/local_ops.py
@@ -55,12 +55,7 @@ def _run_command(cmd, shell, input, timeout, encoding, temp_file=None):
stdout=stdout,
stderr=stderr,
)
-
- try:
- return process.communicate(input=input.encode(encoding) if input else None, timeout=timeout), process
- except subprocess.TimeoutExpired:
- process.kill()
- raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
+ return process
@staticmethod
def _raise_exec_exception(message, command, exit_code, output):
@@ -151,10 +146,12 @@ def _exec_command_windows(self, cmd, wait_exit=False, verbose=False,
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
get_process=None, timeout=None):
with tempfile.NamedTemporaryFile(mode='w+b') as temp_file:
- _, process = self._run_command(cmd, shell, input, timeout, encoding, temp_file)
- if get_process:
- return process
- output = self._process_output(process, encoding, temp_file)
+ process = self._run_command(cmd, shell, input, timeout, encoding, temp_file)
+ try:
+ output = process.communicate(input=input.encode(encoding) if input else None, timeout=timeout)
+ except subprocess.TimeoutExpired:
+ process.kill()
+ raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
if process.returncode != 0 or has_errors(output):
if process.returncode == 0:
diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py
index 41f643ae..01251e1c 100644
--- a/testgres/operations/remote_ops.py
+++ b/testgres/operations/remote_ops.py
@@ -45,6 +45,10 @@ def __init__(self, conn_params: ConnectionParams):
self.conn_params = conn_params
self.host = conn_params.host
self.ssh_key = conn_params.ssh_key
+ if self.ssh_key:
+ self.ssh_cmd = ["-i", self.ssh_key]
+ else:
+ self.ssh_cmd = []
self.remote = True
self.username = conn_params.username or self.get_user()
self.add_known_host(self.host)
@@ -91,9 +95,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
"""
ssh_cmd = []
if isinstance(cmd, str):
- ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key, cmd]
+ ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + [cmd]
elif isinstance(cmd, list):
- ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key] + cmd
+ ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + cmd
process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if get_process:
return process
@@ -242,9 +246,9 @@ def mkdtemp(self, prefix=None):
- prefix (str): The prefix of the temporary directory name.
"""
if prefix:
- command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
+ command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
else:
- command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", "mktemp -d"]
+ command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", "mktemp -d"]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
@@ -288,7 +292,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file:
if not truncate:
- scp_cmd = ['scp', '-i', self.ssh_key, f"{self.username}@{self.host}:{filename}", tmp_file.name]
+ scp_cmd = ['scp'] + self.ssh_cmd + [f"{self.username}@{self.host}:{filename}", tmp_file.name]
subprocess.run(scp_cmd, check=False) # The file might not exist yet
tmp_file.seek(0, os.SEEK_END)
@@ -304,12 +308,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
tmp_file.write(data)
tmp_file.flush()
-
- scp_cmd = ['scp', '-i', self.ssh_key, tmp_file.name, f"{self.username}@{self.host}:{filename}"]
+ scp_cmd = ['scp'] + self.ssh_cmd + [tmp_file.name, f"{self.username}@{self.host}:{filename}"]
subprocess.run(scp_cmd, check=True)
remote_directory = os.path.dirname(filename)
- mkdir_cmd = ['ssh', '-i', self.ssh_key, f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
+ mkdir_cmd = ['ssh'] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
subprocess.run(mkdir_cmd, check=True)
os.remove(tmp_file.name)
@@ -374,7 +377,7 @@ def get_pid(self):
return int(self.exec_command("echo $$", encoding=get_default_encoding()))
def get_process_children(self, pid):
- command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"pgrep -P {pid}"]
+ command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"pgrep -P {pid}"]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
diff --git a/testgres/utils.py b/testgres/utils.py
index 24f25b11..7ce84fea 100644
--- a/testgres/utils.py
+++ b/testgres/utils.py
@@ -4,8 +4,6 @@
from __future__ import print_function
import os
-import random
-import socket
import sys
@@ -13,9 +11,9 @@
from packaging.version import Version, InvalidVersion
import re
-from port_for import PortForException
from six import iteritems
+from helpers.port_manager import PortManager
from .exceptions import ExecUtilException
from .config import testgres_config as tconf
@@ -40,49 +38,13 @@ def reserve_port():
"""
Generate a new port and add it to 'bound_ports'.
"""
- port = select_random(exclude_ports=bound_ports)
+ port_mng = PortManager()
+ port = port_mng.find_free_port(exclude_ports=bound_ports)
bound_ports.add(port)
return port
-def select_random(
- ports=None,
- exclude_ports=None,
-) -> int:
- """
- Return random unused port number.
- Standard function from port_for does not work on Windows because of error
- 'port_for.exceptions.PortForException: Can't select a port'
- We should update it.
- """
- if ports is None:
- ports = set(range(1024, 65535))
-
- if exclude_ports is None:
- exclude_ports = set()
-
- ports.difference_update(set(exclude_ports))
-
- sampled_ports = random.sample(tuple(ports), min(len(ports), 100))
-
- for port in sampled_ports:
- if is_port_free(port):
- return port
-
- raise PortForException("Can't select a port")
-
-
-def is_port_free(port: int) -> bool:
- """Check if a port is free to use."""
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- try:
- s.bind(("", port))
- return True
- except OSError:
- return False
-
-
def release_port(port):
"""
Free port provided by reserve_port().
diff --git a/tests/test_remote.py b/tests/test_remote.py
index 2e0f0676..e0e4a555 100755
--- a/tests/test_remote.py
+++ b/tests/test_remote.py
@@ -11,10 +11,9 @@ class TestRemoteOperations:
@pytest.fixture(scope="function", autouse=True)
def setup(self):
- conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3',
- username='dev',
- ssh_key=os.getenv(
- 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519')
+ conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1',
+ username=os.getenv('USER'),
+ ssh_key=os.getenv('RDBMS_TESTPOOL_SSHKEY'))
self.operations = RemoteOperations(conn_params)
def test_exec_command_success(self):
@@ -41,7 +40,7 @@ def test_is_executable_true(self):
"""
Test is_executable for an existing executable.
"""
- cmd = "postgres"
+ cmd = os.getenv('PG_CONFIG')
response = self.operations.is_executable(cmd)
assert response is True
diff --git a/tests/test_simple.py b/tests/test_simple.py
index 4b4ab7ef..8e3abf1c 100644
--- a/tests/test_simple.py
+++ b/tests/test_simple.py
@@ -763,6 +763,8 @@ def test_pgbench(self):
out, _ = proc.communicate()
out = out.decode('utf-8')
+ proc.stdout.close()
+
self.assertTrue('tps' in out)
def test_pg_config(self):
diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py
index 1042f3c4..d51820ba 100755
--- a/tests/test_simple_remote.py
+++ b/tests/test_simple_remote.py
@@ -52,10 +52,9 @@
from testgres.utils import PgVer
from testgres.node import ProcessProxy, ConnectionParams
-conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3',
- username='dev',
- ssh_key=os.getenv(
- 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519')
+conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1',
+ username=os.getenv('USER'),
+ ssh_key=os.getenv('RDBMS_TESTPOOL_SSHKEY'))
os_ops = RemoteOperations(conn_params)
testgres_config.set_os_ops(os_ops=os_ops)
From 901a639b9171d646928578f5d015098c1352e095 Mon Sep 17 00:00:00 2001
From: Viktoriia Shepard
Date: Mon, 18 Dec 2023 00:35:07 +0000
Subject: [PATCH 4/4] Refactoring local_ops.py
---
setup.py | 2 +-
testgres/__init__.py | 4 +-
testgres/node.py | 8 +-
testgres/operations/local_ops.py | 175 +++++++++++--------------------
testgres/operations/os_ops.py | 2 +-
5 files changed, 68 insertions(+), 123 deletions(-)
diff --git a/setup.py b/setup.py
index 9a01bf16..e0287659 100755
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,7 @@
setup(
version='1.9.3',
name='testgres',
- packages=['testgres', 'testgres.operations'],
+ packages=['testgres', 'testgres.operations', 'testgres.helpers'],
description='Testing utility for PostgreSQL and its extensions',
url='https://github.com/postgrespro/testgres',
long_description=readme,
diff --git a/testgres/__init__.py b/testgres/__init__.py
index 383daf2d..8d0e38c6 100644
--- a/testgres/__init__.py
+++ b/testgres/__init__.py
@@ -52,6 +52,8 @@
from .operations.local_ops import LocalOperations
from .operations.remote_ops import RemoteOperations
+from .helpers.port_manager import PortManager
+
__all__ = [
"get_new_node",
"get_remote_node",
@@ -62,6 +64,6 @@
"XLogMethod", "IsolationLevel", "NodeStatus", "ProcessType", "DumpFormat",
"PostgresNode", "NodeApp",
"reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version",
- "First", "Any",
+ "First", "Any", "PortManager",
"OsOperations", "LocalOperations", "RemoteOperations", "ConnectionParams"
]
diff --git a/testgres/node.py b/testgres/node.py
index 52e6d2ee..20cf4264 100644
--- a/testgres/node.py
+++ b/testgres/node.py
@@ -623,8 +623,8 @@ def status(self):
"-D", self.data_dir,
"status"
] # yapf: disable
- status_code, out, err = execute_utility(_params, self.utils_log_file, verbose=True)
- if 'does not exist' in err:
+ status_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True)
+ if error and 'does not exist' in error:
return NodeStatus.Uninitialized
elif 'no server running' in out:
return NodeStatus.Stopped
@@ -717,7 +717,7 @@ def start(self, params=[], wait=True):
try:
exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True)
- if 'does not exist' in error:
+ if error and 'does not exist' in error:
raise Exception
except Exception as e:
msg = 'Cannot start node'
@@ -791,7 +791,7 @@ def restart(self, params=[]):
try:
error_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True)
- if 'could not start server' in error:
+ if error and 'could not start server' in error:
raise ExecUtilException
except ExecUtilException as e:
msg = 'Cannot restart node'
diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py
index 91576867..93ebf012 100644
--- a/testgres/operations/local_ops.py
+++ b/testgres/operations/local_ops.py
@@ -22,9 +22,11 @@
def has_errors(output):
- if isinstance(output, str):
- output = output.encode(get_default_encoding())
- return any(marker in output for marker in error_markers)
+ if output:
+ if isinstance(output, str):
+ output = output.encode(get_default_encoding())
+ return any(marker in output for marker in error_markers)
+ return False
class LocalOperations(OsOperations):
@@ -38,32 +40,6 @@ def __init__(self, conn_params=None):
self.remote = False
self.username = conn_params.username or self.get_user()
- @staticmethod
- def _run_command(cmd, shell, input, stdin, stdout, stderr, timeout, encoding, temp_file=None, get_process=None):
- """Execute a command and return the process."""
- if temp_file is not None:
- stdout = stdout or temp_file
- stderr = stderr or subprocess.STDOUT
- else:
- stdout = stdout or subprocess.PIPE
- stderr = stderr or subprocess.PIPE
-
- process = subprocess.Popen(
- cmd,
- shell=shell,
- stdin=stdin or subprocess.PIPE if input is not None else None,
- stdout=stdout,
- stderr=stderr,
- )
-
- if get_process:
- return None, process
- try:
- return process.communicate(input=input.encode(encoding) if input else None, timeout=timeout), process
- except subprocess.TimeoutExpired:
- process.kill()
- raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
-
@staticmethod
def _raise_exec_exception(message, command, exit_code, output):
"""Raise an ExecUtilException."""
@@ -72,105 +48,72 @@ def _raise_exec_exception(message, command, exit_code, output):
exit_code=exit_code,
out=output)
- def exec_command(self, cmd, wait_exit=False, verbose=False,
- expect_error=False, encoding=None, shell=False, text=False,
- input=None, stdin=None, stdout=None, stderr=None,
- get_process=None, timeout=None):
- """
- Execute a command in a subprocess.
-
- Args:
- - cmd: The command to execute.
- - wait_exit: Whether to wait for the subprocess to exit before returning.
- - verbose: Whether to return verbose output.
- - expect_error: Whether to raise an error if the subprocess exits with an error status.
- - encoding: The encoding to use for decoding the subprocess output.
- - shell: Whether to use shell when executing the subprocess.
- - text: Whether to return str instead of bytes for the subprocess output.
- - input: The input to pass to the subprocess.
- - stdout: The stdout to use for the subprocess.
- - stderr: The stderr to use for the subprocess.
- - proc: The process to use for subprocess creation.
- :return: The output of the subprocess.
- """
- if os.name == 'nt':
- return self._exec_command_windows(cmd, wait_exit=wait_exit, verbose=verbose,
- expect_error=expect_error, encoding=encoding, shell=shell, text=text,
- input=input, stdin=stdin, stdout=stdout, stderr=stderr,
- get_process=get_process, timeout=timeout)
- else:
+ @staticmethod
+ def _process_output(encoding, temp_file_path):
+ """Process the output of a command from a temporary file."""
+ with open(temp_file_path, 'rb') as temp_file:
+ output = temp_file.read()
+ if encoding:
+ output = output.decode(encoding)
+ return output, None # In Windows stderr writing in stdout
+
+ def _run_command(self, cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding):
+ """Execute a command and return the process and its output."""
+ if os.name == 'nt' and stdout is None: # Windows
+ with tempfile.NamedTemporaryFile(mode='w+b', delete=False) as temp_file:
+ stdout = temp_file
+ stderr = subprocess.STDOUT
+ process = subprocess.Popen(
+ cmd,
+ shell=shell,
+ stdin=stdin or subprocess.PIPE if input is not None else None,
+ stdout=stdout,
+ stderr=stderr,
+ )
+ if get_process:
+ return process, None, None
+ temp_file_path = temp_file.name
+
+ # Wait process finished
+ process.wait()
+
+ output, error = self._process_output(encoding, temp_file_path)
+ return process, output, error
+ else: # Other OS
process = subprocess.Popen(
cmd,
shell=shell,
- stdin=stdin,
- stdout=stdout,
- stderr=stderr,
+ stdin=stdin or subprocess.PIPE if input is not None else None,
+ stdout=stdout or subprocess.PIPE,
+ stderr=stderr or subprocess.PIPE,
)
if get_process:
- return process
-
+ return process, None, None
try:
- result, error = process.communicate(input, timeout=timeout)
+ output, error = process.communicate(input=input.encode(encoding) if input else None, timeout=timeout)
+ if encoding:
+ output = output.decode(encoding)
+ error = error.decode(encoding)
+ return process, output, error
except subprocess.TimeoutExpired:
process.kill()
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
- exit_status = process.returncode
- error_found = exit_status != 0 or has_errors(error)
-
- if encoding:
- result = result.decode(encoding)
- error = error.decode(encoding)
-
- if expect_error:
- raise Exception(result, error)
-
- if exit_status != 0 or error_found:
- if exit_status == 0:
- exit_status = 1
- self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, exit_status, result)
- if verbose:
- return exit_status, result, error
- else:
- return result
+ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=False,
+ text=False, input=None, stdin=None, stdout=None, stderr=None, get_process=False, timeout=None):
+ """
+ Execute a command in a subprocess and handle the output based on the provided parameters.
+ """
+ process, output, error = self._run_command(cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding)
+ if get_process:
+ return process
+ if process.returncode != 0 or (has_errors(error) and not expect_error):
+ self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, process.returncode, error)
- @staticmethod
- def _process_output(process, encoding, temp_file=None):
- """Process the output of a command."""
- if temp_file is not None:
- temp_file.seek(0)
- output = temp_file.read()
+ if verbose:
+ return process.returncode, output, error
else:
- output = process.stdout.read()
-
- if encoding:
- output = output.decode(encoding)
-
- return output
-
- def _exec_command_windows(self, cmd, wait_exit=False, verbose=False,
- expect_error=False, encoding=None, shell=False, text=False,
- input=None, stdin=None, stdout=None, stderr=None,
- get_process=None, timeout=None):
- with tempfile.NamedTemporaryFile(mode='w+b') as temp_file:
- _, process = self._run_command(cmd, shell, input, stdin, stdout, stderr, timeout, encoding, temp_file, get_process)
- if get_process:
- return process
- result = self._process_output(process, encoding, temp_file)
-
- if process.returncode != 0 or has_errors(result):
- if process.returncode == 0:
- process.returncode = 1
- if expect_error:
- if verbose:
- return process.returncode, result, result
- else:
- return result
- else:
- self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, process.returncode,
- result)
-
- return (process.returncode, result, result) if verbose else result
+ return output
# Environment setup
def environ(self, var_name):
diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py
index 6ee07170..dd6613cf 100644
--- a/testgres/operations/os_ops.py
+++ b/testgres/operations/os_ops.py
@@ -81,7 +81,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
def touch(self, filename):
raise NotImplementedError()
- def read(self, filename):
+ def read(self, filename, encoding, binary):
raise NotImplementedError()
def readlines(self, filename):
--- a PPN by Garber Painting Akron. With Image Size Reduction included!Fetched URL: http://github.com/postgrespro/testgres/pull/99.patch
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy