Skip to content

Concat tables to be backward compatible #647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 15 additions & 33 deletions src/databricks/sql/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from databricks.sql.utils import (
ColumnTable,
ColumnQueue,
concat_table_chunks,
)
from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
from databricks.sql.telemetry.models.event import StatementType
Expand Down Expand Up @@ -296,23 +297,6 @@ def _convert_columnar_table(self, table):

return result

def merge_columnar(self, result1, result2) -> "ColumnTable":
    """
    Merge two columnar results into a single ColumnTable.

    :param result1: first ColumnTable; its column names/order define the output
    :param result2: second ColumnTable; must expose identical column names
    :return: a new ColumnTable whose columns are result1's values followed by result2's
    :raises ValueError: if the two results do not share the same column names
    """

    if result1.column_names != result2.column_names:
        raise ValueError("The columns in the results don't match")

    combined_columns = []
    for col_idx in range(result1.num_columns):
        combined_columns.append(
            result1.column_table[col_idx] + result2.column_table[col_idx]
        )
    return ColumnTable(combined_columns, result1.column_names)

def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
"""
Fetch the next set of rows of a query result, returning a PyArrow table.
Expand All @@ -337,7 +321,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
n_remaining_rows -= partial_results.num_rows
self._next_row_index += partial_results.num_rows

return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
return concat_table_chunks(partial_result_chunks)

def fetchmany_columnar(self, size: int):
    """
    Fetch the next set of rows of a query result, returning a columnar table.

    Buffered chunks are collected until ``size`` rows have been gathered or
    the server reports no more rows, then concatenated into a single table
    via ``concat_table_chunks``.

    :param size: maximum number of rows to fetch
    :return: a single table containing up to ``size`` rows
    """
    results = self.results.next_n_rows(size)
    n_remaining_rows = size - results.num_rows
    self._next_row_index += results.num_rows

    # Keep each partial chunk and concatenate once at the end instead of
    # merging pairwise on every iteration (avoids repeated column copies).
    partial_result_chunks = [results]
    while (
        n_remaining_rows > 0
        and not self.has_been_closed_server_side
        and self.has_more_rows
    ):
        self._fill_results_buffer()
        partial_results = self.results.next_n_rows(n_remaining_rows)
        partial_result_chunks.append(partial_results)
        n_remaining_rows -= partial_results.num_rows
        self._next_row_index += partial_results.num_rows

    return concat_table_chunks(partial_result_chunks)

def fetchall_arrow(self) -> "pyarrow.Table":
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
Expand All @@ -372,36 +356,34 @@ def fetchall_arrow(self) -> "pyarrow.Table":
while not self.has_been_closed_server_side and self.has_more_rows:
self._fill_results_buffer()
partial_results = self.results.remaining_rows()
if isinstance(results, ColumnTable) and isinstance(
partial_results, ColumnTable
):
results = self.merge_columnar(results, partial_results)
else:
partial_result_chunks.append(partial_results)
partial_result_chunks.append(partial_results)
self._next_row_index += partial_results.num_rows

result_table = concat_table_chunks(partial_result_chunks)
# If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
# Valid only for metadata commands result set
if isinstance(results, ColumnTable) and pyarrow:
if isinstance(result_table, ColumnTable) and pyarrow:
data = {
name: col
for name, col in zip(results.column_names, results.column_table)
for name, col in zip(
result_table.column_names, result_table.column_table
)
}
return pyarrow.Table.from_pydict(data)
return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
return result_table

def fetchall_columnar(self):
    """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
    results = self.results.remaining_rows()
    self._next_row_index += results.num_rows

    # Collect every buffered chunk, then concatenate once at the end;
    # the per-iteration merge_columnar call was removed in favor of
    # concat_table_chunks (one merge instead of O(chunks) merges).
    partial_result_chunks = [results]
    while not self.has_been_closed_server_side and self.has_more_rows:
        self._fill_results_buffer()
        partial_results = self.results.remaining_rows()
        partial_result_chunks.append(partial_results)
        self._next_row_index += partial_results.num_rows

    return concat_table_chunks(partial_result_chunks)

def fetchone(self) -> Optional[Row]:
"""
Expand Down
22 changes: 22 additions & 0 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,3 +853,25 @@ def _create_python_tuple(t_col_value_wrapper):
result[i] = None

return tuple(result)


def concat_table_chunks(
    table_chunks: "List[Union[pyarrow.Table, ColumnTable]]",
) -> "Union[pyarrow.Table, ColumnTable]":
    """
    Concatenate a list of homogeneous table chunks into a single table.

    The chunk type is detected from the first element: ``ColumnTable`` chunks
    are merged column-wise in pure Python; anything else is assumed to be a
    ``pyarrow.Table`` and delegated to ``pyarrow.concat_tables``.

    :param table_chunks: chunks of the same type, in row order
    :return: a single concatenated table; an empty input list is returned
        unchanged (callers rely on this passthrough for the no-data case)
    :raises ValueError: if ColumnTable chunks do not share the same column names
    """
    if not table_chunks:
        # Deliberate passthrough: return the empty list itself rather than
        # fabricating an empty table of unknown type.
        return table_chunks

    first_chunk = table_chunks[0]
    if isinstance(first_chunk, ColumnTable):
        # Every chunk must have an identical column layout before merging.
        if not all(
            chunk.column_names == first_chunk.column_names
            for chunk in table_chunks
        ):
            raise ValueError("The columns in the results don't match")

        merged_columns: List[List[Any]] = [
            [] for _ in range(first_chunk.num_columns)
        ]
        for chunk in table_chunks:
            for col_idx in range(chunk.num_columns):
                merged_columns[col_idx].extend(chunk.column_table[col_idx])
        return ColumnTable(merged_columns, first_chunk.column_names)

    return pyarrow.concat_tables(table_chunks, use_threads=True)
41 changes: 40 additions & 1 deletion tests/unit/test_util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import decimal
import datetime
from datetime import timezone, timedelta
import pytest
from databricks.sql.utils import (
convert_to_assigned_datatypes_in_column_table,
ColumnTable,
concat_table_chunks,
)

from databricks.sql.utils import convert_to_assigned_datatypes_in_column_table
try:
import pyarrow
except ImportError:
pyarrow = None


class TestUtils:
Expand Down Expand Up @@ -122,3 +131,33 @@ def test_convert_to_assigned_datatypes_in_column_table(self):
for index, entry in enumerate(converted_column_table):
assert entry[0] == expected_convertion[index][0]
assert isinstance(entry[0], expected_convertion[index][1])

def test_concat_table_chunks_column_table(self):
    # Two chunks with matching column layouts merge column-wise, in order.
    first_chunk = ColumnTable([[1, 2], [5, 6]], ["col1", "col2"])
    second_chunk = ColumnTable([[3, 4], [7, 8]], ["col1", "col2"])

    merged = concat_table_chunks([first_chunk, second_chunk])

    assert merged.column_table == [[1, 2, 3, 4], [5, 6, 7, 8]]
    assert merged.column_names == ["col1", "col2"]

@pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed")
def test_concat_table_chunks_arrow_table(self):
    # Arrow chunks are delegated to pyarrow.concat_tables; verify order/values.
    first_chunk = pyarrow.Table.from_pydict({"col1": [1, 2], "col2": [5, 6]})
    second_chunk = pyarrow.Table.from_pydict({"col1": [3, 4], "col2": [7, 8]})

    merged = concat_table_chunks([first_chunk, second_chunk])

    assert merged.column_names == ["col1", "col2"]
    assert merged.column("col1").to_pylist() == [1, 2, 3, 4]
    assert merged.column("col2").to_pylist() == [5, 6, 7, 8]

def test_concat_table_chunks_empty(self):
    # An empty chunk list passes straight through unchanged.
    assert concat_table_chunks([]) == []

def test_concat_table_chunks__incorrect_column_names_error(self):
    # Chunks whose column names disagree must be rejected.
    first_chunk = ColumnTable([[1, 2], [5, 6]], ["col1", "col2"])
    second_chunk = ColumnTable([[3, 4], [7, 8]], ["col1", "col3"])

    with pytest.raises(ValueError):
        concat_table_chunks([first_chunk, second_chunk])
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy