diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index 0c9a08a85..112a60dc9 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -213,6 +213,11 @@ def read(self) -> Optional[OAuthToken]:
         #   (True by default)
         # use_cloud_fetch
         #   Enable use of cloud fetch to extract large query results in parallel via cloud storage
+        # _arrow_pandas_type_override
+        #   Override the default pandas dtype mapping for Arrow types.
+        #   This is a dictionary mapping Arrow types to pandas dtypes.
+        # _arrow_to_pandas_kwargs
+        #   Additional or modified arguments to pass to pyarrow.Table.to_pandas() when converting results.
 
         logger.debug(
             "Connection.__init__(server_hostname=%s, http_path=%s)",
@@ -229,6 +234,8 @@ def read(self) -> Optional[OAuthToken]:
         self.port = kwargs.get("_port", 443)
         self.disable_pandas = kwargs.get("_disable_pandas", False)
         self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
+        self._arrow_pandas_type_override = kwargs.get("_arrow_pandas_type_override", {})
+        self._arrow_to_pandas_kwargs = kwargs.get("_arrow_to_pandas_kwargs", {})
 
         auth_provider = get_python_sql_connector_auth_provider(
             server_hostname, **kwargs
@@ -1346,7 +1353,9 @@ def _convert_arrow_table(self, table):
         # Need to use nullable types, as otherwise type can change when there are missing values.
         # See https://arrow.apache.org/docs/python/pandas.html#nullable-types
         # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
-        dtype_mapping = {
+        DEFAULT_DTYPE_MAPPING: Dict[
+            pyarrow.DataType, pandas.api.extensions.ExtensionDtype
+        ] = {
             pyarrow.int8(): pandas.Int8Dtype(),
             pyarrow.int16(): pandas.Int16Dtype(),
             pyarrow.int32(): pandas.Int32Dtype(),
@@ -1361,13 +1370,35 @@ def _convert_arrow_table(self, table):
             pyarrow.string(): pandas.StringDtype(),
         }
 
+        arrow_pandas_type_override = self.connection._arrow_pandas_type_override
+        if not isinstance(arrow_pandas_type_override, dict):
+            logger.debug(
+                "_arrow_pandas_type_override on connection was not a dict, using default type mapping"
+            )
+            arrow_pandas_type_override = {}
+
+        dtype_mapping = {
+            **DEFAULT_DTYPE_MAPPING,
+            **arrow_pandas_type_override,
+        }
+
+        to_pandas_kwargs: dict[str, Any] = {
+            "types_mapper": dtype_mapping.get,
+            "date_as_object": True,
+            "timestamp_as_object": True,
+        }
+
+        arrow_to_pandas_kwargs = self.connection._arrow_to_pandas_kwargs
+        if isinstance(arrow_to_pandas_kwargs, dict):
+            to_pandas_kwargs.update(arrow_to_pandas_kwargs)
+        else:
+            logger.debug(
+                "_arrow_to_pandas_kwargs on connection was not a dict, using default arguments"
+            )
+
         # Need to rename columns, as the to_pandas function cannot handle duplicate column names
         table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
 
-        df = table_renamed.to_pandas(
-            types_mapper=dtype_mapping.get,
-            date_as_object=True,
-            timestamp_as_object=True,
-        )
+        df = table_renamed.to_pandas(**to_pandas_kwargs)
         res = df.to_numpy(na_value=None, dtype="object")
         return [ResultRow(*v) for v in res]
diff --git a/tests/unit/test_arrow_conversion.py b/tests/unit/test_arrow_conversion.py
new file mode 100644
index 000000000..78d43635a
--- /dev/null
+++ b/tests/unit/test_arrow_conversion.py
@@ -0,0 +1,184 @@
+import pytest
+
+try:
+    import pyarrow as pa
+except ImportError:
+    pa = None
+import pandas
+import datetime
+import unittest
+from unittest.mock import MagicMock
+
+from databricks.sql.client import ResultSet, Connection, ExecuteResponse
+from databricks.sql.types import Row
+from databricks.sql.utils import ArrowQueue
+
+
+@pytest.mark.skipif(pa is None, reason="PyArrow is not installed")
+class ArrowConversionTests(unittest.TestCase):
+    @staticmethod
+    def mock_connection_static():
+        conn = MagicMock(spec=Connection)
+        conn.disable_pandas = False
+        conn._arrow_pandas_type_override = {}
+        conn._arrow_to_pandas_kwargs = {}
+        return conn
+
+    @staticmethod
+    def sample_arrow_table_static():
+        data = [
+            pa.array([1, 2, 3], type=pa.int32()),
+            pa.array(["a", "b", "c"], type=pa.string()),
+        ]
+        schema = pa.schema([("col_int", pa.int32()), ("col_str", pa.string())])
+        return pa.Table.from_arrays(data, schema=schema)
+
+    @staticmethod
+    def mock_thrift_backend_static():
+        sample_table = ArrowConversionTests.sample_arrow_table_static()
+        tb = MagicMock()
+        empty_arrays = [pa.array([], type=field.type) for field in sample_table.schema]
+        empty_table = pa.Table.from_arrays(empty_arrays, schema=sample_table.schema)
+        tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False)
+        return tb
+
+    @staticmethod
+    def mock_raw_execute_response_static():
+        er = MagicMock(spec=ExecuteResponse)
+        er.description = [
+            ("col_int", "int", None, None, None, None, None),
+            ("col_str", "string", None, None, None, None, None),
+        ]
+        er.arrow_schema_bytes = None
+        er.arrow_queue = None
+        er.has_more_rows = False
+        er.lz4_compressed = False
+        er.command_handle = MagicMock()
+        er.status = MagicMock()
+        er.has_been_closed_server_side = False
+        er.is_staging_operation = False
+        return er
+
+    def test_convert_arrow_table_default(self):
+        mock_connection = ArrowConversionTests.mock_connection_static()
+        sample_arrow_table = ArrowConversionTests.sample_arrow_table_static()
+        mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static()
+        mock_raw_execute_response = (
+            ArrowConversionTests.mock_raw_execute_response_static()
+        )
+
+        mock_raw_execute_response.arrow_queue = ArrowQueue(
+            sample_arrow_table, sample_arrow_table.num_rows
+        )
+        rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
+        result_one = rs.fetchone()
+        self.assertIsInstance(result_one, Row)
+        self.assertEqual(result_one.col_int, 1)
+        self.assertEqual(result_one.col_str, "a")
+
+        mock_raw_execute_response.arrow_queue = ArrowQueue(
+            sample_arrow_table, sample_arrow_table.num_rows
+        )
+        rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
+        result_all = rs.fetchall()
+        self.assertEqual(len(result_all), 3)
+        self.assertIsInstance(result_all[0], Row)
+        self.assertEqual(result_all[0].col_int, 1)
+        self.assertEqual(result_all[1].col_str, "b")
+
+    def test_convert_arrow_table_disable_pandas(self):
+        mock_connection = ArrowConversionTests.mock_connection_static()
+        sample_arrow_table = ArrowConversionTests.sample_arrow_table_static()
+        mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static()
+        mock_raw_execute_response = (
+            ArrowConversionTests.mock_raw_execute_response_static()
+        )
+
+        mock_connection.disable_pandas = True
+        mock_raw_execute_response.arrow_queue = ArrowQueue(
+            sample_arrow_table, sample_arrow_table.num_rows
+        )
+        rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
+        result = rs.fetchall()
+        self.assertEqual(len(result), 3)
+        self.assertIsInstance(result[0], Row)
+        self.assertEqual(result[0].col_int, 1)
+        self.assertEqual(result[0].col_str, "a")
+        self.assertIsInstance(sample_arrow_table.column(0)[0].as_py(), int)
+        self.assertIsInstance(sample_arrow_table.column(1)[0].as_py(), str)
+
+    def test_convert_arrow_table_type_override(self):
+        mock_connection = ArrowConversionTests.mock_connection_static()
+        sample_arrow_table = ArrowConversionTests.sample_arrow_table_static()
+        mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static()
+        mock_raw_execute_response = (
+            ArrowConversionTests.mock_raw_execute_response_static()
+        )
+
+        mock_connection._arrow_pandas_type_override = {
+            pa.int32(): pandas.Float64Dtype()
+        }
+        mock_raw_execute_response.arrow_queue = ArrowQueue(
+            sample_arrow_table, sample_arrow_table.num_rows
+        )
+        rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
+        result = rs.fetchall()
+        self.assertEqual(len(result), 3)
+        self.assertIsInstance(result[0].col_int, float)
+        self.assertEqual(result[0].col_int, 1.0)
+        self.assertEqual(result[0].col_str, "a")
+
+    def test_convert_arrow_table_to_pandas_kwargs(self):
+        mock_connection = ArrowConversionTests.mock_connection_static()
+        mock_thrift_backend = (
+            ArrowConversionTests.mock_thrift_backend_static()
+        )  # Does not use sample_arrow_table
+        mock_raw_execute_response = (
+            ArrowConversionTests.mock_raw_execute_response_static()
+        )
+
+        dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc)
+        ts_array = pa.array([dt_obj], type=pa.timestamp("us", tz="UTC"))
+        ts_schema = pa.schema([("col_ts", pa.timestamp("us", tz="UTC"))])
+        ts_table = pa.Table.from_arrays([ts_array], schema=ts_schema)
+
+        mock_raw_execute_response.description = [
+            ("col_ts", "timestamp", None, None, None, None, None)
+        ]
+        mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows)
+
+        # Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row.
+        mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True}
+        rs_ts_true = ResultSet(
+            mock_connection, mock_raw_execute_response, mock_thrift_backend
+        )
+        result_true = rs_ts_true.fetchall()
+        self.assertEqual(len(result_true), 1)
+        self.assertIsInstance(result_true[0].col_ts, datetime.datetime)
+
+        # Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input.
+        mock_raw_execute_response.arrow_queue = ArrowQueue(
+            ts_table, ts_table.num_rows
+        )  # Reset queue
+        mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False}
+        rs_ts_false = ResultSet(
+            mock_connection, mock_raw_execute_response, mock_thrift_backend
+        )
+        result_false = rs_ts_false.fetchall()
+        self.assertEqual(len(result_false), 1)
+        self.assertIsInstance(result_false[0].col_ts, pandas.Timestamp)
+
+        # Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default.
+        mock_raw_execute_response.arrow_queue = ArrowQueue(
+            ts_table, ts_table.num_rows
+        )  # Reset queue
+        mock_connection._arrow_to_pandas_kwargs = {}
+        rs_ts_default = ResultSet(
+            mock_connection, mock_raw_execute_response, mock_thrift_backend
+        )
+        result_default = rs_ts_default.fetchall()
+        self.assertEqual(len(result_default), 1)
+        self.assertIsInstance(result_default[0].col_ts, datetime.datetime)
+
+
+if __name__ == "__main__":
+    unittest.main()
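For reference, a minimal usage sketch of the two new connection options. The hostname, HTTP path, and access token below are placeholders, and the leading underscore marks both options as private/experimental. Per the dict merge in _convert_arrow_table, override keys must match a column's Arrow type exactly (pyarrow.int32() remaps only 32-bit integer columns), and _arrow_to_pandas_kwargs entries are applied on top of the connector's default to_pandas arguments:

    import pyarrow
    import pandas
    from databricks import sql

    # Placeholder connection parameters -- substitute real workspace values.
    with sql.connect(
        server_hostname="example.cloud.databricks.com",
        http_path="/sql/1.0/warehouses/abc123",
        access_token="dapi-placeholder",
        # Decode int32 result columns with pandas' nullable Float64 dtype
        # instead of the default Int32Dtype.
        _arrow_pandas_type_override={pyarrow.int32(): pandas.Float64Dtype()},
        # Forwarded to pyarrow.Table.to_pandas(); overrides the connector's
        # timestamp_as_object=True default, so timestamp columns come back
        # as pandas.Timestamp rather than datetime.datetime.
        _arrow_to_pandas_kwargs={"timestamp_as_object": False},
    ) as connection:
        with connection.cursor() as cursor:
            cursor.execute("SELECT 1 AS n, current_timestamp() AS ts")
            print(cursor.fetchall())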