diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index e6042e31..0c8e0c87 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -180,6 +180,14 @@ def pre_exec(self): if priority is not None: self._dbapi_connection.connection.request_priority = priority + transaction_tag = self.execution_options.get("transaction_tag") + if transaction_tag: + self._dbapi_connection.connection.transaction_tag = transaction_tag + + request_tag = self.execution_options.get("request_tag") + if request_tag: + self.cursor.request_tag = request_tag + def fire_sequence(self, seq, type_): """Builds a statement for fetching next value of the sequence.""" return self._execute_scalar( diff --git a/test/mockserver_tests/tags_model.py b/test/mockserver_tests/tags_model.py new file mode 100644 index 00000000..9965dbf0 --- /dev/null +++ b/test/mockserver_tests/tags_model.py @@ -0,0 +1,28 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import String, BigInteger +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class Singer(Base): + __tablename__ = "singers" + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + name: Mapped[str] = mapped_column(String) diff --git a/test/mockserver_tests/test_basics.py b/test/mockserver_tests/test_basics.py index f895c9e4..29bffa82 100644 --- a/test/mockserver_tests/test_basics.py +++ b/test/mockserver_tests/test_basics.py @@ -74,8 +74,12 @@ def test_sqlalchemy_select1(self): with engine.connect().execution_options( isolation_level="AUTOCOMMIT" ) as connection: - results = connection.execute(select(1)).fetchall() + results = connection.execute( + select(1).execution_options(request_tag="my-tag") + ).fetchall() self.verify_select1(results) + request: ExecuteSqlRequest = self.spanner_service.requests[1] + eq_("my-tag", request.request_options.request_tag) def test_sqlalchemy_select_now(self): now = datetime.datetime.now(datetime.UTC) diff --git a/test/mockserver_tests/test_tags.py b/test/mockserver_tests/test_tags.py new file mode 100644 index 00000000..c422bc5e --- /dev/null +++ b/test/mockserver_tests/test_tags.py @@ -0,0 +1,139 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import create_engine, select +from sqlalchemy.orm import Session +from sqlalchemy.testing import eq_, is_instance_of +from google.cloud.spanner_v1 import ( + FixedSizePool, + BatchCreateSessionsRequest, + ExecuteSqlRequest, + BeginTransactionRequest, + CommitRequest, +) +from test.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_update_count, +) +from test.mockserver_tests.mock_server_test_base import add_result +import google.cloud.spanner_v1.types.type as spanner_type +import google.cloud.spanner_v1.types.result_set as result_set + + +class TestStaleReads(MockServerTestBase): + def test_request_tag(self): + from test.mockserver_tests.tags_model import Singer + + add_singer_query_result("SELECT singers.id, singers.name \n" + "FROM singers") + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + ) + + with Session(engine.execution_options(read_only=True)) as session: + # Execute two queries in a read-only transaction. + session.scalars( + select(Singer).execution_options(request_tag="my-tag-1") + ).all() + session.scalars( + select(Singer).execution_options(request_tag="my-tag-2") + ).all() + + # Verify the requests that we got. + requests = self.spanner_service.requests + eq_(4, len(requests)) + is_instance_of(requests[0], BatchCreateSessionsRequest) + is_instance_of(requests[1], BeginTransactionRequest) + is_instance_of(requests[2], ExecuteSqlRequest) + is_instance_of(requests[3], ExecuteSqlRequest) + # Verify that we got a request tag for the queries. + eq_("my-tag-1", requests[2].request_options.request_tag) + eq_("my-tag-2", requests[3].request_options.request_tag) + + def test_transaction_tag(self): + from test.mockserver_tests.tags_model import Singer + + add_singer_query_result("SELECT singers.id, singers.name\n" + "FROM singers") + add_update_count("INSERT INTO singers (id, name) VALUES (@a0, @a1)", 1) + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + ) + + with Session( + engine.execution_options(transaction_tag="my-transaction-tag") + ) as session: + # Execute a query and an insert statement in a read/write transaction. + session.scalars( + select(Singer).execution_options(request_tag="my-tag-1") + ).all() + session.add(Singer(id=1, name="Some Singer")) + session.commit() + + # Verify the requests that we got. + requests = self.spanner_service.requests + eq_(5, len(requests)) + is_instance_of(requests[0], BatchCreateSessionsRequest) + is_instance_of(requests[1], BeginTransactionRequest) + is_instance_of(requests[2], ExecuteSqlRequest) + is_instance_of(requests[3], ExecuteSqlRequest) + is_instance_of(requests[4], CommitRequest) + for request in requests[2:]: + eq_("my-transaction-tag", request.request_options.transaction_tag) + + +def add_singer_query_result(sql: str): + result = result_set.ResultSet( + dict( + metadata=result_set.ResultSetMetadata( + dict( + row_type=spanner_type.StructType( + dict( + fields=[ + spanner_type.StructType.Field( + dict( + name="singers_id", + type=spanner_type.Type( + dict(code=spanner_type.TypeCode.INT64) + ), + ) + ), + spanner_type.StructType.Field( + dict( + name="singers_name", + type=spanner_type.Type( + dict(code=spanner_type.TypeCode.STRING) + ), + ) + ), + ] + ) + ) + ) + ), + ) + ) + result.rows.extend( + [ + ( + "1", + "Jane Doe", + ), + ( + "2", + "John Doe", + ), + ] + ) + add_result(sql, result) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy