diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..33b5af334 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -1,66 +1,111 @@ +""" +Main script to run all SEA connector tests. + +This script imports and runs all the individual test modules and displays +a summary of test results with visual indicators. +""" import os import sys import logging -from databricks.sql.client import Connection +import importlib.util +from typing import Dict, Callable, List, Tuple -logging.basicConfig(level=logging.DEBUG) +# Configure logging +logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -def test_sea_session(): - """ - Test opening and closing a SEA session using the connector. +# Define test modules and their main test functions +TEST_MODULES = [ + "test_sea_session", + "test_sea_sync_query", + "test_sea_async_query", + "test_sea_metadata", +] + +def load_test_function(module_name: str) -> Callable: + """Load a test function from a module.""" + module_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "tests", + f"{module_name}.py" + ) + + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Get the main test function (assuming it starts with "test_") + for name in dir(module): + if name.startswith("test_") and callable(getattr(module, name)): + # For sync and async query modules, we want the main function that runs both tests + if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": + return getattr(module, name) - This function connects to a Databricks SQL endpoint using the SEA backend, - opens a session, and then closes it. 
+ # Fallback to the first test function found + for name in dir(module): + if name.startswith("test_") and callable(getattr(module, name)): + return getattr(module, name) - Required environment variables: - - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - - DATABRICKS_TOKEN: Personal access token for authentication - """ + raise ValueError(f"No test function found in module {module_name}") - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") +def run_tests() -> List[Tuple[str, bool]]: + """Run all tests and return results.""" + results = [] - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") - sys.exit(1) + for module_name in TEST_MODULES: + try: + test_func = load_test_function(module_name) + logger.info(f"\n{'=' * 50}") + logger.info(f"Running test: {module_name}") + logger.info(f"{'-' * 50}") + + success = test_func() + results.append((module_name, success)) + + status = "✅ PASSED" if success else "❌ FAILED" + logger.info(f"Test {module_name}: {status}") + + except Exception as e: + logger.error(f"Error loading or running test {module_name}: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + results.append((module_name, False)) - logger.info(f"Connecting to {server_hostname}") - logger.info(f"HTTP Path: {http_path}") - if catalog: - logger.info(f"Using catalog: {catalog}") + return results + +def print_summary(results: List[Tuple[str, bool]]) -> None: + """Print a summary of test results.""" + logger.info(f"\n{'=' * 50}") + logger.info("TEST SUMMARY") + logger.info(f"{'-' * 50}") - try: - logger.info("Creating connection with SEA backend...") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent - ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") - logger.info(f"backend type: {type(connection.session.backend)}") - - # Close the connection - logger.info("Closing the SEA session...") - connection.close() - logger.info("Successfully closed SEA session") - - except Exception as e: - logger.error(f"Error testing SEA session: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - sys.exit(1) + passed = sum(1 for _, success in results if success) + total = len(results) - logger.info("SEA session test completed successfully") + for module_name, success in results: + status = "✅ PASSED" if success else "❌ FAILED" + logger.info(f"{status} - {module_name}") + + logger.info(f"{'-' * 50}") + logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") + logger.info(f"{'=' * 50}") if __name__ == "__main__": - test_sea_session() + # Check if required environment variables are set + required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"] + missing_vars = [var for var in required_vars if not os.environ.get(var)] + + if missing_vars: + logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") + logger.error("Please set these variables before running the tests.") + sys.exit(1) + + # Run 
all tests + results = run_tests() + + # Print summary + print_summary(results) + + # Exit with appropriate status code + all_passed = all(success for _, success in results) + sys.exit(0 if all_passed else 1) \ No newline at end of file diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py new file mode 100644 index 000000000..5e1a8a58b --- /dev/null +++ b/examples/experimental/tests/__init__.py @@ -0,0 +1 @@ +# This file makes the tests directory a Python package \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py new file mode 100644 index 000000000..a4f3702f9 --- /dev/null +++ b/examples/experimental/tests/test_sea_async_query.py @@ -0,0 +1,165 @@ +""" +Test for SEA asynchronous query execution functionality. +""" +import os +import sys +import logging +import time +from databricks.sql.client import Connection +from databricks.sql.backend.types import CommandState + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_async_query_with_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info("Creating connection for asynchronous query execution with cloud fetch enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info("Executing asynchronous query with cloud fetch: SELECT 1 as test_value") + cursor.execute_async("SELECT 1 as test_value") + logger.info("Asynchronous query submitted successfully with cloud fetch enabled") + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info("Successfully retrieved asynchronous query results with cloud fetch enabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_without_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch disabled. 
+ + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info("Creating connection for asynchronous query execution with cloud fetch disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info("Executing asynchronous query without cloud fetch: SELECT 1 as test_value") + cursor.execute_async("SELECT 1 as test_value") + logger.info("Asynchronous query submitted successfully with cloud fetch disabled") + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info("Successfully retrieved asynchronous query results with cloud fetch disabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_exec(): + """ + Run both asynchronous query tests and return overall success. + """ + with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() + logger.info(f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") + + without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() + logger.info(f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_async_query_exec() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py new file mode 100644 index 000000000..ba760b61a --- /dev/null +++ b/examples/experimental/tests/test_sea_metadata.py @@ -0,0 +1,91 @@ +""" +Test for SEA metadata functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_metadata(): + """ + Test metadata operations using the SEA backend. 
+ + This function connects to a Databricks SQL endpoint using the SEA backend, + and executes metadata operations like catalogs(), schemas(), tables(), and columns(). + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + if not catalog: + logger.error("DATABRICKS_CATALOG environment variable is required for metadata tests.") + return False + + try: + # Create connection + logger.info("Creating connection for metadata operations") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Test catalogs + cursor = connection.cursor() + logger.info("Fetching catalogs...") + cursor.catalogs() + logger.info("Successfully fetched catalogs") + + # Test schemas + logger.info(f"Fetching schemas for catalog '{catalog}'...") + cursor.schemas(catalog_name=catalog) + logger.info("Successfully fetched schemas") + + # Test tables + logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") + cursor.tables(catalog_name=catalog, schema_name="default") + logger.info("Successfully fetched tables") + + # Test columns for a specific table + # Using a common table that should exist in most environments + logger.info(f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'...") + cursor.columns(catalog_name=catalog, schema_name="default", table_name="information_schema") + logger.info("Successfully fetched columns") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA metadata test: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_metadata() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py new file mode 100644 index 000000000..c0f6817da --- /dev/null +++ b/examples/experimental/tests/test_sea_session.py @@ -0,0 +1,70 @@ +""" +Test for SEA session management functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"Backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_session() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py new file mode 100644 index 000000000..4879e587a --- /dev/null +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -0,0 +1,143 @@ +""" +Test for SEA synchronous query execution functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_sync_query_with_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info("Creating connection for synchronous query execution with cloud fetch enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info("Executing synchronous query with cloud fetch: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch enabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_without_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info("Creating connection for synchronous query execution with cloud fetch disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info("Executing synchronous query without cloud fetch: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch disabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_exec(): + """ + Run both synchronous query tests and return overall success. 
+ """ + with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() + logger.info(f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") + + without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() + logger.info(f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_sync_query_exec() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 20b059fa7..bbca4c502 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,8 +16,6 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState -from databricks.sql.utils import ExecuteResponse -from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING @@ -88,34 +86,6 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: - """ - Executes a SQL command or query within the specified session. - - This method sends a SQL command to the server for execution and handles - the response. It can operate in both synchronous and asynchronous modes. - - Args: - operation: The SQL command or query to execute - session_id: The session identifier in which to execute the command - max_rows: Maximum number of rows to fetch in a single fetch batch - max_bytes: Maximum number of bytes to fetch in a single fetch batch - lz4_compression: Whether to use LZ4 compression for result data - cursor: The cursor object that will handle the results - use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets - parameters: List of parameters to bind to the query - async_op: Whether to execute the command asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - If async_op is False, returns a ResultSet object containing the - query results and metadata. If async_op is True, returns None and the - results must be fetched later using get_execution_result(). - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error executing the command - ServerOperationError: If the server encounters an error during execution - """ pass @abstractmethod diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..7f48b6179 --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,143 @@ +""" +Client-side filtering utilities for Databricks SQL connector. + +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + +from databricks.sql.result_set import SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. 
+ """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. + + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Create a filtered version of the result set + filtered_response = result_set._response.copy() + + # If there's a result with rows, filter them + if ( + "result" in filtered_response + and "data_array" in filtered_response["result"] + ): + rows = filtered_response["result"]["data_array"] + filtered_rows = [row for row in rows if filter_func(row)] + filtered_response["result"]["data_array"] = filtered_rows + + # Update row count if present + if "row_count" in filtered_response["result"]: + filtered_response["result"]["row_count"] = len(filtered_rows) + + # Create a new result set with the filtered data + return SeaResultSet( + connection=result_set.connection, + sea_response=filtered_response, + sea_client=result_set.backend, + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + ) + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. + + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. 
+ + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is typically in the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=False + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..10100e86e 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,34 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +import uuid +import time +from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +66,9 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -274,41 +288,221 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. 
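+
+        In synchronous mode (async_op=False) the method polls the statement
+        until it reaches a terminal state and returns a SeaResultSet via
+        get_execution_result(). With async_op=True it returns None after
+        storing the command ID on the cursor, and callers poll
+        get_query_state() before fetching results.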
+ + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else None + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows if max_rows > 0 else None, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + "operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. 
+ + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
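+
+        For asynchronous executions this is typically called once
+        get_query_state() reports a terminal state; the raw response payload
+        is wrapped in a SeaResultSet.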
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + return SeaResultSet( + connection=cursor.connection, + sea_response=response_data, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == @@ -319,9 +513,22 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -331,9 +538,30 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -345,9 +573,43 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -359,6 +621,33 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index c9310d367..b7c8bd399 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,19 +4,49 @@ This package contains data models for SEA API requests and responses. """ +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ColumnInfo, + ResultManifest, +) + from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) __all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ColumnInfo", + "ResultManifest", # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models + "ExecuteStatementResponse", + "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 000000000..671f7be13 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,68 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. 
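+
+A minimal, illustrative sketch of how these dataclasses compose (the message
+and error_code values below are placeholders, not real API output):
+
+    status = StatementStatus(
+        state=CommandState.FAILED,
+        error=ServiceError(message="table not found", error_code="NOT_FOUND"),
+    )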
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + + +@dataclass +class ColumnInfo: + """Information about a column in the result set.""" + + name: str + type_name: str + type_text: str + nullable: bool = True + precision: Optional[int] = None + scale: Optional[int] = None + ordinal_position: Optional[int] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + schema: List[ColumnInfo] + total_row_count: int + total_byte_count: int + truncated: bool = False + chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb502..e26b32e0a 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,5 +1,107 @@ -from typing import Dict, Any, Optional -from dataclasses import dataclass +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Request to execute a SQL statement.""" + + warehouse_id: str + statement: str + session_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + parameters: Optional[List[StatementParameter]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + result_compression: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + **({"value": param.value} if param.value is not None else {}), + **({"type": param.type} if param.type is not None else {}), + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590f..d70459b9f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,5 +1,96 @@ -from typing import Dict, Any -from dataclasses import dataclass +""" +Response models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, +) + + +@dataclass +class ExecuteStatementResponse: + """Response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) + + +@dataclass +class GetStatementResponse: + """Response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) @dataclass diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index de388f1d4..e03d6f235 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,11 +5,10 @@ import time import uuid import threading -from typing import List, Optional, Union, Any, TYPE_CHECKING +from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( @@ -17,8 +16,9 @@ SessionId, CommandId, BackendType, + guid_to_hex_id, + ExecuteResponse, ) -from databricks.sql.backend.utils import guid_to_hex_id try: import pyarrow @@ -42,7 +42,7 @@ ) from databricks.sql.utils import ( - ExecuteResponse, + ResultSetQueueFactory, _bound, RequestErrorInfo, NoRetryReason, @@ -53,6 +53,7 @@ ) from databricks.sql.types import SSLOptions from 
databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.result_set import ResultSet, ThriftResultSet logger = logging.getLogger(__name__) @@ -351,7 +352,6 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ - # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. @@ -797,23 +797,27 @@ def _results_message_to_execute_response(self, resp, operation_state): command_id = CommandId.from_thrift_handle(resp.operationHandle) - return ExecuteResponse( - arrow_queue=arrow_queue_opt, - status=CommandState.from_thrift_state(operation_state), - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, - arrow_schema_bytes=schema_bytes, + status = CommandState.from_thrift_state(operation_state) + if status is None: + raise ValueError(f"Invalid operation state: {operation_state}") + + return ( + ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_more_rows=has_more_rows, + results_queue=arrow_queue_opt, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, + ), + schema_bytes, ) def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -863,15 +867,14 @@ def get_execution_result( ) execute_response = ExecuteResponse( - arrow_queue=queue, - status=CommandState.from_thrift_state(resp.status), - has_been_closed_server_side=False, + command_id=command_id, + status=resp.status, + description=description, has_more_rows=has_more_rows, + results_queue=queue, + has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, - arrow_schema_bytes=schema_bytes, ) return ThriftResultSet( @@ -881,6 +884,7 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=schema_bytes, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -909,10 +913,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - state = CommandState.from_thrift_state(operation_state) - if state is None: - raise ValueError(f"Unknown command state: {operation_state}") - return state + return CommandState.from_thrift_state(operation_state) @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -947,8 +948,6 @@ def execute_command( async_op=False, enforce_embedded_schema_correctness=False, ) -> Union["ResultSet", None]: - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -995,7 +994,9 @@ def execute_command( 
self._handle_execute_response_async(resp, cursor) return None else: - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1004,6 +1005,7 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_catalogs( @@ -1013,8 +1015,6 @@ def get_catalogs( max_bytes: int, cursor: "Cursor", ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1027,7 +1027,9 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1036,6 +1038,7 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_schemas( @@ -1047,8 +1050,6 @@ def get_schemas( catalog_name=None, schema_name=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1063,7 +1064,9 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1072,6 +1075,7 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_tables( @@ -1085,8 +1089,6 @@ def get_tables( table_name=None, table_types=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1103,7 +1105,9 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1112,6 +1116,7 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_columns( @@ -1125,8 +1130,6 @@ def get_columns( table_name=None, column_name=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1143,7 +1146,9 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1152,6 +1157,7 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def 
_handle_execute_response(self, resp, cursor): @@ -1165,7 +1171,12 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response(resp, final_operation_state) + ( + execute_response, + arrow_schema_bytes, + ) = self._results_message_to_execute_response(resp, final_operation_state) + execute_response.command_id = command_id + return execute_response, arrow_schema_bytes def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) @@ -1226,7 +1237,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.guid)) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..5bf02e0ea 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,28 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. + + Args: + state: SEA state string + + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + class BackendType(Enum): """ @@ -285,28 +308,6 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count - def __str__(self) -> str: - """ - Return a string representation of the CommandId. - - For SEA backend, returns the guid. - For Thrift backend, returns a format like "guid|secret". 
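The new `CommandState.from_sea_state` classmethod maps SEA's string states onto the shared `CommandState` enum, with unrecognised strings falling through to `None`. A small usage example (the last input is a made-up string, included only to show the fallback):

from databricks.sql.backend.types import CommandState

assert CommandState.from_sea_state("PENDING") is CommandState.PENDING
assert CommandState.from_sea_state("CANCELED") is CommandState.CANCELLED  # SEA spelling maps to the enum's CANCELLED
assert CommandState.from_sea_state("NOT_A_REAL_STATE") is None  # unmapped states return None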
- - Returns: - A string representation of the command ID - """ - - if self.backend_type == BackendType.SEA: - return str(self.guid) - elif self.backend_type == BackendType.THRIFT: - secret_hex = ( - guid_to_hex_id(self.secret) - if isinstance(self.secret, bytes) - else str(self.secret) - ) - return f"{self.to_hex_guid()}|{secret_hex}" - return str(self.guid) - @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -318,7 +319,6 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ - if operation_handle is None: return None @@ -394,3 +394,19 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9f7c060a7..e145e4e58 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -24,7 +24,6 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.utils import ( - ExecuteResponse, ParamEscaper, inject_parameters, transform_paramstyle, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a0d8d3579..fc8595839 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,26 +1,23 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Any, Union, TYPE_CHECKING +from typing import List, Optional, Any, Union, Tuple, TYPE_CHECKING import logging import time import pandas -from databricks.sql.backend.types import CommandId, CommandState - try: import pyarrow except ImportError: pyarrow = None if TYPE_CHECKING: - from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection - from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue +from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -34,32 +31,31 @@ class ResultSet(ABC): def __init__( self, - connection: "Connection", - backend: "DatabricksClient", - command_id: CommandId, - op_state: Optional[CommandState], - has_been_closed_server_side: bool, + connection, + backend, arraysize: int, buffer_size_bytes: int, + command_id=None, + status=None, + has_been_closed_server_side: bool = False, + has_more_rows: bool = False, + results_queue=None, + description=None, + is_staging_operation: bool = False, ): - """ - A ResultSet manages the results of a single command. 
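`ExecuteResponse` is now a dataclass in `databricks.sql.backend.types` with only `command_id` and `status` required; the remaining fields default as declared above. A hedged construction example (the statement id and description tuple are invented for illustration):

from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse

response = ExecuteResponse(
    command_id=CommandId.from_sea_statement_id("example-statement-id"),
    status=CommandState.SUCCEEDED,
    description=[("id", "int", None, None, None, None, True)],
    lz4_compressed=False,
)

# Fields that were not passed keep their dataclass defaults.
assert response.has_more_rows is False
assert response.is_staging_operation is False
assert response.results_queue is None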
- - :param connection: The parent connection that was used to execute this command - :param backend: The specialised backend client to be invoked in the fetch phase - :param execute_response: A `ExecuteResponse` class returned by a command execution - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - amount :param arraysize: The max number of rows to fetch at a time (PEP-249) - """ - self.command_id = command_id - self.op_state = op_state - self.has_been_closed_server_side = has_been_closed_server_side + """Initialize the base ResultSet with common properties.""" self.connection = connection - self.backend = backend + self.backend = backend # Store the backend client directly self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 - self.description = None + self.description = description + self.command_id = command_id + self.status = status + self.has_been_closed_server_side = has_been_closed_server_side + self._has_more_rows = has_more_rows + self.results = results_queue + self._is_staging_operation = is_staging_operation def __iter__(self): while True: @@ -74,10 +70,9 @@ def rownumber(self): return self._next_row_index @property - @abstractmethod def is_staging_operation(self) -> bool: """Whether this result set represents a staging operation.""" - pass + return self._is_staging_operation # Define abstract methods that concrete implementations must implement @abstractmethod @@ -101,12 +96,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": + def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" pass @@ -119,7 +114,7 @@ def close(self) -> None: """ try: if ( - self.op_state != CommandState.CLOSED + self.status != CommandState.CLOSED and not self.has_been_closed_server_side and self.connection.open ): @@ -129,7 +124,7 @@ def close(self) -> None: logger.info("Operation was canceled by a prior request") finally: self.has_been_closed_server_side = True - self.op_state = CommandState.CLOSED + self.status = CommandState.CLOSED class ThriftResultSet(ResultSet): @@ -138,11 +133,12 @@ class ThriftResultSet(ResultSet): def __init__( self, connection: "Connection", - execute_response: ExecuteResponse, + execute_response: "ExecuteResponse", thrift_client: "ThriftDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, + arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
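After the refactor, `ThriftResultSet` copies its shared state (command id, status, description, results queue, and so on) from the `ExecuteResponse` it is given, and the Arrow schema bytes arrive as a separate keyword argument. A hedged construction sketch in the style of the unit tests further down; mocks stand in for the connection, backend, and a pre-filled results queue, so `__init__` does not trigger a fetch round-trip:

from unittest.mock import Mock
from databricks.sql.backend.types import CommandState, ExecuteResponse
from databricks.sql.result_set import ThriftResultSet

execute_response = ExecuteResponse(
    command_id=Mock(),
    status=CommandState.SUCCEEDED,
    description=[("col", "string", None, None, None, None, True)],
    has_been_closed_server_side=True,
    results_queue=Mock(),  # truthy queue, so _fill_results_buffer() is skipped
    lz4_compressed=False,
)

result_set = ThriftResultSet(
    connection=Mock(),
    execute_response=execute_response,
    thrift_client=Mock(),
    arrow_schema_bytes=b"",  # schema bytes are now passed to the result set directly
)

assert result_set.status is CommandState.SUCCEEDED
assert result_set.is_staging_operation is False  # concrete property on the base class now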
@@ -154,37 +150,33 @@ def __init__( buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results + arrow_schema_bytes: Arrow schema bytes for the result set """ - super().__init__( - connection, - thrift_client, - execute_response.command_id, - execute_response.status, - execute_response.has_been_closed_server_side, - arraysize, - buffer_size_bytes, - ) - # Initialize ThriftResultSet-specific attributes - self.has_been_closed_server_side = execute_response.has_been_closed_server_side - self.has_more_rows = execute_response.has_more_rows - self.lz4_compressed = execute_response.lz4_compressed - self.description = execute_response.description - self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._arrow_schema_bytes = arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch - self._is_staging_operation = execute_response.is_staging_operation + self.lz4_compressed = execute_response.lz4_compressed - # Initialize results queue - if execute_response.arrow_queue: - # In this case the server has taken the fast path and returned an initial batch of - # results - self.results = execute_response.arrow_queue - else: - # In this case, there are results waiting on the server so we fetch now for simplicity + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=thrift_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + ) + + # Initialize results queue if not provided + if not self.results: self._fill_results_buffer() def _fill_results_buffer(self): - # At initialization or if the server does not have cloud fetch result links available results, has_more_rows = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, @@ -196,7 +188,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self._has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -248,7 +240,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": + def merge_columnar(self, result1, result2): """ Function to merge / combining the columnar results into a single result :param result1: @@ -280,7 +272,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -305,7 +297,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -320,7 +312,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not 
self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -346,7 +338,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -389,24 +381,110 @@ def fetchmany(self, size: int) -> List[Row]: else: return self._convert_arrow_table(self.fetchmany_arrow(size)) - @property - def is_staging_operation(self) -> bool: - """Whether this result set represents a staging operation.""" - return self._is_staging_operation - @staticmethod - def _get_schema_description(table_schema_message): +class SeaResultSet(ResultSet): + """ResultSet implementation for SEA backend.""" + + def __init__( + self, + connection, + sea_client, + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + execute_response=None, + sea_response=None, + ): """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) """ + # Handle both initialization styles + if execute_response is not None: + # New style with ExecuteResponse + command_id = execute_response.command_id + status = execute_response.status + has_been_closed_server_side = execute_response.has_been_closed_server_side + has_more_rows = execute_response.has_more_rows + results_queue = execute_response.results_queue + description = execute_response.description + is_staging_operation = execute_response.is_staging_operation + self._response = getattr(execute_response, "sea_response", {}) + self.statement_id = command_id.to_sea_statement_id() if command_id else None + elif sea_response is not None: + # Legacy style with direct sea_response + self._response = sea_response + # Extract values from sea_response + command_id = CommandId.from_sea_statement_id( + sea_response.get("statement_id", "") + ) + self.statement_id = sea_response.get("statement_id", "") + + # Extract status + status_data = sea_response.get("status", {}) + status = CommandState.from_sea_state(status_data.get("state", "PENDING")) + + # Set defaults for other fields + has_been_closed_server_side = False + has_more_rows = False + results_queue = None + description = None + is_staging_operation = False + else: + raise ValueError("Either execute_response or sea_response must be provided") - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=status, + has_been_closed_server_side=has_been_closed_server_side, + has_more_rows=has_more_rows, + results_queue=results_queue, + description=description, + 
is_staging_operation=is_staging_operation, + ) - return [ - (column.name, map_col_type(column.datatype), None, None, None, None, None) - for column in table_schema_message.columns - ] + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchmany(self, size: Optional[int] = None) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + + raise NotImplementedError("fetchmany is not implemented for SEA backend") + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + raise NotImplementedError("fetchall is not implemented for SEA backend") + + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 7c33d9b2d..76aec4675 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -10,7 +10,7 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId +from databricks.sql.backend.types import SessionId, BackendType logger = logging.getLogger(__name__) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 2622b1172..edb13ef6d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -349,13 +349,6 @@ def _create_empty_table(self) -> "pyarrow.Table": return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) -ExecuteResponse = namedtuple( - "ExecuteResponse", - "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation " - "command_id arrow_queue arrow_schema_bytes", -) - - def _bound(min_x, max_x, x): """Bound x by [min_x, max_x] diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 1a7950870..090ec255e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -26,7 +26,7 @@ from databricks.sql.types import Row from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.backend.types import CommandId, CommandState -from databricks.sql.utils import ExecuteResponse +from databricks.sql.backend.types import ExecuteResponse from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite @@ -121,10 +121,10 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Verify initial state self.assertEqual(real_result_set.has_been_closed_server_side, closed) - expected_op_state = ( + expected_status = ( CommandState.CLOSED if closed else CommandState.SUCCEEDED ) - self.assertEqual(real_result_set.op_state, expected_op_state) + 
self.assertEqual(real_result_set.status, expected_status) # Mock execute_command to return our real result set cursor.backend.execute_command = Mock(return_value=real_result_set) @@ -146,8 +146,8 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # 1. has_been_closed_server_side should always be True after close() self.assertTrue(real_result_set.has_been_closed_server_side) - # 2. op_state should always be CLOSED after close() - self.assertEqual(real_result_set.op_state, CommandState.CLOSED) + # 2. status should always be CLOSED after close() + self.assertEqual(real_result_set.status, CommandState.CLOSED) # 3. Backend close_command should be called appropriately if not closed: @@ -556,7 +556,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): self.assertEqual(instance.close_session.call_count, 0) cursor.close() - @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) + @patch("%s.backend.types.ExecuteResponse" % PACKAGE_NAME) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( @@ -678,10 +678,10 @@ def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" result_set = client.ThriftResultSet.__new__(client.ThriftResultSet) result_set.backend = Mock() - result_set.backend.CLOSED_OP_STATE = "CLOSED" + result_set.backend.CLOSED_OP_STATE = CommandState.CLOSED result_set.connection = Mock() result_set.connection.open = True - result_set.op_state = "RUNNING" + result_set.status = CommandState.RUNNING result_set.has_been_closed_server_side = False result_set.command_id = Mock() @@ -695,7 +695,7 @@ def __init__(self): try: try: if ( - result_set.op_state != result_set.backend.CLOSED_OP_STATE + result_set.status != result_set.backend.CLOSED_OP_STATE and not result_set.has_been_closed_server_side and result_set.connection.open ): @@ -705,7 +705,7 @@ def __init__(self): pass finally: result_set.has_been_closed_server_side = True - result_set.op_state = result_set.backend.CLOSED_OP_STATE + result_set.status = result_set.backend.CLOSED_OP_STATE result_set.backend.close_command.assert_called_once_with( result_set.command_id @@ -713,7 +713,7 @@ def __init__(self): assert result_set.has_been_closed_server_side is True - assert result_set.op_state == result_set.backend.CLOSED_OP_STATE + assert result_set.status == result_set.backend.CLOSED_OP_STATE finally: pass diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 030510a64..7249a59e6 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -8,7 +8,8 @@ pa = None import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.result_set import ThriftResultSet @@ -42,14 +43,13 @@ def make_dummy_result_set_from_initial_results(initial_results): rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), lz4_compressed=Mock(), - command_id=None, - arrow_queue=arrow_queue, - arrow_schema_bytes=schema.serialize().to_pybytes(), + results_queue=arrow_queue, is_staging_operation=False, ), thrift_client=None, @@ 
-88,6 +88,7 @@ def fetch_results( rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=False, has_more_rows=True, @@ -96,9 +97,7 @@ def fetch_results( for col_id in range(num_cols) ], lz4_compressed=Mock(), - command_id=None, - arrow_queue=None, - arrow_schema_bytes=None, + results_queue=None, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index b302c00da..7e025cf82 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -10,7 +10,8 @@ import pytest import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py new file mode 100644 index 000000000..e8eb2a757 --- /dev/null +++ b/tests/unit/test_result_set_filter.py @@ -0,0 +1,246 @@ +""" +Tests for the ResultSetFilter class. + +This module contains tests for the ResultSetFilter class, which provides +filtering capabilities for result sets returned by different backends. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestResultSetFilter: + """Test suite for the ResultSetFilter class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response_with_tables(self): + """Create a sample SEA response with table data based on the server schema.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 7, + "columns": [ + {"name": "namespace", "type_text": "STRING", "position": 0}, + {"name": "tableName", "type_text": "STRING", "position": 1}, + {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, + {"name": "information", "type_text": "STRING", "position": 3}, + {"name": "catalogName", "type_text": "STRING", "position": 4}, + {"name": "tableType", "type_text": "STRING", "position": 5}, + {"name": "remarks", "type_text": "STRING", "position": 6}, + ], + }, + "total_row_count": 4, + }, + "result": { + "data_array": [ + ["schema1", "table1", False, None, "catalog1", "TABLE", None], + ["schema1", "table2", False, None, "catalog1", "VIEW", None], + [ + "schema1", + "table3", + False, + None, + "catalog1", + "SYSTEM TABLE", + None, + ], + ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], + ] + }, + } + + @pytest.fixture + def sea_result_set_with_tables( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Create a SeaResultSet with table data.""" + # Create a deep copy of the response to avoid test interference + import copy + + sea_response_copy = copy.deepcopy(sea_response_with_tables) + + return SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy, + buffer_size_bytes=1000, + arraysize=100, + 
) + + def test_filter_tables_by_type_default(self, sea_result_set_with_tables): + """Test filtering tables by type with default types.""" + # Default types are TABLE, VIEW, SYSTEM TABLE + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables + ) + + # Verify that only the default types are included + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): + """Test filtering tables by type with custom types.""" + # Filter for only TABLE and EXTERNAL + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] + ) + + # Verify that only the specified types are included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "EXTERNAL" in table_types + assert "VIEW" not in table_types + assert "SYSTEM TABLE" not in table_types + + def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): + """Test that table type filtering is case-insensitive.""" + # Filter for lowercase "table" and "view" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["table", "view"] + ) + + # Verify that the matching types are included despite case differences + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" not in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): + """Test filtering tables with an empty type list (should use defaults).""" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=[] + ) + + # Verify that default types are used + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_by_column_values(self, sea_result_set_with_tables): + """Test filtering by values in a specific column.""" + # Filter by namespace in column index 0 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] + ) + + # All rows have schema1 in namespace, so all should be included + assert len(filtered_result._response["result"]["data_array"]) == 4 + + # Filter by table name in column index 1 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, + column_index=1, + allowed_values=["table1", "table3"], + ) + + # Only rows with table1 or table3 should be included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_names = [ + row[1] for row in filtered_result._response["result"]["data_array"] + ] + assert "table1" in table_names + assert "table3" in table_names + assert "table2" not in 
table_names + assert "table4" not in table_names + + def test_filter_by_column_values_case_sensitive( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Test case-sensitive filtering by column values.""" + import copy + + # Create a fresh result set for the first test + sea_response_copy1 = copy.deepcopy(sea_response_with_tables) + result_set1 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy1, + buffer_size_bytes=1000, + arraysize=100, + ) + + # First test: Case-sensitive filtering with lowercase values (should find no matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set1, + column_index=5, # tableType column + allowed_values=["table", "view"], # lowercase + case_sensitive=True, + ) + + # Verify no matches with lowercase values + assert len(filtered_result._response["result"]["data_array"]) == 0 + + # Create a fresh result set for the second test + sea_response_copy2 = copy.deepcopy(sea_response_with_tables) + result_set2 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy2, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Second test: Case-sensitive filtering with correct case (should find matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set2, + column_index=5, # tableType column + allowed_values=["TABLE", "VIEW"], # correct case + case_sensitive=True, + ) + + # Verify matches with correct case + assert len(filtered_result._response["result"]["data_array"]) == 2 + + # Extract the table types from the filtered results + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + + def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): + """Test filtering with a column index that's out of bounds.""" + # Filter by column index 10 (out of bounds) + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=10, allowed_values=["value"] + ) + + # No rows should match + assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..c2078f731 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,650 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def 
test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, 
sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, } - assert set(allowed_configs) == expected_keys + mock_http_client._make_request.return_value = execute_response - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = 
mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the 
result of a command execution.""" + # Set up mock response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + # Tests for metadata operations + + def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): + """Test getting catalogs.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["lz4_compression"] is False + assert kwargs["use_cloud_fetch"] is False + assert kwargs["parameters"] == [] + assert kwargs["async_op"] is False + assert kwargs["enforce_embedded_schema_correctness"] is False + + def test_get_schemas_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_with_catalog_and_schema_pattern( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with catalog and schema pattern.""" + # Mock execute_command to verify it's called with the right 
parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog and schema pattern + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting schemas without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + ) + + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_all_catalogs( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables from all catalogs using wildcard.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_schema_and_table_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with schema and table patterns.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + 
sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog, schema, and table patterns + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting tables without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_columns_with_all_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with all patterns specified.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with all patterns + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def 
test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting columns without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_columns" in str(excinfo.value) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py new file mode 100644 index 000000000..f666fd613 --- /dev/null +++ b/tests/unit/test_sea_result_set.py @@ -0,0 +1,275 @@ +""" +Tests for the SeaResultSet class. + +This module contains tests for the SeaResultSet class, which implements +the result set functionality for the SEA (Statement Execution API) backend. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response(self): + """Create a sample SEA response.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.has_more_rows = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + mock_response.sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "result": {"data_array": [["1"]]}, + } + return mock_response + + def test_init_with_sea_response( + self, mock_connection, mock_sea_client, sea_response + ): + """Test initializing SeaResultSet with a SEA response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == sea_response + + def 
test_init_with_execute_response( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + execute_response=execute_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == execute_response.sea_response + + def test_init_with_no_response(self, mock_connection, mock_sea_client): + """Test that initialization fails when neither response type is provided.""" + with pytest.raises(ValueError) as excinfo: + SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + assert "Either execute_response or sea_response must be provided" in str( + excinfo.value + ) + + def test_close(self, mock_connection, mock_sea_client, sea_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_unimplemented_methods( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that unimplemented methods raise NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test each unimplemented method individually with specific error messages + with pytest.raises( + 
NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set.fetchone() + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() + + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() + + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) + + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass + + def test_fill_results_buffer_not_implemented( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that _fill_results_buffer raises NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set._fill_results_buffer() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92de8d8fd..fef070362 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,11 @@ from unittest.mock import patch, Mock import gc +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57a2a61e3..b8de970db 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,11 +619,18 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 + + # Create a valid operation status + op_status = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=Mock(), + operationStatus=op_status, resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -644,7 +651,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -878,11 +885,12 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - results_message_response = thrift_backend._handle_execute_response( + execute_response, _ = 
thrift_backend._handle_execute_response( execute_resp, Mock() ) + self.assertEqual( - results_message_response.status, + execute_response.status, CommandState.SUCCEEDED, ) @@ -915,7 +923,9 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -943,15 +953,21 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( + execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -971,6 +987,12 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req thrift_backend = self._make_fake_thrift_backend() thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -1018,7 +1040,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( execute_resp, Mock() ) @@ -1150,7 +1172,7 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1184,7 +1206,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1215,7 +1237,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1255,7 +1277,7 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), 
ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1299,7 +1321,7 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1645,7 +1667,9 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) # Create a mock response with a real operation handle mock_resp = Mock() @@ -2204,7 +2228,8 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", + return_value=(Mock(), Mock()), ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class