Skip to content

Commit 3453c52

Browse files
committed
Add support for asynchronous working
1 parent 1444759 commit 3453c52

File tree

2 files changed

+335
-293
lines changed

2 files changed

+335
-293
lines changed

testgres/testgres.py

Lines changed: 100 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,10 @@
4545
from enum import Enum
4646
from distutils.version import LooseVersion
4747

48-
# Try to use psycopg2 by default. If psycopg2 isn't available then use
49-
# pg8000 which is slower but much more portable because uses only
50-
# pure-Python code
5148
try:
52-
import psycopg2 as pglib
49+
import asyncpg as pglib
5350
except ImportError:
54-
try:
55-
import pg8000 as pglib
56-
except ImportError:
57-
raise ImportError("You must have psycopg2 or pg8000 modules installed")
51+
raise ImportError("You must have asyncpg module installed")
5852

5953
# ports used by nodes
6054
bound_ports = set()
@@ -193,26 +187,34 @@ def __init__(self,
193187
password=None):
194188

195189
# Use default user if not specified
196-
username = username or default_username()
197-
190+
self.username = username or default_username()
191+
self.dbname = dbname
192+
self.host = host
193+
self.password = password
198194
self.parent_node = parent_node
195+
self.connection = None
196+
self.current_transaction = None
199197

200-
self.connection = pglib.connect(
201-
database=dbname,
202-
user=username,
203-
port=parent_node.port,
204-
host=host,
205-
password=password)
198+
async def init_connection(self):
199+
if self.connection:
200+
return
206201

207-
self.cursor = self.connection.cursor()
202+
self.connection = await pglib.connect(
203+
database=self.dbname,
204+
user=self.username,
205+
port=self.parent_node.port,
206+
host=self.host,
207+
password=self.password)
208208

209-
def __enter__(self):
209+
async def __aenter__(self):
210210
return self
211211

212-
def __exit__(self, type, value, traceback):
213-
self.close()
212+
async def __aexit__(self, type, value, traceback):
213+
await self.close()
214+
215+
async def begin(self, isolation_level=IsolationLevel.ReadCommitted):
216+
await self.init_connection()
214217

215-
def begin(self, isolation_level=IsolationLevel.ReadCommitted):
216218
# yapf: disable
217219
levels = [
218220
'read uncommitted',
@@ -245,37 +247,45 @@ def begin(self, isolation_level=IsolationLevel.ReadCommitted):
245247

246248
# Set isolation level
247249
cmd = 'SET TRANSACTION ISOLATION LEVEL {}'
248-
self.cursor.execute(cmd.format(isolation_level))
250+
self.current_transaction = self.connection.transaction()
251+
await self.current_transaction.start()
252+
await self.connection.execute(cmd.format(isolation_level))
249253

250254
return self
251255

252-
def commit(self):
253-
self.connection.commit()
256+
async def commit(self):
257+
if not self.current_transaction:
258+
raise QueryException("transaction is not started")
254259

255-
return self
256-
257-
def rollback(self):
258-
self.connection.rollback()
260+
await self.current_transaction.commit()
261+
self.current_transaction = None
259262

260-
return self
263+
async def rollback(self):
264+
if not self.current_transaction:
265+
raise QueryException("transaction is not started")
261266

262-
def execute(self, query, *args):
263-
self.cursor.execute(query, args)
267+
await self.current_transaction.rollback()
268+
self.current_transaction = None
264269

265-
try:
266-
res = self.cursor.fetchall()
267-
268-
# pg8000 might return tuples
269-
if isinstance(res, tuple):
270-
res = [tuple(t) for t in res]
270+
async def execute(self, query, *args):
271+
await self.init_connection()
272+
if self.current_transaction:
273+
return await self.connection.execute(query, *args)
274+
else:
275+
async with self.connection.transaction():
276+
return await self.connection.execute(query, *args)
271277

272-
return res
273-
except Exception:
274-
return None
278+
async def fetch(self, query, *args):
279+
await self.init_connection()
280+
if self.current_transaction:
281+
return await self.connection.fetch(query, *args)
282+
else:
283+
async with self.connection.transaction():
284+
return await self.connection.fetch(query, *args)
275285

276-
def close(self):
277-
self.cursor.close()
278-
self.connection.close()
286+
async def close(self):
287+
if self.connection:
288+
await self.connection.close()
279289

280290

281291
class NodeBackup(object):
@@ -943,7 +953,7 @@ def restore(self, dbname, filename, username=None):
943953

944954
self.psql(dbname=dbname, filename=filename, username=username)
945955

946-
def poll_query_until(self,
956+
async def poll_query_until(self,
947957
dbname,
948958
query,
949959
username=None,
@@ -973,41 +983,54 @@ def poll_query_until(self,
973983

974984
attempts = 0
975985
while max_attempts == 0 or attempts < max_attempts:
976-
try:
977-
res = self.execute(dbname=dbname,
978-
query=query,
979-
username=username,
980-
commit=True)
981-
982-
if expected is None and res is None:
983-
return # done
986+
res = await self.fetch(dbname=dbname,
987+
query=query,
988+
username=username,
989+
commit=True)
984990

985-
if res is None:
986-
raise QueryException('Query returned None')
991+
if expected is None and res is None:
992+
return # done
987993

988-
if len(res) == 0:
989-
raise QueryException('Query returned 0 rows')
994+
if res is None:
995+
raise QueryException('Query returned None')
990996

991-
if len(res[0]) == 0:
992-
raise QueryException('Query returned 0 columns')
997+
if len(res) == 0:
998+
raise QueryException('Query returned 0 rows')
993999

994-
if res[0][0]:
995-
return # done
1000+
if len(res[0]) == 0:
1001+
raise QueryException('Query returned 0 columns')
9961002

997-
except pglib.ProgrammingError as e:
998-
if raise_programming_error:
999-
raise e
1000-
1001-
except pglib.InternalError as e:
1002-
if raise_internal_error:
1003-
raise e
1003+
if res[0][0]:
1004+
return # done
10041005

10051006
time.sleep(sleep_time)
10061007
attempts += 1
10071008

10081009
raise TimeoutException('Query timeout')
10091010

1010-
def execute(self, dbname, query, username=None, commit=True):
1011+
async def execute(self, dbname, query, username=None, commit=True):
1012+
"""
1013+
Execute a query
1014+
1015+
Args:
1016+
dbname: database name to connect to.
1017+
query: query to be executed.
1018+
username: database user name.
1019+
commit: should we commit this query?
1020+
1021+
Returns:
1022+
A list of tuples representing rows.
1023+
"""
1024+
1025+
async with self.connect(dbname, username) as node_con:
1026+
if commit:
1027+
await node_con.begin()
1028+
1029+
await node_con.execute(query)
1030+
if commit:
1031+
await node_con.commit()
1032+
1033+
async def fetch(self, dbname, query, username=None, commit=True):
10111034
"""
10121035
Execute a query and return all rows as list.
10131036
@@ -1021,10 +1044,13 @@ def execute(self, dbname, query, username=None, commit=True):
10211044
A list of tuples representing rows.
10221045
"""
10231046

1024-
with self.connect(dbname, username) as node_con:
1025-
res = node_con.execute(query)
1047+
async with self.connect(dbname, username) as node_con:
1048+
if commit:
1049+
await node_con.begin()
1050+
1051+
res = await node_con.fetch(query)
10261052
if commit:
1027-
node_con.commit()
1053+
await node_con.commit()
10281054
return res
10291055

10301056
def backup(self, username=None, xlog_method=DEFAULT_XLOG_METHOD):
@@ -1059,7 +1085,7 @@ def replicate(self, name, username=None,
10591085
backup = self.backup(username=username, xlog_method=xlog_method)
10601086
return backup.spawn_replica(name, use_logging=use_logging)
10611087

1062-
def catchup(self, username=None):
1088+
async def catchup(self, username=None):
10631089
"""
10641090
Wait until async replica catches up with its master.
10651091
"""
@@ -1080,8 +1106,8 @@ def catchup(self, username=None):
10801106
raise CatchUpException("Master node is not specified")
10811107

10821108
try:
1083-
lsn = master.execute('postgres', poll_lsn)[0][0]
1084-
self.poll_query_until(dbname='postgres',
1109+
lsn = (await master.fetch('postgres', poll_lsn))[0][0]
1110+
await self.poll_query_until(dbname='postgres',
10851111
username=username,
10861112
query=wait_lsn.format(lsn),
10871113
max_attempts=0) # infinite

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy