diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py
index 130f0c5b..db430a52 100644
--- a/src/databricks/sql/backend/sea/queue.py
+++ b/src/databricks/sql/backend/sea/queue.py
@@ -364,14 +364,14 @@ def __init__(
         # Initialize table and position
         self.table = self._create_next_table()

-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+    def _create_next_table(self) -> "pyarrow.Table":
         """Create next table by retrieving the logical next downloaded file."""
         if self.link_fetcher is None:
-            return None
+            return self._create_empty_table()

         chunk_link = self.link_fetcher.get_chunk_link(self._current_chunk_index)
         if chunk_link is None:
-            return None
+            return self._create_empty_table()

         row_offset = chunk_link.row_offset
         # NOTE: link has already been submitted to download manager at this point
diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py
index 32b698be..e187771f 100644
--- a/src/databricks/sql/cloudfetch/download_manager.py
+++ b/src/databricks/sql/cloudfetch/download_manager.py
@@ -1,6 +1,7 @@
 import logging

 from concurrent.futures import ThreadPoolExecutor, Future
+import threading
 from typing import List, Union, Tuple, Optional

 from databricks.sql.cloudfetch.downloader import (
@@ -8,6 +9,7 @@
     DownloadableResultSettings,
     DownloadedFile,
 )
+from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
 from databricks.sql.telemetry.models.event import StatementType
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
@@ -39,8 +41,10 @@ def __init__(
             self._pending_links.append((i, link))
         self.chunk_id += len(links)

-        self._download_tasks: List[Future[DownloadedFile]] = []
         self._max_download_threads: int = max_download_threads
+
+        self._download_condition = threading.Condition()
+        self._download_tasks: List[Future[DownloadedFile]] = []
         self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)

         self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
@@ -48,17 +52,13 @@ def __init__(
         self.session_id_hex = session_id_hex
         self.statement_id = statement_id

-    def get_next_downloaded_file(
-        self, next_row_offset: int
-    ) -> Union[DownloadedFile, None]:
+    def get_next_downloaded_file(self, next_row_offset: int) -> DownloadedFile:
         """
         Get next file that starts at given offset.

         This function gets the next downloaded file in which its rows start at the specified next_row_offset
         in relation to the full result. File downloads are scheduled if not already, and once the correct
         download handler is located, the function waits for the download status and returns the resulting file.
-        If there are no more downloads, a download was not successful, or the correct file could not be located,
-        this function shuts down the thread pool and returns None.

         Args:
             next_row_offset (int): The offset of the starting row of the next file we want data from.
@@ -67,10 +67,11 @@ def get_next_downloaded_file(
         # Make sure the download queue is always full
         self._schedule_downloads()

-        # No more files to download from this batch of links
-        if len(self._download_tasks) == 0:
-            self._shutdown_manager()
-            return None
+        while len(self._download_tasks) == 0:
+            if self._thread_pool._shutdown:
+                raise Error("download manager shut down before file was ready")
+            with self._download_condition:
+                self._download_condition.wait()

         task = self._download_tasks.pop(0)
         # Future's `result()` method will wait for the call to complete, and return
@@ -113,6 +114,9 @@ def _schedule_downloads(self):
             task = self._thread_pool.submit(handler.run)
             self._download_tasks.append(task)

+        with self._download_condition:
+            self._download_condition.notify_all()
+
     def add_link(self, link: TSparkArrowResultLink):
         """
         Add more links to the download manager.
@@ -132,8 +136,12 @@ def add_link(self, link: TSparkArrowResultLink):
         self._pending_links.append((self.chunk_id, link))
         self.chunk_id += 1

+        self._schedule_downloads()
+
     def _shutdown_manager(self):
         # Clear download handlers and shutdown the thread pool
         self._pending_links = []
         self._download_tasks = []
         self._thread_pool.shutdown(wait=False)
+        with self._download_condition:
+            self._download_condition.notify_all()
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index 4617f7de..dea187ce 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -249,7 +249,7 @@ def __init__(
         self.chunk_id = chunk_id

         # Table state
-        self.table = None
+        self.table = self._create_empty_table()
         self.table_row_index = 0

         # Initialize download manager
@@ -273,24 +273,20 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
             pyarrow.Table
         """

-        if not self.table:
-            logger.debug("CloudFetchQueue: no more rows available")
-            # Return empty pyarrow table to cause retry of fetch
-            return self._create_empty_table()
         logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows))
         results = self.table.slice(0, 0)
         partial_result_chunks = [results]
-        while num_rows > 0 and self.table:
+        while num_rows > 0 and self.table.num_rows > 0:
+            # Replace current table with the next table if we are at the end of the current table
+            if self.table_row_index == self.table.num_rows:
+                self.table = self._create_next_table()
+                self.table_row_index = 0
+
             # Get remaining of num_rows or the rest of the current table, whichever is smaller
             length = min(num_rows, self.table.num_rows - self.table_row_index)
             table_slice = self.table.slice(self.table_row_index, length)
             partial_result_chunks.append(table_slice)
             self.table_row_index += table_slice.num_rows
-
-            # Replace current table with the next table if we are at the end of the current table
-            if self.table_row_index == self.table.num_rows:
-                self.table = self._create_next_table()
-                self.table_row_index = 0
             num_rows -= table_slice.num_rows

         logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
@@ -304,12 +300,9 @@ def remaining_rows(self) -> "pyarrow.Table":
             pyarrow.Table
         """

-        if not self.table:
-            # Return empty pyarrow table to cause retry of fetch
-            return self._create_empty_table()
         results = self.table.slice(0, 0)
         partial_result_chunks = [results]
-        while self.table:
+        while self.table.num_rows > 0:
             table_slice = self.table.slice(
                 self.table_row_index, self.table.num_rows - self.table_row_index
             )
@@ -319,17 +312,11 @@ def remaining_rows(self) -> "pyarrow.Table":
             self.table_row_index = 0
         return pyarrow.concat_tables(partial_result_chunks, use_threads=True)

-    def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]:
+    def _create_table_at_offset(self, offset: int) -> "pyarrow.Table":
         """Create next table at the given row offset"""
         # Create next table by retrieving the logical next downloaded file, or return None to signal end of queue
         downloaded_file = self.download_manager.get_next_downloaded_file(offset)
-        if not downloaded_file:
-            logger.debug(
-                "CloudFetchQueue: Cannot find downloaded file for row {}".format(offset)
-            )
-            # None signals no more Arrow tables can be built from the remaining handlers if any remain
-            return None
         arrow_table = create_arrow_table_from_arrow_file(
             downloaded_file.file_bytes, self.description
         )
@@ -345,7 +332,7 @@ def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]:
         return arrow_table

     @abstractmethod
-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+    def _create_next_table(self) -> "pyarrow.Table":
         """Create next table by retrieving the logical next downloaded file."""
         pass

@@ -364,7 +351,7 @@ class ThriftCloudFetchQueue(CloudFetchQueue):

     def __init__(
         self,
-        schema_bytes,
+        schema_bytes: Optional[bytes],
         max_download_threads: int,
         ssl_options: SSLOptions,
         session_id_hex: Optional[str],
@@ -398,6 +385,8 @@ def __init__(
             chunk_id=chunk_id,
         )

+        self.num_links_downloaded = 0
+
         self.start_row_index = start_row_offset
         self.result_links = result_links or []
         self.session_id_hex = session_id_hex
@@ -421,20 +410,23 @@ def __init__(
         # Initialize table and position
         self.table = self._create_next_table()

-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+    def _create_next_table(self) -> "pyarrow.Table":
+        if self.num_links_downloaded >= len(self.result_links):
+            return self._create_empty_table()
+
         logger.debug(
             "ThriftCloudFetchQueue: Trying to get downloaded file for row {}".format(
                 self.start_row_index
             )
         )
         arrow_table = self._create_table_at_offset(self.start_row_index)
-        if arrow_table:
-            self.start_row_index += arrow_table.num_rows
-            logger.debug(
-                "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
-                    arrow_table.num_rows, self.start_row_index
-                )
+        self.num_links_downloaded += 1
+        self.start_row_index += arrow_table.num_rows
+        logger.debug(
+            "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
+                arrow_table.num_rows, self.start_row_index
             )
+        )

         return arrow_table

@@ -740,7 +732,6 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":


 def convert_to_assigned_datatypes_in_column_table(column_table, description):
-
     converted_column_table = []
     for i, col in enumerate(column_table):
         if description[i][1] == "decimal":
diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py
index faa8e2f9..31450e7f 100644
--- a/tests/unit/test_cloud_fetch_queue.py
+++ b/tests/unit/test_cloud_fetch_queue.py
@@ -36,27 +36,30 @@ def create_result_links(self, num_files: int, start_row_offset: int = 0):
         return result_links

     @staticmethod
-    def make_arrow_table():
-        batch = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
-        n_cols = len(batch[0]) if batch else 0
+    def make_arrow_table(num_rows: int = 4, num_cols: int = 4):
+        batch = [[i for i in range(num_cols)] for _ in range(num_rows)]
+        n_cols = len(batch[0]) if batch else num_cols
         schema = pyarrow.schema({"col%s" % i: pyarrow.uint32() for i in range(n_cols)})
         cols = [[batch[row][col] for row in range(len(batch))] for col in range(n_cols)]
         return pyarrow.Table.from_pydict(dict(zip(schema.names, cols)), schema=schema)

     @staticmethod
-    def get_schema_bytes():
+    def get_schema_bytes_and_description():
         schema = pyarrow.schema({"col%s" % i: pyarrow.uint32() for i in range(4)})
+        description = [
+            ("col%s" % i, "int", None, None, None, None, None) for i in range(4)
+        ]
         sink = pyarrow.BufferOutputStream()
         writer = pyarrow.ipc.RecordBatchStreamWriter(sink, schema)
         writer.close()
-        return sink.getvalue().to_pybytes()
+        return sink.getvalue().to_pybytes(), description

     @patch(
         "databricks.sql.utils.ThriftCloudFetchQueue._create_next_table",
         return_value=[None, None],
     )
     def test_initializer_adds_links(self, mock_create_next_table):
-        schema_bytes = MagicMock()
+        schema_bytes, description = self.get_schema_bytes_and_description()
         result_links = self.create_result_links(10)
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
@@ -66,14 +69,18 @@ def test_initializer_adds_links(self, mock_create_next_table):
             session_id_hex=Mock(),
             statement_id=Mock(),
             chunk_id=0,
+            description=description,
         )

-        assert len(queue.download_manager._pending_links) == 10
-        assert len(queue.download_manager._download_tasks) == 0
+        assert (
+            len(queue.download_manager._pending_links)
+            + len(queue.download_manager._download_tasks)
+            == 10
+        )
         mock_create_next_table.assert_called()

     def test_initializer_no_links_to_add(self):
-        schema_bytes = MagicMock()
+        schema_bytes, description = self.get_schema_bytes_and_description()
         result_links = []
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
@@ -83,29 +90,11 @@ def test_initializer_no_links_to_add(self):
             session_id_hex=Mock(),
             statement_id=Mock(),
             chunk_id=0,
+            description=description,
         )

         assert len(queue.download_manager._pending_links) == 0
         assert len(queue.download_manager._download_tasks) == 0
-        assert queue.table is None
-
-    @patch(
-        "databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file",
-        return_value=None,
-    )
-    def test_create_next_table_no_download(self, mock_get_next_downloaded_file):
-        queue = utils.ThriftCloudFetchQueue(
-            MagicMock(),
-            result_links=[],
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
-
-        assert queue._create_next_table() is None
-        mock_get_next_downloaded_file.assert_called_with(0)

     @patch("databricks.sql.utils.create_arrow_table_from_arrow_file")
     @patch(
@@ -116,10 +105,23 @@ def test_initializer_create_next_table_success(
         self, mock_get_next_downloaded_file, mock_create_arrow_table
     ):
         mock_create_arrow_table.return_value = self.make_arrow_table()
-        schema_bytes, description = MagicMock(), MagicMock()
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                ),
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=4,
+                    rowCount=4,
+                    bytesNum=10,
+                ),
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -144,10 +146,17 @@ def test_initializer_create_next_table_success(
     @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table")
     def test_next_n_rows_0_rows(self, mock_create_next_table):
         mock_create_next_table.return_value = self.make_arrow_table()
-        schema_bytes, description = MagicMock(), MagicMock()
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -166,10 +175,17 @@ def test_next_n_rows_0_rows(self, mock_create_next_table):
     @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table")
     def test_next_n_rows_partial_table(self, mock_create_next_table):
         mock_create_next_table.return_value = self.make_arrow_table()
-        schema_bytes, description = MagicMock(), MagicMock()
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -189,10 +205,17 @@ def test_next_n_rows_partial_table(self, mock_create_next_table):
     @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table")
     def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
         mock_create_next_table.return_value = self.make_arrow_table()
-        schema_bytes, description = MagicMock(), MagicMock()
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -216,11 +239,21 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
     @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table")
     def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
-        mock_create_next_table.side_effect = [self.make_arrow_table(), None]
-        schema_bytes, description = MagicMock(), MagicMock()
+        mock_create_next_table.side_effect = [
+            self.make_arrow_table(),
+            self.make_arrow_table(0),
+        ]
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -241,11 +274,19 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
         return_value=None,
     )
     def test_next_n_rows_empty_table(self, mock_create_next_table):
-        schema_bytes = self.get_schema_bytes()
-        description = MagicMock()
+        mock_create_next_table.side_effect = [self.make_arrow_table(0)]
+
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -253,7 +294,8 @@ def test_next_n_rows_empty_table(self, mock_create_next_table):
             statement_id=Mock(),
             chunk_id=0,
         )
-        assert queue.table is None
+
+        assert queue.table == self.make_arrow_table(0)

         result = queue.next_n_rows(100)
         mock_create_next_table.assert_called()
@@ -261,11 +303,21 @@ def test_next_n_rows_empty_table(self, mock_create_next_table):

     @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table")
     def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table):
-        mock_create_next_table.side_effect = [self.make_arrow_table(), None, 0]
-        schema_bytes, description = MagicMock(), MagicMock()
+        mock_create_next_table.side_effect = [
+            self.make_arrow_table(),
+            self.make_arrow_table(0),
+        ]
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -283,11 +335,21 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table)
     @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table")
     def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table):
-        mock_create_next_table.side_effect = [self.make_arrow_table(), None]
-        schema_bytes, description = MagicMock(), MagicMock()
+        mock_create_next_table.side_effect = [
+            self.make_arrow_table(),
+            self.make_arrow_table(0),
+        ]
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -305,11 +367,21 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl
     @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table")
     def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table):
-        mock_create_next_table.side_effect = [self.make_arrow_table(), None]
-        schema_bytes, description = MagicMock(), MagicMock()
+        mock_create_next_table.side_effect = [
+            self.make_arrow_table(),
+            self.make_arrow_table(0),
+        ]
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -332,12 +404,19 @@ def test_remaining_rows_multiple_tables_fully_returned(
         mock_create_next_table.side_effect = [
             self.make_arrow_table(),
             self.make_arrow_table(),
-            None,
+            self.make_arrow_table(0),
         ]
-        schema_bytes, description = MagicMock(), MagicMock()
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -364,11 +443,19 @@ def test_remaining_rows_multiple_tables_fully_returned(
         return_value=None,
     )
     def test_remaining_rows_empty_table(self, mock_create_next_table):
-        schema_bytes = self.get_schema_bytes()
-        description = MagicMock()
+        mock_create_next_table.side_effect = [self.make_arrow_table(0)]
+
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -376,7 +463,7 @@ def test_remaining_rows_empty_table(self, mock_create_next_table):
             statement_id=Mock(),
             chunk_id=0,
         )
-        assert queue.table is None
+        assert queue.table == self.make_arrow_table(0)

         result = queue.remaining_rows()
         assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all()
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
index c514980e..dbcea900 100644
--- a/tests/unit/test_downloader.py
+++ b/tests/unit/test_downloader.py
@@ -86,11 +86,15 @@ def test_run_get_response_not_ok(self, mock_time):
         settings.use_proxy = False
         result_link = Mock(expiryTime=1001)

-        with patch.object(
-            http_client,
-            "execute",
-            return_value=create_response(status_code=404, _content=b"1234"),
-        ):
+        # Create a mock response with 404 status
+        mock_response = create_response(status_code=404, _content=b"Not Found")
+        mock_response.raise_for_status = Mock(
+            side_effect=requests.exceptions.HTTPError("404")
+        )
+
+        with patch.object(http_client, "execute") as mock_execute:
+            mock_execute.return_value.__enter__.return_value = mock_response
+
             d = downloader.ResultSetDownloadHandler(
                 settings,
                 result_link,
diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py
index cbeae098..f0dcf529 100644
--- a/tests/unit/test_sea_queue.py
+++ b/tests/unit/test_sea_queue.py
@@ -27,6 +27,11 @@
 import threading
 import time

+try:
+    import pyarrow as pa
+except ImportError:
+    pa = None
+

 class TestJsonQueue:
     """Test suite for the JsonQueue class."""
@@ -199,6 +204,7 @@ def test_build_queue_json_array(self, json_manifest, sample_data):
         assert isinstance(queue, JsonQueue)
         assert queue.data_array == sample_data

+    @pytest.mark.skipif(pa is None, reason="pyarrow not installed")
     def test_build_queue_arrow_stream(
         self, arrow_manifest, ssl_options, mock_sea_client, description
     ):
@@ -328,6 +334,7 @@ def test_convert_to_thrift_link_no_headers(self, sample_external_link_no_headers

     @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager")
     @patch("databricks.sql.backend.sea.queue.logger")
+    @pytest.mark.skipif(pa is None, reason="pyarrow not installed")
     def test_init_with_valid_initial_link(
         self,
         mock_logger,
@@ -357,6 +364,7 @@ def test_init_with_valid_initial_link(

     @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager")
     @patch("databricks.sql.backend.sea.queue.logger")
+    @pytest.mark.skipif(pa is None, reason="pyarrow not installed")
     def test_init_no_initial_links(
         self,
         mock_logger,
@@ -377,7 +385,7 @@ def test_init_no_initial_links(
             lz4_compressed=False,
             description=description,
         )
-        assert queue.table is None
+        assert queue.table == pa.Table.from_pydict({})

     @patch("databricks.sql.backend.sea.queue.logger")
     def test_create_next_table_success(self, mock_logger):
@@ -481,6 +489,7 @@ def test_hybrid_disposition_with_attachment(

     @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager")
     @patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None)
+    @pytest.mark.skipif(pa is None, reason="pyarrow not installed")
    def test_hybrid_disposition_with_external_links(
         self,
         mock_create_table,
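
The concurrency change above is the heart of this patch: get_next_downloaded_file no longer returns None when its task list is empty; it blocks on a threading.Condition until _schedule_downloads (the producer side) queues a Future, or until _shutdown_manager wakes it so it can fail fast with Error. Below is a minimal, self-contained sketch of that handoff. BlockingTaskQueue and its method names are illustrative assumptions, not the library's API; and unlike the patch, which re-checks emptiness outside the lock and reads the pool's private _shutdown flag, this sketch tests the predicate while holding the lock, the standard form of the pattern that cannot miss a notify between the emptiness check and wait().

    import threading
    from concurrent.futures import ThreadPoolExecutor, Future
    from typing import Callable, List


    class BlockingTaskQueue:
        """Consumers block until a producer schedules work or the queue shuts down."""

        def __init__(self, max_workers: int = 4):
            self._condition = threading.Condition()
            self._tasks: List[Future] = []
            self._pool = ThreadPoolExecutor(max_workers=max_workers)
            self._shutdown = False

        def schedule(self, fn: Callable, *args) -> None:
            # Producer side, mirroring _schedule_downloads(): submit the work,
            # then wake any consumer parked in next().
            task = self._pool.submit(fn, *args)
            with self._condition:
                self._tasks.append(task)
                self._condition.notify_all()

        def next(self):
            # Consumer side, mirroring the new while-loop in
            # get_next_downloaded_file(): wait until a task exists or we are
            # told to give up instead of returning None.
            with self._condition:
                while not self._tasks:
                    if self._shutdown:
                        raise RuntimeError("shut down before a task was ready")
                    self._condition.wait()
                task = self._tasks.pop(0)
            return task.result()  # blocks until the submitted call completes

        def shutdown(self) -> None:
            # Mirrors _shutdown_manager(): mark shutdown, wake blocked
            # consumers so they raise rather than hang, then stop the pool.
            with self._condition:
                self._shutdown = True
                self._condition.notify_all()
            self._pool.shutdown(wait=False)


    if __name__ == "__main__":
        q = BlockingTaskQueue()
        q.schedule(lambda: 42)
        print(q.next())  # 42

The other recurring change is the sentinel swap: every _create_next_table and _create_table_at_offset path now yields a pyarrow.Table, using an empty table rather than None to mark exhaustion, so the fetch loops can test table.num_rows > 0 and keep slicing and concatenating without None guards. A small illustration under stated assumptions: make_empty_table is a hypothetical stand-in for the queue's _create_empty_table helper, while Schema.empty_table() and concat_tables() are real pyarrow APIs.

    import pyarrow

    def make_empty_table(schema: pyarrow.Schema) -> pyarrow.Table:
        # The schema is preserved, so slice/concat keep working downstream;
        # num_rows == 0 is the end-of-queue signal.
        return schema.empty_table()

    schema = pyarrow.schema({"col0": pyarrow.uint32()})
    table = make_empty_table(schema)
    assert table.num_rows == 0              # loops like `while table.num_rows > 0` exit
    assert table.schema == schema           # readers still see the expected columns
    assert pyarrow.concat_tables([table, table]).num_rows == 0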
