diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py
index 3d3587cae..9feb6e924 100644
--- a/src/databricks/sql/result_set.py
+++ b/src/databricks/sql/result_set.py
@@ -20,6 +20,7 @@
 from databricks.sql.utils import (
     ColumnTable,
     ColumnQueue,
+    concat_table_chunks,
 )
 from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
 from databricks.sql.telemetry.models.event import StatementType
@@ -296,23 +297,6 @@ def _convert_columnar_table(self, table):
 
         return result
 
-    def merge_columnar(self, result1, result2) -> "ColumnTable":
-        """
-        Function to merge / combining the columnar results into a single result
-        :param result1:
-        :param result2:
-        :return:
-        """
-
-        if result1.column_names != result2.column_names:
-            raise ValueError("The columns in the results don't match")
-
-        merged_result = [
-            result1.column_table[i] + result2.column_table[i]
-            for i in range(result1.num_columns)
-        ]
-        return ColumnTable(merged_result, result1.column_names)
-
     def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         """
         Fetch the next set of rows of a query result, returning a PyArrow table.
@@ -337,7 +321,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
             n_remaining_rows -= partial_results.num_rows
             self._next_row_index += partial_results.num_rows
 
-        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
+        return concat_table_chunks(partial_result_chunks)
 
     def fetchmany_columnar(self, size: int):
         """
@@ -350,7 +334,7 @@ def fetchmany_columnar(self, size: int):
         results = self.results.next_n_rows(size)
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows
-
+        partial_result_chunks = [results]
         while (
             n_remaining_rows > 0
             and not self.has_been_closed_server_side
@@ -358,11 +342,11 @@ def fetchmany_columnar(self, size: int):
         ):
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
-            results = self.merge_columnar(results, partial_results)
+            partial_result_chunks.append(partial_results)
             n_remaining_rows -= partial_results.num_rows
             self._next_row_index += partial_results.num_rows
 
-        return results
+        return concat_table_chunks(partial_result_chunks)
 
     def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
@@ -372,36 +356,34 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         while not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
-            if isinstance(results, ColumnTable) and isinstance(
-                partial_results, ColumnTable
-            ):
-                results = self.merge_columnar(results, partial_results)
-            else:
-                partial_result_chunks.append(partial_results)
+            partial_result_chunks.append(partial_results)
             self._next_row_index += partial_results.num_rows
 
+        result_table = concat_table_chunks(partial_result_chunks)
         # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
         # Valid only for metadata commands result set
-        if isinstance(results, ColumnTable) and pyarrow:
+        if isinstance(result_table, ColumnTable) and pyarrow:
             data = {
                 name: col
-                for name, col in zip(results.column_names, results.column_table)
+                for name, col in zip(
+                    result_table.column_names, result_table.column_table
+                )
             }
             return pyarrow.Table.from_pydict(data)
 
-        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
+        return result_table
 
     def fetchall_columnar(self):
         """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
-
+        partial_result_chunks = [results]
         while not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
-            results = self.merge_columnar(results, partial_results)
+            partial_result_chunks.append(partial_results)
            self._next_row_index += partial_results.num_rows
 
-        return results
+        return concat_table_chunks(partial_result_chunks)
 
     def fetchone(self) -> Optional[Row]:
         """
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index 4617f7de6..c1d89ca5c 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -853,3 +853,25 @@ def _create_python_tuple(t_col_value_wrapper):
             result[i] = None
 
     return tuple(result)
+
+
+def concat_table_chunks(
+    table_chunks: List[Union["pyarrow.Table", ColumnTable]]
+) -> Union["pyarrow.Table", ColumnTable]:
+    if len(table_chunks) == 0:
+        return table_chunks
+
+    if isinstance(table_chunks[0], ColumnTable):
+        ## Check if all have the same column names
+        if not all(
+            table.column_names == table_chunks[0].column_names for table in table_chunks
+        ):
+            raise ValueError("The columns in the results don't match")
+
+        result_table: List[List[Any]] = [[] for _ in range(table_chunks[0].num_columns)]
+        for i in range(0, len(table_chunks)):
+            for j in range(table_chunks[i].num_columns):
+                result_table[j].extend(table_chunks[i].column_table[j])
+        return ColumnTable(result_table, table_chunks[0].column_names)
+    else:
+        return pyarrow.concat_tables(table_chunks, use_threads=True)
diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py
index a47ab786f..713342b2e 100644
--- a/tests/unit/test_util.py
+++ b/tests/unit/test_util.py
@@ -1,8 +1,17 @@
 import decimal
 import datetime
 from datetime import timezone, timedelta
+import pytest
+from databricks.sql.utils import (
+    convert_to_assigned_datatypes_in_column_table,
+    ColumnTable,
+    concat_table_chunks,
+)
 
-from databricks.sql.utils import convert_to_assigned_datatypes_in_column_table
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
 
 
 class TestUtils:
@@ -122,3 +131,33 @@ def test_convert_to_assigned_datatypes_in_column_table(self):
         for index, entry in enumerate(converted_column_table):
             assert entry[0] == expected_convertion[index][0]
             assert isinstance(entry[0], expected_convertion[index][1])
+
+    def test_concat_table_chunks_column_table(self):
+        column_table1 = ColumnTable([[1, 2], [5, 6]], ["col1", "col2"])
+        column_table2 = ColumnTable([[3, 4], [7, 8]], ["col1", "col2"])
+
+        result_table = concat_table_chunks([column_table1, column_table2])
+
+        assert result_table.column_table == [[1, 2, 3, 4], [5, 6, 7, 8]]
+        assert result_table.column_names == ["col1", "col2"]
+
+    @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed")
+    def test_concat_table_chunks_arrow_table(self):
+        arrow_table1 = pyarrow.Table.from_pydict({"col1": [1, 2], "col2": [5, 6]})
+        arrow_table2 = pyarrow.Table.from_pydict({"col1": [3, 4], "col2": [7, 8]})
+
+        result_table = concat_table_chunks([arrow_table1, arrow_table2])
+        assert result_table.column_names == ["col1", "col2"]
+        assert result_table.column("col1").to_pylist() == [1, 2, 3, 4]
+        assert result_table.column("col2").to_pylist() == [5, 6, 7, 8]
+
+    def test_concat_table_chunks_empty(self):
+        result_table = concat_table_chunks([])
+        assert result_table == []
+
+    def test_concat_table_chunks__incorrect_column_names_error(self):
+        column_table1 = ColumnTable([[1, 2], [5, 6]], ["col1", "col2"])
+        column_table2 = ColumnTable([[3, 4], [7, 8]], ["col1", "col3"])
+
+        with pytest.raises(ValueError):
+            concat_table_chunks([column_table1, column_table2])