Skip to content

Commit 263ff9c

Browse files
author
ВашÐViktoria Shepard
committed
Remove sshtunnel
1 parent 46eb92a commit 263ff9c

File tree

6 files changed

+68
-59
lines changed

6 files changed

+68
-59
lines changed

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
"port-for>=0.4",
1212
"six>=1.9.0",
1313
"psutil",
14-
"packaging",
15-
"sshtunnel"
14+
"packaging"
1615
]
1716

1817
# Add compatibility enum class

testgres/exceptions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,16 @@ def __str__(self):
3232
if self.out:
3333
msg.append(u'----\n{}'.format(self.out))
3434

35-
return six.text_type('\n').join(msg)
35+
return self.convert_and_join(msg)
36+
37+
@staticmethod
38+
def convert_and_join(msg_list):
39+
# Convert each byte element in the list to str
40+
str_list = [six.text_type(item, 'utf-8') if isinstance(item, bytes) else six.text_type(item) for item in
41+
msg_list]
42+
43+
# Join the list into a single string with the specified delimiter
44+
return six.text_type('\n').join(str_list)
3645

3746

3847
@six.python_2_unicode_compatible

testgres/node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1371,7 +1371,7 @@ def pgbench(self,
13711371
# should be the last one
13721372
_params.append(dbname)
13731373

1374-
proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, proc=True)
1374+
proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, get_process=True)
13751375

13761376
return proc
13771377

testgres/operations/local_ops.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from distutils.spawn import find_executable
1919
from distutils import rmtree
2020

21-
2221
CMD_TIMEOUT_SEC = 60
2322
error_markers = [b'error', b'Permission denied', b'fatal']
2423

@@ -37,7 +36,8 @@ def __init__(self, conn_params=None):
3736
# Command execution
3837
def exec_command(self, cmd, wait_exit=False, verbose=False,
3938
expect_error=False, encoding=None, shell=False, text=False,
40-
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None):
39+
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
40+
get_process=None, timeout=None):
4141
"""
4242
Execute a command in a subprocess.
4343
@@ -69,9 +69,14 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
6969
stdout=stdout,
7070
stderr=stderr,
7171
)
72-
if proc:
72+
if get_process:
7373
return process
74-
result, error = process.communicate(input)
74+
75+
try:
76+
result, error = process.communicate(input, timeout=timeout)
77+
except subprocess.TimeoutExpired:
78+
process.kill()
79+
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
7580
exit_status = process.returncode
7681

7782
error_found = exit_status != 0 or any(marker in error for marker in error_markers)

testgres/operations/remote_ops.py

Lines changed: 43 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
import os
44
import subprocess
55
import tempfile
6-
import time
76

8-
import sshtunnel
7+
# we support both pg8000 and psycopg2
8+
try:
9+
import psycopg2 as pglib
10+
except ImportError:
11+
try:
12+
import pg8000 as pglib
13+
except ImportError:
14+
raise ImportError("You must have psycopg2 or pg8000 modules installed")
915

1016
from ..exceptions import ExecUtilException
1117

1218
from .os_ops import OsOperations, ConnectionParams
13-
from .os_ops import pglib
14-
15-
sshtunnel.SSH_TIMEOUT = 5.0
16-
sshtunnel.TUNNEL_TIMEOUT = 5.0
1719

1820
ConsoleEncoding = locale.getdefaultlocale()[1]
1921
if not ConsoleEncoding:
@@ -50,21 +52,28 @@ def __init__(self, conn_params: ConnectionParams):
5052
self.remote = True
5153
self.username = conn_params.username or self.get_user()
5254
self.add_known_host(self.host)
55+
self.tunnel_process = None
5356

5457
def __enter__(self):
5558
return self
5659

5760
def __exit__(self, exc_type, exc_val, exc_tb):
58-
self.close_tunnel()
61+
self.close_ssh_tunnel()
5962

60-
def close_tunnel(self):
61-
if getattr(self, 'tunnel', None):
62-
self.tunnel.stop(force=True)
63-
start_time = time.time()
64-
while self.tunnel.is_active:
65-
if time.time() - start_time > sshtunnel.TUNNEL_TIMEOUT:
66-
break
67-
time.sleep(0.5)
63+
def establish_ssh_tunnel(self, local_port, remote_port):
64+
"""
65+
Establish an SSH tunnel from a local port to a remote PostgreSQL port.
66+
"""
67+
ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"]
68+
self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300)
69+
70+
def close_ssh_tunnel(self):
71+
if hasattr(self, 'tunnel_process'):
72+
self.tunnel_process.terminate()
73+
self.tunnel_process.wait()
74+
del self.tunnel_process
75+
else:
76+
print("No active tunnel to close.")
6877

6978
def add_known_host(self, host):
7079
cmd = 'ssh-keyscan -H %s >> /home/%s/.ssh/known_hosts' % (host, os.getlogin())
@@ -78,21 +87,29 @@ def add_known_host(self, host):
7887
raise ExecUtilException(message="Failed to add %s to known_hosts. Error: %s" % (host, str(e)), command=cmd,
7988
exit_code=e.returncode, out=e.stderr)
8089

81-
def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False,
90+
def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
8291
encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None,
83-
stderr=None, proc=None):
92+
stderr=None, get_process=None, timeout=None):
8493
"""
8594
Execute a command in the SSH session.
8695
Args:
8796
- cmd (str): The command to be executed.
8897
"""
98+
ssh_cmd = []
8999
if isinstance(cmd, str):
90100
ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key, cmd]
91101
elif isinstance(cmd, list):
92102
ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key] + cmd
93103
process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
104+
if get_process:
105+
return process
106+
107+
try:
108+
result, error = process.communicate(input, timeout=timeout)
109+
except subprocess.TimeoutExpired:
110+
process.kill()
111+
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
94112

95-
result, error = process.communicate(input)
96113
exit_status = process.returncode
97114

98115
if encoding:
@@ -372,41 +389,19 @@ def get_process_children(self, pid):
372389
raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}")
373390

374391
# Database control
375-
def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, ssh_key=None):
392+
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
376393
"""
377-
Connects to a PostgreSQL database on the remote system.
378-
Args:
379-
- dbname (str): The name of the database to connect to.
380-
- user (str): The username for the database connection.
381-
- password (str, optional): The password for the database connection. Defaults to None.
382-
- host (str, optional): The IP address of the remote system. Defaults to "localhost".
383-
- port (int, optional): The port number of the PostgreSQL service. Defaults to 5432.
384-
385-
This function establishes a connection to a PostgreSQL database on the remote system using the specified
386-
parameters. It returns a connection object that can be used to interact with the database.
394+
Established SSH tunnel and Connects to a PostgreSQL
387395
"""
388-
self.close_tunnel()
389-
self.tunnel = sshtunnel.open_tunnel(
390-
(self.host, 22), # Remote server IP and SSH port
391-
ssh_username=self.username,
392-
ssh_pkey=self.ssh_key,
393-
remote_bind_address=(self.host, port), # PostgreSQL server IP and PostgreSQL port
394-
local_bind_address=('localhost', 0)
395-
# Local machine IP and available port (0 means it will pick any available port)
396-
)
397-
self.tunnel.start()
398-
396+
self.establish_ssh_tunnel(local_port=port, remote_port=5432)
399397
try:
400-
# Use localhost and self.tunnel.local_bind_port to connect
401398
conn = pglib.connect(
402-
host='localhost', # Connect to localhost
403-
port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel
399+
host=host,
400+
port=port,
404401
database=dbname,
405-
user=user or self.username,
406-
password=password
402+
user=user,
403+
password=password,
407404
)
408-
409405
return conn
410406
except Exception as e:
411-
self.tunnel.stop()
412-
raise ExecUtilException("Could not create db tunnel. {}".format(e))
407+
raise Exception(f"Could not connect to the database. Error: {e}")

tests/test_simple_remote.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -735,9 +735,10 @@ def test_pgbench(self):
735735
options=['-q']).pgbench_run(time=2)
736736

737737
# run TPC-B benchmark
738-
out = node.pgbench(stdout=subprocess.PIPE,
739-
stderr=subprocess.STDOUT,
740-
options=['-T3'])
738+
proc = node.pgbench(stdout=subprocess.PIPE,
739+
stderr=subprocess.STDOUT,
740+
options=['-T3'])
741+
out = proc.communicate()[0]
741742
self.assertTrue(b'tps = ' in out)
742743

743744
def test_pg_config(self):

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy