diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml index 40bf9973..d8a1bbca 100644 --- a/.github/.OwlBot.lock.yaml +++ b/.github/.OwlBot.lock.yaml @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,5 +13,5 @@ # limitations under the License. docker: image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest - digest: sha256:230f7fe8a0d2ed81a519cfc15c6bb11c5b46b9fb449b8b1219b3771bcb520ad2 -# created: 2023-12-09T15:16:25.430769578Z + digest: sha256:5ea6d0ab82c956b50962f91d94e206d3921537ae5fe1549ec5326381d8905cfa +# created: 2024-01-15T16:32:08.142785673Z diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml index 68af1253..38bd545c 100644 --- a/.github/sync-repo-settings.yaml +++ b/.github/sync-repo-settings.yaml @@ -14,6 +14,8 @@ branchProtectionRules: - 'Samples - Python 3.8' - 'Samples - Python 3.9' - 'Samples - Python 3.10' + - 'Samples - Python 3.11' + - 'Samples - Python 3.12' permissionRules: - team: actools-python permission: admin diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt index e5c1ffca..bb3d6ca3 100644 --- a/.kokoro/requirements.txt +++ b/.kokoro/requirements.txt @@ -263,9 +263,9 @@ jeepney==0.8.0 \ # via # keyring # secretstorage -jinja2==3.1.2 \ - --hash=sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852 \ - --hash=sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61 +jinja2==3.1.3 \ + --hash=sha256:7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa \ + --hash=sha256:ac8bd6544d4bb2c9792bf3a159e80bba8fda7f07e81bc3aed565432d5925ba90 # via gcp-releasetool keyring==24.2.0 \ --hash=sha256:4901caaf597bfd3bbd78c9a0c7c4c29fcd8310dab2cffefe749e916b6527acd6 \ diff --git a/AUTHORS b/AUTHORS index 5daa663b..fc5345ee 100644 --- a/AUTHORS +++ b/AUTHORS @@ -19,6 +19,7 @@ Maksym Voitko Maxim Zudilov (mxmzdlv) Maxime Beauchemin (mistercrunch) Romain Rigaux +Sharoon Thomas (sharoonthomas) Sumedh Sakdeo Tim Swast (tswast) Vince Broz diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c3b7ca7..2f98741e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,30 @@ Older versions of this project were distributed as [pybigquery][0]. [2]: https://pypi.org/project/pybigquery/#history +## [1.11.0.dev3](https://github.com/googleapis/python-bigquery-sqlalchemy/compare/v1.9.0...v1.11.0.dev3) (2024-02-20) + + +### Bug Fixes + +* Fix grouping sets, rollup and cube rendering issue ([#1019](https://github.com/googleapis/python-bigquery-sqlalchemy/pull/1019)) +* Add more grouping sets/rollup/cube tests ([#1029](https://github.com/googleapis/python-bigquery-sqlalchemy/pull/1029)) + +## [1.11.0.dev2](https://github.com/googleapis/python-bigquery-sqlalchemy/compare/v1.9.0...v1.11.0.dev2) (2024-02-01) + +## [1.11.0.dev1](https://github.com/googleapis/python-bigquery-sqlalchemy/compare/v1.9.0...v1.11.0.dev1) (2024-01-30) + + +### Bug Fixes + +* Fix coverage test issues in SQLAlchemy migration ([#987](https://github.com/googleapis/python-bigquery-sqlalchemy/pull/987)) +* Cleanup test_sqlalchemy_dialect file for readability ([#1018](https://github.com/googleapis/python-bigquery-sqlalchemy/pull/1018)) + +## [1.11.0.dev0](https://github.com/googleapis/python-bigquery-sqlalchemy/compare/v1.9.0...v1.11.0.dev0) (2024-01-25) + + +### Features + +* Drop support for SQLAlchemy versions 1.2 and 1.3, maintain support for 1.4 and add support for 2.0 ([#920](https://github.com/googleapis/python-bigquery-sqlalchemy/pull/920)) ## [1.9.0](https://github.com/googleapis/python-bigquery-sqlalchemy/compare/v1.8.0...v1.9.0) (2023-12-10) diff --git a/README.rst b/README.rst index a2036289..b6693abb 100644 --- a/README.rst +++ b/README.rst @@ -35,7 +35,8 @@ In order to use this library, you first need to go through the following steps: .. _Setup Authentication.: https://googleapis.dev/python/google-api-core/latest/auth.html .. note:: - This library is only compatible with SQLAlchemy versions < 2.0.0 + This library is a prerelease to gauge compatiblity with SQLAlchemy + versions >= 1.4.16 and < 2.1 Installation ------------ @@ -108,7 +109,8 @@ SQLAlchemy from sqlalchemy.schema import * engine = create_engine('bigquery://project') table = Table('dataset.table', MetaData(bind=engine), autoload=True) - print(select([func.count('*')], from_obj=table).scalar()) + print(select([func.count('*')], from_obj=table().scalar()) + Project ^^^^^^^ @@ -281,7 +283,7 @@ If you need additional control, you can supply a BigQuery client of your own: engine = create_engine( 'bigquery://some-project/some-dataset?user_supplied_client=True', - connect_args={'client': custom_bq_client}, + connect_args={'client': custom_bq_client}, ) @@ -292,7 +294,12 @@ To add metadata to a table: .. code-block:: python - table = Table('mytable', ..., bigquery_description='my table description', bigquery_friendly_name='my table friendly name') + table = Table('mytable', ..., + bigquery_description='my table description', + bigquery_friendly_name='my table friendly name', + bigquery_default_rounding_mode="ROUND_HALF_EVEN", + bigquery_expiration_timestamp=datetime.datetime.fromisoformat("2038-01-01T00:00:00+00:00"), + ) To add metadata to a column: @@ -300,6 +307,52 @@ To add metadata to a column: Column('mycolumn', doc='my column description') +To create a clustered table: + +.. code-block:: python + + table = Table('mytable', ..., bigquery_clustering_fields=["a", "b", "c"]) + +To create a time-unit column-partitioned table: + +.. code-block:: python + + from google.cloud import bigquery + + table = Table('mytable', ..., + bigquery_time_partitioning=bigquery.TimePartitioning( + field="mytimestamp", + type_="MONTH", + expiration_ms=1000 * 60 * 60 * 24 * 30 * 6, # 6 months + ), + bigquery_require_partition_filter=True, + ) + +To create an ingestion-time partitioned table: + +.. code-block:: python + + from google.cloud import bigquery + + table = Table('mytable', ..., + bigquery_time_partitioning=bigquery.TimePartitioning(), + bigquery_require_partition_filter=True, + ) + +To create an integer-range partitioned table + +.. code-block:: python + + from google.cloud import bigquery + + table = Table('mytable', ..., + bigquery_range_partitioning=bigquery.RangePartitioning( + field="zipcode", + range_=bigquery.PartitionRange(start=0, end=100000, interval=10), + ), + bigquery_require_partition_filter=True, + ) + Threading and Multiprocessing ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/dev_requirements.txt b/dev_requirements.txt index ddc53054..1798fab5 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -2,4 +2,4 @@ sqlalchemy>=2.0.15,<2.1.0 google-cloud-bigquery>=1.6.0 pytest===6.2.5 pytest-flake8===1.1.0 # versions 1.1.1 and above require pytest 7 -pytz==2023.3 +pytz==2023.3.post1 diff --git a/noxfile.py b/noxfile.py index e31f32c5..fc0ed6c5 100644 --- a/noxfile.py +++ b/noxfile.py @@ -369,7 +369,7 @@ def compliance(session): session.skip("Compliance tests were not found") session.install("--pre", "grpcio") - session.install("--pre", "--no-deps", "--upgrade", "sqlalchemy<2.0.0") + session.install("--pre", "--no-deps", "--upgrade", "sqlalchemy>=1.4.16,<2.1") session.install( "mock", "pytest", @@ -394,10 +394,11 @@ def compliance(session): f"--junitxml=compliance_{session.python}_sponge_log.xml", "--reruns=3", "--reruns-delay=60", - "--only-rerun=403 Exceeded rate limits", - "--only-rerun=409 Already Exists", - "--only-rerun=404 Not found", - "--only-rerun=400 Cannot execute DML over a non-existent table", + "--only-rerun=Exceeded rate limits", + "--only-rerun=Already Exists", + "--only-rerun=Not found", + "--only-rerun=Cannot execute DML over a non-existent table", + "--only-rerun=Job exceeded rate limits", system_test_folder_path, *session.posargs, # To suppress the "Deprecated API features detected!" warning when @@ -427,7 +428,16 @@ def docs(session): session.install("-e", ".") session.install( - "sphinx==4.0.1", + # We need to pin to specific versions of the `sphinxcontrib-*` packages + # which still support sphinx 4.x. + # See https://github.com/googleapis/sphinx-docfx-yaml/issues/344 + # and https://github.com/googleapis/sphinx-docfx-yaml/issues/345. + "sphinxcontrib-applehelp==1.0.4", + "sphinxcontrib-devhelp==1.0.2", + "sphinxcontrib-htmlhelp==2.0.1", + "sphinxcontrib-qthelp==1.0.3", + "sphinxcontrib-serializinghtml==1.1.5", + "sphinx==4.5.0", "alabaster", "geoalchemy2", "shapely", @@ -455,6 +465,15 @@ def docfx(session): session.install("-e", ".") session.install( + # We need to pin to specific versions of the `sphinxcontrib-*` packages + # which still support sphinx 4.x. + # See https://github.com/googleapis/sphinx-docfx-yaml/issues/344 + # and https://github.com/googleapis/sphinx-docfx-yaml/issues/345. + "sphinxcontrib-applehelp==1.0.4", + "sphinxcontrib-devhelp==1.0.2", + "sphinxcontrib-htmlhelp==2.0.1", + "sphinxcontrib-qthelp==1.0.3", + "sphinxcontrib-serializinghtml==1.1.5", "gcp-sphinx-docfx-yaml", "alabaster", "geoalchemy2", @@ -524,7 +543,7 @@ def prerelease_deps(session): prerel_deps = [ "protobuf", - "sqlalchemy<2.0.0", + "sqlalchemy>=1.4.16,<2.1", # dependency of grpc "six", "googleapis-common-protos", diff --git a/owlbot.py b/owlbot.py index 22678c8b..4dfec18d 100644 --- a/owlbot.py +++ b/owlbot.py @@ -42,14 +42,17 @@ system_test_extras=extras, system_test_extras_by_python=extras_by_python, ) -s.move(templated_files, excludes=[ - # sqlalchemy-bigquery was originally licensed MIT - "LICENSE", - "docs/multiprocessing.rst", - # exclude gh actions as credentials are needed for tests - ".github/workflows", - "README.rst", -]) +s.move( + templated_files, + excludes=[ + # sqlalchemy-bigquery was originally licensed MIT + "LICENSE", + "docs/multiprocessing.rst", + # exclude gh actions as credentials are needed for tests + ".github/workflows", + "README.rst", + ], +) # ---------------------------------------------------------------------------- # Fixup files @@ -59,7 +62,7 @@ [".coveragerc"], "google/cloud/__init__.py", "sqlalchemy_bigquery/requirements.py", - ) +) s.replace( ["noxfile.py"], @@ -75,12 +78,14 @@ s.replace( - ["noxfile.py"], "--cov=google", "--cov=sqlalchemy_bigquery", + ["noxfile.py"], + "--cov=google", + "--cov=sqlalchemy_bigquery", ) s.replace( - ["noxfile.py"], + ["noxfile.py"], "\+ SYSTEM_TEST_EXTRAS", "", ) @@ -88,35 +93,34 @@ s.replace( ["noxfile.py"], - '''"protobuf", - # dependency of grpc''', - '''"protobuf", - "sqlalchemy<2.0.0", - # dependency of grpc''', + """"protobuf", + # dependency of grpc""", + """"protobuf", + "sqlalchemy>=1.4.16,<2.1", + # dependency of grpc""", ) s.replace( ["noxfile.py"], r"def default\(session\)", - "def default(session, install_extras=True)", + "def default(session, install_extras=True)", ) - - def place_before(path, text, *before_text, escape=None): replacement = "\n".join(before_text) + "\n" + text if escape: for c in escape: - text = text.replace(c, '\\' + c) + text = text.replace(c, "\\" + c) s.replace([path], text, replacement) + place_before( "noxfile.py", "SYSTEM_TEST_PYTHON_VERSIONS=", "", - "# We're using two Python versions to test with sqlalchemy 1.3 and 1.4.", + "# We're using two Python versions to test with sqlalchemy>=1.4.16", ) place_before( @@ -126,7 +130,7 @@ def place_before(path, text, *before_text, escape=None): ) -install_logic = ''' +install_logic = """ if install_extras and session.python in ["3.11", "3.12"]: install_target = ".[geography,alembic,tests,bqstorage]" elif install_extras: @@ -134,7 +138,7 @@ def place_before(path, text, *before_text, escape=None): else: install_target = "." session.install("-e", install_target, "-c", constraints_path) -''' +""" place_before( "noxfile.py", @@ -163,7 +167,7 @@ def compliance(session): session.skip("Compliance tests were not found") session.install("--pre", "grpcio") - session.install("--pre", "--no-deps", "--upgrade", "sqlalchemy<2.0.0") + session.install("--pre", "--no-deps", "--upgrade", "sqlalchemy>=1.4.16,<2.1") session.install( "mock", "pytest", @@ -188,10 +192,11 @@ def compliance(session): f"--junitxml=compliance_{session.python}_sponge_log.xml", "--reruns=3", "--reruns-delay=60", - "--only-rerun=403 Exceeded rate limits", - "--only-rerun=409 Already Exists", - "--only-rerun=404 Not found", - "--only-rerun=400 Cannot execute DML over a non-existent table", + "--only-rerun=Exceeded rate limits", + "--only-rerun=Already Exists", + "--only-rerun=Not found", + "--only-rerun=Cannot execute DML over a non-existent table", + "--only-rerun=Job exceeded rate limits", system_test_folder_path, *session.posargs, # To suppress the "Deprecated API features detected!" warning when @@ -205,12 +210,11 @@ def compliance(session): ''' place_before( - "noxfile.py", - "@nox.session(python=DEFAULT_PYTHON_VERSION)\n" - "def cover(session):", - compliance, - escape="()", - ) + "noxfile.py", + "@nox.session(python=DEFAULT_PYTHON_VERSION)\n" "def cover(session):", + compliance, + escape="()", +) s.replace(["noxfile.py"], '"alabaster"', '"alabaster", "geoalchemy2", "shapely"') @@ -266,11 +270,10 @@ def system_noextras(session): place_before( "noxfile.py", - "@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS[-1])\n" - "def compliance(session):", + "@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS[-1])\n" "def compliance(session):", system_noextras, escape="()[]", - ) +) # Add DB config for SQLAlchemy dialect test suite. @@ -287,7 +290,7 @@ def system_noextras(session): [tool:pytest] addopts= --tb native -v -r fxX -p no:warnings python_files=tests/*test_*.py -""" +""", ) # ---------------------------------------------------------------------------- @@ -298,7 +301,7 @@ def system_noextras(session): python.py_samples(skip_readmes=True) s.replace( - ["./samples/snippets/noxfile.py"], + ["./samples/snippets/noxfile.py"], """session.install\("-e", _get_repo_root\(\)\)""", """session.install("-e", _get_repo_root()) else: diff --git a/samples/snippets/requirements-test.txt b/samples/snippets/requirements-test.txt index 90537d96..552292b6 100644 --- a/samples/snippets/requirements-test.txt +++ b/samples/snippets/requirements-test.txt @@ -1,16 +1,16 @@ attrs==23.1.0 -click==8.1.6 -google-auth==2.22.0 -google-cloud-testutils==1.3.3 +click==8.1.7 +google-auth==2.25.2 +google-cloud-testutils==1.4.0 iniconfig==2.0.0 -packaging==23.1 -pluggy==1.2.0 +packaging==23.2 +pluggy==1.3.0 py==1.11.0 -pyasn1==0.5.0 +pyasn1==0.5.1 pyasn1-modules==0.3.0 pyparsing==3.1.1 pytest===6.2.5 rsa==4.9 six==1.16.0 toml==0.10.2 -typing-extensions==4.7.1 +typing-extensions==4.9.0 diff --git a/samples/snippets/requirements.txt b/samples/snippets/requirements.txt index b15bf2cb..f011f19c 100644 --- a/samples/snippets/requirements.txt +++ b/samples/snippets/requirements.txt @@ -1,33 +1,33 @@ -alembic==1.11.2 -certifi==2023.7.22 -charset-normalizer==3.2.0 -geoalchemy2==0.14.1 -google-api-core[grpc]==2.11.1 -google-auth==2.22.0 -google-cloud-bigquery==3.11.4 -google-cloud-core==2.3.3 +alembic==1.13.0 +certifi==2023.11.17 +charset-normalizer==3.3.2 +geoalchemy2==0.14.2 +google-api-core[grpc]==2.15.0 +google-auth==2.25.2 +google-cloud-bigquery==3.14.1 +google-cloud-core==2.4.1 google-crc32c==1.5.0 -google-resumable-media==2.5.0 -googleapis-common-protos==1.60.0 -greenlet==3.0.1 -grpcio==1.59.0 -grpcio-status==1.57.0 -idna==3.4 -importlib-resources==6.0.1; python_version >= '3.8' -mako==1.2.4 +google-resumable-media==2.7.0 +googleapis-common-protos==1.62.0 +greenlet==3.0.2 +grpcio==1.60.0 +grpcio-status==1.60.0 +idna==3.6 +importlib-resources==6.1.1; python_version >= '3.8' +mako==1.3.0 markupsafe==2.1.3 -packaging==23.1 -proto-plus==1.22.3 -protobuf==4.24.0 -pyasn1==0.5.0 +packaging==23.2 +proto-plus==1.23.0 +protobuf==4.25.1 +pyasn1==0.5.1 pyasn1-modules==0.3.0 pyparsing==3.1.1 python-dateutil==2.8.2 -pytz==2023.3 +pytz==2023.3.post1 requests==2.31.0 rsa==4.9 shapely==2.0.2 six==1.16.0 -sqlalchemy===1.4.27 -typing-extensions==4.7.1 -urllib3==1.26.18 +sqlalchemy==1.4.16 +typing-extensions==4.9.0 +urllib3==2.1.0 diff --git a/setup.py b/setup.py index e035c518..31565afa 100644 --- a/setup.py +++ b/setup.py @@ -99,9 +99,9 @@ def readme(): # Until this issue is closed # https://github.com/googleapis/google-cloud-python/issues/10566 "google-auth>=1.25.0,<3.0.0dev", # Work around pip wack. - "google-cloud-bigquery>=2.25.2,<4.0.0dev", + "google-cloud-bigquery>=3.3.6,<4.0.0dev", "packaging", - "sqlalchemy>=1.2.0,<2.0.0dev", + "sqlalchemy>=1.4.16,<2.1", ], extras_require=extras, python_requires=">=3.8, <3.13", diff --git a/sqlalchemy_bigquery/_struct.py b/sqlalchemy_bigquery/_struct.py index fc551c12..309d1080 100644 --- a/sqlalchemy_bigquery/_struct.py +++ b/sqlalchemy_bigquery/_struct.py @@ -17,20 +17,14 @@ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -import packaging.version import sqlalchemy.sql.default_comparator import sqlalchemy.sql.sqltypes import sqlalchemy.types from . import base -sqlalchemy_1_4_or_more = packaging.version.parse( - sqlalchemy.__version__ -) >= packaging.version.parse("1.4") - -if sqlalchemy_1_4_or_more: - import sqlalchemy.sql.coercions - import sqlalchemy.sql.roles +import sqlalchemy.sql.coercions +import sqlalchemy.sql.roles def _get_subtype_col_spec(type_): @@ -103,34 +97,20 @@ def _setup_getitem(self, name): def __getattr__(self, name): if name.lower() in self.expr.type._STRUCT_byname: return self[name] + else: + raise AttributeError(name) comparator_factory = Comparator -# In the implementations of _field_index below, we're stealing from -# the JSON type implementation, but the code to steal changed in -# 1.4. :/ - -if sqlalchemy_1_4_or_more: - - def _field_index(self, name, operator): - return sqlalchemy.sql.coercions.expect( - sqlalchemy.sql.roles.BinaryElementRole, - name, - expr=self.expr, - operator=operator, - bindparam_type=sqlalchemy.types.String(), - ) - -else: - - def _field_index(self, name, operator): - return sqlalchemy.sql.default_comparator._check_literal( - self.expr, - operator, - name, - bindparam_type=sqlalchemy.types.String(), - ) +def _field_index(self, name, operator): + return sqlalchemy.sql.coercions.expect( + sqlalchemy.sql.roles.BinaryElementRole, + name, + expr=self.expr, + operator=operator, + bindparam_type=sqlalchemy.types.String(), + ) def struct_getitem_op(a, b): diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 5297f223..e80f2891 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -19,6 +19,7 @@ """Integration between SQLAlchemy and BigQuery.""" +import datetime from decimal import Decimal import random import operator @@ -27,7 +28,11 @@ from google import auth import google.api_core.exceptions from google.cloud.bigquery import dbapi -from google.cloud.bigquery.table import TableReference +from google.cloud.bigquery.table import ( + RangePartitioning, + TableReference, + TimePartitioning, +) from google.api_core.exceptions import NotFound import packaging.version import sqlalchemy @@ -35,7 +40,7 @@ import sqlalchemy.sql.functions import sqlalchemy.sql.sqltypes import sqlalchemy.sql.type_api -from sqlalchemy.exc import NoSuchTableError +from sqlalchemy.exc import NoSuchTableError, NoSuchColumnError from sqlalchemy import util from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.compiler import ( @@ -158,7 +163,7 @@ def get_insert_default(self, column): # pragma: NO COVER """, flags=re.IGNORECASE | re.VERBOSE, ) - def __distribute_types_to_expanded_placeholders(self, m): + def __distribute_types_to_expanded_placeholders(self, m): # pragma: NO COVER # If we have an in parameter, it sometimes gets expaned to 0 or more # parameters and we need to move the type marker to each # parameter. @@ -169,6 +174,8 @@ def __distribute_types_to_expanded_placeholders(self, m): # suffixes refect that when an array parameter is expanded, # numeric suffixes are added. For example, a placeholder like # `%(foo)s` gets expaneded to `%(foo_0)s, `%(foo_1)s, ...`. + + # Coverage: despite our best efforts, never recognized this segment of code as being tested. placeholders, type_ = m.groups() if placeholders: placeholders = placeholders.replace(")", f":{type_})") @@ -214,7 +221,7 @@ def visit_table_valued_alias(self, element, **kw): # For example, given SQLAlchemy code: # # print( - # select([func.unnest(foo.c.objects).alias('foo_objects').column]) + # select(func.unnest(foo.c.objects).alias('foo_objects').column) # .compile(engine)) # # Left to it's own devices, SQLAlchemy would outout: @@ -269,6 +276,14 @@ def _known_tables(self): if table is not None: known_tables.add(table.name) + # If we have the table in the `from` of our parent, do not add the alias + # as this will add the table twice and cause an implicit JOIN for that + # table on itself + asfrom_froms = self.stack[-1].get("asfrom_froms", []) + for from_ in asfrom_froms: + if isinstance(from_, Table): + known_tables.add(from_.name) + return known_tables def visit_column( @@ -323,7 +338,14 @@ def visit_label(self, *args, within_group_by=False, **kwargs): # Flag set in the group_by_clause method. Works around missing # equivalent to supports_simple_order_by_label for group by. if within_group_by: - kwargs["render_label_as_label"] = args[0] + column_label = args[0] + sql_keywords = {"GROUPING SETS", "ROLLUP", "CUBE"} + for keyword in sql_keywords: + if keyword in str(column_label): + break + else: # for/else always happens unless break gets called + kwargs["render_label_as_label"] = column_label + return super(BigQueryCompiler, self).visit_label(*args, **kwargs) def group_by_clause(self, select, **kw): @@ -343,11 +365,7 @@ def group_by_clause(self, select, **kw): __sqlalchemy_version_info = packaging.version.parse(sqlalchemy.__version__) - __expanding_text = ( - "EXPANDING" - if __sqlalchemy_version_info < packaging.version.parse("1.4") - else "POSTCOMPILE" - ) + __expanding_text = "POSTCOMPILE" # https://github.com/sqlalchemy/sqlalchemy/commit/f79df12bd6d99b8f6f09d4bf07722638c4b4c159 __expanding_conflict = ( @@ -375,9 +393,6 @@ def visit_in_op_binary(self, binary, operator_, **kw): self._generate_generic_binary(binary, " IN ", **kw) ) - def visit_empty_set_expr(self, element_types): - return "" - def visit_not_in_op_binary(self, binary, operator, **kw): return ( "(" @@ -387,8 +402,6 @@ def visit_not_in_op_binary(self, binary, operator, **kw): + ")" ) - visit_notin_op_binary = visit_not_in_op_binary # before 1.4 - ############################################################################ ############################################################################ @@ -411,8 +424,8 @@ def visit_contains_op_binary(self, binary, operator, **kw): self._maybe_reescape(binary), operator, **kw ) - def visit_notcontains_op_binary(self, binary, operator, **kw): - return super(BigQueryCompiler, self).visit_notcontains_op_binary( + def visit_not_contains_op_binary(self, binary, operator, **kw): + return super(BigQueryCompiler, self).visit_not_contains_op_binary( self._maybe_reescape(binary), operator, **kw ) @@ -421,8 +434,8 @@ def visit_startswith_op_binary(self, binary, operator, **kw): self._maybe_reescape(binary), operator, **kw ) - def visit_notstartswith_op_binary(self, binary, operator, **kw): - return super(BigQueryCompiler, self).visit_notstartswith_op_binary( + def visit_not_startswith_op_binary(self, binary, operator, **kw): + return super(BigQueryCompiler, self).visit_not_startswith_op_binary( self._maybe_reescape(binary), operator, **kw ) @@ -431,8 +444,8 @@ def visit_endswith_op_binary(self, binary, operator, **kw): self._maybe_reescape(binary), operator, **kw ) - def visit_notendswith_op_binary(self, binary, operator, **kw): - return super(BigQueryCompiler, self).visit_notendswith_op_binary( + def visit_not_endswith_op_binary(self, binary, operator, **kw): + return super(BigQueryCompiler, self).visit_not_endswith_op_binary( self._maybe_reescape(binary), operator, **kw ) @@ -497,7 +510,8 @@ def visit_bindparam( # here, because then we can't do a recompile later (e.g., first # print the statment, then execute it). See issue #357. # - if getattr(bindparam, "expand_op", None) is not None: + # Coverage: despite our best efforts, never recognized this segment of code as being tested. + if getattr(bindparam, "expand_op", None) is not None: # pragma: NO COVER assert bindparam.expand_op.__name__.endswith("in_op") # in in bindparam = bindparam._clone(maintain_key=True) bindparam.expanding = False @@ -623,16 +637,23 @@ def visit_NUMERIC(self, type_, **kw): class BigQueryDDLCompiler(DDLCompiler): + option_datatype_mapping = { + "friendly_name": str, + "expiration_timestamp": datetime.datetime, + "require_partition_filter": bool, + "default_rounding_mode": str, + } + # BigQuery has no support for foreign keys. - def visit_foreign_key_constraint(self, constraint): + def visit_foreign_key_constraint(self, constraint, **kw): return None # BigQuery has no support for primary keys. - def visit_primary_key_constraint(self, constraint): + def visit_primary_key_constraint(self, constraint, **kw): return None # BigQuery has no support for unique constraints. - def visit_unique_constraint(self, constraint): + def visit_unique_constraint(self, constraint, **kw): return None def get_column_specification(self, column, **kwargs): @@ -646,38 +667,257 @@ def get_column_specification(self, column, **kwargs): return colspec def post_create_table(self, table): + """ + Constructs additional SQL clauses for table creation in BigQuery. + + This function processes the BigQuery dialect-specific options and generates SQL clauses for partitioning, + clustering, and other table options. + + Args: + table (Table): The SQLAlchemy Table object for which the SQL is being generated. + + Returns: + str: A string composed of SQL clauses for time partitioning, clustering, and other BigQuery specific + options, each separated by a newline. Returns an empty string if no such options are specified. + + Raises: + TypeError: If the time_partitioning option is not a `TimePartitioning` object or if the clustering_fields option is not a list. + NoSuchColumnError: If any field specified in clustering_fields does not exist in the table. + """ + bq_opts = table.dialect_options["bigquery"] - opts = [] - if ("description" in bq_opts) or table.comment: - description = process_string_literal( - bq_opts.get("description", table.comment) + options = {} + clauses = [] + + if ( + bq_opts.get("time_partitioning") is not None + and bq_opts.get("range_partitioning") is not None + ): + raise ValueError( + "biquery_time_partitioning and bigquery_range_partitioning" + " dialect options are mutually exclusive." ) - opts.append(f"description={description}") - if "friendly_name" in bq_opts: - opts.append( - "friendly_name={}".format( - process_string_literal(bq_opts["friendly_name"]) + if (time_partitioning := bq_opts.get("time_partitioning")) is not None: + self._raise_for_type( + "time_partitioning", + time_partitioning, + TimePartitioning, + ) + + if time_partitioning.expiration_ms: + _24hours = 1000 * 60 * 60 * 24 + options["partition_expiration_days"] = ( + time_partitioning.expiration_ms / _24hours ) + + partition_by_clause = self._process_time_partitioning( + table, + time_partitioning, ) - if opts: - return "\nOPTIONS({})".format(", ".join(opts)) + clauses.append(partition_by_clause) - return "" + if (range_partitioning := bq_opts.get("range_partitioning")) is not None: + self._raise_for_type( + "range_partitioning", + range_partitioning, + RangePartitioning, + ) + + partition_by_clause = self._process_range_partitioning( + table, + range_partitioning, + ) + + clauses.append(partition_by_clause) + + if (clustering_fields := bq_opts.get("clustering_fields")) is not None: + self._raise_for_type("clustering_fields", clustering_fields, list) + + for field in clustering_fields: + if field not in table.c: + raise NoSuchColumnError(field) + + clauses.append(f"CLUSTER BY {', '.join(clustering_fields)}") - def visit_set_table_comment(self, create): + if ("description" in bq_opts) or table.comment: + description = bq_opts.get("description", table.comment) + self._validate_option_value_type("description", description) + options["description"] = description + + for option in self.option_datatype_mapping: + if option in bq_opts: + options[option] = bq_opts.get(option) + + if options: + individual_option_statements = [ + "{}={}".format(k, self._process_option_value(v)) + for (k, v) in options.items() + if self._validate_option_value_type(k, v) + ] + clauses.append(f"OPTIONS({', '.join(individual_option_statements)})") + + return " " + "\n".join(clauses) + + def visit_set_table_comment(self, create, **kw): table_name = self.preparer.format_table(create.element) description = self.sql_compiler.render_literal_value( create.element.comment, sqlalchemy.sql.sqltypes.String() ) return f"ALTER TABLE {table_name} SET OPTIONS(description={description})" - def visit_drop_table_comment(self, drop): + def visit_drop_table_comment(self, drop, **kw): table_name = self.preparer.format_table(drop.element) return f"ALTER TABLE {table_name} SET OPTIONS(description=null)" + def _validate_option_value_type(self, option: str, value): + """ + Validates the type of the given option value against the expected data type. + + Args: + option (str): The name of the option to be validated. + value: The value of the dialect option whose type is to be checked. The type of this parameter + is dynamic and is verified against the expected type in `self.option_datatype_mapping`. + + Returns: + bool: True if the type of the value matches the expected type, or if the option is not found in + `self.option_datatype_mapping`. + + Raises: + TypeError: If the type of the provided value does not match the expected type as defined in + `self.option_datatype_mapping`. + """ + if option in self.option_datatype_mapping: + self._raise_for_type( + option, + value, + self.option_datatype_mapping[option], + ) + + return True + + def _raise_for_type(self, option, value, expected_type): + if type(value) is not expected_type: + raise TypeError( + f"bigquery_{option} dialect option accepts only {expected_type}," + f" provided {repr(value)}" + ) + + def _process_time_partitioning( + self, table: Table, time_partitioning: TimePartitioning + ): + """ + Generates a SQL 'PARTITION BY' clause for partitioning a table by a date or timestamp. + + Args: + - table (Table): The SQLAlchemy table object representing the BigQuery table to be partitioned. + - time_partitioning (TimePartitioning): The time partitioning details, + including the field to be used for partitioning. + + Returns: + - str: A SQL 'PARTITION BY' clause that uses either TIMESTAMP_TRUNC or DATE_TRUNC to + partition data on the specified field. + + Example: + - Given a table with a TIMESTAMP type column 'event_timestamp' and setting + 'time_partitioning.field' to 'event_timestamp', the function returns + "PARTITION BY TIMESTAMP_TRUNC(event_timestamp, DAY)". + """ + field = "_PARTITIONDATE" + trunc_fn = "DATE_TRUNC" + + if time_partitioning.field is not None: + field = time_partitioning.field + if isinstance( + table.columns[time_partitioning.field].type, + sqlalchemy.sql.sqltypes.TIMESTAMP, + ): + trunc_fn = "TIMESTAMP_TRUNC" + + return f"PARTITION BY {trunc_fn}({field}, {time_partitioning.type_})" + + def _process_range_partitioning( + self, table: Table, range_partitioning: RangePartitioning + ): + """ + Generates a SQL 'PARTITION BY' clause for partitioning a table by a range of integers. + + Args: + - table (Table): The SQLAlchemy table object representing the BigQuery table to be partitioned. + - range_partitioning (RangePartitioning): The RangePartitioning object containing the + partitioning field, range start, range end, and interval. + + Returns: + - str: A SQL string for range partitioning using RANGE_BUCKET and GENERATE_ARRAY functions. + + Raises: + - AttributeError: If the partitioning field is not defined. + - ValueError: If the partitioning field (i.e. column) data type is not an integer. + - TypeError: If the partitioning range start/end values are not integers. + + Example: + "PARTITION BY RANGE_BUCKET(zipcode, GENERATE_ARRAY(0, 100000, 10))" + """ + if range_partitioning.field is None: + raise AttributeError( + "bigquery_range_partitioning expects field to be defined" + ) + + if not isinstance( + table.columns[range_partitioning.field].type, + sqlalchemy.sql.sqltypes.INT, + ): + raise ValueError( + "bigquery_range_partitioning expects field (i.e. column) data type to be INTEGER" + ) + + range_ = range_partitioning.range_ + + if not isinstance(range_.start, int): + raise TypeError( + "bigquery_range_partitioning expects range_.start to be an int," + f" provided {repr(range_.start)}" + ) + + if not isinstance(range_.end, int): + raise TypeError( + "bigquery_range_partitioning expects range_.end to be an int," + f" provided {repr(range_.end)}" + ) + + default_interval = 1 + + return f"PARTITION BY RANGE_BUCKET({range_partitioning.field}, GENERATE_ARRAY({range_.start}, {range_.end}, {range_.interval or default_interval}))" + + def _process_option_value(self, value): + """ + Transforms the given option value into a literal representation suitable for SQL queries in BigQuery. + + Args: + value: The value to be transformed. + + Returns: + The processed value in a format suitable for inclusion in a SQL query. + + Raises: + NotImplementedError: When there is no transformation registered for a data type. + """ + option_casting = { + # Mapping from option type to its casting method + str: lambda x: process_string_literal(x), + int: lambda x: x, + float: lambda x: x, + bool: lambda x: "true" if x else "false", + datetime.datetime: lambda x: BQTimestamp.process_timestamp_literal(x), + } + + if (option_cast := option_casting.get(type(value))) is not None: + return option_cast(value) + + raise NotImplementedError(f"No transformation registered for {repr(value)}") + def process_string_literal(value): return repr(value.replace("%", "%%")) @@ -791,6 +1031,14 @@ def __init__( @classmethod def dbapi(cls): + """ + Use `import_dbapi()` instead. + Maintained for backward compatibility. + """ + return dbapi + + @classmethod + def import_dbapi(cls): return dbapi @staticmethod @@ -963,7 +1211,21 @@ def _get_table(self, connection, table_name, schema=None): raise NoSuchTableError(table_name) return table - def has_table(self, connection, table_name, schema=None): + def has_table(self, connection, table_name, schema=None, **kw): + """Checks whether a table exists in BigQuery. + + Args: + connection (google.cloud.bigquery.client.Client): The client + object used to interact with BigQuery. + table_name (str): The name of the table to check for. + schema (str, optional): The name of the schema to which the table + belongs. Defaults to the default schema. + **kw (dict): Any extra keyword arguments will be ignored. + + Returns: + bool: True if the table exists, False otherwise. + + """ try: self._get_table(connection, table_name, schema) return True @@ -989,25 +1251,8 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw): return {"constrained_columns": []} def get_indexes(self, connection, table_name, schema=None, **kw): - table = self._get_table(connection, table_name, schema) - indexes = [] - if table.time_partitioning: - indexes.append( - { - "name": "partition", - "column_names": [table.time_partitioning.field], - "unique": False, - } - ) - if table.clustering_fields: - indexes.append( - { - "name": "clustering", - "column_names": table.clustering_fields, - "unique": False, - } - ) - return indexes + # BigQuery has no support for indexes. + return [] def get_schema_names(self, connection, **kw): if isinstance(connection, Engine): @@ -1034,10 +1279,6 @@ def do_rollback(self, dbapi_connection): # BigQuery has no support for transactions. pass - def _check_unicode_returns(self, connection, additional_tests=None): - # requests gives back Unicode strings - return True - def get_view_definition(self, connection, view_name, schema=None, **kw): if isinstance(connection, Engine): connection = connection.connect() @@ -1057,7 +1298,13 @@ def __init__(self, *args, **kwargs): raise TypeError("The unnest function requires a single argument.") arg = args[0] if isinstance(arg, sqlalchemy.sql.expression.ColumnElement): - if not isinstance(arg.type, sqlalchemy.sql.sqltypes.ARRAY): + if not ( + isinstance(arg.type, sqlalchemy.sql.sqltypes.ARRAY) + or ( + hasattr(arg.type, "impl") + and isinstance(arg.type.impl, sqlalchemy.sql.sqltypes.ARRAY) + ) + ): raise TypeError("The argument to unnest must have an ARRAY type.") self.type = arg.type.item_type super().__init__(*args, **kwargs) diff --git a/sqlalchemy_bigquery/requirements.py b/sqlalchemy_bigquery/requirements.py index 90cc08db..118e3946 100644 --- a/sqlalchemy_bigquery/requirements.py +++ b/sqlalchemy_bigquery/requirements.py @@ -136,6 +136,11 @@ def schemas(self): return unsupported() + @property + def array_type(self): + """Target database must support array_type""" + return supported() + @property def implicit_default_schema(self): """target system has a strong concept of 'default' schema that can diff --git a/sqlalchemy_bigquery/version.py b/sqlalchemy_bigquery/version.py index f15b4f67..265cef18 100644 --- a/sqlalchemy_bigquery/version.py +++ b/sqlalchemy_bigquery/version.py @@ -17,4 +17,4 @@ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -__version__ = "1.9.0" +__version__ = "1.11.0.dev3" diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt deleted file mode 100644 index 1d0a1b72..00000000 --- a/testing/constraints-3.7.txt +++ /dev/null @@ -1,12 +0,0 @@ -# This constraints file is used to check that lower bounds -# are correct in setup.py -# List *all* library dependencies and extras in this file. -# Pin the version to the lower bound. -# -# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", -sqlalchemy==1.2.0 -google-auth==1.25.0 -google-cloud-bigquery==3.3.6 -google-cloud-bigquery-storage==2.0.0 -google-api-core==1.31.5 -pyarrow==3.0.0 diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt index 4884f96a..667a747d 100644 --- a/testing/constraints-3.8.txt +++ b/testing/constraints-3.8.txt @@ -1 +1,13 @@ -sqlalchemy==1.3.24 +# This constraints file is used to check that lower bounds +# are correct in setup.py +# List *all* library dependencies and extras in this file. +# Pin the version to the lower bound. +# +# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", +sqlalchemy==1.4.16 +google-auth==1.25.0 +google-cloud-bigquery==3.3.6 +google-cloud-bigquery-storage==2.0.0 +google-api-core==1.31.5 +grpcio==1.47.0 +pyarrow==3.0.0 diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index 77dc823a..e69de29b 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -1 +0,0 @@ -sqlalchemy>=1.4.13,<2.0.0 diff --git a/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py index a79f2818..5420bf32 100644 --- a/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py +++ b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py @@ -18,6 +18,7 @@ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. import datetime +import decimal import mock import packaging.version import pytest @@ -27,45 +28,203 @@ import sqlalchemy.testing.suite.test_types import sqlalchemy.sql.sqltypes -from sqlalchemy.testing import util +from sqlalchemy.testing import util, config from sqlalchemy.testing.assertions import eq_ -from sqlalchemy.testing.suite import config, select, exists +from sqlalchemy.testing.suite import select, exists from sqlalchemy.testing.suite import * # noqa +from sqlalchemy.testing.suite import Integer, Table, Column, String, bindparam, testing from sqlalchemy.testing.suite import ( - ComponentReflectionTest as _ComponentReflectionTest, CTETest as _CTETest, ExistsTest as _ExistsTest, + FetchLimitOffsetTest as _FetchLimitOffsetTest, + DifficultParametersTest as _DifficultParametersTest, + DistinctOnTest, + HasIndexTest, + IdentityAutoincrementTest, InsertBehaviorTest as _InsertBehaviorTest, LongNameBlowoutTest, + PostCompileParamsTest, QuotedNameArgumentTest, SimpleUpdateDeleteTest as _SimpleUpdateDeleteTest, TimestampMicrosecondsTest as _TimestampMicrosecondsTest, ) +from sqlalchemy.testing.suite.test_types import ( + ArrayTest, +) -if packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"): - from sqlalchemy.testing.suite import LimitOffsetTest as _LimitOffsetTest +from sqlalchemy.testing.suite.test_reflection import ( + BizarroCharacterFKResolutionTest, + ComponentReflectionTest, + HasTableTest, +) - class LimitOffsetTest(_LimitOffsetTest): - @pytest.mark.skip("BigQuery doesn't allow an offset without a limit.") - def test_simple_offset(self): - pass +if packaging.version.parse(sqlalchemy.__version__) >= packaging.version.parse("2.0"): + import uuid + from sqlalchemy.sql import type_coerce + from sqlalchemy.testing.suite import ( + TrueDivTest as _TrueDivTest, + IntegerTest as _IntegerTest, + NumericTest as _NumericTest, + StringTest as _StringTest, + UuidTest as _UuidTest, + ) - test_bound_offset = test_simple_offset + class DifficultParametersTest(_DifficultParametersTest): + """There are some parameters that don't work with bigquery that were removed from this test""" + + tough_parameters = testing.combinations( + ("boring",), + ("per cent",), + ("per % cent",), + ("%percent",), + ("col:ons",), + ("_starts_with_underscore",), + ("more :: %colons%",), + ("_name",), + ("___name",), + ("42numbers",), + ("percent%signs",), + ("has spaces",), + ("1param",), + ("1col:on",), + argnames="paramname", + ) - class TimestampMicrosecondsTest(_TimestampMicrosecondsTest): - data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396, tzinfo=pytz.UTC) + @tough_parameters + @config.requirements.unusual_column_name_characters + def test_round_trip_same_named_column(self, paramname, connection, metadata): + name = paramname - def test_literal(self): - # The base tests doesn't set up the literal properly, because - # it doesn't pass its datatype to `literal`. + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column(name, String(50), nullable=False), + ) - def literal(value): - assert value == self.data - return sqlalchemy.sql.elements.literal(value, self.datatype) + # table is created + t.create(connection) - with mock.patch("sqlalchemy.testing.suite.test_types.literal", literal): - super(TimestampMicrosecondsTest, self).test_literal() + # automatic param generated by insert + connection.execute(t.insert().values({"id": 1, name: "some name"})) + + # automatic param generated by criteria, plus selecting the column + stmt = select(t.c[name]).where(t.c[name] == "some name") + + eq_(connection.scalar(stmt), "some name") + + # use the name in a param explicitly + stmt = select(t.c[name]).where(t.c[name] == bindparam(name)) + + row = connection.execute(stmt, {name: "some name"}).first() + + # name works as the key from cursor.description + eq_(row._mapping[name], "some name") + + # use expanding IN + stmt = select(t.c[name]).where( + t.c[name].in_(["some name", "some other_name"]) + ) + + row = connection.execute(stmt).first() + + @testing.fixture + def multirow_fixture(self, metadata, connection): + mytable = Table( + "mytable", + metadata, + Column("myid", Integer), + Column("name", String(50)), + Column("desc", String(50)), + ) + + mytable.create(connection) + + connection.execute( + mytable.insert(), + [ + {"myid": 1, "name": "a", "desc": "a_desc"}, + {"myid": 2, "name": "b", "desc": "b_desc"}, + {"myid": 3, "name": "c", "desc": "c_desc"}, + {"myid": 4, "name": "d", "desc": "d_desc"}, + ], + ) + yield mytable + + @tough_parameters + def test_standalone_bindparam_escape( + self, paramname, connection, multirow_fixture + ): + tbl1 = multirow_fixture + stmt = select(tbl1.c.myid).where( + tbl1.c.name == bindparam(paramname, value="x") + ) + res = connection.scalar(stmt, {paramname: "c"}) + eq_(res, 3) + + @tough_parameters + def test_standalone_bindparam_escape_expanding( + self, paramname, connection, multirow_fixture + ): + tbl1 = multirow_fixture + stmt = ( + select(tbl1.c.myid) + .where(tbl1.c.name.in_(bindparam(paramname, value=["a", "b"]))) + .order_by(tbl1.c.myid) + ) + + res = connection.scalars(stmt, {paramname: ["d", "a"]}).all() + eq_(res, [1, 4]) + + # BQ has no autoinc and client-side defaults can't work for select + del _IntegerTest.test_huge_int_auto_accommodation + + class NumericTest(_NumericTest): + """Added a where clause for BQ compatibility.""" + + @testing.fixture + def do_numeric_test(self, metadata, connection): + def run(type_, input_, output, filter_=None, check_scale=False): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + connection.execute(t.insert(), [{"x": x} for x in input_]) + + result = {row[0] for row in connection.execute(t.select())} + output = set(output) + if filter_: + result = {filter_(x) for x in result} + output = {filter_(x) for x in output} + eq_(result, output) + if check_scale: + eq_([str(x) for x in result], [str(x) for x in output]) + + where_expr = True + + connection.execute(t.delete().where(where_expr)) + + if type_.asdecimal: + test_value = decimal.Decimal("2.9") + add_value = decimal.Decimal("37.12") + else: + test_value = 2.9 + add_value = 37.12 + + connection.execute(t.insert(), {"x": test_value}) + assert_we_are_a_number = connection.scalar( + select(type_coerce(t.c.x + add_value, type_)) + ) + eq_( + round(assert_we_are_a_number, 3), + round(test_value + add_value, 3), + ) + + return run + + class TimestampMicrosecondsTest(_TimestampMicrosecondsTest): + """BQ has no support for BQ util.text_type""" + + data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396, tzinfo=pytz.UTC) def test_select_direct(self, connection): # This func added because this test was failing when passed the @@ -82,44 +241,243 @@ def literal(value, type_=None): with mock.patch("sqlalchemy.testing.suite.test_types.literal", literal): super(TimestampMicrosecondsTest, self).test_select_direct(connection) -else: - from sqlalchemy.testing.suite import ( - FetchLimitOffsetTest as _FetchLimitOffsetTest, - RowCountTest as _RowCountTest, + def test_round_trip_executemany(self, connection): + unicode_table = self.tables.unicode_table + connection.execute( + unicode_table.insert(), + [{"id": i, "unicode_data": self.data} for i in range(3)], + ) + + rows = connection.execute(select(unicode_table.c.unicode_data)).fetchall() + eq_(rows, [(self.data,) for i in range(3)]) + for row in rows: + assert isinstance(row[0], str) + + sqlalchemy.testing.suite.test_types._UnicodeFixture.test_round_trip_executemany = ( + test_round_trip_executemany ) - class FetchLimitOffsetTest(_FetchLimitOffsetTest): - @pytest.mark.skip("BigQuery doesn't allow an offset without a limit.") - def test_simple_offset(self): + class TrueDivTest(_TrueDivTest): + @pytest.mark.skip("BQ rounds based on datatype") + def test_floordiv_integer(self): pass - test_bound_offset = test_simple_offset - test_expr_offset = test_simple_offset_zero = test_simple_offset + @pytest.mark.skip("BQ rounds based on datatype") + def test_floordiv_integer_bound(self): + pass + + class SimpleUpdateDeleteTest(_SimpleUpdateDeleteTest): + """The base tests fail if operations return rows for some reason.""" + + def test_update(self): + t = self.tables.plain_pk + connection = config.db.connect() + # In SQLAlchemy 2.0, the datatype changed to dict in the following function. + r = connection.execute(t.update().where(t.c.id == 2), dict(data="d2_new")) + assert not r.is_insert + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (2, "d2_new"), (3, "d3")], + ) + + def test_delete(self): + t = self.tables.plain_pk + connection = config.db.connect() + r = connection.execute(t.delete().where(t.c.id == 2)) + assert not r.is_insert + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (3, "d3")], + ) + + class StringTest(_StringTest): + """Added a where clause for BQ compatibility""" + + def test_dont_truncate_rightside( + self, metadata, connection, expr=None, expected=None + ): + t = Table( + "t", + metadata, + Column("x", String(2)), + Column("id", Integer, primary_key=True), + ) + t.create(connection) + connection.connection.commit() + connection.execute( + t.insert(), + [{"x": "AB", "id": 1}, {"x": "BC", "id": 2}, {"x": "AC", "id": 3}], + ) + combinations = [("%B%", ["AB", "BC"]), ("A%C", ["AC"]), ("A%C%Z", [])] + + for args in combinations: + eq_( + connection.scalars(select(t.c.x).where(t.c.x.like(args[0]))).all(), + args[1], + ) + + class UuidTest(_UuidTest): + """BQ needs to pass in UUID as a string""" + + @classmethod + def define_tables(cls, metadata): + Table( + "uuid_table", + metadata, + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), + Column("uuid_data", String), # Use native UUID for primary data + Column( + "uuid_text_data", String, nullable=True + ), # Optional text representation + Column("uuid_data_nonnative", String), + Column("uuid_text_data_nonnative", String), + ) + + def test_uuid_round_trip(self, connection): + data = str(uuid.uuid4()) + uuid_table = self.tables.uuid_table + + connection.execute( + uuid_table.insert(), + {"id": 1, "uuid_data": data, "uuid_data_nonnative": data}, + ) + row = connection.execute( + select(uuid_table.c.uuid_data, uuid_table.c.uuid_data_nonnative).where( + uuid_table.c.uuid_data == data, + uuid_table.c.uuid_data_nonnative == data, + ) + ).first() + eq_(row, (data, data)) - # The original test is missing an order by. + def test_uuid_text_round_trip(self, connection): + data = str(uuid.uuid4()) + uuid_table = self.tables.uuid_table - # Also, note that sqlalchemy union is a union distinct, not a - # union all. This test caught that were were getting that wrong. - def test_limit_render_multiple_times(self, connection): - table = self.tables.some_table - stmt = select(table.c.id).order_by(table.c.id).limit(1).scalar_subquery() + connection.execute( + uuid_table.insert(), + { + "id": 1, + "uuid_text_data": data, + "uuid_text_data_nonnative": data, + }, + ) + row = connection.execute( + select( + uuid_table.c.uuid_text_data, + uuid_table.c.uuid_text_data_nonnative, + ).where( + uuid_table.c.uuid_text_data == data, + uuid_table.c.uuid_text_data_nonnative == data, + ) + ).first() + eq_((row[0].lower(), row[1].lower()), (data, data)) + + def test_literal_uuid(self, literal_round_trip): + data = str(uuid.uuid4()) + literal_round_trip(String(), [data], [data]) + + def test_literal_text(self, literal_round_trip): + data = str(uuid.uuid4()) + literal_round_trip( + String(), + [data], + [data], + filter_=lambda x: x.lower(), + ) - u = sqlalchemy.union(select(stmt), select(stmt)).subquery().select() + def test_literal_nonnative_uuid(self, literal_round_trip): + data = str(uuid.uuid4()) + literal_round_trip(String(), [data], [data]) + + def test_literal_nonnative_text(self, literal_round_trip): + data = str(uuid.uuid4()) + literal_round_trip( + String(), + [data], + [data], + filter_=lambda x: x.lower(), + ) - self._assert_result( - connection, - u, - [(1,)], + @testing.requires.insert_returning + def test_uuid_returning(self, connection): + data = str(uuid.uuid4()) + str_data = str(data) + uuid_table = self.tables.uuid_table + + result = connection.execute( + uuid_table.insert().returning( + uuid_table.c.uuid_data, + uuid_table.c.uuid_text_data, + uuid_table.c.uuid_data_nonnative, + uuid_table.c.uuid_text_data_nonnative, + ), + { + "id": 1, + "uuid_data": data, + "uuid_text_data": str_data, + "uuid_data_nonnative": data, + "uuid_text_data_nonnative": str_data, + }, ) + row = result.first() + + eq_(row, (data, str_data, data, str_data)) + +else: + from sqlalchemy.testing.suite import ( + RowCountTest as _RowCountTest, + ) del DifficultParametersTest # exercises column names illegal in BQ - del DistinctOnTest # expects unquoted table names. - del HasIndexTest # BQ doesn't do the indexes that SQLA is loooking for. - del IdentityAutoincrementTest # BQ doesn't do autoincrement - # This test makes makes assertions about generated sql and trips - # over the backquotes that we add everywhere. XXX Why do we do that? - del PostCompileParamsTest + class RowCountTest(_RowCountTest): + """""" + + @classmethod + def insert_data(cls, connection): + cls.data = data = [ + ("Angela", "A"), + ("Andrew", "A"), + ("Anand", "A"), + ("Bob", "B"), + ("Bobette", "B"), + ("Buffy", "B"), + ("Charlie", "C"), + ("Cynthia", "C"), + ("Chris", "C"), + ] + + employees_table = cls.tables.employees + connection.execute( + employees_table.insert(), + [ + {"employee_id": i, "name": n, "department": d} + for i, (n, d) in enumerate(data) + ], + ) + + class SimpleUpdateDeleteTest(_SimpleUpdateDeleteTest): + """The base tests fail if operations return rows for some reason.""" + + def test_update(self): + t = self.tables.plain_pk + r = config.db.execute(t.update().where(t.c.id == 2), data="d2_new") + assert not r.is_insert + + eq_( + config.db.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (2, "d2_new"), (3, "d3")], + ) + + def test_delete(self): + t = self.tables.plain_pk + r = config.db.execute(t.delete().where(t.c.id == 2)) + assert not r.is_insert + eq_( + config.db.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (3, "d3")], + ) class TimestampMicrosecondsTest(_TimestampMicrosecondsTest): data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396, tzinfo=pytz.UTC) @@ -171,40 +529,14 @@ def test_round_trip_executemany(self, connection): test_round_trip_executemany ) - class RowCountTest(_RowCountTest): - @classmethod - def insert_data(cls, connection): - cls.data = data = [ - ("Angela", "A"), - ("Andrew", "A"), - ("Anand", "A"), - ("Bob", "B"), - ("Bobette", "B"), - ("Buffy", "B"), - ("Charlie", "C"), - ("Cynthia", "C"), - ("Chris", "C"), - ] - - employees_table = cls.tables.employees - connection.execute( - employees_table.insert(), - [ - {"employee_id": i, "name": n, "department": d} - for i, (n, d) in enumerate(data) - ], - ) - - -# Quotes aren't allowed in BigQuery table names. -del QuotedNameArgumentTest +class CTETest(_CTETest): + @pytest.mark.skip("Can't use CTEs with insert") + def test_insert_from_select_round_trip(self): + pass -class InsertBehaviorTest(_InsertBehaviorTest): - @pytest.mark.skip( - "BQ has no autoinc and client-side defaults can't work for select." - ) - def test_insert_from_select_autoinc(cls): + @pytest.mark.skip("Recusive CTEs aren't supported.") + def test_select_recursive_round_trip(self): pass @@ -220,7 +552,7 @@ def test_select_exists(self, connection): stuff = self.tables.stuff eq_( connection.execute( - select([stuff.c.id]).where( + select(stuff.c.id).where( and_( stuff.c.id == 1, exists().where(stuff.c.data == "some data"), @@ -234,58 +566,71 @@ def test_select_exists_false(self, connection): stuff = self.tables.stuff eq_( connection.execute( - select([stuff.c.id]).where(exists().where(stuff.c.data == "no data")) + select(stuff.c.id).where(exists().where(stuff.c.data == "no data")) ).fetchall(), [], ) -# This test requires features (indexes, primary keys, etc., that BigQuery doesn't have. -del LongNameBlowoutTest - +class FetchLimitOffsetTest(_FetchLimitOffsetTest): + @pytest.mark.skip("BigQuery doesn't allow an offset without a limit.") + def test_simple_offset(self): + pass -class SimpleUpdateDeleteTest(_SimpleUpdateDeleteTest): - """The base tests fail if operations return rows for some reason.""" + test_bound_offset = test_simple_offset + test_expr_offset = test_simple_offset_zero = test_simple_offset + test_limit_offset_nobinds = test_simple_offset # TODO figure out + # how to prevent this from failing + # The original test is missing an order by. - def test_update(self): - t = self.tables.plain_pk - r = config.db.execute(t.update().where(t.c.id == 2), data="d2_new") - assert not r.is_insert - # assert not r.returns_rows + # Also, note that sqlalchemy union is a union distinct, not a + # union all. This test caught that were were getting that wrong. + def test_limit_render_multiple_times(self, connection): + table = self.tables.some_table + stmt = select(table.c.id).order_by(table.c.id).limit(1).scalar_subquery() - eq_( - config.db.execute(t.select().order_by(t.c.id)).fetchall(), - [(1, "d1"), (2, "d2_new"), (3, "d3")], - ) + u = sqlalchemy.union(select(stmt), select(stmt)).subquery().select() - def test_delete(self): - t = self.tables.plain_pk - r = config.db.execute(t.delete().where(t.c.id == 2)) - assert not r.is_insert - # assert not r.returns_rows - eq_( - config.db.execute(t.select().order_by(t.c.id)).fetchall(), - [(1, "d1"), (3, "d3")], + self._assert_result( + connection, + u, + [(1,)], ) -class CTETest(_CTETest): - @pytest.mark.skip("Can't use CTEs with insert") - def test_insert_from_select_round_trip(self): - pass - - @pytest.mark.skip("Recusive CTEs aren't supported.") - def test_select_recursive_round_trip(self): +class InsertBehaviorTest(_InsertBehaviorTest): + @pytest.mark.skip( + "BQ has no autoinc and client-side defaults can't work for select." + ) + def test_insert_from_select_autoinc(cls): pass - -class ComponentReflectionTest(_ComponentReflectionTest): - @pytest.mark.skip("Big query types don't track precision, length, etc.") - def course_grained_types(): + @pytest.mark.skip( + "BQ has no autoinc and client-side defaults can't work for select." + ) + def test_no_results_for_non_returning_insert(cls): pass - test_numeric_reflection = test_varchar_reflection = course_grained_types - @pytest.mark.skip("BQ doesn't have indexes (in the way these tests expect).") - def test_get_indexes(self): - pass +del ComponentReflectionTest # Multiple tests re: CHECK CONSTRAINTS, etc which +# BQ does not support +# class ComponentReflectionTest(_ComponentReflectionTest): +# @pytest.mark.skip("Big query types don't track precision, length, etc.") +# def course_grained_types(): +# pass + +# test_numeric_reflection = test_varchar_reflection = course_grained_types + +# @pytest.mark.skip("BQ doesn't have indexes (in the way these tests expect).") +# def test_get_indexes(self): +# pass + +del ArrayTest # only appears to apply to postgresql +del BizarroCharacterFKResolutionTest +del HasTableTest.test_has_table_cache # TODO confirm whether BQ has table caching +del DistinctOnTest # expects unquoted table names. +del HasIndexTest # BQ doesn't do the indexes that SQLA is loooking for. +del IdentityAutoincrementTest # BQ doesn't do autoincrement +del LongNameBlowoutTest # Requires features (indexes, primary keys, etc., that BigQuery doesn't have. +del PostCompileParamsTest # BQ adds backticks to bind parameters, causing failure of tests TODO: fix this? +del QuotedNameArgumentTest # Quotes aren't allowed in BigQuery table names. diff --git a/tests/system/test__struct.py b/tests/system/test__struct.py index bb7958c9..69d2ba76 100644 --- a/tests/system/test__struct.py +++ b/tests/system/test__struct.py @@ -54,7 +54,7 @@ def test_struct(engine, bigquery_dataset, metadata): ) ) - assert list(conn.execute(sqlalchemy.select([table]))) == [ + assert list(conn.execute(sqlalchemy.select(table))) == [ ( { "name": "bob", @@ -62,16 +62,16 @@ def test_struct(engine, bigquery_dataset, metadata): }, ) ] - assert list(conn.execute(sqlalchemy.select([table.c.person.NAME]))) == [("bob",)] - assert list(conn.execute(sqlalchemy.select([table.c.person.children[0]]))) == [ + assert list(conn.execute(sqlalchemy.select(table.c.person.NAME))) == [("bob",)] + assert list(conn.execute(sqlalchemy.select(table.c.person.children[0]))) == [ ({"name": "billy", "bdate": datetime.date(2020, 1, 1)},) ] - assert list( - conn.execute(sqlalchemy.select([table.c.person.children[0].bdate])) - ) == [(datetime.date(2020, 1, 1),)] + assert list(conn.execute(sqlalchemy.select(table.c.person.children[0].bdate))) == [ + (datetime.date(2020, 1, 1),) + ] assert list( conn.execute( - sqlalchemy.select([table]).where(table.c.person.children[0].NAME == "billy") + sqlalchemy.select(table).where(table.c.person.children[0].NAME == "billy") ) ) == [ ( @@ -84,7 +84,7 @@ def test_struct(engine, bigquery_dataset, metadata): assert ( list( conn.execute( - sqlalchemy.select([table]).where( + sqlalchemy.select(table).where( table.c.person.children[0].NAME == "sally" ) ) @@ -99,21 +99,22 @@ def test_complex_literals_pr_67(engine, bigquery_dataset, metadata): # Simple select example: table_name = f"{bigquery_dataset}.test_comples_literals_pr_67" - engine.execute( - f""" - create table {table_name} as ( - select 'a' as id, - struct(1 as x__count, 2 as y__count, 3 as z__count) as dimensions + with engine.connect() as conn: + conn.execute( + sqlalchemy.text( + f""" + create table {table_name} as ( + select 'a' as id, + struct(1 as x__count, 2 as y__count, 3 as z__count) as dimensions + ) + """ ) - """ - ) + ) table = sqlalchemy.Table(table_name, metadata, autoload_with=engine) got = str( - sqlalchemy.select([(table.c.dimensions.x__count + 5).label("c")]).compile( - engine - ) + sqlalchemy.select((table.c.dimensions.x__count + 5).label("c")).compile(engine) ) want = ( f"SELECT (`{table_name}`.`dimensions`.x__count) + %(param_1:INT64)s AS `c` \n" @@ -149,9 +150,11 @@ def test_unnest_and_struct_access_233(engine, bigquery_dataset, metadata): conn.execute( mock_table.insert(), - dict(mock_id="x"), - dict(mock_id="y"), - dict(mock_id="z"), + [ + dict(mock_id="x"), + dict(mock_id="y"), + dict(mock_id="z"), + ], ) conn.execute( another_mock_table.insert(), diff --git a/tests/system/test_alembic.py b/tests/system/test_alembic.py index 1948a19a..30308c68 100644 --- a/tests/system/test_alembic.py +++ b/tests/system/test_alembic.py @@ -23,7 +23,7 @@ from sqlalchemy import Column, DateTime, Integer, String, Numeric import google.api_core.exceptions -from google.cloud.bigquery import SchemaField +from google.cloud.bigquery import SchemaField, TimePartitioning alembic = pytest.importorskip("alembic") @@ -138,15 +138,12 @@ def test_alembic_scenario(alembic_table): op.drop_table("accounts") assert alembic_table("accounts") is None - op.execute( - """ - create table transactions( - account INT64 NOT NULL, - transaction_time DATETIME NOT NULL, - amount NUMERIC(11, 2) NOT NULL - ) - partition by DATE(transaction_time) - """ + op.create_table( + "transactions", + Column("account", Integer, nullable=False), + Column("transaction_time", DateTime(), nullable=False), + Column("amount", Numeric(11, 2), nullable=False), + bigquery_time_partitioning=TimePartitioning(field="transaction_time"), ) op.alter_column("transactions", "amount", nullable=True) diff --git a/tests/system/test_geography.py b/tests/system/test_geography.py index 7189eebb..c04748af 100644 --- a/tests/system/test_geography.py +++ b/tests/system/test_geography.py @@ -74,7 +74,7 @@ def test_geoalchemy2_core(bigquery_dataset): from sqlalchemy.sql import select assert sorted( - (r.name, r.geog.desc[:4]) for r in conn.execute(select([lake_table])) + (r.name, r.geog.desc[:4]) for r in conn.execute(select(lake_table)) ) == [("Garde", "0103"), ("Majeur", "0103"), ("Orta", "0103")] # Spatial query @@ -82,26 +82,32 @@ def test_geoalchemy2_core(bigquery_dataset): from sqlalchemy import func [[result]] = conn.execute( - select([lake_table.c.name], func.ST_Contains(lake_table.c.geog, "POINT(4 1)")) + select(lake_table.c.name).where( + func.ST_Contains(lake_table.c.geog, "POINT(4 1)") + ) ) assert result == "Orta" assert sorted( (r.name, int(r.area)) for r in conn.execute( - select([lake_table.c.name, lake_table.c.geog.ST_AREA().label("area")]) + select(lake_table.c.name, lake_table.c.geog.ST_AREA().label("area")) ) ) == [("Garde", 49452374328), ("Majeur", 12364036567), ("Orta", 111253664228)] # Extra: Make sure we can save a retrieved value back: - [[geog]] = conn.execute(select([lake_table.c.geog], lake_table.c.name == "Garde")) + [[geog]] = conn.execute( + select(lake_table.c.geog).where(lake_table.c.name == "Garde") + ) conn.execute(lake_table.insert().values(name="test", geog=geog)) assert ( int( list( conn.execute( - select([lake_table.c.geog.st_area()], lake_table.c.name == "test") + select(lake_table.c.geog.st_area()).where( + lake_table.c.name == "test" + ) ) )[0][0] ) @@ -122,7 +128,9 @@ def test_geoalchemy2_core(bigquery_dataset): int( list( conn.execute( - select([lake_table.c.geog.st_area()], lake_table.c.name == "test2") + select(lake_table.c.geog.st_area()).where( + lake_table.c.name == "test2" + ) ) )[0][0] ) diff --git a/tests/system/test_sqlalchemy_bigquery.py b/tests/system/test_sqlalchemy_bigquery.py index 62b534ff..457a8ea8 100644 --- a/tests/system/test_sqlalchemy_bigquery.py +++ b/tests/system/test_sqlalchemy_bigquery.py @@ -22,6 +22,8 @@ import datetime import decimal +from google.cloud.bigquery import TimePartitioning + from sqlalchemy.engine import create_engine from sqlalchemy.schema import Table, MetaData, Column from sqlalchemy.ext.declarative import declarative_base @@ -155,24 +157,22 @@ def engine_with_location(): @pytest.fixture(scope="session") def table(engine, bigquery_dataset): - return Table(f"{bigquery_dataset}.sample", MetaData(bind=engine), autoload=True) + return Table(f"{bigquery_dataset}.sample", MetaData(), autoload_with=engine) @pytest.fixture(scope="session") def table_using_test_dataset(engine_using_test_dataset): - return Table("sample", MetaData(bind=engine_using_test_dataset), autoload=True) + return Table("sample", MetaData(), autoload_with=engine_using_test_dataset) @pytest.fixture(scope="session") def table_one_row(engine, bigquery_dataset): - return Table( - f"{bigquery_dataset}.sample_one_row", MetaData(bind=engine), autoload=True - ) + return Table(f"{bigquery_dataset}.sample_one_row", MetaData(), autoload_with=engine) @pytest.fixture(scope="session") def table_dml(engine, bigquery_empty_table): - return Table(bigquery_empty_table, MetaData(bind=engine), autoload=True) + return Table(bigquery_empty_table, MetaData(), autoload_with=engine) @pytest.fixture(scope="session") @@ -214,7 +214,7 @@ def query(table): .label("outer") ) query = ( - select([col1, col2, col3]) + select(col1, col2, col3) .where(col1 < "2017-01-01 00:00:00") .group_by(col1) .order_by(col2) @@ -225,37 +225,47 @@ def query(table): def test_engine_with_dataset(engine_using_test_dataset, bigquery_dataset): - rows = engine_using_test_dataset.execute("SELECT * FROM sample_one_row").fetchall() - assert list(rows[0]) == ONE_ROW_CONTENTS + with engine_using_test_dataset.connect() as conn: + rows = conn.execute(sqlalchemy.text("SELECT * FROM sample_one_row")).fetchall() + assert list(rows[0]) == ONE_ROW_CONTENTS - table_one_row = Table( - "sample_one_row", MetaData(bind=engine_using_test_dataset), autoload=True - ) - rows = table_one_row.select(use_labels=True).execute().fetchall() - assert list(rows[0]) == ONE_ROW_CONTENTS_EXPANDED + table_one_row = Table( + "sample_one_row", MetaData(), autoload_with=engine_using_test_dataset + ) + rows = conn.execute( + table_one_row.select().set_label_style( + sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL + ) + ).fetchall() + assert list(rows[0]) == ONE_ROW_CONTENTS_EXPANDED - table_one_row = Table( - f"{bigquery_dataset}.sample_one_row", - MetaData(bind=engine_using_test_dataset), - autoload=True, - ) - rows = table_one_row.select(use_labels=True).execute().fetchall() - # verify that we are pulling from the specifically-named dataset, - # instead of pulling from the default dataset of the engine (which - # does not have this table at all) - assert list(rows[0]) == ONE_ROW_CONTENTS_EXPANDED + table_one_row = Table( + f"{bigquery_dataset}.sample_one_row", + MetaData(), + autoload_with=engine_using_test_dataset, + ) + rows = conn.execute( + table_one_row.select().set_label_style( + sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL + ) + ).fetchall() + # verify that we are pulling from the specifically-named dataset, + # instead of pulling from the default dataset of the engine (which + # does not have this table at all) + assert list(rows[0]) == ONE_ROW_CONTENTS_EXPANDED def test_dataset_location( engine_with_location, bigquery_dataset, bigquery_regional_dataset ): - rows = engine_with_location.execute( - f"SELECT * FROM {bigquery_regional_dataset}.sample_one_row" - ).fetchall() - assert list(rows[0]) == ONE_ROW_CONTENTS + with engine_with_location.connect() as conn: + rows = conn.execute( + sqlalchemy.text(f"SELECT * FROM {bigquery_regional_dataset}.sample_one_row") + ).fetchall() + assert list(rows[0]) == ONE_ROW_CONTENTS -def test_reflect_select(table, table_using_test_dataset): +def test_reflect_select(table, engine_using_test_dataset, table_using_test_dataset): for table in [table, table_using_test_dataset]: assert table.comment == "A sample table containing most data types." @@ -276,61 +286,73 @@ def test_reflect_select(table, table_using_test_dataset): assert isinstance(table.c["nested_record.record.name"].type, types.String) assert isinstance(table.c.array.type, types.ARRAY) - # Force unique column labels using `use_labels` below to deal - # with BQ sometimes complaining about duplicate column names - # when a destination table is specified, even though no - # destination table is specified. When this test was written, - # `use_labels` was forced by the dialect. - rows = table.select(use_labels=True).execute().fetchall() - assert len(rows) == 1000 + with engine_using_test_dataset.connect() as conn: + rows = conn.execute( + table.select().set_label_style( + sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL + ) + ).fetchall() + assert len(rows) == 1000 def test_content_from_raw_queries(engine, bigquery_dataset): - rows = engine.execute(f"SELECT * FROM {bigquery_dataset}.sample_one_row").fetchall() - assert list(rows[0]) == ONE_ROW_CONTENTS + with engine.connect() as conn: + rows = conn.execute( + sqlalchemy.text(f"SELECT * FROM {bigquery_dataset}.sample_one_row") + ).fetchall() + assert list(rows[0]) == ONE_ROW_CONTENTS def test_record_content_from_raw_queries(engine, bigquery_dataset): - rows = engine.execute( - f"SELECT record.name FROM {bigquery_dataset}.sample_one_row" - ).fetchall() - assert rows[0][0] == "John Doe" + with engine.connect() as conn: + rows = conn.execute( + sqlalchemy.text( + f"SELECT record.name FROM {bigquery_dataset}.sample_one_row" + ) + ).fetchall() + assert rows[0][0] == "John Doe" def test_content_from_reflect(engine, table_one_row): - rows = table_one_row.select(use_labels=True).execute().fetchall() - assert list(rows[0]) == ONE_ROW_CONTENTS_EXPANDED + with engine.connect() as conn: + rows = conn.execute( + table_one_row.select().set_label_style( + sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL + ) + ).fetchall() + assert list(rows[0]) == ONE_ROW_CONTENTS_EXPANDED def test_unicode(engine, table_one_row): unicode_str = "白人看不懂" - returned_str = sqlalchemy.select( - [expression.bindparam("好", unicode_str)], - from_obj=table_one_row, - ).scalar() + with engine.connect() as conn: + returned_str = conn.execute( + sqlalchemy.select(expression.bindparam("好", unicode_str)).select_from( + table_one_row + ) + ).scalar() assert returned_str == unicode_str def test_reflect_select_shared_table(engine): one_row = Table( - "bigquery-public-data.samples.natality", MetaData(bind=engine), autoload=True + "bigquery-public-data.samples.natality", MetaData(), autoload_with=engine ) - row = one_row.select().limit(1).execute().first() - assert len(row) >= 1 + with engine.connect() as conn: + row = conn.execute(one_row.select().limit(1)).first() + assert len(row) >= 1 def test_reflect_table_does_not_exist(engine, bigquery_dataset): with pytest.raises(NoSuchTableError): Table( f"{bigquery_dataset}.table_does_not_exist", - MetaData(bind=engine), - autoload=True, + MetaData(), + autoload_with=engine, ) assert ( - Table( - f"{bigquery_dataset}.table_does_not_exist", MetaData(bind=engine) - ).exists() + sqlalchemy.inspect(engine).has_table(f"{bigquery_dataset}.table_does_not_exist") is False ) @@ -339,18 +361,18 @@ def test_reflect_dataset_does_not_exist(engine): with pytest.raises(NoSuchTableError): Table( "dataset_does_not_exist.table_does_not_exist", - MetaData(bind=engine), - autoload=True, + MetaData(), + autoload_with=engine, ) def test_tables_list(engine, engine_using_test_dataset, bigquery_dataset): - tables = engine.table_names() + tables = sqlalchemy.inspect(engine).get_table_names() assert f"{bigquery_dataset}.sample" in tables assert f"{bigquery_dataset}.sample_one_row" in tables assert f"{bigquery_dataset}.sample_view" not in tables - tables = engine_using_test_dataset.table_names() + tables = sqlalchemy.inspect(engine_using_test_dataset).get_table_names() assert "sample" in tables assert "sample_one_row" in tables assert "sample_view" not in tables @@ -377,13 +399,13 @@ def test_nested_labels(engine, table): sqlalchemy.func.sum(col.label("inner")).label("outer") ).over(), sqlalchemy.func.sum( - sqlalchemy.case([[sqlalchemy.literal(True), col.label("inner")]]).label( + sqlalchemy.case((sqlalchemy.literal(True), col.label("inner"))).label( "outer" ) ), sqlalchemy.func.sum( sqlalchemy.func.sum( - sqlalchemy.case([[sqlalchemy.literal(True), col.label("inner")]]).label( + sqlalchemy.case((sqlalchemy.literal(True), col.label("inner"))).label( "middle" ) ).label("outer") @@ -410,7 +432,7 @@ def test_session_query( col_concat, func.avg(table.c.integer), func.sum( - case([(table.c.boolean == sqlalchemy.literal(True), 1)], else_=0) + case((table.c.boolean == sqlalchemy.literal(True), 1), else_=0) ), ) .group_by(table.c.string, col_concat) @@ -443,13 +465,14 @@ def test_custom_expression( ): """GROUP BY clause should use labels instead of expressions""" q = query(table) - result = engine.execute(q).fetchall() - assert len(result) > 0 + with engine.connect() as conn: + result = conn.execute(q).fetchall() + assert len(result) > 0 q = query(table_using_test_dataset) - result = engine_using_test_dataset.execute(q).fetchall() - - assert len(result) > 0 + with engine_using_test_dataset.connect() as conn: + result = conn.execute(q).fetchall() + assert len(result) > 0 def test_compiled_query_literal_binds( @@ -457,15 +480,17 @@ def test_compiled_query_literal_binds( ): q = query(table) compiled = q.compile(engine, compile_kwargs={"literal_binds": True}) - result = engine.execute(compiled).fetchall() - assert len(result) > 0 + with engine.connect() as conn: + result = conn.execute(compiled).fetchall() + assert len(result) > 0 q = query(table_using_test_dataset) compiled = q.compile( engine_using_test_dataset, compile_kwargs={"literal_binds": True} ) - result = engine_using_test_dataset.execute(compiled).fetchall() - assert len(result) > 0 + with engine_using_test_dataset.connect() as conn: + result = conn.execute(compiled).fetchall() + assert len(result) > 0 @pytest.mark.parametrize( @@ -494,31 +519,46 @@ def test_joins(session, table, table_one_row): def test_querying_wildcard_tables(engine): table = Table( - "bigquery-public-data.noaa_gsod.gsod*", MetaData(bind=engine), autoload=True + "bigquery-public-data.noaa_gsod.gsod*", MetaData(), autoload_with=engine ) - rows = table.select().limit(1).execute().first() - assert len(rows) > 0 + with engine.connect() as conn: + rows = conn.execute(table.select().limit(1)).first() + assert len(rows) > 0 def test_dml(engine, session, table_dml): - # test insert - engine.execute(table_dml.insert(ONE_ROW_CONTENTS_DML)) - result = table_dml.select(use_labels=True).execute().fetchall() - assert len(result) == 1 - - # test update - session.query(table_dml).filter(table_dml.c.string == "test").update( - {"string": "updated_row"}, synchronize_session=False - ) - updated_result = table_dml.select(use_labels=True).execute().fetchone() - assert updated_result[table_dml.c.string] == "updated_row" + """ + Test DML operations on a table with no data. This table is created + in the `bigquery_empty_table` fixture. - # test delete - session.query(table_dml).filter(table_dml.c.string == "updated_row").delete( - synchronize_session=False - ) - result = table_dml.select(use_labels=True).execute().fetchall() - assert len(result) == 0 + Modern versions of sqlalchemy does not really require setting the + label style. This has been maintained to retain this test. + """ + # test insert + with engine.connect() as conn: + conn.execute(table_dml.insert().values(ONE_ROW_CONTENTS_DML)) + result = conn.execute( + table_dml.select().set_label_style(sqlalchemy.LABEL_STYLE_DEFAULT) + ).fetchall() + assert len(result) == 1 + + # test update + session.query(table_dml).filter(table_dml.c.string == "test").update( + {"string": "updated_row"}, synchronize_session=False + ) + updated_result = conn.execute( + table_dml.select().set_label_style(sqlalchemy.LABEL_STYLE_DEFAULT) + ).fetchone() + assert updated_result._mapping[table_dml.c.string] == "updated_row" + + # test delete + session.query(table_dml).filter(table_dml.c.string == "updated_row").delete( + synchronize_session=False + ) + result = conn.execute( + table_dml.select().set_label_style(sqlalchemy.LABEL_STYLE_DEFAULT) + ).fetchall() + assert len(result) == 0 def test_create_table(engine, bigquery_dataset): @@ -539,6 +579,14 @@ def test_create_table(engine, bigquery_dataset): Column("binary_c", sqlalchemy.BINARY), bigquery_description="test table description", bigquery_friendly_name="test table name", + bigquery_expiration_timestamp=datetime.datetime(2183, 3, 26, 8, 30, 0), + bigquery_time_partitioning=TimePartitioning( + field="timestamp_c", + expiration_ms=1000 * 60 * 60 * 24 * 30, # 30 days + ), + bigquery_require_partition_filter=True, + bigquery_default_rounding_mode="ROUND_HALF_EVEN", + bigquery_clustering_fields=["integer_c", "decimal_c"], ) meta.create_all(engine) meta.drop_all(engine) @@ -594,17 +642,7 @@ def test_view_names(inspector, inspector_using_test_dataset, bigquery_dataset): def test_get_indexes(inspector, inspector_using_test_dataset, bigquery_dataset): for _ in [f"{bigquery_dataset}.sample", f"{bigquery_dataset}.sample_one_row"]: indexes = inspector.get_indexes(f"{bigquery_dataset}.sample") - assert len(indexes) == 2 - assert indexes[0] == { - "name": "partition", - "column_names": ["timestamp"], - "unique": False, - } - assert indexes[1] == { - "name": "clustering", - "column_names": ["integer", "string"], - "unique": False, - } + assert len(indexes) == 0 def test_get_columns(inspector, inspector_using_test_dataset, bigquery_dataset): @@ -679,16 +717,34 @@ def test_invalid_table_reference( def test_has_table(engine, engine_using_test_dataset, bigquery_dataset): - assert engine.has_table("sample", bigquery_dataset) is True - assert engine.has_table(f"{bigquery_dataset}.sample") is True - assert engine.has_table(f"{bigquery_dataset}.nonexistent_table") is False - assert engine.has_table("nonexistent_table", "nonexistent_dataset") is False + assert sqlalchemy.inspect(engine).has_table("sample", bigquery_dataset) is True + assert sqlalchemy.inspect(engine).has_table(f"{bigquery_dataset}.sample") is True + assert ( + sqlalchemy.inspect(engine).has_table(f"{bigquery_dataset}.nonexistent_table") + is False + ) + assert ( + sqlalchemy.inspect(engine).has_table("nonexistent_table", "nonexistent_dataset") + is False + ) - assert engine_using_test_dataset.has_table("sample") is True - assert engine_using_test_dataset.has_table("sample", bigquery_dataset) is True - assert engine_using_test_dataset.has_table(f"{bigquery_dataset}.sample") is True + assert sqlalchemy.inspect(engine_using_test_dataset).has_table("sample") is True + assert ( + sqlalchemy.inspect(engine_using_test_dataset).has_table( + "sample", bigquery_dataset + ) + is True + ) + assert ( + sqlalchemy.inspect(engine_using_test_dataset).has_table( + f"{bigquery_dataset}.sample" + ) + is True + ) - assert engine_using_test_dataset.has_table("sample_alt") is False + assert ( + sqlalchemy.inspect(engine_using_test_dataset).has_table("sample_alt") is False + ) def test_distinct_188(engine, bigquery_dataset): @@ -735,7 +791,7 @@ def test_huge_in(): try: assert list( conn.execute( - sqlalchemy.select([sqlalchemy.literal(-1).in_(list(range(99999)))]) + sqlalchemy.select(sqlalchemy.literal(-1).in_(list(range(99999)))) ) ) == [(False,)] except Exception: @@ -765,7 +821,7 @@ def test_unnest(engine, bigquery_dataset): conn.execute( table.insert(), [dict(objects=["a", "b", "c"]), dict(objects=["x", "y"])] ) - query = select([func.unnest(table.c.objects).alias("foo_objects").column]) + query = select(func.unnest(table.c.objects).alias("foo_objects").column) compiled = str(query.compile(engine)) assert " ".join(compiled.strip().split()) == ( f"SELECT `foo_objects`" @@ -800,10 +856,8 @@ def test_unnest_with_cte(engine, bigquery_dataset): ) selectable = select(table.c).select_from(table).cte("cte") query = select( - [ - selectable.c.foo, - func.unnest(selectable.c.bars).column_valued("unnest_bars"), - ] + selectable.c.foo, + func.unnest(selectable.c.bars).column_valued("unnest_bars"), ).select_from(selectable) compiled = str(query.compile(engine)) assert " ".join(compiled.strip().split()) == ( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index f808b380..c75113a9 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -25,21 +25,19 @@ import pytest import sqlalchemy +from sqlalchemy_bigquery.base import BigQueryDDLCompiler, BigQueryDialect + from . import fauxdbi sqlalchemy_version = packaging.version.parse(sqlalchemy.__version__) -sqlalchemy_1_3_or_higher = pytest.mark.skipif( - sqlalchemy_version < packaging.version.parse("1.3"), - reason="requires sqlalchemy 1.3 or higher", +sqlalchemy_before_2_0 = pytest.mark.skipif( + sqlalchemy_version >= packaging.version.parse("2.0"), + reason="requires sqlalchemy 1.3 or lower", ) -sqlalchemy_1_4_or_higher = pytest.mark.skipif( - sqlalchemy_version < packaging.version.parse("1.4"), +sqlalchemy_2_0_or_higher = pytest.mark.skipif( + sqlalchemy_version < packaging.version.parse("2.0"), reason="requires sqlalchemy 1.4 or higher", ) -sqlalchemy_before_1_4 = pytest.mark.skipif( - sqlalchemy_version >= packaging.version.parse("1.4"), - reason="requires sqlalchemy 1.3 or lower", -) @pytest.fixture() @@ -91,6 +89,11 @@ def metadata(): return sqlalchemy.MetaData() +@pytest.fixture() +def ddl_compiler(): + return BigQueryDDLCompiler(BigQueryDialect(), None) + + def setup_table(connection, name, *columns, initial_data=(), **kw): metadata = sqlalchemy.MetaData() table = sqlalchemy.Table(name, metadata, *columns, **kw) diff --git a/tests/unit/test__struct.py b/tests/unit/test__struct.py index 77577066..6e7c7a3d 100644 --- a/tests/unit/test__struct.py +++ b/tests/unit/test__struct.py @@ -84,7 +84,7 @@ def _col(): ) def test_struct_traversal_project(faux_conn, expr, sql): sql = f"SELECT {sql} AS `anon_1` \nFROM `t`" - assert str(sqlalchemy.select([expr]).compile(faux_conn.engine)) == sql + assert str(sqlalchemy.select(expr).compile(faux_conn.engine)) == sql @pytest.mark.parametrize( @@ -117,7 +117,7 @@ def test_struct_traversal_project(faux_conn, expr, sql): ) def test_struct_traversal_filter(faux_conn, expr, sql, param=1): want = f"SELECT `t`.`person` \nFROM `t`, `t` \nWHERE {sql}" - got = str(sqlalchemy.select([_col()]).where(expr).compile(faux_conn.engine)) + got = str(sqlalchemy.select(_col()).where(expr).compile(faux_conn.engine)) assert got == want diff --git a/tests/unit/test_catalog_functions.py b/tests/unit/test_catalog_functions.py index 78614c9f..7eab7b7b 100644 --- a/tests/unit/test_catalog_functions.py +++ b/tests/unit/test_catalog_functions.py @@ -126,18 +126,7 @@ def test_get_indexes(faux_conn): client.tables.foo.time_partitioning = TimePartitioning(field="tm") client.tables.foo.clustering_fields = ["user_email", "store_code"] - assert faux_conn.dialect.get_indexes(faux_conn, "foo") == [ - dict( - name="partition", - column_names=["tm"], - unique=False, - ), - dict( - name="clustering", - column_names=["user_email", "store_code"], - unique=False, - ), - ] + assert faux_conn.dialect.get_indexes(faux_conn, "foo") == [] def test_no_table_pk_constraint(faux_conn): diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index db02e593..cc9116e3 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -21,7 +21,28 @@ import sqlalchemy.exc from .conftest import setup_table -from .conftest import sqlalchemy_1_4_or_higher +from .conftest import ( + sqlalchemy_2_0_or_higher, + sqlalchemy_before_2_0, +) +from sqlalchemy.sql.functions import rollup, cube, grouping_sets + + +@pytest.fixture +def table(faux_conn, metadata): + # Fixture to create a sample table for testing + + table = setup_table( + faux_conn, + "table1", + metadata, + sqlalchemy.Column("foo", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), + ) + + yield table + + table.drop(faux_conn) def test_constraints_are_ignored(faux_conn, metadata): @@ -58,7 +79,6 @@ def test_cant_compile_unnamed_column(faux_conn, metadata): sqlalchemy.Column(sqlalchemy.Integer).compile(faux_conn) -@sqlalchemy_1_4_or_higher def test_no_alias_for_known_tables(faux_conn, metadata): # See: https://github.com/googleapis/python-bigquery-sqlalchemy/issues/353 table = setup_table( @@ -80,7 +100,6 @@ def test_no_alias_for_known_tables(faux_conn, metadata): assert found_sql == expected_sql -@sqlalchemy_1_4_or_higher def test_no_alias_for_known_tables_cte(faux_conn, metadata): # See: https://github.com/googleapis/python-bigquery-sqlalchemy/issues/368 table = setup_table( @@ -114,3 +133,258 @@ def test_no_alias_for_known_tables_cte(faux_conn, metadata): ) found_cte_sql = q.compile(faux_conn).string assert found_cte_sql == expected_cte_sql + + +def prepare_implicit_join_base_query( + faux_conn, metadata, select_from_table2, old_syntax +): + table1 = setup_table( + faux_conn, "table1", metadata, sqlalchemy.Column("foo", sqlalchemy.Integer) + ) + table2 = setup_table( + faux_conn, + "table2", + metadata, + sqlalchemy.Column("foos", sqlalchemy.ARRAY(sqlalchemy.Integer)), + sqlalchemy.Column("bar", sqlalchemy.Integer), + ) + F = sqlalchemy.func + + unnested_col_name = "unnested_foos" + unnested_foos = F.unnest(table2.c.foos).alias(unnested_col_name) + unnested_foo_col = sqlalchemy.Column(unnested_col_name) + + # Set up initial query + cols = [table1.c.foo, table2.c.bar] if select_from_table2 else [table1.c.foo] + q = sqlalchemy.select(cols) if old_syntax else sqlalchemy.select(*cols) + q = q.select_from(unnested_foos.join(table1, table1.c.foo == unnested_foo_col)) + return q + + +@sqlalchemy_before_2_0 +def test_no_implicit_join_asterix_for_inner_unnest_before_2_0(faux_conn, metadata): + # See: https://github.com/googleapis/python-bigquery-sqlalchemy/issues/368 + q = prepare_implicit_join_base_query(faux_conn, metadata, True, False) + expected_initial_sql = ( + "SELECT `table1`.`foo`, `table2`.`bar` \n" + "FROM `table2`, unnest(`table2`.`foos`) AS `unnested_foos` JOIN `table1` ON `table1`.`foo` = `unnested_foos`" + ) + found_initial_sql = q.compile(faux_conn).string + assert found_initial_sql == expected_initial_sql + + q = q.subquery() + q = sqlalchemy.select("*").select_from(q) + + expected_outer_sql = ( + "SELECT * \n" + "FROM (SELECT `table1`.`foo` AS `foo`, `table2`.`bar` AS `bar` \n" + "FROM `table2`, unnest(`table2`.`foos`) AS `unnested_foos` JOIN `table1` ON `table1`.`foo` = `unnested_foos`) AS `anon_1`" + ) + found_outer_sql = q.compile(faux_conn).string + assert found_outer_sql == expected_outer_sql + + +@sqlalchemy_2_0_or_higher +def test_no_implicit_join_asterix_for_inner_unnest(faux_conn, metadata): + # See: https://github.com/googleapis/python-bigquery-sqlalchemy/issues/368 + q = prepare_implicit_join_base_query(faux_conn, metadata, True, False) + expected_initial_sql = ( + "SELECT `table1`.`foo`, `table2`.`bar` \n" + "FROM unnest(`table2`.`foos`) AS `unnested_foos` JOIN `table1` ON `table1`.`foo` = `unnested_foos`, `table2`" + ) + found_initial_sql = q.compile(faux_conn).string + assert found_initial_sql == expected_initial_sql + + q = q.subquery() + q = sqlalchemy.select("*").select_from(q) + + expected_outer_sql = ( + "SELECT * \n" + "FROM (SELECT `table1`.`foo` AS `foo`, `table2`.`bar` AS `bar` \n" + "FROM unnest(`table2`.`foos`) AS `unnested_foos` JOIN `table1` ON `table1`.`foo` = `unnested_foos`, `table2`) AS `anon_1`" + ) + found_outer_sql = q.compile(faux_conn).string + assert found_outer_sql == expected_outer_sql + + +@sqlalchemy_before_2_0 +def test_no_implicit_join_for_inner_unnest_before_2_0(faux_conn, metadata): + # See: https://github.com/googleapis/python-bigquery-sqlalchemy/issues/368 + q = prepare_implicit_join_base_query(faux_conn, metadata, True, False) + expected_initial_sql = ( + "SELECT `table1`.`foo`, `table2`.`bar` \n" + "FROM `table2`, unnest(`table2`.`foos`) AS `unnested_foos` JOIN `table1` ON `table1`.`foo` = `unnested_foos`" + ) + found_initial_sql = q.compile(faux_conn).string + assert found_initial_sql == expected_initial_sql + + q = q.subquery() + q = sqlalchemy.select(q.c.foo).select_from(q) + + expected_outer_sql = ( + "SELECT `anon_1`.`foo` \n" + "FROM (SELECT `table1`.`foo` AS `foo`, `table2`.`bar` AS `bar` \n" + "FROM `table2`, unnest(`table2`.`foos`) AS `unnested_foos` JOIN `table1` ON `table1`.`foo` = `unnested_foos`) AS `anon_1`" + ) + found_outer_sql = q.compile(faux_conn).string + assert found_outer_sql == expected_outer_sql + + +@sqlalchemy_2_0_or_higher +def test_no_implicit_join_for_inner_unnest(faux_conn, metadata): + # See: https://github.com/googleapis/python-bigquery-sqlalchemy/issues/368 + q = prepare_implicit_join_base_query(faux_conn, metadata, True, False) + expected_initial_sql = ( + "SELECT `table1`.`foo`, `table2`.`bar` \n" + "FROM unnest(`table2`.`foos`) AS `unnested_foos` JOIN `table1` ON `table1`.`foo` = `unnested_foos`, `table2`" + ) + found_initial_sql = q.compile(faux_conn).string + assert found_initial_sql == expected_initial_sql + + q = q.subquery() + q = sqlalchemy.select(q.c.foo).select_from(q) + + expected_outer_sql = ( + "SELECT `anon_1`.`foo` \n" + "FROM (SELECT `table1`.`foo` AS `foo`, `table2`.`bar` AS `bar` \n" + "FROM unnest(`table2`.`foos`) AS `unnested_foos` JOIN `table1` ON `table1`.`foo` = `unnested_foos`, `table2`) AS `anon_1`" + ) + found_outer_sql = q.compile(faux_conn).string + assert found_outer_sql == expected_outer_sql + + +def test_no_implicit_join_asterix_for_inner_unnest_no_table2_column( + faux_conn, metadata +): + # See: https://github.com/googleapis/python-bigquery-sqlalchemy/issues/368 + q = prepare_implicit_join_base_query(faux_conn, metadata, False, False) + expected_initial_sql = ( + "SELECT `table1`.`foo` \n" + "FROM `table2` `table2_1`, unnest(`table2_1`.`foos`) AS `unnested_foos` JOIN `table1` ON `table1`.`foo` = `unnested_foos`" + ) + found_initial_sql = q.compile(faux_conn).string + assert found_initial_sql == expected_initial_sql + + q = q.subquery() + q = sqlalchemy.select("*").select_from(q) + + expected_outer_sql = ( + "SELECT * \n" + "FROM (SELECT `table1`.`foo` AS `foo` \n" + "FROM `table2` `table2_1`, unnest(`table2_1`.`foos`) AS `unnested_foos` JOIN `table1` ON `table1`.`foo` = `unnested_foos`) AS `anon_1`" + ) + found_outer_sql = q.compile(faux_conn).string + assert found_outer_sql == expected_outer_sql + + +def test_no_implicit_join_for_inner_unnest_no_table2_column(faux_conn, metadata): + # See: https://github.com/googleapis/python-bigquery-sqlalchemy/issues/368 + q = prepare_implicit_join_base_query(faux_conn, metadata, False, False) + expected_initial_sql = ( + "SELECT `table1`.`foo` \n" + "FROM `table2` `table2_1`, unnest(`table2_1`.`foos`) AS `unnested_foos` JOIN `table1` ON `table1`.`foo` = `unnested_foos`" + ) + found_initial_sql = q.compile(faux_conn).string + assert found_initial_sql == expected_initial_sql + + q = q.subquery() + q = sqlalchemy.select(q.c.foo).select_from(q) + + expected_outer_sql = ( + "SELECT `anon_1`.`foo` \n" + "FROM (SELECT `table1`.`foo` AS `foo` \n" + "FROM `table2` `table2_1`, unnest(`table2_1`.`foos`) AS `unnested_foos` JOIN `table1` ON `table1`.`foo` = `unnested_foos`) AS `anon_1`" + ) + found_outer_sql = q.compile(faux_conn).string + assert found_outer_sql == expected_outer_sql + + +grouping_ops = ( + "grouping_op, grouping_op_func", + [("GROUPING SETS", grouping_sets), ("ROLLUP", rollup), ("CUBE", cube)], +) + + +@pytest.mark.parametrize(*grouping_ops) +def test_grouping_ops_vs_single_column(faux_conn, table, grouping_op, grouping_op_func): + # Tests each of the grouping ops against a single column + + q = sqlalchemy.select(table.c.foo).group_by(grouping_op_func(table.c.foo)) + found_sql = q.compile(faux_conn).string + + expected_sql = ( + f"SELECT `table1`.`foo` \n" + f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`)" + ) + + assert found_sql == expected_sql + + +@pytest.mark.parametrize(*grouping_ops) +def test_grouping_ops_vs_multi_columns(faux_conn, table, grouping_op, grouping_op_func): + # Tests each of the grouping ops against multiple columns + + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( + grouping_op_func(table.c.foo, table.c.bar) + ) + found_sql = q.compile(faux_conn).string + + expected_sql = ( + f"SELECT `table1`.`foo`, `table1`.`bar` \n" + f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`, `table1`.`bar`)" + ) + + assert found_sql == expected_sql + + +@pytest.mark.parametrize(*grouping_ops) +def test_grouping_op_with_grouping_op(faux_conn, table, grouping_op, grouping_op_func): + # Tests multiple grouping ops in a single statement + + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( + grouping_op_func(table.c.foo, table.c.bar), grouping_op_func(table.c.foo) + ) + found_sql = q.compile(faux_conn).string + + expected_sql = ( + f"SELECT `table1`.`foo`, `table1`.`bar` \n" + f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`, `table1`.`bar`), {grouping_op}(`table1`.`foo`)" + ) + + assert found_sql == expected_sql + + +@pytest.mark.parametrize(*grouping_ops) +def test_grouping_ops_vs_group_by(faux_conn, table, grouping_op, grouping_op_func): + # Tests grouping op against regular group by statement + + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( + table.c.foo, grouping_op_func(table.c.bar) + ) + found_sql = q.compile(faux_conn).string + + expected_sql = ( + f"SELECT `table1`.`foo`, `table1`.`bar` \n" + f"FROM `table1` GROUP BY `table1`.`foo`, {grouping_op}(`table1`.`bar`)" + ) + + assert found_sql == expected_sql + + +@pytest.mark.parametrize(*grouping_ops) +def test_complex_grouping_ops_vs_nested_grouping_ops( + faux_conn, table, grouping_op, grouping_op_func +): + # Tests grouping ops nested within grouping ops + + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( + grouping_sets(table.c.foo, grouping_op_func(table.c.bar)) + ) + found_sql = q.compile(faux_conn).string + + expected_sql = ( + f"SELECT `table1`.`foo`, `table1`.`bar` \n" + f"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, {grouping_op}(`table1`.`bar`))" + ) + + assert found_sql == expected_sql diff --git a/tests/unit/test_compliance.py b/tests/unit/test_compliance.py index fd1fbb83..bd90d936 100644 --- a/tests/unit/test_compliance.py +++ b/tests/unit/test_compliance.py @@ -27,7 +27,7 @@ from sqlalchemy import Column, Integer, literal_column, select, String, Table, union from sqlalchemy.testing.assertions import eq_, in_ -from .conftest import setup_table, sqlalchemy_1_3_or_higher +from .conftest import setup_table def assert_result(connection, sel, expected, params=()): @@ -52,8 +52,8 @@ def some_table(connection): def test_distinct_selectable_in_unions(faux_conn): table = some_table(faux_conn) - s1 = select([table]).where(table.c.id == 2).distinct() - s2 = select([table]).where(table.c.id == 3).distinct() + s1 = select(table).where(table.c.id == 2).distinct() + s2 = select(table).where(table.c.id == 3).distinct() u1 = union(s1, s2).limit(2) assert_result(faux_conn, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) @@ -62,7 +62,7 @@ def test_distinct_selectable_in_unions(faux_conn): def test_limit_offset_aliased_selectable_in_unions(faux_conn): table = some_table(faux_conn) s1 = ( - select([table]) + select(table) .where(table.c.id == 2) .limit(1) .order_by(table.c.id) @@ -70,7 +70,7 @@ def test_limit_offset_aliased_selectable_in_unions(faux_conn): .select() ) s2 = ( - select([table]) + select(table) .where(table.c.id == 3) .limit(1) .order_by(table.c.id) @@ -93,27 +93,24 @@ def test_percent_sign_round_trip(faux_conn, metadata): faux_conn.execute(t.insert(), dict(data="some %% other value")) eq_( faux_conn.scalar( - select([t.c.data]).where(t.c.data == literal_column("'some % value'")) + select(t.c.data).where(t.c.data == literal_column("'some % value'")) ), "some % value", ) eq_( faux_conn.scalar( - select([t.c.data]).where( - t.c.data == literal_column("'some %% other value'") - ) + select(t.c.data).where(t.c.data == literal_column("'some %% other value'")) ), "some %% other value", ) -@sqlalchemy_1_3_or_higher def test_empty_set_against_integer(faux_conn): table = some_table(faux_conn) stmt = ( - select([table.c.id]) + select(table.c.id) .where(table.c.x.in_(sqlalchemy.bindparam("q", expanding=True))) .order_by(table.c.id) ) @@ -121,22 +118,17 @@ def test_empty_set_against_integer(faux_conn): assert_result(faux_conn, stmt, [], params={"q": []}) -@sqlalchemy_1_3_or_higher def test_null_in_empty_set_is_false(faux_conn): stmt = select( - [ - sqlalchemy.case( - [ - ( - sqlalchemy.null().in_( - sqlalchemy.bindparam("foo", value=(), expanding=True) - ), - sqlalchemy.true(), - ) - ], - else_=sqlalchemy.false(), - ) - ] + sqlalchemy.case( + ( + sqlalchemy.null().in_( + sqlalchemy.bindparam("foo", value=(), expanding=True) + ), + sqlalchemy.true(), + ), + else_=sqlalchemy.false(), + ) ) in_(faux_conn.execute(stmt).fetchone()[0], (False, 0)) @@ -170,12 +162,12 @@ def test_likish(faux_conn, meth, arg, expected): ], ) expr = getattr(table.c.data, meth)(arg) - rows = {value for value, in faux_conn.execute(select([table.c.id]).where(expr))} + rows = {value for value, in faux_conn.execute(select(table.c.id).where(expr))} eq_(rows, expected) all = {i for i in range(1, 11)} expr = sqlalchemy.not_(expr) - rows = {value for value, in faux_conn.execute(select([table.c.id]).where(expr))} + rows = {value for value, in faux_conn.execute(select(table.c.id).where(expr))} eq_(rows, all - expected) @@ -196,9 +188,7 @@ def test_group_by_composed(faux_conn): ) expr = (table.c.x + table.c.y).label("lx") - stmt = ( - select([sqlalchemy.func.count(table.c.id), expr]).group_by(expr).order_by(expr) - ) + stmt = select(sqlalchemy.func.count(table.c.id), expr).group_by(expr).order_by(expr) assert_result(faux_conn, stmt, [(1, 3), (1, 5), (1, 7)]) diff --git a/tests/unit/test_geography.py b/tests/unit/test_geography.py index 6924ade0..93b7eb37 100644 --- a/tests/unit/test_geography.py +++ b/tests/unit/test_geography.py @@ -76,7 +76,7 @@ def test_geoalchemy2_core(faux_conn, last_query): from sqlalchemy.sql import select try: - conn.execute(select([lake_table])) + conn.execute(select(lake_table)) except Exception: pass # sqlite had no special functions :) last_query( @@ -89,8 +89,8 @@ def test_geoalchemy2_core(faux_conn, last_query): try: conn.execute( - select( - [lake_table.c.name], func.ST_Contains(lake_table.c.geog, "POINT(4 1)") + select(lake_table.c.name).where( + func.ST_Contains(lake_table.c.geog, "POINT(4 1)") ) ) except Exception: @@ -104,7 +104,7 @@ def test_geoalchemy2_core(faux_conn, last_query): try: conn.execute( - select([lake_table.c.name, lake_table.c.geog.ST_Area().label("area")]) + select(lake_table.c.name, lake_table.c.geog.ST_Area().label("area")) ) except Exception: pass # sqlite had no special functions :) @@ -171,7 +171,7 @@ def test_calling_st_functions_that_dont_take_geographies(faux_conn, last_query): from sqlalchemy import select, func try: - faux_conn.execute(select([func.ST_GeogFromText("point(0 0)")])) + faux_conn.execute(select(func.ST_GeogFromText("point(0 0)"))) except Exception: pass # sqlite had no special functions :) diff --git a/tests/unit/test_select.py b/tests/unit/test_select.py index ee5e01cb..ad80047a 100644 --- a/tests/unit/test_select.py +++ b/tests/unit/test_select.py @@ -20,25 +20,18 @@ import datetime from decimal import Decimal -import packaging.version import pytest import sqlalchemy from sqlalchemy import not_ import sqlalchemy_bigquery -from .conftest import ( - setup_table, - sqlalchemy_version, - sqlalchemy_1_3_or_higher, - sqlalchemy_1_4_or_higher, - sqlalchemy_before_1_4, -) +from .conftest import setup_table def test_labels_not_forced(faux_conn): table = setup_table(faux_conn, "t", sqlalchemy.Column("id", sqlalchemy.Integer)) - result = faux_conn.execute(sqlalchemy.select([table.c.id])) + result = faux_conn.execute(sqlalchemy.select(table.c.id)) assert result.keys() == ["id"] # Look! Just the column name! @@ -154,14 +147,18 @@ def test_typed_parameters(faux_conn, type_, val, btype, vrep): {}, ) - assert list(map(list, faux_conn.execute(sqlalchemy.select([table])))) == [[val]] * 2 + assert list(map(list, faux_conn.execute(sqlalchemy.select(table)))) == [[val]] * 2 assert faux_conn.test_data["execute"][-1][0] == "SELECT `t`.`foo` \nFROM `t`" assert ( list( map( list, - faux_conn.execute(sqlalchemy.select([table.c.foo], use_labels=True)), + faux_conn.execute( + sqlalchemy.select(table.c.foo).set_label_style( + sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL + ) + ), ) ) == [[val]] * 2 @@ -183,7 +180,7 @@ def test_select_struct(faux_conn, metadata): faux_conn.ex("create table t (x RECORD)") faux_conn.ex("""insert into t values ('{"y": 1}')""") - row = list(faux_conn.execute(sqlalchemy.select([table])))[0] + row = list(faux_conn.execute(sqlalchemy.select(table)))[0] # We expect the raw string, because sqlite3, unlike BigQuery # doesn't deserialize for us. assert row.x == '{"y": 1}' @@ -191,7 +188,7 @@ def test_select_struct(faux_conn, metadata): def test_select_label_starts_w_digit(faux_conn): # Make sure label names are legal identifiers - faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(1).label("2foo")])) + faux_conn.execute(sqlalchemy.select(sqlalchemy.literal(1).label("2foo"))) assert ( faux_conn.test_data["execute"][-1][0] == "SELECT %(param_1:INT64)s AS `_2foo`" ) @@ -205,7 +202,7 @@ def test_force_quote(faux_conn): "t", sqlalchemy.Column(quoted_name("foo", True), sqlalchemy.Integer), ) - faux_conn.execute(sqlalchemy.select([table])) + faux_conn.execute(sqlalchemy.select(table)) assert faux_conn.test_data["execute"][-1][0] == ("SELECT `t`.`foo` \nFROM `t`") @@ -217,26 +214,12 @@ def test_disable_quote(faux_conn): "t", sqlalchemy.Column(quoted_name("foo", False), sqlalchemy.Integer), ) - faux_conn.execute(sqlalchemy.select([table])) + faux_conn.execute(sqlalchemy.select(table)) assert faux_conn.test_data["execute"][-1][0] == ("SELECT `t`.foo \nFROM `t`") -@sqlalchemy_before_1_4 -def test_select_in_lit_13(faux_conn): - [[isin]] = faux_conn.execute( - sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])]) - ) - assert isin - assert faux_conn.test_data["execute"][-1] == ( - "SELECT %(param_1:INT64)s IN " - "(%(param_2:INT64)s, %(param_3:INT64)s, %(param_4:INT64)s) AS `anon_1`", - {"param_1": 1, "param_2": 1, "param_3": 2, "param_4": 3}, - ) - - -@sqlalchemy_1_4_or_higher def test_select_in_lit(faux_conn, last_query): - faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])])) + faux_conn.execute(sqlalchemy.select(sqlalchemy.literal(1).in_([1, 2, 3]))) last_query( "SELECT %(param_1:INT64)s IN UNNEST(%(param_2:INT64)s) AS `anon_1`", {"param_1": 1, "param_2": [1, 2, 3]}, @@ -244,83 +227,47 @@ def test_select_in_lit(faux_conn, last_query): def test_select_in_param(faux_conn, last_query): - [[isin]] = faux_conn.execute( + faux_conn.execute( sqlalchemy.select( - [sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))] + sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True)) ), dict(q=[1, 2, 3]), ) - if sqlalchemy_version >= packaging.version.parse("1.4"): - last_query( - "SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`", - {"param_1": 1, "q": [1, 2, 3]}, - ) - else: - assert isin - last_query( - "SELECT %(param_1:INT64)s IN UNNEST(" - "[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]" - ") AS `anon_1`", - {"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3}, - ) + + last_query( + "SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`", + {"param_1": 1, "q": [1, 2, 3]}, + ) def test_select_in_param1(faux_conn, last_query): - [[isin]] = faux_conn.execute( + faux_conn.execute( sqlalchemy.select( - [sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))] + sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True)) ), dict(q=[1]), ) - if sqlalchemy_version >= packaging.version.parse("1.4"): - last_query( - "SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`", - {"param_1": 1, "q": [1]}, - ) - else: - assert isin - last_query( - "SELECT %(param_1:INT64)s IN UNNEST(" "[ %(q_1:INT64)s ]" ") AS `anon_1`", - {"param_1": 1, "q_1": 1}, - ) + last_query( + "SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`", + {"param_1": 1, "q": [1]}, + ) -@sqlalchemy_1_3_or_higher def test_select_in_param_empty(faux_conn, last_query): - [[isin]] = faux_conn.execute( + faux_conn.execute( sqlalchemy.select( - [sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))] + sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True)) ), dict(q=[]), ) - if sqlalchemy_version >= packaging.version.parse("1.4"): - last_query( - "SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`", - {"param_1": 1, "q": []}, - ) - else: - assert not isin - last_query( - "SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`", {"param_1": 1} - ) - - -@sqlalchemy_before_1_4 -def test_select_notin_lit13(faux_conn): - [[isnotin]] = faux_conn.execute( - sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])]) - ) - assert isnotin - assert faux_conn.test_data["execute"][-1] == ( - "SELECT (%(param_1:INT64)s NOT IN " - "(%(param_2:INT64)s, %(param_3:INT64)s, %(param_4:INT64)s)) AS `anon_1`", - {"param_1": 0, "param_2": 1, "param_3": 2, "param_4": 3}, + last_query( + "SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`", + {"param_1": 1, "q": []}, ) -@sqlalchemy_1_4_or_higher def test_select_notin_lit(faux_conn, last_query): - faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])])) + faux_conn.execute(sqlalchemy.select(sqlalchemy.literal(0).notin_([1, 2, 3]))) last_query( "SELECT (%(param_1:INT64)s NOT IN UNNEST(%(param_2:INT64)s)) AS `anon_1`", {"param_1": 0, "param_2": [1, 2, 3]}, @@ -328,45 +275,29 @@ def test_select_notin_lit(faux_conn, last_query): def test_select_notin_param(faux_conn, last_query): - [[isnotin]] = faux_conn.execute( + faux_conn.execute( sqlalchemy.select( - [sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))] + sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True)) ), dict(q=[1, 2, 3]), ) - if sqlalchemy_version >= packaging.version.parse("1.4"): - last_query( - "SELECT (%(param_1:INT64)s NOT IN UNNEST(%(q:INT64)s)) AS `anon_1`", - {"param_1": 1, "q": [1, 2, 3]}, - ) - else: - assert not isnotin - last_query( - "SELECT (%(param_1:INT64)s NOT IN UNNEST(" - "[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]" - ")) AS `anon_1`", - {"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3}, - ) + last_query( + "SELECT (%(param_1:INT64)s NOT IN UNNEST(%(q:INT64)s)) AS `anon_1`", + {"param_1": 1, "q": [1, 2, 3]}, + ) -@sqlalchemy_1_3_or_higher def test_select_notin_param_empty(faux_conn, last_query): - [[isnotin]] = faux_conn.execute( + faux_conn.execute( sqlalchemy.select( - [sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))] + sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True)) ), dict(q=[]), ) - if sqlalchemy_version >= packaging.version.parse("1.4"): - last_query( - "SELECT (%(param_1:INT64)s NOT IN UNNEST(%(q:INT64)s)) AS `anon_1`", - {"param_1": 1, "q": []}, - ) - else: - assert isnotin - last_query( - "SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`", {"param_1": 1} - ) + last_query( + "SELECT (%(param_1:INT64)s NOT IN UNNEST(%(q:INT64)s)) AS `anon_1`", + {"param_1": 1, "q": []}, + ) def test_literal_binds_kwarg_with_an_IN_operator_252(faux_conn): @@ -376,7 +307,7 @@ def test_literal_binds_kwarg_with_an_IN_operator_252(faux_conn): sqlalchemy.Column("val", sqlalchemy.Integer), initial_data=[dict(val=i) for i in range(3)], ) - q = sqlalchemy.select([table.c.val]).where(table.c.val.in_([2])) + q = sqlalchemy.select(table.c.val).where(table.c.val.in_([2])) def nstr(q): return " ".join(str(q).strip().split()) @@ -387,7 +318,6 @@ def nstr(q): ) -@sqlalchemy_1_4_or_higher @pytest.mark.parametrize("alias", [True, False]) def test_unnest(faux_conn, alias): from sqlalchemy import String @@ -405,7 +335,6 @@ def test_unnest(faux_conn, alias): ) -@sqlalchemy_1_4_or_higher @pytest.mark.parametrize("alias", [True, False]) def test_table_valued_alias_w_multiple_references_to_the_same_table(faux_conn, alias): from sqlalchemy import String @@ -424,7 +353,6 @@ def test_table_valued_alias_w_multiple_references_to_the_same_table(faux_conn, a ) -@sqlalchemy_1_4_or_higher @pytest.mark.parametrize("alias", [True, False]) def test_unnest_w_no_table_references(faux_conn, alias): fcall = sqlalchemy.func.unnest([1, 2, 3]) @@ -444,14 +372,10 @@ def test_array_indexing(faux_conn, metadata): metadata, sqlalchemy.Column("a", sqlalchemy.ARRAY(sqlalchemy.String)), ) - got = str(sqlalchemy.select([t.c.a[0]]).compile(faux_conn.engine)) + got = str(sqlalchemy.select(t.c.a[0]).compile(faux_conn.engine)) assert got == "SELECT `t`.`a`[OFFSET(%(a_1:INT64)s)] AS `anon_1` \nFROM `t`" -@pytest.mark.skipif( - packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"), - reason="regexp_match support requires version 1.4 or higher", -) def test_visit_regexp_match_op_binary(faux_conn): table = setup_table( faux_conn, @@ -468,10 +392,6 @@ def test_visit_regexp_match_op_binary(faux_conn): assert result == expected -@pytest.mark.skipif( - packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"), - reason="regexp_match support requires version 1.4 or higher", -) def test_visit_not_regexp_match_op_binary(faux_conn): table = setup_table( faux_conn, diff --git a/tests/unit/test_sqlalchemy_bigquery.py b/tests/unit/test_sqlalchemy_bigquery.py index 06ef79d2..db20e2f0 100644 --- a/tests/unit/test_sqlalchemy_bigquery.py +++ b/tests/unit/test_sqlalchemy_bigquery.py @@ -10,7 +10,6 @@ from google.cloud import bigquery from google.cloud.bigquery.dataset import DatasetListItem from google.cloud.bigquery.table import TableListItem -import packaging.version import pytest import sqlalchemy @@ -98,7 +97,7 @@ def test_get_table_names( ): mock_bigquery_client.list_datasets.return_value = datasets_list mock_bigquery_client.list_tables.side_effect = tables_lists - table_names = engine_under_test.table_names() + table_names = sqlalchemy.inspect(engine_under_test).get_table_names() mock_bigquery_client.list_datasets.assert_called_once() assert mock_bigquery_client.list_tables.call_count == len(datasets_list) assert list(sorted(table_names)) == list(sorted(expected)) @@ -227,12 +226,7 @@ def test_unnest_function(args, kw): f = sqlalchemy.func.unnest(*args, **kw) assert isinstance(f.type, sqlalchemy.String) - if packaging.version.parse(sqlalchemy.__version__) >= packaging.version.parse( - "1.4" - ): - assert isinstance( - sqlalchemy.select([f]).subquery().c.unnest.type, sqlalchemy.String - ) + assert isinstance(sqlalchemy.select(f).subquery().c.unnest.type, sqlalchemy.String) @mock.patch("sqlalchemy_bigquery._helpers.create_bigquery_client") diff --git a/tests/unit/test_table_options.py b/tests/unit/test_table_options.py new file mode 100644 index 00000000..2147fb1d --- /dev/null +++ b/tests/unit/test_table_options.py @@ -0,0 +1,474 @@ +# Copyright (c) 2021 The sqlalchemy-bigquery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import datetime +import sqlite3 +import pytest +import sqlalchemy + +from google.cloud.bigquery import ( + PartitionRange, + RangePartitioning, + TimePartitioning, + TimePartitioningType, +) + +from .conftest import setup_table + + +def test_table_expiration_timestamp_dialect_option(faux_conn): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("createdAt", sqlalchemy.DateTime), + bigquery_expiration_timestamp=datetime.datetime.fromisoformat( + "2038-01-01T00:00:00+00:00" + ), + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `createdAt` DATETIME )" + " OPTIONS(expiration_timestamp=TIMESTAMP '2038-01-01 00:00:00+00:00')" + ) + + +def test_table_default_rounding_mode_dialect_option(faux_conn): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("createdAt", sqlalchemy.DateTime), + bigquery_default_rounding_mode="ROUND_HALF_EVEN", + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `createdAt` DATETIME )" + " OPTIONS(default_rounding_mode='ROUND_HALF_EVEN')" + ) + + +def test_table_clustering_fields_dialect_option_no_such_column(faux_conn): + with pytest.raises(sqlalchemy.exc.NoSuchColumnError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("createdAt", sqlalchemy.DateTime), + bigquery_clustering_fields=["country", "unknown"], + ) + + +def test_table_clustering_fields_dialect_option(faux_conn): + # expect table creation to fail as SQLite does not support clustering + with pytest.raises(sqlite3.OperationalError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("country", sqlalchemy.Text), + sqlalchemy.Column("town", sqlalchemy.Text), + bigquery_clustering_fields=["country", "town"], + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `id` INT64, `country` STRING, `town` STRING )" + " CLUSTER BY country, town" + ) + + +def test_table_clustering_fields_dialect_option_type_error(faux_conn): + # expect TypeError when bigquery_clustering_fields is not a list + with pytest.raises(TypeError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("country", sqlalchemy.Text), + sqlalchemy.Column("town", sqlalchemy.Text), + bigquery_clustering_fields="country, town", + ) + + +def test_table_time_partitioning_dialect_option(faux_conn): + # expect table creation to fail as SQLite does not support partitioned tables + with pytest.raises(sqlite3.OperationalError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("createdAt", sqlalchemy.DateTime), + bigquery_time_partitioning=TimePartitioning(), + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `id` INT64, `createdAt` DATETIME )" + " PARTITION BY DATE_TRUNC(_PARTITIONDATE, DAY)" + ) + + +def test_table_require_partition_filter_dialect_option(faux_conn): + # expect table creation to fail as SQLite does not support partitioned tables + with pytest.raises(sqlite3.OperationalError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("createdAt", sqlalchemy.DateTime), + bigquery_time_partitioning=TimePartitioning(field="createdAt"), + bigquery_require_partition_filter=True, + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `createdAt` DATETIME )" + " PARTITION BY DATE_TRUNC(createdAt, DAY)" + " OPTIONS(require_partition_filter=true)" + ) + + +def test_table_time_partitioning_with_field_dialect_option(faux_conn): + # expect table creation to fail as SQLite does not support partitioned tables + with pytest.raises(sqlite3.OperationalError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("createdAt", sqlalchemy.DateTime), + bigquery_time_partitioning=TimePartitioning(field="createdAt"), + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `id` INT64, `createdAt` DATETIME )" + " PARTITION BY DATE_TRUNC(createdAt, DAY)" + ) + + +def test_table_time_partitioning_by_month_dialect_option(faux_conn): + # expect table creation to fail as SQLite does not support partitioned tables + with pytest.raises(sqlite3.OperationalError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("createdAt", sqlalchemy.DateTime), + bigquery_time_partitioning=TimePartitioning( + field="createdAt", + type_=TimePartitioningType.MONTH, + ), + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `id` INT64, `createdAt` DATETIME )" + " PARTITION BY DATE_TRUNC(createdAt, MONTH)" + ) + + +def test_table_time_partitioning_with_timestamp_dialect_option(faux_conn): + # expect table creation to fail as SQLite does not support partitioned tables + with pytest.raises(sqlite3.OperationalError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("createdAt", sqlalchemy.TIMESTAMP), + bigquery_time_partitioning=TimePartitioning(field="createdAt"), + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `id` INT64, `createdAt` TIMESTAMP )" + " PARTITION BY TIMESTAMP_TRUNC(createdAt, DAY)" + ) + + +def test_table_time_partitioning_dialect_option_partition_expiration_days(faux_conn): + # expect table creation to fail as SQLite does not support partitioned tables + with pytest.raises(sqlite3.OperationalError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("createdAt", sqlalchemy.DateTime), + bigquery_time_partitioning=TimePartitioning( + field="createdAt", + type_="DAY", + expiration_ms=21600000, + ), + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `createdAt` DATETIME )" + " PARTITION BY DATE_TRUNC(createdAt, DAY)" + " OPTIONS(partition_expiration_days=0.25)" + ) + + +def test_table_partitioning_dialect_option_type_error(faux_conn): + # expect TypeError when bigquery_time_partitioning is not a TimePartitioning object + with pytest.raises(TypeError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("createdAt", sqlalchemy.DateTime), + bigquery_time_partitioning="DATE(createdAt)", + ) + + +def test_table_range_partitioning_dialect_option(faux_conn): + # expect table creation to fail as SQLite does not support partitioned tables + with pytest.raises(sqlite3.OperationalError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("zipcode", sqlalchemy.INT), + bigquery_range_partitioning=RangePartitioning( + field="zipcode", + range_=PartitionRange( + start=0, + end=100000, + interval=2, + ), + ), + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `id` INT64, `zipcode` INT64 )" + " PARTITION BY RANGE_BUCKET(zipcode, GENERATE_ARRAY(0, 100000, 2))" + ) + + +def test_table_range_partitioning_dialect_option_no_field(faux_conn): + # expect TypeError when bigquery_range_partitioning field is not defined + with pytest.raises( + AttributeError, + match="bigquery_range_partitioning expects field to be defined", + ): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("zipcode", sqlalchemy.FLOAT), + bigquery_range_partitioning=RangePartitioning( + range_=PartitionRange( + start=0, + end=100000, + interval=10, + ), + ), + ) + + +def test_table_range_partitioning_dialect_option_bad_column_type(faux_conn): + # expect ValueError when bigquery_range_partitioning field is not an INTEGER + with pytest.raises( + ValueError, + match=r"bigquery_range_partitioning expects field \(i\.e\. column\) data type to be INTEGER", + ): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("zipcode", sqlalchemy.FLOAT), + bigquery_range_partitioning=RangePartitioning( + field="zipcode", + range_=PartitionRange( + start=0, + end=100000, + interval=10, + ), + ), + ) + + +def test_table_range_partitioning_dialect_option_range_missing(faux_conn): + # expect TypeError when bigquery_range_partitioning range start or end is missing + with pytest.raises( + TypeError, + match="bigquery_range_partitioning expects range_.start to be an int, provided None", + ): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("zipcode", sqlalchemy.INT), + bigquery_range_partitioning=RangePartitioning(field="zipcode"), + ) + + with pytest.raises( + TypeError, + match="bigquery_range_partitioning expects range_.end to be an int, provided None", + ): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("zipcode", sqlalchemy.INT), + bigquery_range_partitioning=RangePartitioning( + field="zipcode", + range_=PartitionRange(start=1), + ), + ) + + +def test_table_range_partitioning_dialect_option_default_interval(faux_conn): + # expect table creation to fail as SQLite does not support partitioned tables + with pytest.raises(sqlite3.OperationalError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("zipcode", sqlalchemy.INT), + bigquery_range_partitioning=RangePartitioning( + field="zipcode", + range_=PartitionRange( + start=0, + end=100000, + ), + ), + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `id` INT64, `zipcode` INT64 )" + " PARTITION BY RANGE_BUCKET(zipcode, GENERATE_ARRAY(0, 100000, 1))" + ) + + +def test_time_and_range_partitioning_mutually_exclusive(faux_conn): + # expect ValueError when both bigquery_time_partitioning and bigquery_range_partitioning are provided + with pytest.raises(ValueError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("createdAt", sqlalchemy.DateTime), + bigquery_range_partitioning=RangePartitioning(), + bigquery_time_partitioning=TimePartitioning(), + ) + + +def test_table_all_dialect_option(faux_conn): + # expect table creation to fail as SQLite does not support clustering and partitioned tables + with pytest.raises(sqlite3.OperationalError): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("country", sqlalchemy.Text), + sqlalchemy.Column("town", sqlalchemy.Text), + sqlalchemy.Column("createdAt", sqlalchemy.DateTime), + bigquery_expiration_timestamp=datetime.datetime.fromisoformat( + "2038-01-01T00:00:00+00:00" + ), + bigquery_require_partition_filter=True, + bigquery_default_rounding_mode="ROUND_HALF_EVEN", + bigquery_clustering_fields=["country", "town"], + bigquery_time_partitioning=TimePartitioning( + field="createdAt", + type_="DAY", + expiration_ms=2592000000, + ), + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `id` INT64, `country` STRING, `town` STRING, `createdAt` DATETIME )" + " PARTITION BY DATE_TRUNC(createdAt, DAY)" + " CLUSTER BY country, town" + " OPTIONS(partition_expiration_days=30.0, expiration_timestamp=TIMESTAMP '2038-01-01 00:00:00+00:00', require_partition_filter=true, default_rounding_mode='ROUND_HALF_EVEN')" + ) + + +def test_validate_friendly_name_value_type(ddl_compiler): + # expect option value to be transformed as a string expression + + assert ddl_compiler._validate_option_value_type("friendly_name", "Friendly name") + + with pytest.raises(TypeError): + ddl_compiler._validate_option_value_type("friendly_name", 1983) + + +def test_validate_expiration_timestamp_value_type(ddl_compiler): + # expect option value to be transformed as a timestamp expression + + assert ddl_compiler._validate_option_value_type( + "expiration_timestamp", + datetime.datetime.fromisoformat("2038-01-01T00:00:00+00:00"), + ) + + with pytest.raises(TypeError): + ddl_compiler._validate_option_value_type("expiration_timestamp", "2038-01-01") + + +def test_validate_require_partition_filter_type(ddl_compiler): + # expect option value to be transformed as a literal boolean + + assert ddl_compiler._validate_option_value_type("require_partition_filter", True) + assert ddl_compiler._validate_option_value_type("require_partition_filter", False) + + with pytest.raises(TypeError): + ddl_compiler._validate_option_value_type("require_partition_filter", "true") + + with pytest.raises(TypeError): + ddl_compiler._validate_option_value_type("require_partition_filter", "false") + + +def test_validate_default_rounding_mode_type(ddl_compiler): + # expect option value to be transformed as a string expression + + assert ddl_compiler._validate_option_value_type( + "default_rounding_mode", "ROUND_HALF_EVEN" + ) + + with pytest.raises(TypeError): + ddl_compiler._validate_option_value_type("default_rounding_mode", True) + + +def test_validate_unmapped_option_type(ddl_compiler): + # expect option value with no typed specified in mapping to be transformed as a string expression + + assert ddl_compiler._validate_option_value_type("unknown", "DEFAULT_IS_STRING") + + +def test_process_str_option_value(ddl_compiler): + # expect string to be transformed as a string expression + assert ddl_compiler._process_option_value("Some text") == "'Some text'" + + +def test_process_datetime_value(ddl_compiler): + # expect datetime object to be transformed as a timestamp expression + assert ( + ddl_compiler._process_option_value( + datetime.datetime.fromisoformat("2038-01-01T00:00:00+00:00") + ) + == "TIMESTAMP '2038-01-01 00:00:00+00:00'" + ) + + +def test_process_int_option_value(ddl_compiler): + # expect int to be unchanged + assert ddl_compiler._process_option_value(90) == 90 + + +def test_process_boolean_option_value(ddl_compiler): + # expect boolean to be transformed as a literal boolean expression + + assert ddl_compiler._process_option_value(True) == "true" + assert ddl_compiler._process_option_value(False) == "false" + + +def test_process_not_implementer_option_value(ddl_compiler): + # expect to raise + with pytest.raises(NotImplementedError): + ddl_compiler._process_option_value(float) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy