From b2cf39b80c0cb53787c496d67d1b6ffe314a8bc0 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Tue, 22 Jul 2025 13:18:48 +0530 Subject: [PATCH 1/3] fixed --- src/databricks/sql/result_set.py | 54 ++++++++++++-------------------- src/databricks/sql/utils.py | 22 +++++++++++++ tests/unit/test_util.py | 41 +++++++++++++++++++++++- 3 files changed, 82 insertions(+), 35 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 074877d32..9ed0188bf 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -20,7 +20,12 @@ from databricks.sql.types import Row from databricks.sql.exc import RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue +from databricks.sql.utils import ( + ExecuteResponse, + ColumnTable, + ColumnQueue, + concat_table_chunks, +) logger = logging.getLogger(__name__) @@ -251,23 +256,6 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": - """ - Function to merge / combining the columnar results into a single result - :param result1: - :param result2: - :return: - """ - - if result1.column_names != result2.column_names: - raise ValueError("The columns in the results don't match") - - merged_result = [ - result1.column_table[i] + result2.column_table[i] - for i in range(result1.num_columns) - ] - return ColumnTable(merged_result, result1.column_names) - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ Fetch the next set of rows of a query result, returning a PyArrow table. 
@@ -292,7 +280,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows - return pyarrow.concat_tables(partial_result_chunks, use_threads=True) + return concat_table_chunks(partial_result_chunks) def fetchmany_columnar(self, size: int): """ @@ -305,7 +293,7 @@ def fetchmany_columnar(self, size: int): results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows - + partial_result_chunks = [results] while ( n_remaining_rows > 0 and not self.has_been_closed_server_side @@ -313,11 +301,11 @@ def fetchmany_columnar(self, size: int): ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) - results = self.merge_columnar(results, partial_results) + partial_result_chunks.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows - return results + return concat_table_chunks(partial_result_chunks) def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" @@ -327,36 +315,34 @@ def fetchall_arrow(self) -> "pyarrow.Table": while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() - if isinstance(results, ColumnTable) and isinstance( - partial_results, ColumnTable - ): - results = self.merge_columnar(results, partial_results) - else: - partial_result_chunks.append(partial_results) + partial_result_chunks.append(partial_results) self._next_row_index += partial_results.num_rows + result_table = concat_table_chunks(partial_result_chunks) # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table # Valid only for metadata commands result set - if isinstance(results, ColumnTable) and pyarrow: + if isinstance(result_table, ColumnTable) and 
pyarrow: data = { name: col - for name, col in zip(results.column_names, results.column_table) + for name, col in zip( + result_table.column_names, result_table.column_table + ) } return pyarrow.Table.from_pydict(data) - return pyarrow.concat_tables(partial_result_chunks, use_threads=True) + return result_table def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" results = self.results.remaining_rows() self._next_row_index += results.num_rows - + partial_result_chunks = [results] while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() - results = self.merge_columnar(results, partial_results) + partial_result_chunks.append(partial_results) self._next_row_index += partial_results.num_rows - return results + return concat_table_chunks(partial_result_chunks) def fetchone(self) -> Optional[Row]: """ diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index a3e3e1dd0..d62f2394f 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -785,3 +785,25 @@ def _create_python_tuple(t_col_value_wrapper): result[i] = None return tuple(result) + + +def concat_table_chunks( + table_chunks: List[Union["pyarrow.Table", ColumnTable]] +) -> Union["pyarrow.Table", ColumnTable]: + if len(table_chunks) == 0: + return table_chunks + + if isinstance(table_chunks[0], ColumnTable): + ## Check if all have the same column names + if not all( + table.column_names == table_chunks[0].column_names for table in table_chunks + ): + raise ValueError("The columns in the results don't match") + + result_table = table_chunks[0].column_table + for i in range(1, len(table_chunks)): + for j in range(table_chunks[i].num_columns): + result_table[j].extend(table_chunks[i].column_table[j]) + return ColumnTable(result_table, table_chunks[0].column_names) + else: + return pyarrow.concat_tables(table_chunks, use_threads=True) 
diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index a47ab786f..713342b2e 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -1,8 +1,17 @@ import decimal import datetime from datetime import timezone, timedelta +import pytest +from databricks.sql.utils import ( + convert_to_assigned_datatypes_in_column_table, + ColumnTable, + concat_table_chunks, +) -from databricks.sql.utils import convert_to_assigned_datatypes_in_column_table +try: + import pyarrow +except ImportError: + pyarrow = None class TestUtils: @@ -122,3 +131,33 @@ def test_convert_to_assigned_datatypes_in_column_table(self): for index, entry in enumerate(converted_column_table): assert entry[0] == expected_convertion[index][0] assert isinstance(entry[0], expected_convertion[index][1]) + + def test_concat_table_chunks_column_table(self): + column_table1 = ColumnTable([[1, 2], [5, 6]], ["col1", "col2"]) + column_table2 = ColumnTable([[3, 4], [7, 8]], ["col1", "col2"]) + + result_table = concat_table_chunks([column_table1, column_table2]) + + assert result_table.column_table == [[1, 2, 3, 4], [5, 6, 7, 8]] + assert result_table.column_names == ["col1", "col2"] + + @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") + def test_concat_table_chunks_arrow_table(self): + arrow_table1 = pyarrow.Table.from_pydict({"col1": [1, 2], "col2": [5, 6]}) + arrow_table2 = pyarrow.Table.from_pydict({"col1": [3, 4], "col2": [7, 8]}) + + result_table = concat_table_chunks([arrow_table1, arrow_table2]) + assert result_table.column_names == ["col1", "col2"] + assert result_table.column("col1").to_pylist() == [1, 2, 3, 4] + assert result_table.column("col2").to_pylist() == [5, 6, 7, 8] + + def test_concat_table_chunks_empty(self): + result_table = concat_table_chunks([]) + assert result_table == [] + + def test_concat_table_chunks__incorrect_column_names_error(self): + column_table1 = ColumnTable([[1, 2], [5, 6]], ["col1", "col2"]) + column_table2 = ColumnTable([[3, 4], [7, 
8]], ["col1", "col3"]) + + with pytest.raises(ValueError): + concat_table_chunks([column_table1, column_table2]) From 3232768fb68672d5474eb4fe7bb2ab1a21862a20 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Tue, 22 Jul 2025 15:21:13 +0530 Subject: [PATCH 2/3] Minor fix --- src/databricks/sql/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index d62f2394f..0c2dd54e5 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -800,8 +800,8 @@ def concat_table_chunks( ): raise ValueError("The columns in the results don't match") - result_table = table_chunks[0].column_table - for i in range(1, len(table_chunks)): + result_table = [[] for _ in range(table_chunks[0].num_columns)] + for i in range(0, len(table_chunks)): for j in range(table_chunks[i].num_columns): result_table[j].extend(table_chunks[i].column_table[j]) return ColumnTable(result_table, table_chunks[0].column_names) From 6eba353c5251198522e3db9d2eefc94ea0e6f2cd Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Tue, 22 Jul 2025 15:26:19 +0530 Subject: [PATCH 3/3] more types --- src/databricks/sql/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 0c2dd54e5..9a70ed38d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -800,7 +800,7 @@ def concat_table_chunks( ): raise ValueError("The columns in the results don't match") - result_table = [[] for _ in range(table_chunks[0].num_columns)] + result_table: List[List[Any]] = [[] for _ in range(table_chunks[0].num_columns)] for i in range(0, len(table_chunks)): for j in range(table_chunks[i].num_columns): result_table[j].extend(table_chunks[i].column_table[j]) pFad - Phonifier reborn

pFad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy