Skip to content

Commit 3ac0f91

Browse files
authored
feat: add SQL statement for begin transaction isolation level (#1331)
* feat: add SQL statement for egin transaction isolation level Adds an additional option to the `begin [transaction]` SQL statement to specify the isolation level of that transaction. The following format is now supported: ``` {begin | start} [transaction] [isolation level {repeatable read | serializable}] ``` * test: add test for invalid isolation level
1 parent b3c259d commit 3ac0f91

File tree

5 files changed

+186
-3
lines changed

5 files changed

+186
-3
lines changed

google/cloud/spanner_dbapi/client_side_statement_executor.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import TYPE_CHECKING
14+
from typing import TYPE_CHECKING, Union
15+
from google.cloud.spanner_v1 import TransactionOptions
1516

1617
if TYPE_CHECKING:
1718
from google.cloud.spanner_dbapi.cursor import Cursor
@@ -58,7 +59,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
5859
connection.commit()
5960
return None
6061
if statement_type == ClientSideStatementType.BEGIN:
61-
connection.begin()
62+
connection.begin(isolation_level=_get_isolation_level(parsed_statement))
6263
return None
6364
if statement_type == ClientSideStatementType.ROLLBACK:
6465
connection.rollback()
@@ -121,3 +122,19 @@ def _get_streamed_result_set(column_name, type_code, column_values):
121122
column_values_pb.append(_make_value_pb(column_value))
122123
result_set.values.extend(column_values_pb)
123124
return StreamedResultSet(iter([result_set]))
125+
126+
127+
def _get_isolation_level(
128+
statement: ParsedStatement,
129+
) -> Union[TransactionOptions.IsolationLevel, None]:
130+
if (
131+
statement.client_side_statement_params is None
132+
or len(statement.client_side_statement_params) == 0
133+
):
134+
return None
135+
level = statement.client_side_statement_params[0]
136+
if not isinstance(level, str) or level == "":
137+
return None
138+
# Replace (duplicate) whitespaces in the string with an underscore.
139+
level = "_".join(level.split()).upper()
140+
return TransactionOptions.IsolationLevel[level]

google/cloud/spanner_dbapi/client_side_statement_parser.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
Statement,
2222
)
2323

24-
RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(\s+TRANSACTION)?\s*$", re.IGNORECASE)
24+
RE_BEGIN = re.compile(
25+
r"^\s*(?:BEGIN|START)(?:\s+TRANSACTION)?(?:\s+ISOLATION\s+LEVEL\s+(REPEATABLE\s+READ|SERIALIZABLE))?\s*$",
26+
re.IGNORECASE,
27+
)
2528
RE_COMMIT = re.compile(r"^\s*(COMMIT)(\s+TRANSACTION)?\s*$", re.IGNORECASE)
2629
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(\s+TRANSACTION)?\s*$", re.IGNORECASE)
2730
RE_SHOW_COMMIT_TIMESTAMP = re.compile(
@@ -68,6 +71,10 @@ def parse_stmt(query):
6871
elif RE_START_BATCH_DML.match(query):
6972
client_side_statement_type = ClientSideStatementType.START_BATCH_DML
7073
elif RE_BEGIN.match(query):
74+
match = re.search(RE_BEGIN, query)
75+
isolation_level = match.group(1)
76+
if isolation_level is not None:
77+
client_side_statement_params.append(isolation_level)
7178
client_side_statement_type = ClientSideStatementType.BEGIN
7279
elif RE_RUN_BATCH.match(query):
7380
client_side_statement_type = ClientSideStatementType.RUN_BATCH

tests/mockserver_tests/test_dbapi_isolation_level.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from google.api_core.exceptions import Unknown
1516
from google.cloud.spanner_dbapi import Connection
1617
from google.cloud.spanner_v1 import (
1718
BeginTransactionRequest,
@@ -117,3 +118,33 @@ def test_transaction_isolation_level(self):
117118
self.assertEqual(1, len(begin_requests))
118119
self.assertEqual(begin_requests[0].options.isolation_level, level)
119120
MockServerTestBase.spanner_service.clear_requests()
121+
122+
def test_begin_isolation_level(self):
123+
connection = Connection(self.instance, self.database)
124+
for level in [
125+
TransactionOptions.IsolationLevel.REPEATABLE_READ,
126+
TransactionOptions.IsolationLevel.SERIALIZABLE,
127+
]:
128+
isolation_level_name = level.name.replace("_", " ")
129+
with connection.cursor() as cursor:
130+
cursor.execute(f"begin isolation level {isolation_level_name}")
131+
cursor.execute(
132+
"insert into singers (id, name) values (1, 'Some Singer')"
133+
)
134+
self.assertEqual(1, cursor.rowcount)
135+
connection.commit()
136+
begin_requests = list(
137+
filter(
138+
lambda msg: isinstance(msg, BeginTransactionRequest),
139+
self.spanner_service.requests,
140+
)
141+
)
142+
self.assertEqual(1, len(begin_requests))
143+
self.assertEqual(begin_requests[0].options.isolation_level, level)
144+
MockServerTestBase.spanner_service.clear_requests()
145+
146+
def test_begin_invalid_isolation_level(self):
147+
connection = Connection(self.instance, self.database)
148+
with connection.cursor() as cursor:
149+
with self.assertRaises(Unknown):
150+
cursor.execute("begin isolation level does_not_exist")
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
from google.cloud.spanner_dbapi.client_side_statement_executor import (
18+
_get_isolation_level,
19+
)
20+
from google.cloud.spanner_dbapi.parse_utils import classify_statement
21+
from google.cloud.spanner_v1 import TransactionOptions
22+
23+
24+
class TestParseUtils(unittest.TestCase):
25+
def test_get_isolation_level(self):
26+
self.assertIsNone(_get_isolation_level(classify_statement("begin")))
27+
self.assertEqual(
28+
TransactionOptions.IsolationLevel.SERIALIZABLE,
29+
_get_isolation_level(
30+
classify_statement("begin isolation level serializable")
31+
),
32+
)
33+
self.assertEqual(
34+
TransactionOptions.IsolationLevel.SERIALIZABLE,
35+
_get_isolation_level(
36+
classify_statement(
37+
"begin transaction isolation level serializable "
38+
)
39+
),
40+
)
41+
self.assertEqual(
42+
TransactionOptions.IsolationLevel.REPEATABLE_READ,
43+
_get_isolation_level(
44+
classify_statement("begin isolation level repeatable read")
45+
),
46+
)
47+
self.assertEqual(
48+
TransactionOptions.IsolationLevel.REPEATABLE_READ,
49+
_get_isolation_level(
50+
classify_statement(
51+
"begin transaction isolation level repeatable read "
52+
)
53+
),
54+
)

tests/unit/spanner_dbapi/test_parse_utils.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,28 @@ def test_classify_stmt(self):
6363
("commit", StatementType.CLIENT_SIDE),
6464
("begin", StatementType.CLIENT_SIDE),
6565
("start", StatementType.CLIENT_SIDE),
66+
("begin isolation level serializable", StatementType.CLIENT_SIDE),
67+
("start isolation level serializable", StatementType.CLIENT_SIDE),
68+
("begin isolation level repeatable read", StatementType.CLIENT_SIDE),
69+
("start isolation level repeatable read", StatementType.CLIENT_SIDE),
6670
("begin transaction", StatementType.CLIENT_SIDE),
6771
("start transaction", StatementType.CLIENT_SIDE),
72+
(
73+
"begin transaction isolation level serializable",
74+
StatementType.CLIENT_SIDE,
75+
),
76+
(
77+
"start transaction isolation level serializable",
78+
StatementType.CLIENT_SIDE,
79+
),
80+
(
81+
"begin transaction isolation level repeatable read",
82+
StatementType.CLIENT_SIDE,
83+
),
84+
(
85+
"start transaction isolation level repeatable read",
86+
StatementType.CLIENT_SIDE,
87+
),
6888
("rollback", StatementType.CLIENT_SIDE),
6989
(" commit TRANSACTION ", StatementType.CLIENT_SIDE),
7090
(" rollback TRANSACTION ", StatementType.CLIENT_SIDE),
@@ -84,6 +104,16 @@ def test_classify_stmt(self):
84104
("udpate table set col2=1 where col1 = 2", StatementType.UNKNOWN),
85105
("begin foo", StatementType.UNKNOWN),
86106
("begin transaction foo", StatementType.UNKNOWN),
107+
("begin transaction isolation level", StatementType.UNKNOWN),
108+
("begin transaction repeatable read", StatementType.UNKNOWN),
109+
(
110+
"begin transaction isolation level repeatable read foo",
111+
StatementType.UNKNOWN,
112+
),
113+
(
114+
"begin transaction isolation level unspecified",
115+
StatementType.UNKNOWN,
116+
),
87117
("commit foo", StatementType.UNKNOWN),
88118
("commit transaction foo", StatementType.UNKNOWN),
89119
("rollback foo", StatementType.UNKNOWN),
@@ -100,6 +130,50 @@ def test_classify_stmt(self):
100130
classify_statement(query).statement_type, want_class, query
101131
)
102132

133+
def test_begin_isolation_level(self):
134+
parsed_statement = classify_statement("begin")
135+
self.assertEqual(
136+
parsed_statement,
137+
ParsedStatement(
138+
StatementType.CLIENT_SIDE,
139+
Statement("begin"),
140+
ClientSideStatementType.BEGIN,
141+
[],
142+
),
143+
)
144+
parsed_statement = classify_statement("begin isolation level serializable")
145+
self.assertEqual(
146+
parsed_statement,
147+
ParsedStatement(
148+
StatementType.CLIENT_SIDE,
149+
Statement("begin isolation level serializable"),
150+
ClientSideStatementType.BEGIN,
151+
["serializable"],
152+
),
153+
)
154+
parsed_statement = classify_statement("begin isolation level repeatable read")
155+
self.assertEqual(
156+
parsed_statement,
157+
ParsedStatement(
158+
StatementType.CLIENT_SIDE,
159+
Statement("begin isolation level repeatable read"),
160+
ClientSideStatementType.BEGIN,
161+
["repeatable read"],
162+
),
163+
)
164+
parsed_statement = classify_statement(
165+
"begin isolation level repeatable read "
166+
)
167+
self.assertEqual(
168+
parsed_statement,
169+
ParsedStatement(
170+
StatementType.CLIENT_SIDE,
171+
Statement("begin isolation level repeatable read"),
172+
ClientSideStatementType.BEGIN,
173+
["repeatable read"],
174+
),
175+
)
176+
103177
def test_partition_query_classify_stmt(self):
104178
parsed_statement = classify_statement(
105179
" PARTITION SELECT s.SongName FROM Songs AS s "

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy