From 048af73f330958908380d5d7aa60f1aba0275961 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 30 May 2025 11:33:04 +0530 Subject: [PATCH 1/5] Enhance Arrow to Pandas conversion with type overrides and additional kwargs * Introduced _arrow_pandas_type_override and _arrow_to_pandas_kwargs in Connection class for customizable dtype mapping and DataFrame construction parameters. * Updated ResultSet to utilize these new options during conversion from Arrow tables to Pandas DataFrames. * Added unit tests to validate the new functionality, including scenarios for type overrides and additional kwargs handling. --- src/databricks/sql/client.py | 21 +++-- tests/unit/test_arrow_conversion.py | 128 ++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+), 6 deletions(-) create mode 100644 tests/unit/test_arrow_conversion.py diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0c9a08a85..0b17ae7d7 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -213,6 +213,11 @@ def read(self) -> Optional[OAuthToken]: # (True by default) # use_cloud_fetch # Enable use of cloud fetch to extract large query results in parallel via cloud storage + # _arrow_pandas_type_override + # Override the default pandas dtype mapping for Arrow types. + # This is a dictionary of Arrow types to pandas dtypes. + # _arrow_to_pandas_kwargs + # Additional or modified arguments to pass to pandas.DataFrame constructor. logger.debug( "Connection.__init__(server_hostname=%s, http_path=%s)", @@ -1346,7 +1351,7 @@ def _convert_arrow_table(self, table): # Need to use nullable types, as otherwise type can change when there are missing values. 
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - dtype_mapping = { + DEFAULT_DTYPE_MAPPING: Dict[pyarrow.DataType, pandas.api.extensions.ExtensionDtype] = { pyarrow.int8(): pandas.Int8Dtype(), pyarrow.int16(): pandas.Int16Dtype(), pyarrow.int32(): pandas.Int32Dtype(), @@ -1360,14 +1365,18 @@ def _convert_arrow_table(self, table): pyarrow.float64(): pandas.Float64Dtype(), pyarrow.string(): pandas.StringDtype(), } + dtype_mapping = {**DEFAULT_DTYPE_MAPPING, **self.connection._arrow_pandas_type_override} + + to_pandas_kwargs: dict[str, Any] = { + "types_mapper": dtype_mapping.get, + "date_as_object": True, + "timestamp_as_object": True, + } + to_pandas_kwargs.update(self.connection._arrow_to_pandas_kwargs) # Need to rename columns, as the to_pandas function cannot handle duplicate column names table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - df = table_renamed.to_pandas( - types_mapper=dtype_mapping.get, - date_as_object=True, - timestamp_as_object=True, - ) + df = table_renamed.to_pandas(**to_pandas_kwargs) res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] diff --git a/tests/unit/test_arrow_conversion.py b/tests/unit/test_arrow_conversion.py new file mode 100644 index 000000000..b673d0dfa --- /dev/null +++ b/tests/unit/test_arrow_conversion.py @@ -0,0 +1,128 @@ +import pytest +import pyarrow +import pandas +import datetime +from unittest.mock import MagicMock, patch + +from databricks.sql.client import ResultSet, Connection, ExecuteResponse +from databricks.sql.types import Row +from databricks.sql.utils import ArrowQueue + + +@pytest.fixture +def mock_connection(): + conn = MagicMock(spec=Connection) + conn.disable_pandas = False + conn._arrow_pandas_type_override = {} + conn._arrow_to_pandas_kwargs = {} + if not hasattr(conn, '_arrow_to_pandas_kwargs'): + 
conn._arrow_to_pandas_kwargs = {} + return conn + +@pytest.fixture +def mock_thrift_backend(sample_arrow_table): + tb = MagicMock() + empty_arrays = [pyarrow.array([], type=field.type) for field in sample_arrow_table.schema] + empty_table = pyarrow.Table.from_arrays(empty_arrays, schema=sample_arrow_table.schema) + tb.fetch_results.return_value = (ArrowQueue(empty_table, 0) , False) + return tb + +@pytest.fixture +def mock_raw_execute_response(): + er = MagicMock(spec=ExecuteResponse) + er.description = [("col_int", "int", None, None, None, None, None), + ("col_str", "string", None, None, None, None, None)] + er.arrow_schema_bytes = None + er.arrow_queue = None + er.has_more_rows = False + er.lz4_compressed = False + er.command_handle = MagicMock() + er.status = MagicMock() + er.has_been_closed_server_side = False + er.is_staging_operation = False + return er + +@pytest.fixture +def sample_arrow_table(): + data = [ + pyarrow.array([1, 2, 3], type=pyarrow.int32()), + pyarrow.array(["a", "b", "c"], type=pyarrow.string()) + ] + schema = pyarrow.schema([ + ('col_int', pyarrow.int32()), + ('col_str', pyarrow.string()) + ]) + return pyarrow.Table.from_arrays(data, schema=schema) + + +def test_convert_arrow_table_default(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table): + mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_one = rs.fetchone() + assert isinstance(result_one, Row) + assert result_one.col_int == 1 + assert result_one.col_str == "a" + mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_all = rs.fetchall() + assert len(result_all) == 3 + assert isinstance(result_all[0], Row) + assert result_all[0].col_int == 1 + assert result_all[1].col_str == "b" + + 
+def test_convert_arrow_table_disable_pandas(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table): + mock_connection.disable_pandas = True + mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result = rs.fetchall() + assert len(result) == 3 + assert isinstance(result[0], Row) + assert result[0].col_int == 1 + assert result[0].col_str == "a" + assert isinstance(sample_arrow_table.column(0)[0].as_py(), int) + assert isinstance(sample_arrow_table.column(1)[0].as_py(), str) + + +def test_convert_arrow_table_type_override(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table): + mock_connection._arrow_pandas_type_override = {pyarrow.int32(): pandas.Float64Dtype()} + mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result = rs.fetchall() + assert len(result) == 3 + assert isinstance(result[0].col_int, float) + assert result[0].col_int == 1.0 + assert result[0].col_str == "a" + + +def test_convert_arrow_table_to_pandas_kwargs(mock_connection, mock_thrift_backend, mock_raw_execute_response): + dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + ts_array = pyarrow.array([dt_obj], type=pyarrow.timestamp('us', tz='UTC')) + ts_schema = pyarrow.schema([('col_ts', pyarrow.timestamp('us', tz='UTC'))]) + ts_table = pyarrow.Table.from_arrays([ts_array], schema=ts_schema) + + mock_raw_execute_response.description = [("col_ts", "timestamp", None, None, None, None, None)] + mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) + + # Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row. 
+ mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True} + rs_ts_true = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_true = rs_ts_true.fetchall() + assert len(result_true) == 1 + assert isinstance(result_true[0].col_ts, datetime.datetime) + + # Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input. + mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) + mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False} + rs_ts_false = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_false = rs_ts_false.fetchall() + assert len(result_false) == 1 + assert isinstance(result_false[0].col_ts, pandas.Timestamp) + + # Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default. + mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) + mock_connection._arrow_to_pandas_kwargs = {} + rs_ts_true = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_true = rs_ts_true.fetchall() + assert len(result_true) == 1 + assert isinstance(result_true[0].col_ts, datetime.datetime) From 0b1b05b9fc7d88036f180178a99883a6dbeda921 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 30 May 2025 11:54:14 +0530 Subject: [PATCH 2/5] fmt --- src/databricks/sql/client.py | 9 ++- tests/unit/test_arrow_conversion.py | 86 ++++++++++++++++++++--------- 2 files changed, 67 insertions(+), 28 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0b17ae7d7..79338f387 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1351,7 +1351,9 @@ def _convert_arrow_table(self, table): # Need to use nullable types, as otherwise type can change when there are missing values. 
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - DEFAULT_DTYPE_MAPPING: Dict[pyarrow.DataType, pandas.api.extensions.ExtensionDtype] = { + DEFAULT_DTYPE_MAPPING: Dict[ + pyarrow.DataType, pandas.api.extensions.ExtensionDtype + ] = { pyarrow.int8(): pandas.Int8Dtype(), pyarrow.int16(): pandas.Int16Dtype(), pyarrow.int32(): pandas.Int32Dtype(), @@ -1365,7 +1367,10 @@ def _convert_arrow_table(self, table): pyarrow.float64(): pandas.Float64Dtype(), pyarrow.string(): pandas.StringDtype(), } - dtype_mapping = {**DEFAULT_DTYPE_MAPPING, **self.connection._arrow_pandas_type_override} + dtype_mapping = { + **DEFAULT_DTYPE_MAPPING, + **self.connection._arrow_pandas_type_override, + } to_pandas_kwargs: dict[str, Any] = { "types_mapper": dtype_mapping.get, diff --git a/tests/unit/test_arrow_conversion.py b/tests/unit/test_arrow_conversion.py index b673d0dfa..30fd4f04e 100644 --- a/tests/unit/test_arrow_conversion.py +++ b/tests/unit/test_arrow_conversion.py @@ -15,23 +15,31 @@ def mock_connection(): conn.disable_pandas = False conn._arrow_pandas_type_override = {} conn._arrow_to_pandas_kwargs = {} - if not hasattr(conn, '_arrow_to_pandas_kwargs'): + if not hasattr(conn, "_arrow_to_pandas_kwargs"): conn._arrow_to_pandas_kwargs = {} return conn + @pytest.fixture def mock_thrift_backend(sample_arrow_table): tb = MagicMock() - empty_arrays = [pyarrow.array([], type=field.type) for field in sample_arrow_table.schema] - empty_table = pyarrow.Table.from_arrays(empty_arrays, schema=sample_arrow_table.schema) - tb.fetch_results.return_value = (ArrowQueue(empty_table, 0) , False) + empty_arrays = [ + pyarrow.array([], type=field.type) for field in sample_arrow_table.schema + ] + empty_table = pyarrow.Table.from_arrays( + empty_arrays, schema=sample_arrow_table.schema + ) + tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False) return tb + 
@pytest.fixture def mock_raw_execute_response(): er = MagicMock(spec=ExecuteResponse) - er.description = [("col_int", "int", None, None, None, None, None), - ("col_str", "string", None, None, None, None, None)] + er.description = [ + ("col_int", "int", None, None, None, None, None), + ("col_str", "string", None, None, None, None, None), + ] er.arrow_schema_bytes = None er.arrow_queue = None er.has_more_rows = False @@ -42,27 +50,33 @@ def mock_raw_execute_response(): er.is_staging_operation = False return er + @pytest.fixture def sample_arrow_table(): data = [ pyarrow.array([1, 2, 3], type=pyarrow.int32()), - pyarrow.array(["a", "b", "c"], type=pyarrow.string()) + pyarrow.array(["a", "b", "c"], type=pyarrow.string()), ] - schema = pyarrow.schema([ - ('col_int', pyarrow.int32()), - ('col_str', pyarrow.string()) - ]) + schema = pyarrow.schema( + [("col_int", pyarrow.int32()), ("col_str", pyarrow.string())] + ) return pyarrow.Table.from_arrays(data, schema=schema) -def test_convert_arrow_table_default(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table): - mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) +def test_convert_arrow_table_default( + mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table +): + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) result_one = rs.fetchone() assert isinstance(result_one, Row) assert result_one.col_int == 1 assert result_one.col_str == "a" - mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) result_all = rs.fetchall() assert len(result_all) == 3 @@ -71,9 
+85,13 @@ def test_convert_arrow_table_default(mock_connection, mock_thrift_backend, mock_ assert result_all[1].col_str == "b" -def test_convert_arrow_table_disable_pandas(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table): +def test_convert_arrow_table_disable_pandas( + mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table +): mock_connection.disable_pandas = True - mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) result = rs.fetchall() assert len(result) == 3 @@ -84,9 +102,15 @@ def test_convert_arrow_table_disable_pandas(mock_connection, mock_thrift_backend assert isinstance(sample_arrow_table.column(1)[0].as_py(), str) -def test_convert_arrow_table_type_override(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table): - mock_connection._arrow_pandas_type_override = {pyarrow.int32(): pandas.Float64Dtype()} - mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) +def test_convert_arrow_table_type_override( + mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table +): + mock_connection._arrow_pandas_type_override = { + pyarrow.int32(): pandas.Float64Dtype() + } + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) result = rs.fetchall() assert len(result) == 3 @@ -95,18 +119,24 @@ def test_convert_arrow_table_type_override(mock_connection, mock_thrift_backend, assert result[0].col_str == "a" -def test_convert_arrow_table_to_pandas_kwargs(mock_connection, mock_thrift_backend, mock_raw_execute_response): +def 
test_convert_arrow_table_to_pandas_kwargs( + mock_connection, mock_thrift_backend, mock_raw_execute_response +): dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) - ts_array = pyarrow.array([dt_obj], type=pyarrow.timestamp('us', tz='UTC')) - ts_schema = pyarrow.schema([('col_ts', pyarrow.timestamp('us', tz='UTC'))]) + ts_array = pyarrow.array([dt_obj], type=pyarrow.timestamp("us", tz="UTC")) + ts_schema = pyarrow.schema([("col_ts", pyarrow.timestamp("us", tz="UTC"))]) ts_table = pyarrow.Table.from_arrays([ts_array], schema=ts_schema) - mock_raw_execute_response.description = [("col_ts", "timestamp", None, None, None, None, None)] + mock_raw_execute_response.description = [ + ("col_ts", "timestamp", None, None, None, None, None) + ] mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) # Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row. mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True} - rs_ts_true = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + rs_ts_true = ResultSet( + mock_connection, mock_raw_execute_response, mock_thrift_backend + ) result_true = rs_ts_true.fetchall() assert len(result_true) == 1 assert isinstance(result_true[0].col_ts, datetime.datetime) @@ -114,7 +144,9 @@ def test_convert_arrow_table_to_pandas_kwargs(mock_connection, mock_thrift_backe # Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input. 
mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False} - rs_ts_false = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + rs_ts_false = ResultSet( + mock_connection, mock_raw_execute_response, mock_thrift_backend + ) result_false = rs_ts_false.fetchall() assert len(result_false) == 1 assert isinstance(result_false[0].col_ts, pandas.Timestamp) @@ -122,7 +154,9 @@ def test_convert_arrow_table_to_pandas_kwargs(mock_connection, mock_thrift_backe # Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default. mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) mock_connection._arrow_to_pandas_kwargs = {} - rs_ts_true = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + rs_ts_true = ResultSet( + mock_connection, mock_raw_execute_response, mock_thrift_backend + ) result_true = rs_ts_true.fetchall() assert len(result_true) == 1 assert isinstance(result_true[0].col_ts, datetime.datetime) From 647ed391be8377afab7fb2ad48d336f656f16185 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 30 May 2025 12:29:47 +0530 Subject: [PATCH 3/5] fix unit tests --- src/databricks/sql/client.py | 19 +- tests/unit/test_arrow_conversion.py | 328 +++++++++++++++------------- 2 files changed, 192 insertions(+), 155 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 79338f387..da1177f45 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1367,9 +1367,17 @@ def _convert_arrow_table(self, table): pyarrow.float64(): pandas.Float64Dtype(), pyarrow.string(): pandas.StringDtype(), } + + arrow_pandas_type_override = self.connection._arrow_pandas_type_override + if not isinstance(arrow_pandas_type_override, dict): + logger.debug( + "_arrow_pandas_type_override on connection was not a dict, using default type 
mapping" + ) + arrow_pandas_type_override = {} + dtype_mapping = { **DEFAULT_DTYPE_MAPPING, - **self.connection._arrow_pandas_type_override, + **arrow_pandas_type_override, } to_pandas_kwargs: dict[str, Any] = { @@ -1377,7 +1385,14 @@ def _convert_arrow_table(self, table): "date_as_object": True, "timestamp_as_object": True, } - to_pandas_kwargs.update(self.connection._arrow_to_pandas_kwargs) + + arrow_to_pandas_kwargs = self.connection._arrow_to_pandas_kwargs + if isinstance(arrow_to_pandas_kwargs, dict): + to_pandas_kwargs.update(arrow_to_pandas_kwargs) + else: + logger.debug( + "_arrow_to_pandas_kwargs on connection was not a dict, using default arguments" + ) # Need to rename columns, as the to_pandas function cannot handle duplicate column names table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) diff --git a/tests/unit/test_arrow_conversion.py b/tests/unit/test_arrow_conversion.py index 30fd4f04e..78d43635a 100644 --- a/tests/unit/test_arrow_conversion.py +++ b/tests/unit/test_arrow_conversion.py @@ -1,162 +1,184 @@ import pytest -import pyarrow + +try: + import pyarrow as pa +except ImportError: + pa = None import pandas import datetime -from unittest.mock import MagicMock, patch +import unittest +from unittest.mock import MagicMock from databricks.sql.client import ResultSet, Connection, ExecuteResponse from databricks.sql.types import Row from databricks.sql.utils import ArrowQueue - -@pytest.fixture -def mock_connection(): - conn = MagicMock(spec=Connection) - conn.disable_pandas = False - conn._arrow_pandas_type_override = {} - conn._arrow_to_pandas_kwargs = {} - if not hasattr(conn, "_arrow_to_pandas_kwargs"): +@pytest.mark.skipif(pa is None, reason="PyArrow is not installed") +class ArrowConversionTests(unittest.TestCase): + @staticmethod + def mock_connection_static(): + conn = MagicMock(spec=Connection) + conn.disable_pandas = False + conn._arrow_pandas_type_override = {} conn._arrow_to_pandas_kwargs = {} - return conn - 
- -@pytest.fixture -def mock_thrift_backend(sample_arrow_table): - tb = MagicMock() - empty_arrays = [ - pyarrow.array([], type=field.type) for field in sample_arrow_table.schema - ] - empty_table = pyarrow.Table.from_arrays( - empty_arrays, schema=sample_arrow_table.schema - ) - tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False) - return tb - - -@pytest.fixture -def mock_raw_execute_response(): - er = MagicMock(spec=ExecuteResponse) - er.description = [ - ("col_int", "int", None, None, None, None, None), - ("col_str", "string", None, None, None, None, None), - ] - er.arrow_schema_bytes = None - er.arrow_queue = None - er.has_more_rows = False - er.lz4_compressed = False - er.command_handle = MagicMock() - er.status = MagicMock() - er.has_been_closed_server_side = False - er.is_staging_operation = False - return er - - -@pytest.fixture -def sample_arrow_table(): - data = [ - pyarrow.array([1, 2, 3], type=pyarrow.int32()), - pyarrow.array(["a", "b", "c"], type=pyarrow.string()), - ] - schema = pyarrow.schema( - [("col_int", pyarrow.int32()), ("col_str", pyarrow.string())] - ) - return pyarrow.Table.from_arrays(data, schema=schema) - - -def test_convert_arrow_table_default( - mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table -): - mock_raw_execute_response.arrow_queue = ArrowQueue( - sample_arrow_table, sample_arrow_table.num_rows - ) - rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) - result_one = rs.fetchone() - assert isinstance(result_one, Row) - assert result_one.col_int == 1 - assert result_one.col_str == "a" - mock_raw_execute_response.arrow_queue = ArrowQueue( - sample_arrow_table, sample_arrow_table.num_rows - ) - rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) - result_all = rs.fetchall() - assert len(result_all) == 3 - assert isinstance(result_all[0], Row) - assert result_all[0].col_int == 1 - assert result_all[1].col_str == "b" - - -def 
test_convert_arrow_table_disable_pandas( - mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table -): - mock_connection.disable_pandas = True - mock_raw_execute_response.arrow_queue = ArrowQueue( - sample_arrow_table, sample_arrow_table.num_rows - ) - rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) - result = rs.fetchall() - assert len(result) == 3 - assert isinstance(result[0], Row) - assert result[0].col_int == 1 - assert result[0].col_str == "a" - assert isinstance(sample_arrow_table.column(0)[0].as_py(), int) - assert isinstance(sample_arrow_table.column(1)[0].as_py(), str) - - -def test_convert_arrow_table_type_override( - mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table -): - mock_connection._arrow_pandas_type_override = { - pyarrow.int32(): pandas.Float64Dtype() - } - mock_raw_execute_response.arrow_queue = ArrowQueue( - sample_arrow_table, sample_arrow_table.num_rows - ) - rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) - result = rs.fetchall() - assert len(result) == 3 - assert isinstance(result[0].col_int, float) - assert result[0].col_int == 1.0 - assert result[0].col_str == "a" - - -def test_convert_arrow_table_to_pandas_kwargs( - mock_connection, mock_thrift_backend, mock_raw_execute_response -): - dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) - ts_array = pyarrow.array([dt_obj], type=pyarrow.timestamp("us", tz="UTC")) - ts_schema = pyarrow.schema([("col_ts", pyarrow.timestamp("us", tz="UTC"))]) - ts_table = pyarrow.Table.from_arrays([ts_array], schema=ts_schema) - - mock_raw_execute_response.description = [ - ("col_ts", "timestamp", None, None, None, None, None) - ] - mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) - - # Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row. 
- mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True} - rs_ts_true = ResultSet( - mock_connection, mock_raw_execute_response, mock_thrift_backend - ) - result_true = rs_ts_true.fetchall() - assert len(result_true) == 1 - assert isinstance(result_true[0].col_ts, datetime.datetime) - - # Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input. - mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) - mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False} - rs_ts_false = ResultSet( - mock_connection, mock_raw_execute_response, mock_thrift_backend - ) - result_false = rs_ts_false.fetchall() - assert len(result_false) == 1 - assert isinstance(result_false[0].col_ts, pandas.Timestamp) - - # Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default. - mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) - mock_connection._arrow_to_pandas_kwargs = {} - rs_ts_true = ResultSet( - mock_connection, mock_raw_execute_response, mock_thrift_backend - ) - result_true = rs_ts_true.fetchall() - assert len(result_true) == 1 - assert isinstance(result_true[0].col_ts, datetime.datetime) + return conn + + @staticmethod + def sample_arrow_table_static(): + data = [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["a", "b", "c"], type=pa.string()), + ] + schema = pa.schema([("col_int", pa.int32()), ("col_str", pa.string())]) + return pa.Table.from_arrays(data, schema=schema) + + @staticmethod + def mock_thrift_backend_static(): + sample_table = ArrowConversionTests.sample_arrow_table_static() + tb = MagicMock() + empty_arrays = [pa.array([], type=field.type) for field in sample_table.schema] + empty_table = pa.Table.from_arrays(empty_arrays, schema=sample_table.schema) + tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False) + return tb + + @staticmethod + def mock_raw_execute_response_static(): + er 
= MagicMock(spec=ExecuteResponse) + er.description = [ + ("col_int", "int", None, None, None, None, None), + ("col_str", "string", None, None, None, None, None), + ] + er.arrow_schema_bytes = None + er.arrow_queue = None + er.has_more_rows = False + er.lz4_compressed = False + er.command_handle = MagicMock() + er.status = MagicMock() + er.has_been_closed_server_side = False + er.is_staging_operation = False + return er + + def test_convert_arrow_table_default(self): + mock_connection = ArrowConversionTests.mock_connection_static() + sample_arrow_table = ArrowConversionTests.sample_arrow_table_static() + mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static() + mock_raw_execute_response = ( + ArrowConversionTests.mock_raw_execute_response_static() + ) + + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_one = rs.fetchone() + self.assertIsInstance(result_one, Row) + self.assertEqual(result_one.col_int, 1) + self.assertEqual(result_one.col_str, "a") + + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_all = rs.fetchall() + self.assertEqual(len(result_all), 3) + self.assertIsInstance(result_all[0], Row) + self.assertEqual(result_all[0].col_int, 1) + self.assertEqual(result_all[1].col_str, "b") + + def test_convert_arrow_table_disable_pandas(self): + mock_connection = ArrowConversionTests.mock_connection_static() + sample_arrow_table = ArrowConversionTests.sample_arrow_table_static() + mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static() + mock_raw_execute_response = ( + ArrowConversionTests.mock_raw_execute_response_static() + ) + + mock_connection.disable_pandas = True + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, 
sample_arrow_table.num_rows + ) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result = rs.fetchall() + self.assertEqual(len(result), 3) + self.assertIsInstance(result[0], Row) + self.assertEqual(result[0].col_int, 1) + self.assertEqual(result[0].col_str, "a") + self.assertIsInstance(sample_arrow_table.column(0)[0].as_py(), int) + self.assertIsInstance(sample_arrow_table.column(1)[0].as_py(), str) + + def test_convert_arrow_table_type_override(self): + mock_connection = ArrowConversionTests.mock_connection_static() + sample_arrow_table = ArrowConversionTests.sample_arrow_table_static() + mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static() + mock_raw_execute_response = ( + ArrowConversionTests.mock_raw_execute_response_static() + ) + + mock_connection._arrow_pandas_type_override = { + pa.int32(): pandas.Float64Dtype() + } + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result = rs.fetchall() + self.assertEqual(len(result), 3) + self.assertIsInstance(result[0].col_int, float) + self.assertEqual(result[0].col_int, 1.0) + self.assertEqual(result[0].col_str, "a") + + def test_convert_arrow_table_to_pandas_kwargs(self): + mock_connection = ArrowConversionTests.mock_connection_static() + mock_thrift_backend = ( + ArrowConversionTests.mock_thrift_backend_static() + ) # Does not use sample_arrow_table + mock_raw_execute_response = ( + ArrowConversionTests.mock_raw_execute_response_static() + ) + + dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + ts_array = pa.array([dt_obj], type=pa.timestamp("us", tz="UTC")) + ts_schema = pa.schema([("col_ts", pa.timestamp("us", tz="UTC"))]) + ts_table = pa.Table.from_arrays([ts_array], schema=ts_schema) + + mock_raw_execute_response.description = [ + ("col_ts", "timestamp", None, None, None, None, None) + ] + 
mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) + + # Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row. + mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True} + rs_ts_true = ResultSet( + mock_connection, mock_raw_execute_response, mock_thrift_backend + ) + result_true = rs_ts_true.fetchall() + self.assertEqual(len(result_true), 1) + self.assertIsInstance(result_true[0].col_ts, datetime.datetime) + + # Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input. + mock_raw_execute_response.arrow_queue = ArrowQueue( + ts_table, ts_table.num_rows + ) # Reset queue + mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False} + rs_ts_false = ResultSet( + mock_connection, mock_raw_execute_response, mock_thrift_backend + ) + result_false = rs_ts_false.fetchall() + self.assertEqual(len(result_false), 1) + self.assertIsInstance(result_false[0].col_ts, pandas.Timestamp) + + # Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default. 
+ mock_raw_execute_response.arrow_queue = ArrowQueue( + ts_table, ts_table.num_rows + ) # Reset queue + mock_connection._arrow_to_pandas_kwargs = {} + rs_ts_default = ResultSet( + mock_connection, mock_raw_execute_response, mock_thrift_backend + ) + result_default = rs_ts_default.fetchall() + self.assertEqual(len(result_default), 1) + self.assertIsInstance(result_default[0].col_ts, datetime.datetime) + + +if __name__ == "__main__": + unittest.main() From 31b44d4d53bdf154a0cd63630509621fcd753264 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Mon, 16 Jun 2025 06:38:48 +0000 Subject: [PATCH 4/5] Add _arrow_pandas_type_override and _arrow_to_pandas_kwargs to Connection class --- src/databricks/sql/client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index da1177f45..d47c3f24c 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -234,6 +234,10 @@ def read(self) -> Optional[OAuthToken]: self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) + self._arrow_pandas_type_override = kwargs.get( + "_arrow_pandas_type_override", {} + ) + self._arrow_to_pandas_kwargs = kwargs.get("_arrow_to_pandas_kwargs", {}) auth_provider = get_python_sql_connector_auth_provider( server_hostname, **kwargs From 2f32c6c08d9b40266a931b75b7868d526642de7e Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Mon, 16 Jun 2025 06:47:33 +0000 Subject: [PATCH 5/5] fmt --- src/databricks/sql/client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index d47c3f24c..112a60dc9 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -234,9 +234,7 @@ def read(self) -> Optional[OAuthToken]: self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) 
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) - self._arrow_pandas_type_override = kwargs.get( - "_arrow_pandas_type_override", {} - ) + self._arrow_pandas_type_override = kwargs.get("_arrow_pandas_type_override", {}) self._arrow_to_pandas_kwargs = kwargs.get("_arrow_to_pandas_kwargs", {}) auth_provider = get_python_sql_connector_auth_provider( pFad - Phonifier reborn