From 138c2aebab99659d1c970fa70e4a431fec78aae2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:24:22 +0000 Subject: [PATCH 01/66] [squash from exec-sea] bring over execution phase changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 ++- .../sql/backend/databricks_client.py | 30 - src/databricks/sql/backend/filters.py | 143 ++++ src/databricks/sql/backend/sea/backend.py | 360 ++++++++- .../sql/backend/sea/models/__init__.py | 30 + src/databricks/sql/backend/sea/models/base.py | 68 ++ .../sql/backend/sea/models/requests.py | 110 ++- .../sql/backend/sea/models/responses.py | 95 ++- src/databricks/sql/backend/thrift_backend.py | 99 ++- src/databricks/sql/backend/types.py | 64 +- src/databricks/sql/client.py | 1 - src/databricks/sql/result_set.py | 234 ++++-- src/databricks/sql/session.py | 2 +- src/databricks/sql/utils.py | 7 - tests/unit/test_client.py | 22 +- tests/unit/test_fetches.py | 13 +- tests/unit/test_fetches_bench.py | 3 +- tests/unit/test_result_set_filter.py | 246 ++++++ tests/unit/test_sea_backend.py | 755 +++++++++++++++--- tests/unit/test_sea_result_set.py | 275 +++++++ tests/unit/test_session.py | 5 + tests/unit/test_thrift_backend.py | 55 +- 22 files changed, 2375 insertions(+), 366 deletions(-) create mode 100644 src/databricks/sql/backend/filters.py create mode 100644 src/databricks/sql/backend/sea/models/base.py create mode 100644 tests/unit/test_result_set_filter.py create mode 100644 tests/unit/test_sea_result_set.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..87b62efea 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,34 +6,122 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) + +def test_sea_query_exec(): + """ + Test executing a query using the SEA backend with result compression. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with result compression enabled and disabled, + and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+    )
+    sys.exit(1)
+
+    try:
+        # Test with compression enabled
+        logger.info("Creating connection with LZ4 compression enabled")
+        connection = Connection(
+            server_hostname=server_hostname,
+            http_path=http_path,
+            access_token=access_token,
+            catalog=catalog,
+            schema="default",
+            use_sea=True,
+            user_agent_entry="SEA-Test-Client",
+            use_cloud_fetch=True,  # Enable cloud fetch to use compression
+            enable_query_result_lz4_compression=True,  # Enable LZ4 compression
+        )
+
+        logger.info(
+            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
+        )
+        logger.info(f"backend type: {type(connection.session.backend)}")
+
+        # Execute a simple query with compression enabled
+        cursor = connection.cursor(arraysize=0, buffer_size_bytes=0)
+        logger.info("Executing query with LZ4 compression: SELECT 1 as test_value")
+        cursor.execute("SELECT 1 as test_value")
+        logger.info("Query with compression executed successfully")
+        cursor.close()
+        connection.close()
+        logger.info("Successfully closed SEA session with compression enabled")
+
+        # Test with compression disabled
+        logger.info("Creating connection with LZ4 compression disabled")
+        connection = Connection(
+            server_hostname=server_hostname,
+            http_path=http_path,
+            access_token=access_token,
+            catalog=catalog,
+            schema="default",
+            use_sea=True,
+            user_agent_entry="SEA-Test-Client",
+            use_cloud_fetch=False,  # Disable cloud fetch
+            enable_query_result_lz4_compression=False,  # Disable LZ4 compression
+        )
+
+        logger.info(
+            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
+        )
+
+        # Execute a simple query with compression disabled
+        cursor = connection.cursor(arraysize=0, buffer_size_bytes=0)
+        logger.info("Executing query without compression: SELECT 1 as test_value")
+        cursor.execute("SELECT 1 as test_value")
+        logger.info("Query without compression executed successfully")
+        cursor.close()
+        connection.close()
+        logger.info("Successfully closed SEA session with compression disabled")
+
+    except Exception as e:
+        logger.error(f"Error during SEA query execution test: {str(e)}")
+        import traceback
+
+        logger.error(traceback.format_exc())
+        sys.exit(1)
+
+    logger.info("SEA query execution test with compression completed successfully")
+
+
 def test_sea_session():
     """
     Test opening and closing a SEA session using the connector.
-    
+
     This function connects to a Databricks SQL endpoint using the SEA backend,
     opens a session, and then closes it.
-    
+
     Required environment variables:
     - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname
     - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint
     - DATABRICKS_TOKEN: Personal access token for authentication
     """
-    
     server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
     http_path = os.environ.get("DATABRICKS_HTTP_PATH")
     access_token = os.environ.get("DATABRICKS_TOKEN")
     catalog = os.environ.get("DATABRICKS_CATALOG")
-    
+
     if not all([server_hostname, http_path, access_token]):
         logger.error("Missing required environment variables.")
-        logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.")
+        logger.error(
+            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
+ ) sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -42,25 +130,33 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent + use_sea=True, + user_agent_entry="SEA-Test-Client", # add custom user agent + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback + logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") + if __name__ == "__main__": + # Test session management test_sea_session() + + # Test query execution with compression + test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 20b059fa7..bbca4c502 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,8 +16,6 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState -from databricks.sql.utils import ExecuteResponse -from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING @@ -88,34 +86,6 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: - """ - Executes a SQL command or query within the specified session. - - This method sends a SQL command to the server for execution and handles - the response. It can operate in both synchronous and asynchronous modes. - - Args: - operation: The SQL command or query to execute - session_id: The session identifier in which to execute the command - max_rows: Maximum number of rows to fetch in a single fetch batch - max_bytes: Maximum number of bytes to fetch in a single fetch batch - lz4_compression: Whether to use LZ4 compression for result data - cursor: The cursor object that will handle the results - use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets - parameters: List of parameters to bind to the query - async_op: Whether to execute the command asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - If async_op is False, returns a ResultSet object containing the - query results and metadata. If async_op is True, returns None and the - results must be fetched later using get_execution_result(). - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error executing the command - ServerOperationError: If the server encounters an error during execution - """ pass @abstractmethod diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..7f48b6179 --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,143 @@ +""" +Client-side filtering utilities for Databricks SQL connector. 
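+
+A typical use (an illustrative sketch; ``result_set`` is assumed to be a
+``SeaResultSet`` produced by a metadata query):
+
+    from databricks.sql.backend.filters import ResultSetFilter
+
+    tables = ResultSetFilter.filter_tables_by_type(result_set, ["TABLE", "VIEW"])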
+ +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + +from databricks.sql.result_set import SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. + + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Create a filtered version of the result set + filtered_response = result_set._response.copy() + + # If there's a result with rows, filter them + if ( + "result" in filtered_response + and "data_array" in filtered_response["result"] + ): + rows = filtered_response["result"]["data_array"] + filtered_rows = [row for row in rows if filter_func(row)] + filtered_response["result"]["data_array"] = filtered_rows + + # Update row count if present + if "row_count" in filtered_response["result"]: + filtered_response["result"]["row_count"] = len(filtered_rows) + + # Create a new result set with the filtered data + return SeaResultSet( + connection=result_set.connection, + sea_response=filtered_response, + sea_client=result_set.backend, + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + ) + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. + + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. 
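+        For example, ``filter_tables_by_type(result_set, ["VIEW"])`` keeps only
+        rows whose table-type column (index 5) equals "VIEW", matched
+        case-insensitively; passing ``None`` falls back to the defaults of
+        TABLE, VIEW and SYSTEM TABLE.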
+ + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is typically in the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=False + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..c7a4ed1b1 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,34 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +import uuid +import time +from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +66,9 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -274,41 +288,222 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. 
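+
+        Example (illustrative sketch only; ``backend``, ``session_id`` and
+        ``cursor`` are assumed to come from an already-open SEA connection):
+
+            result_set = backend.execute_command(
+                operation="SELECT 1 AS test_value",
+                session_id=session_id,
+                max_rows=10000,
+                max_bytes=0,
+                lz4_compression=False,
+                cursor=cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )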
+ + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else "NONE" + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows if max_rows > 0 else None, + byte_limit=max_bytes if max_bytes > 0 else None, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + "operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. 
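+
+        Example (illustrative; ``execute_command`` stores the active command ID
+        on the cursor, which is the usual handle for cancellation):
+
+            backend.cancel_command(cursor.active_command_id)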
+ + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
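+
+        Example (illustrative; typically called after an asynchronous
+        ``execute_command`` has reached a terminal state):
+
+            result_set = backend.get_execution_result(cursor.active_command_id, cursor)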
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + return SeaResultSet( + connection=cursor.connection, + sea_response=response_data, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == @@ -319,9 +514,22 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -331,9 +539,30 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -345,9 +574,43 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -359,6 +622,33 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index c9310d367..b7c8bd399 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,19 +4,49 @@ This package contains data models for SEA API requests and responses. """ +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ColumnInfo, + ResultManifest, +) + from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) __all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ColumnInfo", + "ResultManifest", # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models + "ExecuteStatementResponse", + "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 000000000..671f7be13 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,68 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. 
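+
+Example (an illustrative sketch; the counts are placeholder values):
+
+    status = StatementStatus(state=CommandState.SUCCEEDED)
+    manifest = ResultManifest(schema=[], total_row_count=1, total_byte_count=8)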
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + + +@dataclass +class ColumnInfo: + """Information about a column in the result set.""" + + name: str + type_name: str + type_text: str + nullable: bool = True + precision: Optional[int] = None + scale: Optional[int] = None + ordinal_position: Optional[int] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + schema: List[ColumnInfo] + total_row_count: int + total_byte_count: int + truncated: bool = False + chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb502..1c519d931 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,5 +1,111 @@ -from typing import Dict, Any, Optional -from dataclasses import dataclass +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Request to execute a SQL statement.""" + + warehouse_id: str + statement: str + session_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + byte_limit: Optional[int] = None + parameters: Optional[List[StatementParameter]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + result_compression: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.byte_limit is not None and self.byte_limit > 0: + result["byte_limit"] = self.byte_limit + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + **({"value": param.value} if param.value is not None else {}), + **({"type": param.type} if param.type is not None else {}), + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590f..d70459b9f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,5 +1,96 @@ -from typing import Dict, Any -from dataclasses import dataclass +""" +Response models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, +) + + +@dataclass +class ExecuteStatementResponse: + """Response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) + + +@dataclass +class GetStatementResponse: + """Response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) @dataclass diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index de388f1d4..e03d6f235 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,11 +5,10 @@ import time import uuid import threading -from typing import List, Optional, Union, Any, TYPE_CHECKING +from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( @@ -17,8 +16,9 @@ SessionId, CommandId, BackendType, + guid_to_hex_id, + ExecuteResponse, ) -from databricks.sql.backend.utils import guid_to_hex_id try: import pyarrow @@ -42,7 +42,7 @@ ) from databricks.sql.utils import ( - ExecuteResponse, + ResultSetQueueFactory, _bound, RequestErrorInfo, NoRetryReason, @@ -53,6 +53,7 @@ ) from databricks.sql.types import SSLOptions from 
databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.result_set import ResultSet, ThriftResultSet logger = logging.getLogger(__name__) @@ -351,7 +352,6 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ - # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. @@ -797,23 +797,27 @@ def _results_message_to_execute_response(self, resp, operation_state): command_id = CommandId.from_thrift_handle(resp.operationHandle) - return ExecuteResponse( - arrow_queue=arrow_queue_opt, - status=CommandState.from_thrift_state(operation_state), - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, - arrow_schema_bytes=schema_bytes, + status = CommandState.from_thrift_state(operation_state) + if status is None: + raise ValueError(f"Invalid operation state: {operation_state}") + + return ( + ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_more_rows=has_more_rows, + results_queue=arrow_queue_opt, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, + ), + schema_bytes, ) def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -863,15 +867,14 @@ def get_execution_result( ) execute_response = ExecuteResponse( - arrow_queue=queue, - status=CommandState.from_thrift_state(resp.status), - has_been_closed_server_side=False, + command_id=command_id, + status=resp.status, + description=description, has_more_rows=has_more_rows, + results_queue=queue, + has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, - arrow_schema_bytes=schema_bytes, ) return ThriftResultSet( @@ -881,6 +884,7 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=schema_bytes, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -909,10 +913,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - state = CommandState.from_thrift_state(operation_state) - if state is None: - raise ValueError(f"Unknown command state: {operation_state}") - return state + return CommandState.from_thrift_state(operation_state) @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -947,8 +948,6 @@ def execute_command( async_op=False, enforce_embedded_schema_correctness=False, ) -> Union["ResultSet", None]: - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -995,7 +994,9 @@ def execute_command( 
self._handle_execute_response_async(resp, cursor) return None else: - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1004,6 +1005,7 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_catalogs( @@ -1013,8 +1015,6 @@ def get_catalogs( max_bytes: int, cursor: "Cursor", ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1027,7 +1027,9 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1036,6 +1038,7 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_schemas( @@ -1047,8 +1050,6 @@ def get_schemas( catalog_name=None, schema_name=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1063,7 +1064,9 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1072,6 +1075,7 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_tables( @@ -1085,8 +1089,6 @@ def get_tables( table_name=None, table_types=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1103,7 +1105,9 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1112,6 +1116,7 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_columns( @@ -1125,8 +1130,6 @@ def get_columns( table_name=None, column_name=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1143,7 +1146,9 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1152,6 +1157,7 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def 
_handle_execute_response(self, resp, cursor): @@ -1165,7 +1171,12 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response(resp, final_operation_state) + ( + execute_response, + arrow_schema_bytes, + ) = self._results_message_to_execute_response(resp, final_operation_state) + execute_response.command_id = command_id + return execute_response, arrow_schema_bytes def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) @@ -1226,7 +1237,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.guid)) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..5bf02e0ea 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,28 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. + + Args: + state: SEA state string + + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + class BackendType(Enum): """ @@ -285,28 +308,6 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count - def __str__(self) -> str: - """ - Return a string representation of the CommandId. - - For SEA backend, returns the guid. - For Thrift backend, returns a format like "guid|secret". 
- - Returns: - A string representation of the command ID - """ - - if self.backend_type == BackendType.SEA: - return str(self.guid) - elif self.backend_type == BackendType.THRIFT: - secret_hex = ( - guid_to_hex_id(self.secret) - if isinstance(self.secret, bytes) - else str(self.secret) - ) - return f"{self.to_hex_guid()}|{secret_hex}" - return str(self.guid) - @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -318,7 +319,6 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ - if operation_handle is None: return None @@ -394,3 +394,19 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9f7c060a7..e145e4e58 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -24,7 +24,6 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.utils import ( - ExecuteResponse, ParamEscaper, inject_parameters, transform_paramstyle, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a0d8d3579..fc8595839 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,26 +1,23 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Any, Union, TYPE_CHECKING +from typing import List, Optional, Any, Union, Tuple, TYPE_CHECKING import logging import time import pandas -from databricks.sql.backend.types import CommandId, CommandState - try: import pyarrow except ImportError: pyarrow = None if TYPE_CHECKING: - from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection - from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue +from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -34,32 +31,31 @@ class ResultSet(ABC): def __init__( self, - connection: "Connection", - backend: "DatabricksClient", - command_id: CommandId, - op_state: Optional[CommandState], - has_been_closed_server_side: bool, + connection, + backend, arraysize: int, buffer_size_bytes: int, + command_id=None, + status=None, + has_been_closed_server_side: bool = False, + has_more_rows: bool = False, + results_queue=None, + description=None, + is_staging_operation: bool = False, ): - """ - A ResultSet manages the results of a single command. 
- - :param connection: The parent connection that was used to execute this command - :param backend: The specialised backend client to be invoked in the fetch phase - :param execute_response: A `ExecuteResponse` class returned by a command execution - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - amount :param arraysize: The max number of rows to fetch at a time (PEP-249) - """ - self.command_id = command_id - self.op_state = op_state - self.has_been_closed_server_side = has_been_closed_server_side + """Initialize the base ResultSet with common properties.""" self.connection = connection - self.backend = backend + self.backend = backend # Store the backend client directly self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 - self.description = None + self.description = description + self.command_id = command_id + self.status = status + self.has_been_closed_server_side = has_been_closed_server_side + self._has_more_rows = has_more_rows + self.results = results_queue + self._is_staging_operation = is_staging_operation def __iter__(self): while True: @@ -74,10 +70,9 @@ def rownumber(self): return self._next_row_index @property - @abstractmethod def is_staging_operation(self) -> bool: """Whether this result set represents a staging operation.""" - pass + return self._is_staging_operation # Define abstract methods that concrete implementations must implement @abstractmethod @@ -101,12 +96,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": + def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" pass @@ -119,7 +114,7 @@ def close(self) -> None: """ try: if ( - self.op_state != CommandState.CLOSED + self.status != CommandState.CLOSED and not self.has_been_closed_server_side and self.connection.open ): @@ -129,7 +124,7 @@ def close(self) -> None: logger.info("Operation was canceled by a prior request") finally: self.has_been_closed_server_side = True - self.op_state = CommandState.CLOSED + self.status = CommandState.CLOSED class ThriftResultSet(ResultSet): @@ -138,11 +133,12 @@ class ThriftResultSet(ResultSet): def __init__( self, connection: "Connection", - execute_response: ExecuteResponse, + execute_response: "ExecuteResponse", thrift_client: "ThriftDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, + arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
@@ -154,37 +150,33 @@ def __init__( buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results + arrow_schema_bytes: Arrow schema bytes for the result set """ - super().__init__( - connection, - thrift_client, - execute_response.command_id, - execute_response.status, - execute_response.has_been_closed_server_side, - arraysize, - buffer_size_bytes, - ) - # Initialize ThriftResultSet-specific attributes - self.has_been_closed_server_side = execute_response.has_been_closed_server_side - self.has_more_rows = execute_response.has_more_rows - self.lz4_compressed = execute_response.lz4_compressed - self.description = execute_response.description - self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._arrow_schema_bytes = arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch - self._is_staging_operation = execute_response.is_staging_operation + self.lz4_compressed = execute_response.lz4_compressed - # Initialize results queue - if execute_response.arrow_queue: - # In this case the server has taken the fast path and returned an initial batch of - # results - self.results = execute_response.arrow_queue - else: - # In this case, there are results waiting on the server so we fetch now for simplicity + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=thrift_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + ) + + # Initialize results queue if not provided + if not self.results: self._fill_results_buffer() def _fill_results_buffer(self): - # At initialization or if the server does not have cloud fetch result links available results, has_more_rows = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, @@ -196,7 +188,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self._has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -248,7 +240,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": + def merge_columnar(self, result1, result2): """ Function to merge / combining the columnar results into a single result :param result1: @@ -280,7 +272,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -305,7 +297,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -320,7 +312,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not 
self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -346,7 +338,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -389,24 +381,110 @@ def fetchmany(self, size: int) -> List[Row]: else: return self._convert_arrow_table(self.fetchmany_arrow(size)) - @property - def is_staging_operation(self) -> bool: - """Whether this result set represents a staging operation.""" - return self._is_staging_operation - @staticmethod - def _get_schema_description(table_schema_message): +class SeaResultSet(ResultSet): + """ResultSet implementation for SEA backend.""" + + def __init__( + self, + connection, + sea_client, + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + execute_response=None, + sea_response=None, + ): """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) """ + # Handle both initialization styles + if execute_response is not None: + # New style with ExecuteResponse + command_id = execute_response.command_id + status = execute_response.status + has_been_closed_server_side = execute_response.has_been_closed_server_side + has_more_rows = execute_response.has_more_rows + results_queue = execute_response.results_queue + description = execute_response.description + is_staging_operation = execute_response.is_staging_operation + self._response = getattr(execute_response, "sea_response", {}) + self.statement_id = command_id.to_sea_statement_id() if command_id else None + elif sea_response is not None: + # Legacy style with direct sea_response + self._response = sea_response + # Extract values from sea_response + command_id = CommandId.from_sea_statement_id( + sea_response.get("statement_id", "") + ) + self.statement_id = sea_response.get("statement_id", "") + + # Extract status + status_data = sea_response.get("status", {}) + status = CommandState.from_sea_state(status_data.get("state", "PENDING")) + + # Set defaults for other fields + has_been_closed_server_side = False + has_more_rows = False + results_queue = None + description = None + is_staging_operation = False + else: + raise ValueError("Either execute_response or sea_response must be provided") - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=status, + has_been_closed_server_side=has_been_closed_server_side, + has_more_rows=has_more_rows, + results_queue=results_queue, + description=description, + 
is_staging_operation=is_staging_operation, + ) - return [ - (column.name, map_col_type(column.datatype), None, None, None, None, None) - for column in table_schema_message.columns - ] + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchmany(self, size: Optional[int] = None) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + + raise NotImplementedError("fetchmany is not implemented for SEA backend") + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + raise NotImplementedError("fetchall is not implemented for SEA backend") + + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 7c33d9b2d..76aec4675 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -10,7 +10,7 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId +from databricks.sql.backend.types import SessionId, BackendType logger = logging.getLogger(__name__) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 2622b1172..edb13ef6d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -349,13 +349,6 @@ def _create_empty_table(self) -> "pyarrow.Table": return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) -ExecuteResponse = namedtuple( - "ExecuteResponse", - "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation " - "command_id arrow_queue arrow_schema_bytes", -) - - def _bound(min_x, max_x, x): """Bound x by [min_x, max_x] diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 1a7950870..090ec255e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -26,7 +26,7 @@ from databricks.sql.types import Row from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.backend.types import CommandId, CommandState -from databricks.sql.utils import ExecuteResponse +from databricks.sql.backend.types import ExecuteResponse from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite @@ -121,10 +121,10 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Verify initial state self.assertEqual(real_result_set.has_been_closed_server_side, closed) - expected_op_state = ( + expected_status = ( CommandState.CLOSED if closed else CommandState.SUCCEEDED ) - self.assertEqual(real_result_set.op_state, expected_op_state) + 
self.assertEqual(real_result_set.status, expected_status) # Mock execute_command to return our real result set cursor.backend.execute_command = Mock(return_value=real_result_set) @@ -146,8 +146,8 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # 1. has_been_closed_server_side should always be True after close() self.assertTrue(real_result_set.has_been_closed_server_side) - # 2. op_state should always be CLOSED after close() - self.assertEqual(real_result_set.op_state, CommandState.CLOSED) + # 2. status should always be CLOSED after close() + self.assertEqual(real_result_set.status, CommandState.CLOSED) # 3. Backend close_command should be called appropriately if not closed: @@ -556,7 +556,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): self.assertEqual(instance.close_session.call_count, 0) cursor.close() - @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) + @patch("%s.backend.types.ExecuteResponse" % PACKAGE_NAME) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( @@ -678,10 +678,10 @@ def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" result_set = client.ThriftResultSet.__new__(client.ThriftResultSet) result_set.backend = Mock() - result_set.backend.CLOSED_OP_STATE = "CLOSED" + result_set.backend.CLOSED_OP_STATE = CommandState.CLOSED result_set.connection = Mock() result_set.connection.open = True - result_set.op_state = "RUNNING" + result_set.status = CommandState.RUNNING result_set.has_been_closed_server_side = False result_set.command_id = Mock() @@ -695,7 +695,7 @@ def __init__(self): try: try: if ( - result_set.op_state != result_set.backend.CLOSED_OP_STATE + result_set.status != result_set.backend.CLOSED_OP_STATE and not result_set.has_been_closed_server_side and result_set.connection.open ): @@ -705,7 +705,7 @@ def __init__(self): pass finally: result_set.has_been_closed_server_side = True - result_set.op_state = result_set.backend.CLOSED_OP_STATE + result_set.status = result_set.backend.CLOSED_OP_STATE result_set.backend.close_command.assert_called_once_with( result_set.command_id @@ -713,7 +713,7 @@ def __init__(self): assert result_set.has_been_closed_server_side is True - assert result_set.op_state == result_set.backend.CLOSED_OP_STATE + assert result_set.status == result_set.backend.CLOSED_OP_STATE finally: pass diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 030510a64..7249a59e6 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -8,7 +8,8 @@ pa = None import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.result_set import ThriftResultSet @@ -42,14 +43,13 @@ def make_dummy_result_set_from_initial_results(initial_results): rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), lz4_compressed=Mock(), - command_id=None, - arrow_queue=arrow_queue, - arrow_schema_bytes=schema.serialize().to_pybytes(), + results_queue=arrow_queue, is_staging_operation=False, ), thrift_client=None, @@ 
-88,6 +88,7 @@ def fetch_results( rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=False, has_more_rows=True, @@ -96,9 +97,7 @@ def fetch_results( for col_id in range(num_cols) ], lz4_compressed=Mock(), - command_id=None, - arrow_queue=None, - arrow_schema_bytes=None, + results_queue=None, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index b302c00da..7e025cf82 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -10,7 +10,8 @@ import pytest import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py new file mode 100644 index 000000000..e8eb2a757 --- /dev/null +++ b/tests/unit/test_result_set_filter.py @@ -0,0 +1,246 @@ +""" +Tests for the ResultSetFilter class. + +This module contains tests for the ResultSetFilter class, which provides +filtering capabilities for result sets returned by different backends. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestResultSetFilter: + """Test suite for the ResultSetFilter class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response_with_tables(self): + """Create a sample SEA response with table data based on the server schema.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 7, + "columns": [ + {"name": "namespace", "type_text": "STRING", "position": 0}, + {"name": "tableName", "type_text": "STRING", "position": 1}, + {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, + {"name": "information", "type_text": "STRING", "position": 3}, + {"name": "catalogName", "type_text": "STRING", "position": 4}, + {"name": "tableType", "type_text": "STRING", "position": 5}, + {"name": "remarks", "type_text": "STRING", "position": 6}, + ], + }, + "total_row_count": 4, + }, + "result": { + "data_array": [ + ["schema1", "table1", False, None, "catalog1", "TABLE", None], + ["schema1", "table2", False, None, "catalog1", "VIEW", None], + [ + "schema1", + "table3", + False, + None, + "catalog1", + "SYSTEM TABLE", + None, + ], + ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], + ] + }, + } + + @pytest.fixture + def sea_result_set_with_tables( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Create a SeaResultSet with table data.""" + # Create a deep copy of the response to avoid test interference + import copy + + sea_response_copy = copy.deepcopy(sea_response_with_tables) + + return SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy, + buffer_size_bytes=1000, + arraysize=100, + 
) + + def test_filter_tables_by_type_default(self, sea_result_set_with_tables): + """Test filtering tables by type with default types.""" + # Default types are TABLE, VIEW, SYSTEM TABLE + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables + ) + + # Verify that only the default types are included + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): + """Test filtering tables by type with custom types.""" + # Filter for only TABLE and EXTERNAL + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] + ) + + # Verify that only the specified types are included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "EXTERNAL" in table_types + assert "VIEW" not in table_types + assert "SYSTEM TABLE" not in table_types + + def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): + """Test that table type filtering is case-insensitive.""" + # Filter for lowercase "table" and "view" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["table", "view"] + ) + + # Verify that the matching types are included despite case differences + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" not in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): + """Test filtering tables with an empty type list (should use defaults).""" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=[] + ) + + # Verify that default types are used + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_by_column_values(self, sea_result_set_with_tables): + """Test filtering by values in a specific column.""" + # Filter by namespace in column index 0 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] + ) + + # All rows have schema1 in namespace, so all should be included + assert len(filtered_result._response["result"]["data_array"]) == 4 + + # Filter by table name in column index 1 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, + column_index=1, + allowed_values=["table1", "table3"], + ) + + # Only rows with table1 or table3 should be included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_names = [ + row[1] for row in filtered_result._response["result"]["data_array"] + ] + assert "table1" in table_names + assert "table3" in table_names + assert "table2" not in 
table_names + assert "table4" not in table_names + + def test_filter_by_column_values_case_sensitive( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Test case-sensitive filtering by column values.""" + import copy + + # Create a fresh result set for the first test + sea_response_copy1 = copy.deepcopy(sea_response_with_tables) + result_set1 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy1, + buffer_size_bytes=1000, + arraysize=100, + ) + + # First test: Case-sensitive filtering with lowercase values (should find no matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set1, + column_index=5, # tableType column + allowed_values=["table", "view"], # lowercase + case_sensitive=True, + ) + + # Verify no matches with lowercase values + assert len(filtered_result._response["result"]["data_array"]) == 0 + + # Create a fresh result set for the second test + sea_response_copy2 = copy.deepcopy(sea_response_with_tables) + result_set2 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy2, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Second test: Case-sensitive filtering with correct case (should find matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set2, + column_index=5, # tableType column + allowed_values=["TABLE", "VIEW"], # correct case + case_sensitive=True, + ) + + # Verify matches with correct case + assert len(filtered_result._response["result"]["data_array"]) == 2 + + # Extract the table types from the filtered results + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + + def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): + """Test filtering with a column index that's out of bounds.""" + # Filter by column index 10 (out of bounds) + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=10, allowed_values=["value"] + ) + + # No rows should match + assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..c2078f731 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,650 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def 
test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, 
sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, } - assert set(allowed_configs) == expected_keys + mock_http_client._make_request.return_value = execute_response - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = 
mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the 
result of a command execution.""" + # Set up mock response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + # Tests for metadata operations + + def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): + """Test getting catalogs.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["lz4_compression"] is False + assert kwargs["use_cloud_fetch"] is False + assert kwargs["parameters"] == [] + assert kwargs["async_op"] is False + assert kwargs["enforce_embedded_schema_correctness"] is False + + def test_get_schemas_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_with_catalog_and_schema_pattern( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with catalog and schema pattern.""" + # Mock execute_command to verify it's called with the right 
parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog and schema pattern + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting schemas without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + ) + + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_all_catalogs( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables from all catalogs using wildcard.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_schema_and_table_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with schema and table patterns.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + 
sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog, schema, and table patterns + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting tables without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_columns_with_all_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with all patterns specified.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with all patterns + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def 
test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting columns without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_columns" in str(excinfo.value) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py new file mode 100644 index 000000000..f666fd613 --- /dev/null +++ b/tests/unit/test_sea_result_set.py @@ -0,0 +1,275 @@ +""" +Tests for the SeaResultSet class. + +This module contains tests for the SeaResultSet class, which implements +the result set functionality for the SEA (Statement Execution API) backend. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response(self): + """Create a sample SEA response.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.has_more_rows = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + mock_response.sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "result": {"data_array": [["1"]]}, + } + return mock_response + + def test_init_with_sea_response( + self, mock_connection, mock_sea_client, sea_response + ): + """Test initializing SeaResultSet with a SEA response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == sea_response + + def 
test_init_with_execute_response( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + execute_response=execute_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == execute_response.sea_response + + def test_init_with_no_response(self, mock_connection, mock_sea_client): + """Test that initialization fails when neither response type is provided.""" + with pytest.raises(ValueError) as excinfo: + SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + assert "Either execute_response or sea_response must be provided" in str( + excinfo.value + ) + + def test_close(self, mock_connection, mock_sea_client, sea_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_unimplemented_methods( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that unimplemented methods raise NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test each unimplemented method individually with specific error messages + with pytest.raises( + 
NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set.fetchone() + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() + + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() + + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) + + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass + + def test_fill_results_buffer_not_implemented( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that _fill_results_buffer raises NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set._fill_results_buffer() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92de8d8fd..fef070362 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,11 @@ from unittest.mock import patch, Mock import gc +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57a2a61e3..b8de970db 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,11 +619,18 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 + + # Create a valid operation status + op_status = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=Mock(), + operationStatus=op_status, resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -644,7 +651,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -878,11 +885,12 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - results_message_response = thrift_backend._handle_execute_response( + execute_response, _ = 
thrift_backend._handle_execute_response( execute_resp, Mock() ) + self.assertEqual( - results_message_response.status, + execute_response.status, CommandState.SUCCEEDED, ) @@ -915,7 +923,9 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -943,15 +953,21 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( + execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -971,6 +987,12 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req thrift_backend = self._make_fake_thrift_backend() thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -1018,7 +1040,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( execute_resp, Mock() ) @@ -1150,7 +1172,7 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1184,7 +1206,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1215,7 +1237,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1255,7 +1277,7 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), 
ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1299,7 +1321,7 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1645,7 +1667,9 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) # Create a mock response with a real operation handle mock_resp = Mock() @@ -2204,7 +2228,8 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", + return_value=(Mock(), Mock()), ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class From 3e3ab94e8fa3dd02e4b05b5fc35939aef57793a2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:31:37 +0000 Subject: [PATCH 02/66] remove excess test Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 +++----------------- 1 file changed, 14 insertions(+), 110 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 87b62efea..abe6bd1ab 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,122 +6,34 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) - -def test_sea_query_exec(): - """ - Test executing a query using the SEA backend with result compression. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with result compression enabled and disabled, - and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - sys.exit(1) - - try: - # Test with compression enabled - logger.info("Creating connection with LZ4 compression enabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, # Enable cloud fetch to use compression - enable_query_result_lz4_compression=True, # Enable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - logger.info(f"backend type: {type(connection.session.backend)}") - - # Execute a simple query with compression enabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query with compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression enabled") - - # Test with compression disabled - logger.info("Creating connection with LZ4 compression disabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, # Enable cloud fetch - enable_query_result_lz4_compression=False, # Disable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query with compression disabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query without compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query without compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression disabled") - - except Exception as e: - logger.error(f"Error during SEA query execution test: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - sys.exit(1) - - logger.info("SEA query execution test with compression completed successfully") - - def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -130,33 +42,25 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", # add custom user agent - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent ) + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback - logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") - if __name__ == "__main__": - # Test session management test_sea_session() - - # Test query execution with compression - test_sea_query_exec() From 4a781653375d8f06dd7d9ad745446e49a355c680 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:33:02 +0000 Subject: [PATCH 03/66] add docstring Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index bbca4c502..cd347d9ab 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -86,6 +86,33 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: + """ + Executes a SQL command or query within the specified session. + This method sends a SQL command to the server for execution and handles + the response. It can operate in both synchronous and asynchronous modes. + + Args: + operation: The SQL command or query to execute + session_id: The session identifier in which to execute the command + max_rows: Maximum number of rows to fetch in a single fetch batch + max_bytes: Maximum number of bytes to fetch in a single fetch batch + lz4_compression: Whether to use LZ4 compression for result data + cursor: The cursor object that will handle the results + use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets + parameters: List of parameters to bind to the query + async_op: Whether to execute the command asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + If async_op is False, returns a ResultSet object containing the + query results and metadata. If async_op is True, returns None and the + results must be fetched later using get_execution_result(). 
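[A minimal sketch of the synchronous/asynchronous contract described above.
The client, session_id, and cursor objects here are stand-in mocks, not
connector code; only the keyword arguments mirror the documented signature.]

from unittest.mock import Mock

client, session_id, cursor = Mock(), Mock(), Mock()
kwargs = dict(
    operation="SELECT 1", session_id=session_id, max_rows=100,
    max_bytes=1_000_000, lz4_compression=False, cursor=cursor,
    use_cloud_fetch=False, parameters=[],
    enforce_embedded_schema_correctness=False,
)

# Synchronous path: async_op=False blocks until completion and returns a
# ResultSet directly.
result = client.execute_command(async_op=False, **kwargs)

# Asynchronous path: async_op=True returns None; the caller fetches results
# later, once the cursor holds the active command id.
client.execute_command(async_op=True, **kwargs)
result = client.get_execution_result(cursor.active_command_id, cursor)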
+ + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error executing the command + ServerOperationError: If the server encounters an error during execution + """ pass @abstractmethod From 0dac4aaf90dba50151dd7565adee270a794e8330 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:34:49 +0000 Subject: [PATCH 04/66] remove exec func in sea backend Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 360 +++------------------- 1 file changed, 35 insertions(+), 325 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index c7a4ed1b1..97d25a058 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,34 +1,23 @@ import logging import re -import uuid -import time -from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING -
-from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) +from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.exc import ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, - StatementParameter, - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) @@ -66,9 +55,6 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. - - This implementation provides session management functionality for SEA, - while other operations raise NotImplementedError. """ # SEA API paths @@ -288,222 +274,41 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List, + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Union["ResultSet", None]: - """ - Execute a SQL command using the SEA backend.
- - Args: - operation: SQL command to execute - session_id: Session identifier - max_rows: Maximum number of rows to fetch - max_bytes: Maximum number of bytes to fetch - lz4_compression: Whether to use LZ4 compression - cursor: Cursor executing the command - use_cloud_fetch: Whether to use cloud fetch - parameters: SQL parameters - async_op: Whether to execute asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - ResultSet: A SeaResultSet instance for the executed command - """ - if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") - - sea_session_id = session_id.to_sea_session_id() - - # Convert parameters to StatementParameter objects - sea_parameters = [] - if parameters: - for param in parameters: - sea_parameters.append( - StatementParameter( - name=param.name, - value=param.value, - type=param.type if hasattr(param, "type") else None, - ) - ) - - format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" - disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" - result_compression = "LZ4_FRAME" if lz4_compression else "NONE" - - request = ExecuteStatementRequest( - warehouse_id=self.warehouse_id, - session_id=sea_session_id, - statement=operation, - disposition=disposition, - format=format, - wait_timeout="0s" if async_op else "10s", - on_wait_timeout="CONTINUE", - row_limit=max_rows if max_rows > 0 else None, - byte_limit=max_bytes if max_bytes > 0 else None, - parameters=sea_parameters if sea_parameters else None, - result_compression=result_compression, - ) - - response_data = self.http_client._make_request( - method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ): + """Not implemented yet.""" + raise NotImplementedError( + "execute_command is not yet implemented for SEA backend" ) - response = ExecuteStatementResponse.from_dict(response_data) - statement_id = response.statement_id - if not statement_id: - raise ServerOperationError( - "Failed to execute command: No statement ID returned", - { - "operation-id": None, - "diagnostic-info": None, - }, - ) - - command_id = CommandId.from_sea_statement_id(statement_id) - - # Store the command ID in the cursor - cursor.active_command_id = command_id - - # If async operation, return None and let the client poll for results - if async_op: - return None - - # For synchronous operation, wait for the statement to complete - # Poll until the statement is done - status = response.status - state = status.state - - # Keep polling until we reach a terminal state - while state in [CommandState.PENDING, CommandState.RUNNING]: - time.sleep(0.5) # add a small delay to avoid excessive API calls - state = self.get_query_state(command_id) - - if state != CommandState.SUCCEEDED: - raise ServerOperationError( - f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", - { - "operation-id": command_id.to_sea_statement_id(), - "diagnostic-info": None, - }, - ) - - return self.get_execution_result(command_id, cursor) def cancel_command(self, command_id: CommandId) -> None: - """ - Cancel a running command. 
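[The execute_command body removed above polled get_query_state() in a sleep
loop until the statement reached a terminal state. A standalone sketch of that
pattern follows; the function and state names are illustrative, and the
timeout guard is an addition the original loop did not have.]

import time

def wait_until_done(get_state, pending=("PENDING", "RUNNING"),
                    poll_interval=0.5, timeout=60.0):
    """Poll get_state() until the command leaves a pending state."""
    deadline = time.monotonic() + timeout
    state = get_state()
    while state in pending:
        if time.monotonic() >= deadline:
            raise TimeoutError(f"command still {state} after {timeout}s")
        time.sleep(poll_interval)  # small delay to avoid excessive API calls
        state = get_state()
    return state

states = iter(["PENDING", "RUNNING", "SUCCEEDED"])
print(wait_until_done(lambda: next(states), poll_interval=0.01))  # SUCCEEDED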
- - Args: - command_id: Command identifier to cancel - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CancelStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="POST", - path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "cancel_command is not yet implemented for SEA backend" ) def close_command(self, command_id: CommandId) -> None: - """ - Close a command and release resources. - - Args: - command_id: Command identifier to close - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CloseStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="DELETE", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "close_command is not yet implemented for SEA backend" ) def get_query_state(self, command_id: CommandId) -> CommandState: - """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "get_query_state is not yet implemented for SEA backend" ) - # Parse the response - response = GetStatementResponse.from_dict(response_data) - return response.status.state - def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> "ResultSet": - """ - Get the result of a command execution. 
- - Args: - command_id: Command identifier - cursor: Cursor executing the command - - Returns: - ResultSet: A SeaResultSet instance with the execution results - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - # Create the request model - request = GetStatementRequest(statement_id=sea_statement_id) - - # Get the statement result - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - - # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet - - return SeaResultSet( - connection=cursor.connection, - sea_response=response_data, - sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, + ): + """Not implemented yet.""" + raise NotImplementedError( + "get_execution_result is not yet implemented for SEA backend" ) # == Metadata Operations == @@ -514,22 +319,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - result = self.execute_command( - operation="SHOW CATALOGS", - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -539,30 +331,9 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_schemas") - - operation = f"SHOW SCHEMAS IN `{catalog_name}`" - - if schema_name: - operation += f" LIKE '{schema_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -574,43 +345,9 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - - operation = "SHOW TABLES IN " + ( - "ALL CATALOGS" - if catalog_name in [None, "*", "%"] - else f"CATALOG `{catalog_name}`" - ) - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" LIKE '{table_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - 
enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - - # Apply client-side filtering by table_types if specified - from databricks.sql.backend.filters import ResultSetFilter - - result = ResultSetFilter.filter_tables_by_type(result, table_types) - - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -622,33 +359,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_columns") - - operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" TABLE LIKE '{table_name}'" - - if column_name: - operation += f" LIKE '{column_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") From 1b794c7df6f5e414ef793a5da0f2b8ba19c9bc61 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:35:40 +0000 Subject: [PATCH 05/66] remove excess files Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 143 -------------- tests/unit/test_result_set_filter.py | 246 ----------------------- tests/unit/test_sea_result_set.py | 275 -------------------------- 3 files changed, 664 deletions(-) delete mode 100644 src/databricks/sql/backend/filters.py delete mode 100644 tests/unit/test_result_set_filter.py delete mode 100644 tests/unit/test_sea_result_set.py diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py deleted file mode 100644 index 7f48b6179..000000000 --- a/src/databricks/sql/backend/filters.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Client-side filtering utilities for Databricks SQL connector. - -This module provides filtering capabilities for result sets returned by different backends. -""" - -import logging -from typing import ( - List, - Optional, - Any, - Callable, - TYPE_CHECKING, -) - -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet - -from databricks.sql.result_set import SeaResultSet - -logger = logging.getLogger(__name__) - - -class ResultSetFilter: - """ - A general-purpose filter for result sets that can be applied to any backend. - - This class provides methods to filter result sets based on various criteria, - similar to the client-side filtering in the JDBC connector. - """ - - @staticmethod - def _filter_sea_result_set( - result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] - ) -> "SeaResultSet": - """ - Filter a SEA result set using the provided filter function. 
- - Args: - result_set: The SEA result set to filter - filter_func: Function that takes a row and returns True if the row should be included - - Returns: - A filtered SEA result set - """ - # Create a filtered version of the result set - filtered_response = result_set._response.copy() - - # If there's a result with rows, filter them - if ( - "result" in filtered_response - and "data_array" in filtered_response["result"] - ): - rows = filtered_response["result"]["data_array"] - filtered_rows = [row for row in rows if filter_func(row)] - filtered_response["result"]["data_array"] = filtered_rows - - # Update row count if present - if "row_count" in filtered_response["result"]: - filtered_response["result"]["row_count"] = len(filtered_rows) - - # Create a new result set with the filtered data - return SeaResultSet( - connection=result_set.connection, - sea_response=filtered_response, - sea_client=result_set.backend, - buffer_size_bytes=result_set.buffer_size_bytes, - arraysize=result_set.arraysize, - ) - - @staticmethod - def filter_by_column_values( - result_set: "ResultSet", - column_index: int, - allowed_values: List[str], - case_sensitive: bool = False, - ) -> "ResultSet": - """ - Filter a result set by values in a specific column. - - Args: - result_set: The result set to filter - column_index: The index of the column to filter on - allowed_values: List of allowed values for the column - case_sensitive: Whether to perform case-sensitive comparison - - Returns: - A filtered result set - """ - # Convert to uppercase for case-insensitive comparison if needed - if not case_sensitive: - allowed_values = [v.upper() for v in allowed_values] - - # Determine the type of result set and apply appropriate filtering - if isinstance(result_set, SeaResultSet): - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and isinstance(row[column_index], str) - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), - ) - - # For other result set types, return the original (should be handled by specific implementations) - logger.warning( - f"Filtering not implemented for result set type: {type(result_set).__name__}" - ) - return result_set - - @staticmethod - def filter_tables_by_type( - result_set: "ResultSet", table_types: Optional[List[str]] = None - ) -> "ResultSet": - """ - Filter a result set of tables by the specified table types. - - This is a client-side filter that processes the result set after it has been - retrieved from the server. It filters out tables whose type does not match - any of the types in the table_types list. - - Args: - result_set: The original result set containing tables - table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) - - Returns: - A filtered result set containing only tables of the specified types - """ - # Default table types if none specified - DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] - valid_types = ( - table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES - ) - - # Table type is typically in the 6th column (index 5) - return ResultSetFilter.filter_by_column_values( - result_set, 5, valid_types, case_sensitive=False - ) diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py deleted file mode 100644 index e8eb2a757..000000000 --- a/tests/unit/test_result_set_filter.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Tests for the ResultSetFilter class. 
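[The removed ResultSetFilter filters rows client side by comparing a column
value against an allow-list, upper-casing both sides when the match is case
insensitive. A self-contained sketch of that logic on plain lists of rows:]

def filter_by_column_values(rows, column_index, allowed_values,
                            case_sensitive=False):
    if not case_sensitive:
        allowed_values = [v.upper() for v in allowed_values]

    def keep(row):
        if len(row) <= column_index or not isinstance(row[column_index], str):
            return False
        value = row[column_index] if case_sensitive else row[column_index].upper()
        return value in allowed_values

    return [row for row in rows if keep(row)]

rows = [
    ["schema1", "t1", "TABLE"],
    ["schema1", "t2", "view"],
    ["schema1", "t3", "EXTERNAL"],
]
print(filter_by_column_values(rows, 2, ["table", "view"]))  # keeps t1 and t2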
- -This module contains tests for the ResultSetFilter class, which provides -filtering capabilities for result sets returned by different backends. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.backend.filters import ResultSetFilter -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestResultSetFilter: - """Test suite for the ResultSetFilter class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response_with_tables(self): - """Create a sample SEA response with table data based on the server schema.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 7, - "columns": [ - {"name": "namespace", "type_text": "STRING", "position": 0}, - {"name": "tableName", "type_text": "STRING", "position": 1}, - {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, - {"name": "information", "type_text": "STRING", "position": 3}, - {"name": "catalogName", "type_text": "STRING", "position": 4}, - {"name": "tableType", "type_text": "STRING", "position": 5}, - {"name": "remarks", "type_text": "STRING", "position": 6}, - ], - }, - "total_row_count": 4, - }, - "result": { - "data_array": [ - ["schema1", "table1", False, None, "catalog1", "TABLE", None], - ["schema1", "table2", False, None, "catalog1", "VIEW", None], - [ - "schema1", - "table3", - False, - None, - "catalog1", - "SYSTEM TABLE", - None, - ], - ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], - ] - }, - } - - @pytest.fixture - def sea_result_set_with_tables( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Create a SeaResultSet with table data.""" - # Create a deep copy of the response to avoid test interference - import copy - - sea_response_copy = copy.deepcopy(sea_response_with_tables) - - return SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy, - buffer_size_bytes=1000, - arraysize=100, - ) - - def test_filter_tables_by_type_default(self, sea_result_set_with_tables): - """Test filtering tables by type with default types.""" - # Default types are TABLE, VIEW, SYSTEM TABLE - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables - ) - - # Verify that only the default types are included - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): - """Test filtering tables by type with custom types.""" - # Filter for only TABLE and EXTERNAL - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] - ) - - # Verify that only the specified types are included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert 
"EXTERNAL" in table_types - assert "VIEW" not in table_types - assert "SYSTEM TABLE" not in table_types - - def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): - """Test that table type filtering is case-insensitive.""" - # Filter for lowercase "table" and "view" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["table", "view"] - ) - - # Verify that the matching types are included despite case differences - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" not in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): - """Test filtering tables with an empty type list (should use defaults).""" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=[] - ) - - # Verify that default types are used - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_by_column_values(self, sea_result_set_with_tables): - """Test filtering by values in a specific column.""" - # Filter by namespace in column index 0 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] - ) - - # All rows have schema1 in namespace, so all should be included - assert len(filtered_result._response["result"]["data_array"]) == 4 - - # Filter by table name in column index 1 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, - column_index=1, - allowed_values=["table1", "table3"], - ) - - # Only rows with table1 or table3 should be included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_names = [ - row[1] for row in filtered_result._response["result"]["data_array"] - ] - assert "table1" in table_names - assert "table3" in table_names - assert "table2" not in table_names - assert "table4" not in table_names - - def test_filter_by_column_values_case_sensitive( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Test case-sensitive filtering by column values.""" - import copy - - # Create a fresh result set for the first test - sea_response_copy1 = copy.deepcopy(sea_response_with_tables) - result_set1 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy1, - buffer_size_bytes=1000, - arraysize=100, - ) - - # First test: Case-sensitive filtering with lowercase values (should find no matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set1, - column_index=5, # tableType column - allowed_values=["table", "view"], # lowercase - case_sensitive=True, - ) - - # Verify no matches with lowercase values - assert len(filtered_result._response["result"]["data_array"]) == 0 - - # Create a fresh result set for the second test - sea_response_copy2 = copy.deepcopy(sea_response_with_tables) - result_set2 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy2, - buffer_size_bytes=1000, - 
arraysize=100, - ) - - # Second test: Case-sensitive filtering with correct case (should find matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set2, - column_index=5, # tableType column - allowed_values=["TABLE", "VIEW"], # correct case - case_sensitive=True, - ) - - # Verify matches with correct case - assert len(filtered_result._response["result"]["data_array"]) == 2 - - # Extract the table types from the filtered results - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - - def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): - """Test filtering with a column index that's out of bounds.""" - # Filter by column index 10 (out of bounds) - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=10, allowed_values=["value"] - ) - - # No rows should match - assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py deleted file mode 100644 index f666fd613..000000000 --- a/tests/unit/test_sea_result_set.py +++ /dev/null @@ -1,275 +0,0 @@ -""" -Tests for the SeaResultSet class. - -This module contains tests for the SeaResultSet class, which implements -the result set functionality for the SEA (Statement Execution API) backend. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestSeaResultSet: - """Test suite for the SeaResultSet class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response(self): - """Create a sample SEA response.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - - @pytest.fixture - def execute_response(self): - """Create a sample execute response.""" - mock_response = Mock() - mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") - mock_response.status = CommandState.SUCCEEDED - mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False - mock_response.results_queue = None - mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) - ] - mock_response.is_staging_operation = False - mock_response.sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "result": {"data_array": [["1"]]}, - } - return mock_response - - def test_init_with_sea_response( - self, mock_connection, mock_sea_client, sea_response - ): - """Test initializing SeaResultSet with a SEA response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - 
buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == sea_response - - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): - """Test initializing SeaResultSet with an execute response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - execute_response=execute_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == execute_response.sea_response - - def test_init_with_no_response(self, mock_connection, mock_sea_client): - """Test that initialization fails when neither response type is provided.""" - with pytest.raises(ValueError) as excinfo: - SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - assert "Either execute_response or sea_response must be provided" in str( - excinfo.value - ) - - def test_close(self, mock_connection, mock_sea_client, sea_response): - """Test closing a result set.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, sea_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, sea_response - ): - """Test closing a result set when the connection is closed.""" - mock_connection.open = False - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert 
result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, sea_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) - - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test using the result set in a for loop - for row in result_set: - pass - - def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, sea_response - ): - """Test that _fill_results_buffer raises NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set._fill_results_buffer() From da5a6fe7511e927c511d61adb222b8a6a0da14d3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:39:11 +0000 Subject: [PATCH 06/66] remove excess models Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/models/__init__.py | 30 ----- src/databricks/sql/backend/sea/models/base.py | 68 ----------- .../sql/backend/sea/models/requests.py | 110 +----------------- .../sql/backend/sea/models/responses.py | 95 +-------------- 4 files changed, 4 insertions(+), 299 deletions(-) delete mode 100644 src/databricks/sql/backend/sea/models/base.py diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index b7c8bd399..c9310d367 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,49 +4,19 @@ This package contains data models for SEA API requests and responses. 
""" -from databricks.sql.backend.sea.models.base import ( - ServiceError, - StatementStatus, - ExternalLink, - ResultData, - ColumnInfo, - ResultManifest, -) - from databricks.sql.backend.sea.models.requests import ( - StatementParameter, - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) __all__ = [ - # Base models - "ServiceError", - "StatementStatus", - "ExternalLink", - "ResultData", - "ColumnInfo", - "ResultManifest", # Request models - "StatementParameter", - "ExecuteStatementRequest", - "GetStatementRequest", - "CancelStatementRequest", - "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models - "ExecuteStatementResponse", - "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py deleted file mode 100644 index 671f7be13..000000000 --- a/src/databricks/sql/backend/sea/models/base.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Base models for the SEA (Statement Execution API) backend. - -These models define the common structures used in SEA API requests and responses. -""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState - - -@dataclass -class ServiceError: - """Error information returned by the SEA API.""" - - message: str - error_code: Optional[str] = None - - -@dataclass -class StatementStatus: - """Status information for a statement execution.""" - - state: CommandState - error: Optional[ServiceError] = None - sql_state: Optional[str] = None - - -@dataclass -class ExternalLink: - """External link information for result data.""" - - external_link: str - expiration: str - chunk_index: int - - -@dataclass -class ResultData: - """Result data from a statement execution.""" - - data: Optional[List[List[Any]]] = None - external_links: Optional[List[ExternalLink]] = None - - -@dataclass -class ColumnInfo: - """Information about a column in the result set.""" - - name: str - type_name: str - type_text: str - nullable: bool = True - precision: Optional[int] = None - scale: Optional[int] = None - ordinal_position: Optional[int] = None - - -@dataclass -class ResultManifest: - """Manifest information for a result set.""" - - schema: List[ColumnInfo] - total_row_count: int - total_byte_count: int - truncated: bool = False - chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 1c519d931..7966cb502 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,111 +1,5 @@ -""" -Request models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API requests. 
-""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - - -@dataclass -class StatementParameter: - """Parameter for a SQL statement.""" - - name: str - value: Optional[str] = None - type: Optional[str] = None - - -@dataclass -class ExecuteStatementRequest: - """Request to execute a SQL statement.""" - - warehouse_id: str - statement: str - session_id: str - disposition: str = "EXTERNAL_LINKS" - format: str = "JSON_ARRAY" - wait_timeout: str = "10s" - on_wait_timeout: str = "CONTINUE" - row_limit: Optional[int] = None - byte_limit: Optional[int] = None - parameters: Optional[List[StatementParameter]] = None - catalog: Optional[str] = None - schema: Optional[str] = None - result_compression: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - result: Dict[str, Any] = { - "warehouse_id": self.warehouse_id, - "session_id": self.session_id, - "statement": self.statement, - "disposition": self.disposition, - "format": self.format, - "wait_timeout": self.wait_timeout, - "on_wait_timeout": self.on_wait_timeout, - } - - if self.row_limit is not None and self.row_limit > 0: - result["row_limit"] = self.row_limit - - if self.byte_limit is not None and self.byte_limit > 0: - result["byte_limit"] = self.byte_limit - - if self.catalog: - result["catalog"] = self.catalog - - if self.schema: - result["schema"] = self.schema - - if self.result_compression: - result["result_compression"] = self.result_compression - - if self.parameters: - result["parameters"] = [ - { - "name": param.name, - **({"value": param.value} if param.value is not None else {}), - **({"type": param.type} if param.type is not None else {}), - } - for param in self.parameters - ] - - return result - - -@dataclass -class GetStatementRequest: - """Request to get information about a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CancelStatementRequest: - """Request to cancel a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CloseStatementRequest: - """Request to close a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} +from typing import Dict, Any, Optional +from dataclasses import dataclass @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d70459b9f..1bb54590f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,96 +1,5 @@ -""" -Response models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API responses. 
-""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState -from databricks.sql.backend.sea.models.base import ( - StatementStatus, - ResultManifest, - ResultData, - ServiceError, -) - - -@dataclass -class ExecuteStatementResponse: - """Response from executing a SQL statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": - """Create an ExecuteStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) - - -@dataclass -class GetStatementResponse: - """Response from getting information about a statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": - """Create a GetStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) +from typing import Dict, Any +from dataclasses import dataclass @dataclass From 686ade4fbf8e43a053b61f27220066852682167e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:40:50 +0000 Subject: [PATCH 07/66] remove excess sea backend tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 755 ++++----------------------------- 1 file changed, 94 insertions(+), 661 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index c2078f731..bc2688a68 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,20 +1,11 @@ -""" -Tests for the SEA (Statement Execution API) backend implementation. - -This module contains tests for the SeaDatabricksClient class, which implements -the Databricks SQL connector's SEA backend functionality. 
-""" - -import json import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import patch, MagicMock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error class TestSeaBackend: @@ -50,23 +41,6 @@ def sea_client(self, mock_http_client): return client - @pytest.fixture - def sea_session_id(self): - """Create a SEA session ID.""" - return SessionId.from_sea_session_id("test-session-123") - - @pytest.fixture - def sea_command_id(self): - """Create a SEA command ID.""" - return CommandId.from_sea_statement_id("test-statement-123") - - @pytest.fixture - def mock_cursor(self): - """Create a mock cursor.""" - cursor = Mock() - cursor.active_command_id = None - return cursor - def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -201,650 +175,109 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - # Tests for command execution and management - - def test_execute_command_sync( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command synchronously.""" - # Set up mock responses - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "schema": [ - { - "name": "col1", - "type_name": "STRING", - "type_text": "string", - "nullable": True, - } - ], - "total_row_count": 1, - "total_byte_count": 100, - }, - "result": {"data": [["value1"]]}, - } - mock_http_client._make_request.return_value = execute_response - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "warehouse_id" in kwargs["data"] - assert "session_id" in kwargs["data"] - assert "statement" in kwargs["data"] - assert kwargs["data"]["statement"] == "SELECT 1" - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-123" - - def test_execute_command_async( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command asynchronously.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-456", - "status": {"state": "PENDING"}, - } - 
mock_http_client._make_request.return_value = execute_response - - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=True, # Async mode - enforce_embedded_schema_correctness=False, + def test_session_configuration_helpers(self): + """Test the session configuration helper methods.""" + # Test getting default value for a supported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "ANSI_MODE" ) + assert default_value == "true" - # Verify the result is None for async operation - assert result is None - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "wait_timeout" in kwargs["data"] - assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout - - # Verify the command ID was stored in the cursor - assert hasattr(mock_cursor, "active_command_id") - assert isinstance(mock_cursor.active_command_id, CommandId) - assert mock_cursor.active_command_id.guid == "test-statement-456" - - def test_execute_command_with_polling( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that requires polling.""" - # Set up mock responses for initial request and polling - initial_response = { - "statement_id": "test-statement-789", - "status": {"state": "RUNNING"}, - } - poll_response = { - "statement_id": "test-statement-789", - "status": {"state": "SUCCEEDED"}, - "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, - "result": {"data": []}, - } - - # Configure mock to return different responses on subsequent calls - mock_http_client._make_request.side_effect = [initial_response, poll_response] - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method - result = sea_client.execute_command( - operation="SELECT * FROM large_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP requests (initial and poll) - assert mock_http_client._make_request.call_count == 2 - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-789" - - def test_execute_command_with_parameters( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command with parameters.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, + # Test getting default value for an unsupported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert default_value is None + + # Test getting the list of allowed configurations + allowed_configs = 
SeaDatabricksClient.get_allowed_session_configurations() + + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", } - mock_http_client._make_request.return_value = execute_response + assert set(allowed_configs) == expected_keys - # Create parameter mock - param = MagicMock() - param.name = "param1" - param.value = "value1" - param.type = "STRING" + def test_unimplemented_methods(self, sea_client): + """Test that unimplemented methods raise NotImplementedError.""" + # Create dummy parameters for testing + session_id = SessionId.from_sea_session_id("test-session") + command_id = MagicMock() + cursor = MagicMock() - # Mock the get_execution_result method - with patch.object(sea_client, "get_execution_result") as mock_get_result: - # Call the method with parameters + # Test execute_command + with pytest.raises(NotImplementedError) as excinfo: sea_client.execute_command( - operation="SELECT * FROM table WHERE col = :param1", - session_id=sea_session_id, + operation="SELECT 1", + session_id=session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=mock_cursor, + cursor=cursor, use_cloud_fetch=False, - parameters=[param], + parameters=[], async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the HTTP request contains parameters - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert "parameters" in kwargs["data"] - assert len(kwargs["data"]["parameters"]) == 1 - assert kwargs["data"]["parameters"][0]["name"] == "param1" - assert kwargs["data"]["parameters"][0]["value"] == "value1" - assert kwargs["data"]["parameters"][0]["type"] == "STRING" - - def test_execute_command_failure( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that fails.""" - # Set up mock response for a failed execution - error_response = { - "statement_id": "test-statement-123", - "status": { - "state": "FAILED", - "error": { - "message": "Syntax error in SQL", - "error_code": "SYNTAX_ERROR", - }, - }, - } - - # Configure the mock to return the error response for the initial request - # and then raise an exception when trying to poll (to simulate immediate failure) - mock_http_client._make_request.side_effect = [ - error_response, # Initial response - Error( - "Statement execution did not succeed: Syntax error in SQL" - ), # Will be raised during polling - ] - - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method and expect an error - with pytest.raises(Error) as excinfo: - sea_client.execute_command( - operation="SELECT * FROM nonexistent_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Statement execution did not succeed" in str(excinfo.value) - - def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): - """Test canceling a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.cancel_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == 
sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_close_command(self, sea_client, mock_http_client, sea_command_id): - """Test closing a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.close_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "DELETE" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): - """Test getting the state of a query.""" - # Set up mock response - mock_http_client._make_request.return_value = { - "statement_id": "test-statement-123", - "status": {"state": "RUNNING"}, - } - - # Call the method - state = sea_client.get_query_state(sea_command_id) - - # Verify the result - assert state == CommandState.RUNNING - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_execution_result( - self, sea_client, mock_http_client, mock_cursor, sea_command_id - ): - """Test getting the result of a command execution.""" - # Set up mock response - sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - mock_http_client._make_request.return_value = sea_response - - # Create a real result set to verify the implementation - result = sea_client.get_execution_result(sea_command_id, mock_cursor) - print(result) - - # Verify basic properties of the result - assert result.statement_id == "test-statement-123" - assert result.status == CommandState.SUCCEEDED - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + assert "execute_command is not yet implemented" in str(excinfo.value) + + # Test cancel_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.cancel_command(command_id) + assert "cancel_command is not yet implemented" in str(excinfo.value) + + # Test close_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.close_command(command_id) + assert "close_command is not yet implemented" in str(excinfo.value) + + # Test get_query_state + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_query_state(command_id) + assert "get_query_state is not yet implemented" in str(excinfo.value) + + # Test get_execution_result + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_execution_result(command_id, cursor) + assert "get_execution_result is not yet implemented" in str(excinfo.value) + + # Test metadata operations + with 
pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(session_id, 100, 1000, cursor) + assert "get_catalogs is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(session_id, 100, 1000, cursor) + assert "get_schemas is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(session_id, 100, 1000, cursor) + assert "get_tables is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(session_id, 100, 1000, cursor) + assert "get_columns is not yet implemented" in str(excinfo.value) + + def test_max_download_threads_property(self, sea_client): + """Test the max_download_threads property.""" + assert sea_client.max_download_threads == 10 + + # Create a client with a custom value + custom_client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=20, ) - # Tests for metadata operations - - def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): - """Test getting catalogs.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method - result = sea_client.get_catalogs( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["lz4_compression"] is False - assert kwargs["use_cloud_fetch"] is False - assert kwargs["parameters"] == [] - assert kwargs["async_op"] is False - assert kwargs["enforce_embedded_schema_correctness"] is False - - def test_get_schemas_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_with_catalog_and_schema_pattern( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with catalog and schema pattern.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", 
return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog and schema pattern - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting schemas without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - ) - - assert "Catalog name is required for get_schemas" in str(excinfo.value) - - def test_get_tables_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_all_catalogs( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables from all catalogs using wildcard.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with wildcard catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="*", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_schema_and_table_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with schema and table patterns.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) 
as mock_execute: - # Call the method with catalog, schema, and table patterns - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting tables without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_tables" in str(excinfo.value) - - def test_get_columns_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_with_all_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with all patterns specified.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with all patterns - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - column_name="test_column", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - 
"""Test getting columns without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_columns" in str(excinfo.value) + # Verify the custom value is returned + assert custom_client.max_download_threads == 20 From 31e6c8305154e9c6384b422be35ac17b6f851e0c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:54:05 +0000 Subject: [PATCH 08/66] cleanup Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 8 +- src/databricks/sql/backend/types.py | 38 ++++---- src/databricks/sql/result_set.py | 91 ++++++++------------ 3 files changed, 65 insertions(+), 72 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e03d6f235..21a6befbe 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -913,7 +913,10 @@ def get_query_state(self, command_id: CommandId) -> CommandState: poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - return CommandState.from_thrift_state(operation_state) + state = CommandState.from_thrift_state(operation_state) + if state is None: + raise ValueError(f"Invalid operation state: {operation_state}") + return state @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -1175,7 +1178,6 @@ def _handle_execute_response(self, resp, cursor): execute_response, arrow_schema_bytes, ) = self._results_message_to_execute_response(resp, final_operation_state) - execute_response.command_id = command_id return execute_response, arrow_schema_bytes def _handle_execute_response_async(self, resp, cursor): @@ -1237,7 +1239,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(command_id.guid)) + logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 5bf02e0ea..3107083fb 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -285,9 +285,6 @@ def __init__( backend_type: BackendType, guid: Any, secret: Optional[Any] = None, - operation_type: Optional[int] = None, - has_result_set: bool = False, - modified_row_count: Optional[int] = None, ): """ Initialize a CommandId. @@ -296,17 +293,34 @@ def __init__( backend_type: The type of backend (THRIFT or SEA) guid: The primary identifier for the command secret: The secret part of the identifier (only used for Thrift) - operation_type: The operation type (only used for Thrift) - has_result_set: Whether the command has a result set - modified_row_count: The number of rows modified by the command """ self.backend_type = backend_type self.guid = guid self.secret = secret - self.operation_type = operation_type - self.has_result_set = has_result_set - self.modified_row_count = modified_row_count + + + def __str__(self) -> str: + """ + Return a string representation of the CommandId. 
+ + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the command ID + """ + + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) @classmethod def from_thrift_handle(cls, operation_handle): @@ -329,9 +343,6 @@ def from_thrift_handle(cls, operation_handle): BackendType.THRIFT, guid_bytes, secret_bytes, - operation_handle.operationType, - operation_handle.hasResultSet, - operation_handle.modifiedRowCount, ) @classmethod @@ -364,9 +375,6 @@ def to_thrift_handle(self): handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) return ttypes.TOperationHandle( operationId=handle_identifier, - operationType=self.operation_type, - hasResultSet=self.has_result_set, - modifiedRowCount=self.modified_row_count, ) def to_sea_statement_id(self): diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index fc8595839..12ee129cf 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -5,6 +5,8 @@ import time import pandas +from databricks.sql.backend.sea.backend import SeaDatabricksClient + try: import pyarrow except ImportError: @@ -13,6 +15,7 @@ if TYPE_CHECKING: from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection +from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError @@ -31,21 +34,37 @@ class ResultSet(ABC): def __init__( self, - connection, - backend, + connection: "Connection", + backend: "DatabricksClient", arraysize: int, buffer_size_bytes: int, - command_id=None, - status=None, + command_id: CommandId, + status: CommandState, has_been_closed_server_side: bool = False, has_more_rows: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, ): - """Initialize the base ResultSet with common properties.""" + """ + A ResultSet manages the results of a single command. 
+ + Args: + connection: The parent connection + backend: The backend client + arraysize: The max number of rows to fetch at a time (PEP-249) + buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + command_id: The command ID + status: The command status + has_been_closed_server_side: Whether the command has been closed on the server + has_more_rows: Whether the command has more rows + results_queue: The results queue + description: column description of the results + is_staging_operation: Whether the command is a staging operation + """ + self.connection = connection - self.backend = backend # Store the backend client directly + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -240,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2): + def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result :param result1: @@ -387,12 +406,11 @@ class SeaResultSet(ResultSet): def __init__( self, - connection, - sea_client, + connection: "Connection", + execute_response: "ExecuteResponse", + sea_client: "SeaDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, - execute_response=None, - sea_response=None, ): """ Initialize a SeaResultSet with the response from a SEA query execution. @@ -402,56 +420,21 @@ def __init__( sea_client: The SeaDatabricksClient instance for direct access buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) + execute_response: Response from the execute command """ - # Handle both initialization styles - if execute_response is not None: - # New style with ExecuteResponse - command_id = execute_response.command_id - status = execute_response.status - has_been_closed_server_side = execute_response.has_been_closed_server_side - has_more_rows = execute_response.has_more_rows - results_queue = execute_response.results_queue - description = execute_response.description - is_staging_operation = execute_response.is_staging_operation - self._response = getattr(execute_response, "sea_response", {}) - self.statement_id = command_id.to_sea_statement_id() if command_id else None - elif sea_response is not None: - # Legacy style with direct sea_response - self._response = sea_response - # Extract values from sea_response - command_id = CommandId.from_sea_statement_id( - sea_response.get("statement_id", "") - ) - self.statement_id = sea_response.get("statement_id", "") - - # Extract status - status_data = sea_response.get("status", {}) - status = CommandState.from_sea_state(status_data.get("state", "PENDING")) - - # Set defaults for other fields - has_been_closed_server_side = False - has_more_rows = False - results_queue = None - description = None - is_staging_operation = False - else: - raise ValueError("Either execute_response or sea_response must be provided") - # Call parent constructor with common attributes super().__init__( connection=connection, backend=sea_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, - command_id=command_id, - status=status, - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - results_queue=results_queue, - description=description, - 
is_staging_operation=is_staging_operation, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, ) def _fill_results_buffer(self): From 69ea23811e03705998baba569bcda259a0646de5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:56:09 +0000 Subject: [PATCH 09/66] re-introduce get_schema_desc Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 1 - src/databricks/sql/result_set.py | 21 +++++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3107083fb..7a276c102 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -299,7 +299,6 @@ def __init__( self.guid = guid self.secret = secret - def __str__(self) -> str: """ Return a string representation of the CommandId. diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 12ee129cf..1fee995e5 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -59,12 +59,12 @@ def __init__( has_been_closed_server_side: Whether the command has been closed on the server has_more_rows: Whether the command has more rows results_queue: The results queue - description: column description of the results + description: column description of the results is_staging_operation: Whether the command is a staging operation """ self.connection = connection - self.backend = backend + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -400,6 +400,23 @@ def fetchmany(self, size: int) -> List[Row]: else: return self._convert_arrow_table(self.fetchmany_arrow(size)) + @staticmethod + def _get_schema_description(table_schema_message): + """ + Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + """ + + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ + + return [ + (column.name, map_col_type(column.datatype), None, None, None, None, None) + for column in table_schema_message.columns + ] + class SeaResultSet(ResultSet): """ResultSet implementation for SEA backend.""" From 66d75171991f9fcc98d541729a3127aea0d37a81 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:57:53 +0000 Subject: [PATCH 10/66] remove SeaResultSet Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 72 -------------------------------- 1 file changed, 72 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 1fee995e5..eaabcc186 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -416,75 +416,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - ): - """ - Initialize a SeaResultSet with the response from a SEA query 
execution. - - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") From 71feef96b3a41889a5cd9313fc81910cebd7a084 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:01:22 +0000 Subject: [PATCH 11/66] clean imports and attributes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/databricks_client.py | 1 + src/databricks/sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/result_set.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index cd347d9ab..8fda71e1e 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -88,6 +88,7 @@ def execute_command( ) -> Union["ResultSet", None]: """ Executes a SQL command or query within the specified session. + This method sends a SQL command to the server for execution and handles the response. It can operate in both synchronous and asynchronous modes. 
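The contract described in the docstring above is the same for every backend: a synchronous call blocks and hands back a ResultSet, while an asynchronous call returns None and leaves the command ID on the cursor for later retrieval. A minimal sketch of that calling convention, assuming an already-open session_id and cursor (run_query and its argument values are illustrative, not part of this patch):

    def run_query(client, session_id, cursor, sql, wait=True):
        # Dispatch the statement; in synchronous mode this blocks until the
        # server reaches a terminal state and returns a ResultSet directly.
        result = client.execute_command(
            operation=sql,
            session_id=session_id,
            max_rows=10000,
            max_bytes=0,
            lz4_compression=False,
            cursor=cursor,
            use_cloud_fetch=False,
            parameters=[],
            async_op=not wait,
            enforce_embedded_schema_correctness=False,
        )
        if wait:
            return result
        # Asynchronous mode returns None; the command ID is stored on the
        # cursor, so the caller polls get_query_state until a terminal
        # state and then fetches the result.
        return client.get_execution_result(cursor.active_command_id, cursor)
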
diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index f0b931ee4..fe292919c 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, Union, List, Tuple +from typing import Callable, Dict, Any, Optional, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index eaabcc186..a33fc977d 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -72,7 +72,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows self.results = results_queue self._is_staging_operation = is_staging_operation From ae9862f90e7cf0a4949d6b1c7e04fdbae222c2d8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:05:53 +0000 Subject: [PATCH 12/66] pass CommandId to ExecResp Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 7 ++++++- src/databricks/sql/result_set.py | 10 +++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 21a6befbe..316cf24a0 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,6 +352,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
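The "basic strategy" comment above reduces to a bounded loop over the available attempts that stops early once the elapsed time plus the next delay would exceed the overall duration budget. A standalone sketch of that shape, with illustrative names rather than the connector's actual retry-policy classes:

    import time

    def call_with_retries(attempt_fn, max_attempts, max_duration_s, delay_s=1.0):
        start = time.monotonic()
        # the range models the number of available attempts, as in the comment
        for attempt in range(1, max_attempts + 1):
            try:
                return attempt_fn()
            except Exception:
                elapsed = time.monotonic() - start
                # final failure: attempts exhausted, or the next retry delay
                # would push us past the total duration budget
                if attempt == max_attempts or elapsed + delay_s > max_duration_s:
                    raise
                time.sleep(delay_s)
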
@@ -866,9 +867,13 @@ def get_execution_result( ssl_options=self._ssl_options, ) + status = CommandState.from_thrift_state(resp.status) + if status is None: + raise ValueError(f"Invalid operation state: {resp.status}") + execute_response = ExecuteResponse( command_id=command_id, - status=resp.status, + status=status, description=description, has_more_rows=has_more_rows, results_queue=queue, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a33fc977d..a0cb73732 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) From d8aa69e40438c33014e0d5afaec6a4175e64bea8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:08:04 +0000 Subject: [PATCH 13/66] remove changes in types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 57 +++++++++-------------------- 1 file changed, 17 insertions(+), 40 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 7a276c102..9cd21b5e6 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,6 +1,5 @@ -from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict, Optional, Any import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -81,28 +80,6 @@ def from_thrift_state( else: return None - @classmethod - def from_sea_state(cls, state: str) -> Optional["CommandState"]: - """ - Map SEA state string to CommandState enum. 
- - Args: - state: SEA state string - - Returns: - CommandState: The corresponding CommandState enum value - """ - state_mapping = { - "PENDING": cls.PENDING, - "RUNNING": cls.RUNNING, - "SUCCEEDED": cls.SUCCEEDED, - "FAILED": cls.FAILED, - "CLOSED": cls.CLOSED, - "CANCELED": cls.CANCELLED, - } - - return state_mapping.get(state, None) - class BackendType(Enum): """ @@ -285,6 +262,9 @@ def __init__( backend_type: BackendType, guid: Any, secret: Optional[Any] = None, + operation_type: Optional[int] = None, + has_result_set: bool = False, + modified_row_count: Optional[int] = None, ): """ Initialize a CommandId. @@ -293,11 +273,17 @@ def __init__( backend_type: The type of backend (THRIFT or SEA) guid: The primary identifier for the command secret: The secret part of the identifier (only used for Thrift) + operation_type: The operation type (only used for Thrift) + has_result_set: Whether the command has a result set + modified_row_count: The number of rows modified by the command """ self.backend_type = backend_type self.guid = guid self.secret = secret + self.operation_type = operation_type + self.has_result_set = has_result_set + self.modified_row_count = modified_row_count def __str__(self) -> str: """ @@ -332,6 +318,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None @@ -342,6 +329,9 @@ def from_thrift_handle(cls, operation_handle): BackendType.THRIFT, guid_bytes, secret_bytes, + operation_handle.operationType, + operation_handle.hasResultSet, + operation_handle.modifiedRowCount, ) @classmethod @@ -374,6 +364,9 @@ def to_thrift_handle(self): handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) return ttypes.TOperationHandle( operationId=handle_identifier, + operationType=self.operation_type, + hasResultSet=self.has_result_set, + modifiedRowCount=self.modified_row_count, ) def to_sea_statement_id(self): @@ -401,19 +394,3 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) - - -@dataclass -class ExecuteResponse: - """Response from executing a SQL command.""" - - command_id: CommandId - status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None - has_more_rows: bool = False - results_queue: Optional[Any] = None - has_been_closed_server_side: bool = False - lz4_compressed: bool = True - is_staging_operation: bool = False From db139bc1179bb7cab6ec6f283cdfa0646b04b01b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:09:35 +0000 Subject: [PATCH 14/66] add back essential types (ExecResponse, from_sea_state) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 39 ++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..958eaa289 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,27 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. 
+ Args: + state: SEA state string + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + + class BackendType(Enum): """ @@ -394,3 +416,18 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False \ No newline at end of file From b977b1210a5d39543b8a3734128ba820e597337f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:11:23 +0000 Subject: [PATCH 15/66] fix fetch types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 4 ++-- src/databricks/sql/result_set.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 958eaa289..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -102,7 +102,6 @@ def from_sea_state(cls, state: str) -> Optional["CommandState"]: return state_mapping.get(state, None) - class BackendType(Enum): """ Enum representing the type of backend @@ -417,6 +416,7 @@ def to_hex_guid(self) -> str: else: return str(self.guid) + @dataclass class ExecuteResponse: """Response from executing a SQL command.""" @@ -430,4 +430,4 @@ class ExecuteResponse: results_queue: Optional[Any] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True - is_staging_operation: bool = False \ No newline at end of file + is_staging_operation: bool = False diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a0cb73732..e177d495f 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> Any: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> Any: + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all remaining rows as an Arrow table.""" pass From da615c0db8ba2037c106b533331cf1ca1c9f49f9 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:12:45 +0000 Subject: [PATCH 16/66] excess imports Signed-off-by: varun-edachali-dbx --- tests/unit/test_session.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fef070362..92de8d8fd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,11 +2,6 @@ from unittest.mock import patch, Mock import gc -from databricks.sql.thrift_api.TCLIService.ttypes import ( - TOpenSessionResp, - TSessionHandle, - THandleIdentifier, -) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From 0da04a6f1086998927a28759fc67da4e2c8c71c6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:15:59 +0000 Subject: [PATCH 17/66] reduce diff 
by maintaining logs Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 316cf24a0..821559ad3 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -800,7 +800,7 @@ def _results_message_to_execute_response(self, resp, operation_state): status = CommandState.from_thrift_state(operation_state) if status is None: - raise ValueError(f"Invalid operation state: {operation_state}") + raise ValueError(f"Unknown command state: {operation_state}") return ( ExecuteResponse( From ea9d456ee9ca47434618a079698fa166b6c8a308 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 08:47:54 +0000 Subject: [PATCH 18/66] fix int test types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 4 +--- tests/e2e/common/retry_test_mixins.py | 2 +- tests/e2e/test_driver.py | 6 +++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 821559ad3..f41f4b6d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -867,9 +867,7 @@ def get_execution_result( ssl_options=self._ssl_options, ) - status = CommandState.from_thrift_state(resp.status) - if status is None: - raise ValueError(f"Invalid operation state: {resp.status}") + status = self.get_query_state(command_id) execute_response = ExecuteResponse( command_id=command_id, diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index b5d01a45d..dd509c062 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -326,7 +326,7 @@ def test_retry_abort_close_operation_on_404(self, caplog): with self.connection(extra_params={**self._retry_policy}) as conn: with conn.cursor() as curs: with patch( - "databricks.sql.utils.ExecuteResponse.has_been_closed_server_side", + "databricks.sql.backend.types.ExecuteResponse.has_been_closed_server_side", new_callable=PropertyMock, return_value=False, ): diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 22897644f..8cfed7c28 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -933,12 +933,12 @@ def test_result_set_close(self): result_set = cursor.active_result_set assert result_set is not None - initial_op_state = result_set.op_state + initial_op_state = result_set.status result_set.close() - assert result_set.op_state == CommandState.CLOSED - assert result_set.op_state != initial_op_state + assert result_set.status == CommandState.CLOSED + assert result_set.status != initial_op_state # Closing the result set again should be a no-op and not raise exceptions result_set.close() From 8985c624bcdbb7e0abfa73b7a1a2dbad15b4e1ec Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 08:55:24 +0000 Subject: [PATCH 19/66] [squashed from exec-sea] init execution func Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 ++- .../sql/backend/databricks_client.py | 28 - src/databricks/sql/backend/filters.py | 143 ++++ src/databricks/sql/backend/sea/backend.py | 360 ++++++++- .../sql/backend/sea/models/__init__.py | 30 + src/databricks/sql/backend/sea/models/base.py | 68 ++ .../sql/backend/sea/models/requests.py | 110 ++- 
.../sql/backend/sea/models/responses.py | 95 ++- src/databricks/sql/backend/thrift_backend.py | 3 +- src/databricks/sql/backend/types.py | 25 +- src/databricks/sql/result_set.py | 118 ++- tests/unit/test_result_set_filter.py | 246 ++++++ tests/unit/test_sea_backend.py | 755 +++++++++++++++--- tests/unit/test_sea_result_set.py | 275 +++++++ tests/unit/test_session.py | 5 + 15 files changed, 2166 insertions(+), 219 deletions(-) create mode 100644 src/databricks/sql/backend/filters.py create mode 100644 src/databricks/sql/backend/sea/models/base.py create mode 100644 tests/unit/test_result_set_filter.py create mode 100644 tests/unit/test_sea_result_set.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..87b62efea 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,34 +6,122 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) + +def test_sea_query_exec(): + """ + Test executing a query using the SEA backend with result compression. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with result compression enabled and disabled, + and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + sys.exit(1) + + try: + # Test with compression enabled + logger.info("Creating connection with LZ4 compression enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, # Enable cloud fetch to use compression + enable_query_result_lz4_compression=True, # Enable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"backend type: {type(connection.session.backend)}") + + # Execute a simple query with compression enabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query with compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression enabled") + + # Test with compression disabled + logger.info("Creating connection with LZ4 compression disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, # Enable cloud fetch + enable_query_result_lz4_compression=False, # Disable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query with compression disabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query without compression: SELECT 1 as test_value") + 
cursor.execute("SELECT 1 as test_value") + logger.info("Query without compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression disabled") + + except Exception as e: + logger.error(f"Error during SEA query execution test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + sys.exit(1) + + logger.info("SEA query execution test with compression completed successfully") + + def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -42,25 +130,33 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent + use_sea=True, + user_agent_entry="SEA-Test-Client", # add custom user agent + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback + logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") + if __name__ == "__main__": + # Test session management test_sea_session() + + # Test query execution with compression + test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 8fda71e1e..bbca4c502 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -86,34 +86,6 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: - """ - Executes a SQL command or query within the specified session. - - This method sends a SQL command to the server for execution and handles - the response. It can operate in both synchronous and asynchronous modes. 
- - Args: - operation: The SQL command or query to execute - session_id: The session identifier in which to execute the command - max_rows: Maximum number of rows to fetch in a single fetch batch - max_bytes: Maximum number of bytes to fetch in a single fetch batch - lz4_compression: Whether to use LZ4 compression for result data - cursor: The cursor object that will handle the results - use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets - parameters: List of parameters to bind to the query - async_op: Whether to execute the command asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - If async_op is False, returns a ResultSet object containing the - query results and metadata. If async_op is True, returns None and the - results must be fetched later using get_execution_result(). - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error executing the command - ServerOperationError: If the server encounters an error during execution - """ pass @abstractmethod diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..7f48b6179 --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,143 @@ +""" +Client-side filtering utilities for Databricks SQL connector. + +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + +from databricks.sql.result_set import SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. + + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Create a filtered version of the result set + filtered_response = result_set._response.copy() + + # If there's a result with rows, filter them + if ( + "result" in filtered_response + and "data_array" in filtered_response["result"] + ): + rows = filtered_response["result"]["data_array"] + filtered_rows = [row for row in rows if filter_func(row)] + filtered_response["result"]["data_array"] = filtered_rows + + # Update row count if present + if "row_count" in filtered_response["result"]: + filtered_response["result"]["row_count"] = len(filtered_rows) + + # Create a new result set with the filtered data + return SeaResultSet( + connection=result_set.connection, + sea_response=filtered_response, + sea_client=result_set.backend, + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + ) + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. 
+ + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. + + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is typically in the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=False + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..c7a4ed1b1 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,34 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +import uuid +import time +from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +66,9 @@ 
def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -274,41 +288,222 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. + + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else "NONE" + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows if max_rows > 0 else None, + byte_limit=max_bytes if max_bytes > 0 else None, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + 
"operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. + + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + return SeaResultSet( + connection=cursor.connection, + sea_response=response_data, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == @@ -319,9 +514,22 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -331,9 +539,30 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -345,9 +574,43 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -359,6 +622,33 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index c9310d367..b7c8bd399 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,19 +4,49 @@ This package contains data models for SEA API requests and responses. """ +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ColumnInfo, + ResultManifest, +) + from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) __all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ColumnInfo", + "ResultManifest", # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models + "ExecuteStatementResponse", + "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 000000000..671f7be13 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,68 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. 
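Each metadata helper above reduces to assembling a SHOW statement from optional LIKE patterns; the same string-building approach in isolation (the function name here is illustrative only):

    def build_show_columns(catalog, schema=None, table=None, column=None):
        operation = f"SHOW COLUMNS IN CATALOG `{catalog}`"
        if schema:
            operation += f" SCHEMA LIKE '{schema}'"
        if table:
            operation += f" TABLE LIKE '{table}'"
        if column:
            operation += f" LIKE '{column}'"
        return operation

    assert (
        build_show_columns("main", "tpch", "orders")
        == "SHOW COLUMNS IN CATALOG `main` SCHEMA LIKE 'tpch' TABLE LIKE 'orders'"
    )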
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + + +@dataclass +class ColumnInfo: + """Information about a column in the result set.""" + + name: str + type_name: str + type_text: str + nullable: bool = True + precision: Optional[int] = None + scale: Optional[int] = None + ordinal_position: Optional[int] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + schema: List[ColumnInfo] + total_row_count: int + total_byte_count: int + truncated: bool = False + chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb502..1c519d931 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,5 +1,111 @@ -from typing import Dict, Any, Optional -from dataclasses import dataclass +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Request to execute a SQL statement.""" + + warehouse_id: str + statement: str + session_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + byte_limit: Optional[int] = None + parameters: Optional[List[StatementParameter]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + result_compression: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.byte_limit is not None and self.byte_limit > 0: + result["byte_limit"] = self.byte_limit + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + **({"value": param.value} if param.value is not None else {}), + **({"type": param.type} if param.type is not None else {}), + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590f..d70459b9f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,5 +1,96 @@ -from typing import Dict, Any -from dataclasses import dataclass +""" +Response models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, +) + + +@dataclass +class ExecuteStatementResponse: + """Response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) + + +@dataclass +class GetStatementResponse: + """Response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) @dataclass diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f41f4b6d8..810c2e7a1 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,7 +352,6 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ - # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
@@ -1242,7 +1241,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.guid)) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3d24a09a1..5bf02e0ea 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -85,8 +85,10 @@ def from_thrift_state( def from_sea_state(cls, state: str) -> Optional["CommandState"]: """ Map SEA state string to CommandState enum. + Args: state: SEA state string + Returns: CommandState: The corresponding CommandState enum value """ @@ -306,28 +308,6 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count - def __str__(self) -> str: - """ - Return a string representation of the CommandId. - - For SEA backend, returns the guid. - For Thrift backend, returns a format like "guid|secret". - - Returns: - A string representation of the command ID - """ - - if self.backend_type == BackendType.SEA: - return str(self.guid) - elif self.backend_type == BackendType.THRIFT: - secret_hex = ( - guid_to_hex_id(self.secret) - if isinstance(self.secret, bytes) - else str(self.secret) - ) - return f"{self.to_hex_guid()}|{secret_hex}" - return str(self.guid) - @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -339,7 +319,6 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ - if operation_handle is None: return None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index e177d495f..a4beda629 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -64,7 +64,7 @@ def __init__( """ self.connection = connection - self.backend = backend + self.backend = backend # Store the backend client directly self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": + def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" pass @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self._has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -259,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": + def merge_columnar(self, result1, result2): """ Function to merge / combining the columnar results into a single result :param result1: @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( 
n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -403,16 +403,96 @@ def fetchmany(self, size: int) -> List[Row]: @staticmethod def _get_schema_description(table_schema_message): """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) """ + # Handle both initialization styles + if execute_response is not None: + # New style with ExecuteResponse + command_id = execute_response.command_id + status = execute_response.status + has_been_closed_server_side = execute_response.has_been_closed_server_side + has_more_rows = execute_response.has_more_rows + results_queue = execute_response.results_queue + description = execute_response.description + is_staging_operation = execute_response.is_staging_operation + self._response = getattr(execute_response, "sea_response", {}) + self.statement_id = command_id.to_sea_statement_id() if command_id else None + elif sea_response is not None: + # Legacy style with direct sea_response + self._response = sea_response + # Extract values from sea_response + command_id = CommandId.from_sea_statement_id( + sea_response.get("statement_id", "") + ) + self.statement_id = sea_response.get("statement_id", "") + + # Extract status + status_data = sea_response.get("status", {}) + status = CommandState.from_sea_state(status_data.get("state", "PENDING")) + + # Set defaults for other fields + has_been_closed_server_side = False + has_more_rows = False + results_queue = None + description = None + is_staging_operation = False + else: + raise ValueError("Either execute_response or sea_response must be provided") - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=status, + has_been_closed_server_side=has_been_closed_server_side, + has_more_rows=has_more_rows, + results_queue=results_queue, + description=description, + is_staging_operation=is_staging_operation, + ) - return [ - (column.name, map_col_type(column.datatype), None, None, 
None, None, None) - for column in table_schema_message.columns - ] + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchmany(self, size: Optional[int] = None) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + + raise NotImplementedError("fetchmany is not implemented for SEA backend") + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + raise NotImplementedError("fetchall is not implemented for SEA backend") + + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py new file mode 100644 index 000000000..e8eb2a757 --- /dev/null +++ b/tests/unit/test_result_set_filter.py @@ -0,0 +1,246 @@ +""" +Tests for the ResultSetFilter class. + +This module contains tests for the ResultSetFilter class, which provides +filtering capabilities for result sets returned by different backends. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestResultSetFilter: + """Test suite for the ResultSetFilter class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response_with_tables(self): + """Create a sample SEA response with table data based on the server schema.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 7, + "columns": [ + {"name": "namespace", "type_text": "STRING", "position": 0}, + {"name": "tableName", "type_text": "STRING", "position": 1}, + {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, + {"name": "information", "type_text": "STRING", "position": 3}, + {"name": "catalogName", "type_text": "STRING", "position": 4}, + {"name": "tableType", "type_text": "STRING", "position": 5}, + {"name": "remarks", "type_text": "STRING", "position": 6}, + ], + }, + "total_row_count": 4, + }, + "result": { + "data_array": [ + ["schema1", "table1", False, None, "catalog1", "TABLE", None], + ["schema1", "table2", False, None, "catalog1", "VIEW", None], + [ + "schema1", + "table3", + False, + None, + "catalog1", + "SYSTEM TABLE", + None, + ], + ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], + ] + }, + } + + @pytest.fixture + def sea_result_set_with_tables( + self, mock_connection, 
mock_sea_client, sea_response_with_tables + ): + """Create a SeaResultSet with table data.""" + # Create a deep copy of the response to avoid test interference + import copy + + sea_response_copy = copy.deepcopy(sea_response_with_tables) + + return SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy, + buffer_size_bytes=1000, + arraysize=100, + ) + + def test_filter_tables_by_type_default(self, sea_result_set_with_tables): + """Test filtering tables by type with default types.""" + # Default types are TABLE, VIEW, SYSTEM TABLE + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables + ) + + # Verify that only the default types are included + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): + """Test filtering tables by type with custom types.""" + # Filter for only TABLE and EXTERNAL + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] + ) + + # Verify that only the specified types are included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "EXTERNAL" in table_types + assert "VIEW" not in table_types + assert "SYSTEM TABLE" not in table_types + + def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): + """Test that table type filtering is case-insensitive.""" + # Filter for lowercase "table" and "view" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["table", "view"] + ) + + # Verify that the matching types are included despite case differences + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" not in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): + """Test filtering tables with an empty type list (should use defaults).""" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=[] + ) + + # Verify that default types are used + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_by_column_values(self, sea_result_set_with_tables): + """Test filtering by values in a specific column.""" + # Filter by namespace in column index 0 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] + ) + + # All rows have schema1 in namespace, so all should be included + assert len(filtered_result._response["result"]["data_array"]) == 4 + + # Filter by table name in column index 1 + filtered_result = ResultSetFilter.filter_by_column_values( + 
sea_result_set_with_tables, + column_index=1, + allowed_values=["table1", "table3"], + ) + + # Only rows with table1 or table3 should be included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_names = [ + row[1] for row in filtered_result._response["result"]["data_array"] + ] + assert "table1" in table_names + assert "table3" in table_names + assert "table2" not in table_names + assert "table4" not in table_names + + def test_filter_by_column_values_case_sensitive( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Test case-sensitive filtering by column values.""" + import copy + + # Create a fresh result set for the first test + sea_response_copy1 = copy.deepcopy(sea_response_with_tables) + result_set1 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy1, + buffer_size_bytes=1000, + arraysize=100, + ) + + # First test: Case-sensitive filtering with lowercase values (should find no matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set1, + column_index=5, # tableType column + allowed_values=["table", "view"], # lowercase + case_sensitive=True, + ) + + # Verify no matches with lowercase values + assert len(filtered_result._response["result"]["data_array"]) == 0 + + # Create a fresh result set for the second test + sea_response_copy2 = copy.deepcopy(sea_response_with_tables) + result_set2 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy2, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Second test: Case-sensitive filtering with correct case (should find matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set2, + column_index=5, # tableType column + allowed_values=["TABLE", "VIEW"], # correct case + case_sensitive=True, + ) + + # Verify matches with correct case + assert len(filtered_result._response["result"]["data_array"]) == 2 + + # Extract the table types from the filtered results + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + + def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): + """Test filtering with a column index that's out of bounds.""" + # Filter by column index 10 (out of bounds) + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=10, allowed_values=["value"] + ) + + # No rows should match + assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..c2078f731 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
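The fixtures above all build a SeaResultSet straight from a raw response dictionary (the legacy initialization style); stripped to its essentials, with mocks standing in for the connection and client:

    from unittest.mock import Mock
    from databricks.sql.result_set import SeaResultSet

    result_set = SeaResultSet(
        connection=Mock(open=True),
        sea_client=Mock(),
        sea_response={
            "statement_id": "stmt-1",
            "status": {"state": "SUCCEEDED"},
            "result": {"data_array": [["1"]]},
        },
        buffer_size_bytes=1000,
        arraysize=100,
    )
    assert result_set.statement_id == "stmt-1"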
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,650 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def 
test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, 
sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, } - assert set(allowed_configs) == expected_keys + mock_http_client._make_request.return_value = execute_response - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = 
mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the 
result of a command execution.""" + # Set up mock response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + # Tests for metadata operations + + def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): + """Test getting catalogs.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["lz4_compression"] is False + assert kwargs["use_cloud_fetch"] is False + assert kwargs["parameters"] == [] + assert kwargs["async_op"] is False + assert kwargs["enforce_embedded_schema_correctness"] is False + + def test_get_schemas_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_with_catalog_and_schema_pattern( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with catalog and schema pattern.""" + # Mock execute_command to verify it's called with the right 
parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog and schema pattern + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting schemas without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + ) + + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_all_catalogs( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables from all catalogs using wildcard.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_schema_and_table_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with schema and table patterns.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + 
sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog, schema, and table patterns + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting tables without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_columns_with_all_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with all patterns specified.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with all patterns + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def 
test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting columns without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_columns" in str(excinfo.value) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py new file mode 100644 index 000000000..f666fd613 --- /dev/null +++ b/tests/unit/test_sea_result_set.py @@ -0,0 +1,275 @@ +""" +Tests for the SeaResultSet class. + +This module contains tests for the SeaResultSet class, which implements +the result set functionality for the SEA (Statement Execution API) backend. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response(self): + """Create a sample SEA response.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.has_more_rows = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + mock_response.sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "result": {"data_array": [["1"]]}, + } + return mock_response + + def test_init_with_sea_response( + self, mock_connection, mock_sea_client, sea_response + ): + """Test initializing SeaResultSet with a SEA response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == sea_response + + def 
test_init_with_execute_response( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + execute_response=execute_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == execute_response.sea_response + + def test_init_with_no_response(self, mock_connection, mock_sea_client): + """Test that initialization fails when neither response type is provided.""" + with pytest.raises(ValueError) as excinfo: + SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + assert "Either execute_response or sea_response must be provided" in str( + excinfo.value + ) + + def test_close(self, mock_connection, mock_sea_client, sea_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_unimplemented_methods( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that unimplemented methods raise NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test each unimplemented method individually with specific error messages + with pytest.raises( + 
NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set.fetchone() + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() + + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() + + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) + + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass + + def test_fill_results_buffer_not_implemented( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that _fill_results_buffer raises NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set._fill_results_buffer() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92de8d8fd..fef070362 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,11 @@ from unittest.mock import patch, Mock import gc +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From d9bcdbef396433e01b298fca9a27b1bce2b1414b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:03:13 +0000 Subject: [PATCH 20/66] remove irrelevant changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 +----- .../sql/backend/databricks_client.py | 30 ++ src/databricks/sql/backend/sea/backend.py | 360 ++---------------- .../sql/backend/sea/models/__init__.py | 30 -- src/databricks/sql/backend/sea/models/base.py | 68 ---- .../sql/backend/sea/models/requests.py | 110 +----- .../sql/backend/sea/models/responses.py | 95 +---- src/databricks/sql/backend/types.py | 64 ++-- 8 files changed, 107 insertions(+), 774 deletions(-) delete mode 100644 src/databricks/sql/backend/sea/models/base.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 87b62efea..abe6bd1ab 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,122 +6,34 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) - -def test_sea_query_exec(): - """ - Test executing a query using the SEA backend with result compression. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with result compression enabled and disabled, - and verifies that execution completes successfully. 
- """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - sys.exit(1) - - try: - # Test with compression enabled - logger.info("Creating connection with LZ4 compression enabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, # Enable cloud fetch to use compression - enable_query_result_lz4_compression=True, # Enable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - logger.info(f"backend type: {type(connection.session.backend)}") - - # Execute a simple query with compression enabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query with compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression enabled") - - # Test with compression disabled - logger.info("Creating connection with LZ4 compression disabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, # Enable cloud fetch - enable_query_result_lz4_compression=False, # Disable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query with compression disabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query without compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query without compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression disabled") - - except Exception as e: - logger.error(f"Error during SEA query execution test: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - sys.exit(1) - - logger.info("SEA query execution test with compression completed successfully") - - def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
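The compression example being removed here toggles two independent knobs. For reference, the SEA execute path (deleted later in this same patch) derives its request fields from them, which is why LZ4 only takes effect together with cloud fetch:

    def sea_result_options(use_cloud_fetch: bool, lz4_compression: bool):
        # Mirrors the mapping in the (removed) SeaDatabricksClient.execute_command.
        format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY"
        disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE"
        result_compression = "LZ4_FRAME" if lz4_compression else "NONE"
        return format, disposition, result_compression
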
- ) + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -130,33 +42,25 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", # add custom user agent - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent ) + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback - logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") - if __name__ == "__main__": - # Test session management test_sea_session() - - # Test query execution with compression - test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index bbca4c502..20b059fa7 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,6 +16,8 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState +from databricks.sql.utils import ExecuteResponse +from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING @@ -86,6 +88,34 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: + """ + Executes a SQL command or query within the specified session. + + This method sends a SQL command to the server for execution and handles + the response. It can operate in both synchronous and asynchronous modes. + + Args: + operation: The SQL command or query to execute + session_id: The session identifier in which to execute the command + max_rows: Maximum number of rows to fetch in a single fetch batch + max_bytes: Maximum number of bytes to fetch in a single fetch batch + lz4_compression: Whether to use LZ4 compression for result data + cursor: The cursor object that will handle the results + use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets + parameters: List of parameters to bind to the query + async_op: Whether to execute the command asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + If async_op is False, returns a ResultSet object containing the + query results and metadata. If async_op is True, returns None and the + results must be fetched later using get_execution_result(). 
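An illustrative call against this interface, assuming an open session and cursor (client, session_id, and cursor here are placeholders, not part of the abstract class):

    # Hypothetical usage sketch; argument values mirror the tests in this patch.
    result_set = client.execute_command(
        operation="SELECT 1 AS test_value",
        session_id=session_id,
        max_rows=100,
        max_bytes=1000,
        lz4_compression=False,
        cursor=cursor,
        use_cloud_fetch=False,
        parameters=[],
        async_op=False,  # with True the call returns None; fetch later via get_execution_result()
        enforce_embedded_schema_correctness=False,
    )
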
+ + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error executing the command + ServerOperationError: If the server encounters an error during execution + """ pass @abstractmethod diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index c7a4ed1b1..97d25a058 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,34 +1,23 @@ import logging import re -import uuid -import time -from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING - -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) +from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.exc import ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, - StatementParameter, - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) @@ -66,9 +55,6 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. - - This implementation provides session management functionality for SEA, - while other operations raise NotImplementedError. """ # SEA API paths @@ -288,222 +274,41 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List, + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Union["ResultSet", None]: - """ - Execute a SQL command using the SEA backend. 
- - Args: - operation: SQL command to execute - session_id: Session identifier - max_rows: Maximum number of rows to fetch - max_bytes: Maximum number of bytes to fetch - lz4_compression: Whether to use LZ4 compression - cursor: Cursor executing the command - use_cloud_fetch: Whether to use cloud fetch - parameters: SQL parameters - async_op: Whether to execute asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - ResultSet: A SeaResultSet instance for the executed command - """ - if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") - - sea_session_id = session_id.to_sea_session_id() - - # Convert parameters to StatementParameter objects - sea_parameters = [] - if parameters: - for param in parameters: - sea_parameters.append( - StatementParameter( - name=param.name, - value=param.value, - type=param.type if hasattr(param, "type") else None, - ) - ) - - format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" - disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" - result_compression = "LZ4_FRAME" if lz4_compression else "NONE" - - request = ExecuteStatementRequest( - warehouse_id=self.warehouse_id, - session_id=sea_session_id, - statement=operation, - disposition=disposition, - format=format, - wait_timeout="0s" if async_op else "10s", - on_wait_timeout="CONTINUE", - row_limit=max_rows if max_rows > 0 else None, - byte_limit=max_bytes if max_bytes > 0 else None, - parameters=sea_parameters if sea_parameters else None, - result_compression=result_compression, - ) - - response_data = self.http_client._make_request( - method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ): + """Not implemented yet.""" + raise NotImplementedError( + "execute_command is not yet implemented for SEA backend" ) - response = ExecuteStatementResponse.from_dict(response_data) - statement_id = response.statement_id - if not statement_id: - raise ServerOperationError( - "Failed to execute command: No statement ID returned", - { - "operation-id": None, - "diagnostic-info": None, - }, - ) - - command_id = CommandId.from_sea_statement_id(statement_id) - - # Store the command ID in the cursor - cursor.active_command_id = command_id - - # If async operation, return None and let the client poll for results - if async_op: - return None - - # For synchronous operation, wait for the statement to complete - # Poll until the statement is done - status = response.status - state = status.state - - # Keep polling until we reach a terminal state - while state in [CommandState.PENDING, CommandState.RUNNING]: - time.sleep(0.5) # add a small delay to avoid excessive API calls - state = self.get_query_state(command_id) - - if state != CommandState.SUCCEEDED: - raise ServerOperationError( - f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", - { - "operation-id": command_id.to_sea_statement_id(), - "diagnostic-info": None, - }, - ) - - return self.get_execution_result(command_id, cursor) def cancel_command(self, command_id: CommandId) -> None: - """ - Cancel a running command. 
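The synchronous path removed above reduces to a short poll loop. A condensed sketch, reconstructed from the deleted code with the error-detail plumbing elided:

    import time

    from databricks.sql.backend.types import CommandState
    from databricks.sql.exc import ServerOperationError

    def wait_for_statement(client, command_id, cursor, first_state):
        # Reconstruction of the removed synchronous wait in execute_command.
        state = first_state
        while state in (CommandState.PENDING, CommandState.RUNNING):
            time.sleep(0.5)  # small delay to avoid excessive API calls
            state = client.get_query_state(command_id)
        if state != CommandState.SUCCEEDED:
            raise ServerOperationError("Statement execution did not succeed")
        return client.get_execution_result(command_id, cursor)
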
- - Args: - command_id: Command identifier to cancel - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CancelStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="POST", - path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "cancel_command is not yet implemented for SEA backend" ) def close_command(self, command_id: CommandId) -> None: - """ - Close a command and release resources. - - Args: - command_id: Command identifier to close - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CloseStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="DELETE", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "close_command is not yet implemented for SEA backend" ) def get_query_state(self, command_id: CommandId) -> CommandState: - """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "get_query_state is not yet implemented for SEA backend" ) - # Parse the response - response = GetStatementResponse.from_dict(response_data) - return response.status.state - def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> "ResultSet": - """ - Get the result of a command execution. 
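cancel_command, close_command, and get_query_state (removed above) all share one shape: validate the backend type, unwrap the statement ID, and issue a single HTTP call. A condensed sketch, with paths spelled via the client's own constants as the unit tests do:

    from databricks.sql.backend.types import BackendType

    def cancel_sketch(client, command_id):
        # Same pattern as the removed close (DELETE, STATEMENT_PATH_WITH_ID)
        # and poll (GET, STATEMENT_PATH_WITH_ID) implementations.
        if command_id.backend_type != BackendType.SEA:
            raise ValueError("Not a valid SEA command ID")
        sea_statement_id = command_id.to_sea_statement_id()
        client.http_client._make_request(
            method="POST",
            path=client.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id),
            data={"statement_id": sea_statement_id},
        )
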
- - Args: - command_id: Command identifier - cursor: Cursor executing the command - - Returns: - ResultSet: A SeaResultSet instance with the execution results - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - # Create the request model - request = GetStatementRequest(statement_id=sea_statement_id) - - # Get the statement result - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - - # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet - - return SeaResultSet( - connection=cursor.connection, - sea_response=response_data, - sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, + ): + """Not implemented yet.""" + raise NotImplementedError( + "get_execution_result is not yet implemented for SEA backend" ) # == Metadata Operations == @@ -514,22 +319,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - result = self.execute_command( - operation="SHOW CATALOGS", - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -539,30 +331,9 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_schemas") - - operation = f"SHOW SCHEMAS IN `{catalog_name}`" - - if schema_name: - operation += f" LIKE '{schema_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -574,43 +345,9 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - - operation = "SHOW TABLES IN " + ( - "ALL CATALOGS" - if catalog_name in [None, "*", "%"] - else f"CATALOG `{catalog_name}`" - ) - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" LIKE '{table_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - 
enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - - # Apply client-side filtering by table_types if specified - from databricks.sql.backend.filters import ResultSetFilter - - result = ResultSetFilter.filter_tables_by_type(result, table_types) - - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -622,33 +359,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_columns") - - operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" TABLE LIKE '{table_name}'" - - if column_name: - operation += f" LIKE '{column_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index b7c8bd399..c9310d367 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,49 +4,19 @@ This package contains data models for SEA API requests and responses. """ -from databricks.sql.backend.sea.models.base import ( - ServiceError, - StatementStatus, - ExternalLink, - ResultData, - ColumnInfo, - ResultManifest, -) - from databricks.sql.backend.sea.models.requests import ( - StatementParameter, - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) __all__ = [ - # Base models - "ServiceError", - "StatementStatus", - "ExternalLink", - "ResultData", - "ColumnInfo", - "ResultManifest", # Request models - "StatementParameter", - "ExecuteStatementRequest", - "GetStatementRequest", - "CancelStatementRequest", - "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models - "ExecuteStatementResponse", - "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py deleted file mode 100644 index 671f7be13..000000000 --- a/src/databricks/sql/backend/sea/models/base.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Base models for the SEA (Statement Execution API) backend. - -These models define the common structures used in SEA API requests and responses. 
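The metadata helpers gutted above all composed plain SHOW statements and delegated to execute_command. The composition pattern, reconstructed from the deleted get_columns:

    def build_show_columns(catalog_name, schema_name=None, table_name=None, column_name=None):
        # e.g. ("c", "s", "t", "x") ->
        #   SHOW COLUMNS IN CATALOG `c` SCHEMA LIKE 's' TABLE LIKE 't' LIKE 'x'
        if not catalog_name:
            raise ValueError("Catalog name is required for get_columns")
        operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`"
        if schema_name:
            operation += f" SCHEMA LIKE '{schema_name}'"
        if table_name:
            operation += f" TABLE LIKE '{table_name}'"
        if column_name:
            operation += f" LIKE '{column_name}'"
        return operation
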
-""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState - - -@dataclass -class ServiceError: - """Error information returned by the SEA API.""" - - message: str - error_code: Optional[str] = None - - -@dataclass -class StatementStatus: - """Status information for a statement execution.""" - - state: CommandState - error: Optional[ServiceError] = None - sql_state: Optional[str] = None - - -@dataclass -class ExternalLink: - """External link information for result data.""" - - external_link: str - expiration: str - chunk_index: int - - -@dataclass -class ResultData: - """Result data from a statement execution.""" - - data: Optional[List[List[Any]]] = None - external_links: Optional[List[ExternalLink]] = None - - -@dataclass -class ColumnInfo: - """Information about a column in the result set.""" - - name: str - type_name: str - type_text: str - nullable: bool = True - precision: Optional[int] = None - scale: Optional[int] = None - ordinal_position: Optional[int] = None - - -@dataclass -class ResultManifest: - """Manifest information for a result set.""" - - schema: List[ColumnInfo] - total_row_count: int - total_byte_count: int - truncated: bool = False - chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 1c519d931..7966cb502 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,111 +1,5 @@ -""" -Request models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API requests. -""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - - -@dataclass -class StatementParameter: - """Parameter for a SQL statement.""" - - name: str - value: Optional[str] = None - type: Optional[str] = None - - -@dataclass -class ExecuteStatementRequest: - """Request to execute a SQL statement.""" - - warehouse_id: str - statement: str - session_id: str - disposition: str = "EXTERNAL_LINKS" - format: str = "JSON_ARRAY" - wait_timeout: str = "10s" - on_wait_timeout: str = "CONTINUE" - row_limit: Optional[int] = None - byte_limit: Optional[int] = None - parameters: Optional[List[StatementParameter]] = None - catalog: Optional[str] = None - schema: Optional[str] = None - result_compression: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - result: Dict[str, Any] = { - "warehouse_id": self.warehouse_id, - "session_id": self.session_id, - "statement": self.statement, - "disposition": self.disposition, - "format": self.format, - "wait_timeout": self.wait_timeout, - "on_wait_timeout": self.on_wait_timeout, - } - - if self.row_limit is not None and self.row_limit > 0: - result["row_limit"] = self.row_limit - - if self.byte_limit is not None and self.byte_limit > 0: - result["byte_limit"] = self.byte_limit - - if self.catalog: - result["catalog"] = self.catalog - - if self.schema: - result["schema"] = self.schema - - if self.result_compression: - result["result_compression"] = self.result_compression - - if self.parameters: - result["parameters"] = [ - { - "name": param.name, - **({"value": param.value} if param.value is not None else {}), - **({"type": param.type} if param.type is not None else {}), - } - for param in self.parameters - ] - - return result - - -@dataclass -class 
GetStatementRequest: - """Request to get information about a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CancelStatementRequest: - """Request to cancel a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CloseStatementRequest: - """Request to close a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} +from typing import Dict, Any, Optional +from dataclasses import dataclass @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d70459b9f..1bb54590f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,96 +1,5 @@ -""" -Response models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API responses. -""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState -from databricks.sql.backend.sea.models.base import ( - StatementStatus, - ResultManifest, - ResultData, - ServiceError, -) - - -@dataclass -class ExecuteStatementResponse: - """Response from executing a SQL statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": - """Create an ExecuteStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) - - -@dataclass -class GetStatementResponse: - """Response from getting information about a statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": - """Create a GetStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - 
statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) +from typing import Dict, Any +from dataclasses import dataclass @dataclass diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 5bf02e0ea..9cd21b5e6 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,6 +1,5 @@ -from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict, Optional, Any import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -81,28 +80,6 @@ def from_thrift_state( else: return None - @classmethod - def from_sea_state(cls, state: str) -> Optional["CommandState"]: - """ - Map SEA state string to CommandState enum. - - Args: - state: SEA state string - - Returns: - CommandState: The corresponding CommandState enum value - """ - state_mapping = { - "PENDING": cls.PENDING, - "RUNNING": cls.RUNNING, - "SUCCEEDED": cls.SUCCEEDED, - "FAILED": cls.FAILED, - "CLOSED": cls.CLOSED, - "CANCELED": cls.CANCELLED, - } - - return state_mapping.get(state, None) - class BackendType(Enum): """ @@ -308,6 +285,28 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count + def __str__(self) -> str: + """ + Return a string representation of the CommandId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the command ID + """ + + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) + @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -319,6 +318,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None @@ -394,19 +394,3 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) - - -@dataclass -class ExecuteResponse: - """Response from executing a SQL command.""" - - command_id: CommandId - status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None - has_more_rows: bool = False - results_queue: Optional[Any] = None - has_been_closed_server_side: bool = False - lz4_compressed: bool = True - is_staging_operation: bool = False From ee9fa1c972bad75557ac0671d5eef96c0a0cff21 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:03:59 +0000 Subject: [PATCH 21/66] remove ResultSetFilter functionality Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 143 --------------- tests/unit/test_result_set_filter.py | 246 -------------------------- 2 files changed, 389 deletions(-) delete mode 100644 src/databricks/sql/backend/filters.py delete mode 100644 tests/unit/test_result_set_filter.py diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py deleted file mode 100644 index 7f48b6179..000000000 --- a/src/databricks/sql/backend/filters.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Client-side filtering utilities for Databricks SQL connector. 
- -This module provides filtering capabilities for result sets returned by different backends. -""" - -import logging -from typing import ( - List, - Optional, - Any, - Callable, - TYPE_CHECKING, -) - -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet - -from databricks.sql.result_set import SeaResultSet - -logger = logging.getLogger(__name__) - - -class ResultSetFilter: - """ - A general-purpose filter for result sets that can be applied to any backend. - - This class provides methods to filter result sets based on various criteria, - similar to the client-side filtering in the JDBC connector. - """ - - @staticmethod - def _filter_sea_result_set( - result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] - ) -> "SeaResultSet": - """ - Filter a SEA result set using the provided filter function. - - Args: - result_set: The SEA result set to filter - filter_func: Function that takes a row and returns True if the row should be included - - Returns: - A filtered SEA result set - """ - # Create a filtered version of the result set - filtered_response = result_set._response.copy() - - # If there's a result with rows, filter them - if ( - "result" in filtered_response - and "data_array" in filtered_response["result"] - ): - rows = filtered_response["result"]["data_array"] - filtered_rows = [row for row in rows if filter_func(row)] - filtered_response["result"]["data_array"] = filtered_rows - - # Update row count if present - if "row_count" in filtered_response["result"]: - filtered_response["result"]["row_count"] = len(filtered_rows) - - # Create a new result set with the filtered data - return SeaResultSet( - connection=result_set.connection, - sea_response=filtered_response, - sea_client=result_set.backend, - buffer_size_bytes=result_set.buffer_size_bytes, - arraysize=result_set.arraysize, - ) - - @staticmethod - def filter_by_column_values( - result_set: "ResultSet", - column_index: int, - allowed_values: List[str], - case_sensitive: bool = False, - ) -> "ResultSet": - """ - Filter a result set by values in a specific column. - - Args: - result_set: The result set to filter - column_index: The index of the column to filter on - allowed_values: List of allowed values for the column - case_sensitive: Whether to perform case-sensitive comparison - - Returns: - A filtered result set - """ - # Convert to uppercase for case-insensitive comparison if needed - if not case_sensitive: - allowed_values = [v.upper() for v in allowed_values] - - # Determine the type of result set and apply appropriate filtering - if isinstance(result_set, SeaResultSet): - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and isinstance(row[column_index], str) - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), - ) - - # For other result set types, return the original (should be handled by specific implementations) - logger.warning( - f"Filtering not implemented for result set type: {type(result_set).__name__}" - ) - return result_set - - @staticmethod - def filter_tables_by_type( - result_set: "ResultSet", table_types: Optional[List[str]] = None - ) -> "ResultSet": - """ - Filter a result set of tables by the specified table types. - - This is a client-side filter that processes the result set after it has been - retrieved from the server. It filters out tables whose type does not match - any of the types in the table_types list. 
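For the record, this is how the removed filter was invoked by the (also removed) SEA get_tables path; a sketch reconstructed from the deleted code, no longer importable once this commit lands:

    # ResultSetFilter and databricks.sql.backend.filters are deleted by this commit.
    from databricks.sql.backend.filters import ResultSetFilter

    filtered = ResultSetFilter.filter_tables_by_type(result, ["TABLE", "VIEW"])
    # With table_types empty or None it fell back to ["TABLE", "VIEW", "SYSTEM TABLE"],
    # matching case-insensitively on column index 5 (tableType).
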
- - Args: - result_set: The original result set containing tables - table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) - - Returns: - A filtered result set containing only tables of the specified types - """ - # Default table types if none specified - DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] - valid_types = ( - table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES - ) - - # Table type is typically in the 6th column (index 5) - return ResultSetFilter.filter_by_column_values( - result_set, 5, valid_types, case_sensitive=False - ) diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py deleted file mode 100644 index e8eb2a757..000000000 --- a/tests/unit/test_result_set_filter.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Tests for the ResultSetFilter class. - -This module contains tests for the ResultSetFilter class, which provides -filtering capabilities for result sets returned by different backends. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.backend.filters import ResultSetFilter -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestResultSetFilter: - """Test suite for the ResultSetFilter class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response_with_tables(self): - """Create a sample SEA response with table data based on the server schema.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 7, - "columns": [ - {"name": "namespace", "type_text": "STRING", "position": 0}, - {"name": "tableName", "type_text": "STRING", "position": 1}, - {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, - {"name": "information", "type_text": "STRING", "position": 3}, - {"name": "catalogName", "type_text": "STRING", "position": 4}, - {"name": "tableType", "type_text": "STRING", "position": 5}, - {"name": "remarks", "type_text": "STRING", "position": 6}, - ], - }, - "total_row_count": 4, - }, - "result": { - "data_array": [ - ["schema1", "table1", False, None, "catalog1", "TABLE", None], - ["schema1", "table2", False, None, "catalog1", "VIEW", None], - [ - "schema1", - "table3", - False, - None, - "catalog1", - "SYSTEM TABLE", - None, - ], - ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], - ] - }, - } - - @pytest.fixture - def sea_result_set_with_tables( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Create a SeaResultSet with table data.""" - # Create a deep copy of the response to avoid test interference - import copy - - sea_response_copy = copy.deepcopy(sea_response_with_tables) - - return SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy, - buffer_size_bytes=1000, - arraysize=100, - ) - - def test_filter_tables_by_type_default(self, sea_result_set_with_tables): - """Test filtering tables by type with default types.""" - # Default types are TABLE, VIEW, SYSTEM TABLE - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables - ) - - # Verify that only the default types are included - assert 
len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): - """Test filtering tables by type with custom types.""" - # Filter for only TABLE and EXTERNAL - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] - ) - - # Verify that only the specified types are included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "EXTERNAL" in table_types - assert "VIEW" not in table_types - assert "SYSTEM TABLE" not in table_types - - def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): - """Test that table type filtering is case-insensitive.""" - # Filter for lowercase "table" and "view" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["table", "view"] - ) - - # Verify that the matching types are included despite case differences - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" not in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): - """Test filtering tables with an empty type list (should use defaults).""" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=[] - ) - - # Verify that default types are used - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_by_column_values(self, sea_result_set_with_tables): - """Test filtering by values in a specific column.""" - # Filter by namespace in column index 0 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] - ) - - # All rows have schema1 in namespace, so all should be included - assert len(filtered_result._response["result"]["data_array"]) == 4 - - # Filter by table name in column index 1 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, - column_index=1, - allowed_values=["table1", "table3"], - ) - - # Only rows with table1 or table3 should be included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_names = [ - row[1] for row in filtered_result._response["result"]["data_array"] - ] - assert "table1" in table_names - assert "table3" in table_names - assert "table2" not in table_names - assert "table4" not in table_names - - def test_filter_by_column_values_case_sensitive( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Test case-sensitive filtering by column values.""" - import copy - - # Create a fresh result set for the first test - sea_response_copy1 = 
copy.deepcopy(sea_response_with_tables) - result_set1 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy1, - buffer_size_bytes=1000, - arraysize=100, - ) - - # First test: Case-sensitive filtering with lowercase values (should find no matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set1, - column_index=5, # tableType column - allowed_values=["table", "view"], # lowercase - case_sensitive=True, - ) - - # Verify no matches with lowercase values - assert len(filtered_result._response["result"]["data_array"]) == 0 - - # Create a fresh result set for the second test - sea_response_copy2 = copy.deepcopy(sea_response_with_tables) - result_set2 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy2, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Second test: Case-sensitive filtering with correct case (should find matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set2, - column_index=5, # tableType column - allowed_values=["TABLE", "VIEW"], # correct case - case_sensitive=True, - ) - - # Verify matches with correct case - assert len(filtered_result._response["result"]["data_array"]) == 2 - - # Extract the table types from the filtered results - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - - def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): - """Test filtering with a column index that's out of bounds.""" - # Filter by column index 10 (out of bounds) - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=10, allowed_values=["value"] - ) - - # No rows should match - assert len(filtered_result._response["result"]["data_array"]) == 0 From 24c6152e9c2c003aa3074057c3d7d6e98d8d1916 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:06:23 +0000 Subject: [PATCH 22/66] remove more irrelevant changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 39 +- tests/unit/test_sea_backend.py | 755 ++++------------------------ 2 files changed, 132 insertions(+), 662 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,26 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. 
+ Args: + state: SEA state string + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + class BackendType(Enum): """ @@ -394,3 +415,19 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index c2078f731..bc2688a68 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,20 +1,11 @@ -""" -Tests for the SEA (Statement Execution API) backend implementation. - -This module contains tests for the SeaDatabricksClient class, which implements -the Databricks SQL connector's SEA backend functionality. -""" - -import json import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import patch, MagicMock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error class TestSeaBackend: @@ -50,23 +41,6 @@ def sea_client(self, mock_http_client): return client - @pytest.fixture - def sea_session_id(self): - """Create a SEA session ID.""" - return SessionId.from_sea_session_id("test-session-123") - - @pytest.fixture - def sea_command_id(self): - """Create a SEA command ID.""" - return CommandId.from_sea_statement_id("test-statement-123") - - @pytest.fixture - def mock_cursor(self): - """Create a mock cursor.""" - cursor = Mock() - cursor.active_command_id = None - return cursor - def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -201,650 +175,109 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - # Tests for command execution and management - - def test_execute_command_sync( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command synchronously.""" - # Set up mock responses - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "schema": [ - { - "name": "col1", - "type_name": "STRING", - "type_text": "string", - "nullable": True, - } - ], - "total_row_count": 1, - "total_byte_count": 100, - }, - "result": {"data": [["value1"]]}, - } - mock_http_client._make_request.return_value = execute_response - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Call the method 
- result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "warehouse_id" in kwargs["data"] - assert "session_id" in kwargs["data"] - assert "statement" in kwargs["data"] - assert kwargs["data"]["statement"] == "SELECT 1" - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-123" - - def test_execute_command_async( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command asynchronously.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-456", - "status": {"state": "PENDING"}, - } - mock_http_client._make_request.return_value = execute_response - - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=True, # Async mode - enforce_embedded_schema_correctness=False, + def test_session_configuration_helpers(self): + """Test the session configuration helper methods.""" + # Test getting default value for a supported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "ANSI_MODE" ) + assert default_value == "true" - # Verify the result is None for async operation - assert result is None - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "wait_timeout" in kwargs["data"] - assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout - - # Verify the command ID was stored in the cursor - assert hasattr(mock_cursor, "active_command_id") - assert isinstance(mock_cursor.active_command_id, CommandId) - assert mock_cursor.active_command_id.guid == "test-statement-456" - - def test_execute_command_with_polling( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that requires polling.""" - # Set up mock responses for initial request and polling - initial_response = { - "statement_id": "test-statement-789", - "status": {"state": "RUNNING"}, - } - poll_response = { - "statement_id": "test-statement-789", - "status": {"state": "SUCCEEDED"}, - "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, - "result": {"data": []}, - } - - # Configure mock to return different responses on subsequent calls - mock_http_client._make_request.side_effect = [initial_response, poll_response] - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method - result = 
sea_client.execute_command( - operation="SELECT * FROM large_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP requests (initial and poll) - assert mock_http_client._make_request.call_count == 2 - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-789" - - def test_execute_command_with_parameters( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command with parameters.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, + # Test getting default value for an unsupported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert default_value is None + + # Test getting the list of allowed configurations + allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() + + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", } - mock_http_client._make_request.return_value = execute_response + assert set(allowed_configs) == expected_keys - # Create parameter mock - param = MagicMock() - param.name = "param1" - param.value = "value1" - param.type = "STRING" + def test_unimplemented_methods(self, sea_client): + """Test that unimplemented methods raise NotImplementedError.""" + # Create dummy parameters for testing + session_id = SessionId.from_sea_session_id("test-session") + command_id = MagicMock() + cursor = MagicMock() - # Mock the get_execution_result method - with patch.object(sea_client, "get_execution_result") as mock_get_result: - # Call the method with parameters + # Test execute_command + with pytest.raises(NotImplementedError) as excinfo: sea_client.execute_command( - operation="SELECT * FROM table WHERE col = :param1", - session_id=sea_session_id, + operation="SELECT 1", + session_id=session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=mock_cursor, + cursor=cursor, use_cloud_fetch=False, - parameters=[param], + parameters=[], async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the HTTP request contains parameters - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert "parameters" in kwargs["data"] - assert len(kwargs["data"]["parameters"]) == 1 - assert kwargs["data"]["parameters"][0]["name"] == "param1" - assert kwargs["data"]["parameters"][0]["value"] == "value1" - assert kwargs["data"]["parameters"][0]["type"] == "STRING" - - def test_execute_command_failure( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that fails.""" - # Set up mock response for a failed execution - error_response = { - "statement_id": "test-statement-123", - "status": { - "state": "FAILED", - "error": { - "message": "Syntax error in SQL", - "error_code": "SYNTAX_ERROR", - }, - }, - } - - # Configure the mock to return the error response for the initial request - # 
and then raise an exception when trying to poll (to simulate immediate failure) - mock_http_client._make_request.side_effect = [ - error_response, # Initial response - Error( - "Statement execution did not succeed: Syntax error in SQL" - ), # Will be raised during polling - ] - - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method and expect an error - with pytest.raises(Error) as excinfo: - sea_client.execute_command( - operation="SELECT * FROM nonexistent_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Statement execution did not succeed" in str(excinfo.value) - - def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): - """Test canceling a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.cancel_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_close_command(self, sea_client, mock_http_client, sea_command_id): - """Test closing a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.close_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "DELETE" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): - """Test getting the state of a query.""" - # Set up mock response - mock_http_client._make_request.return_value = { - "statement_id": "test-statement-123", - "status": {"state": "RUNNING"}, - } - - # Call the method - state = sea_client.get_query_state(sea_command_id) - - # Verify the result - assert state == CommandState.RUNNING - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_execution_result( - self, sea_client, mock_http_client, mock_cursor, sea_command_id - ): - """Test getting the result of a command execution.""" - # Set up mock response - sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - mock_http_client._make_request.return_value = sea_response - - # Create a real result set to verify the implementation - result = sea_client.get_execution_result(sea_command_id, mock_cursor) - print(result) - - # Verify basic properties of the result - assert 
result.statement_id == "test-statement-123" - assert result.status == CommandState.SUCCEEDED - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + assert "execute_command is not yet implemented" in str(excinfo.value) + + # Test cancel_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.cancel_command(command_id) + assert "cancel_command is not yet implemented" in str(excinfo.value) + + # Test close_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.close_command(command_id) + assert "close_command is not yet implemented" in str(excinfo.value) + + # Test get_query_state + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_query_state(command_id) + assert "get_query_state is not yet implemented" in str(excinfo.value) + + # Test get_execution_result + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_execution_result(command_id, cursor) + assert "get_execution_result is not yet implemented" in str(excinfo.value) + + # Test metadata operations + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(session_id, 100, 1000, cursor) + assert "get_catalogs is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(session_id, 100, 1000, cursor) + assert "get_schemas is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(session_id, 100, 1000, cursor) + assert "get_tables is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(session_id, 100, 1000, cursor) + assert "get_columns is not yet implemented" in str(excinfo.value) + + def test_max_download_threads_property(self, sea_client): + """Test the max_download_threads property.""" + assert sea_client.max_download_threads == 10 + + # Create a client with a custom value + custom_client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=20, ) - # Tests for metadata operations - - def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): - """Test getting catalogs.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method - result = sea_client.get_catalogs( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["lz4_compression"] is False - assert kwargs["use_cloud_fetch"] is False - assert kwargs["parameters"] == [] - assert kwargs["async_op"] is False - assert kwargs["enforce_embedded_schema_correctness"] is False - - def 
test_get_schemas_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_with_catalog_and_schema_pattern( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with catalog and schema pattern.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog and schema pattern - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting schemas without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - ) - - assert "Catalog name is required for get_schemas" in str(excinfo.value) - - def test_get_tables_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def 
test_get_tables_with_all_catalogs( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables from all catalogs using wildcard.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with wildcard catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="*", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_schema_and_table_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with schema and table patterns.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog, schema, and table patterns - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting tables without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_tables" in str(excinfo.value) - - def test_get_columns_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == 
mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_with_all_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with all patterns specified.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with all patterns - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - column_name="test_column", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting columns without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_columns" in str(excinfo.value) + # Verify the custom value is returned + assert custom_client.max_download_threads == 20 From 67fd1012f9496724aa05183f82d9c92f0c40f1ed Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:10:48 +0000 Subject: [PATCH 23/66] remove more irrelevant changes Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 2 - src/databricks/sql/backend/thrift_backend.py | 3 +- src/databricks/sql/result_set.py | 91 +++++++++---------- 3 files changed, 44 insertions(+), 52 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 20b059fa7..8fda71e1e 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,8 +16,6 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState -from databricks.sql.utils import ExecuteResponse -from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 810c2e7a1..f41f4b6d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,6 +352,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
@@ -1241,7 +1242,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(command_id.guid)) + logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a4beda629..dd61408db 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -64,7 +64,7 @@ def __init__( """ self.connection = connection - self.backend = backend # Store the backend client directly + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> Any: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> Any: + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all remaining rows as an Arrow table.""" pass @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -402,6 +402,33 @@ def fetchmany(self, size: int) -> List[Row]: @staticmethod def _get_schema_description(table_schema_message): + """ + Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + """ + + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ + + return [ + (column.name, map_col_type(column.datatype), None, None, None, None, None) + for column in table_schema_message.columns + ] + + +class SeaResultSet(ResultSet): + """ResultSet implementation for the SEA backend.""" + + def __init__( + self, + connection: "Connection", + execute_response: "ExecuteResponse", + 
sea_client: "SeaDatabricksClient", + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + ): """ Initialize a SeaResultSet with the response from a SEA query execution. @@ -413,53 +440,19 @@ def _get_schema_description(table_schema_message): execute_response: Response from the execute command (new style) sea_response: Direct SEA response (legacy style) """ - # Handle both initialization styles - if execute_response is not None: - # New style with ExecuteResponse - command_id = execute_response.command_id - status = execute_response.status - has_been_closed_server_side = execute_response.has_been_closed_server_side - has_more_rows = execute_response.has_more_rows - results_queue = execute_response.results_queue - description = execute_response.description - is_staging_operation = execute_response.is_staging_operation - self._response = getattr(execute_response, "sea_response", {}) - self.statement_id = command_id.to_sea_statement_id() if command_id else None - elif sea_response is not None: - # Legacy style with direct sea_response - self._response = sea_response - # Extract values from sea_response - command_id = CommandId.from_sea_statement_id( - sea_response.get("statement_id", "") - ) - self.statement_id = sea_response.get("statement_id", "") - - # Extract status - status_data = sea_response.get("status", {}) - status = CommandState.from_sea_state(status_data.get("state", "PENDING")) - - # Set defaults for other fields - has_been_closed_server_side = False - has_more_rows = False - results_queue = None - description = None - is_staging_operation = False - else: - raise ValueError("Either execute_response or sea_response must be provided") - # Call parent constructor with common attributes super().__init__( connection=connection, backend=sea_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, - command_id=command_id, - status=status, - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - results_queue=results_queue, - description=description, - is_staging_operation=is_staging_operation, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, ) def _fill_results_buffer(self): From 271fcafbb04e7c5e08423b7536dac57f9595c5b6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:12:13 +0000 Subject: [PATCH 24/66] even more irrelevant changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 2 +- tests/unit/test_session.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index dd61408db..7ea4976c1 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -259,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2): + def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result :param result1: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fef070362..92de8d8fd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,11 +2,6 @@ from 
unittest.mock import patch, Mock import gc -from databricks.sql.thrift_api.TCLIService.ttypes import ( - TOpenSessionResp, - TSessionHandle, - THandleIdentifier, -) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From bf26ea3e4dae441d0e82d1f55c3da36ee2282568 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:19:46 +0000 Subject: [PATCH 25/66] remove sea response as init option Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_result_set.py | 103 ++++-------------------------- 1 file changed, 14 insertions(+), 89 deletions(-) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index f666fd613..02421a915 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -27,38 +27,6 @@ def mock_sea_client(self): """Create a mock SEA client.""" return Mock() - @pytest.fixture - def sea_response(self): - """Create a sample SEA response.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - @pytest.fixture def execute_response(self): """Create a sample execute response.""" @@ -72,78 +40,35 @@ def execute_response(self): ("test_value", "INT", None, None, None, None, None) ] mock_response.is_staging_operation = False - mock_response.sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "result": {"data_array": [["1"]]}, - } return mock_response - def test_init_with_sea_response( - self, mock_connection, mock_sea_client, sea_response - ): - """Test initializing SeaResultSet with a SEA response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == sea_response - def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): """Test initializing SeaResultSet with an execute response.""" result_set = SeaResultSet( connection=mock_connection, - sea_client=mock_sea_client, execute_response=execute_response, + sea_client=mock_sea_client, buffer_size_bytes=1000, arraysize=100, ) # Verify basic properties - assert result_set.statement_id == "test-statement-123" + assert result_set.command_id == execute_response.command_id assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA assert result_set.connection == mock_connection assert result_set.backend == mock_sea_client assert result_set.buffer_size_bytes == 1000 assert result_set.arraysize == 
100 - assert result_set._response == execute_response.sea_response + assert result_set.description == execute_response.description - def test_init_with_no_response(self, mock_connection, mock_sea_client): - """Test that initialization fails when neither response type is provided.""" - with pytest.raises(ValueError) as excinfo: - SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - assert "Either execute_response or sea_response must be provided" in str( - excinfo.value - ) - - def test_close(self, mock_connection, mock_sea_client, sea_response): + def test_close(self, mock_connection, mock_sea_client, execute_response): """Test closing a result set.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -157,13 +82,13 @@ def test_close(self, mock_connection, mock_sea_client, sea_response): assert result_set.status == CommandState.CLOSED def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test closing a result set that has already been closed server-side.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -178,14 +103,14 @@ def test_close_when_already_closed_server_side( assert result_set.status == CommandState.CLOSED def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test closing a result set when the connection is closed.""" mock_connection.open = False result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -199,13 +124,13 @@ def test_close_when_connection_closed( assert result_set.status == CommandState.CLOSED def test_unimplemented_methods( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test that unimplemented methods raise NotImplementedError.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -258,13 +183,13 @@ def test_unimplemented_methods( pass def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test that _fill_results_buffer raises NotImplementedError.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -272,4 +197,4 @@ def test_fill_results_buffer_not_implemented( with pytest.raises( NotImplementedError, match="fetchone is not implemented for SEA backend" ): - result_set._fill_results_buffer() + result_set._fill_results_buffer() \ No newline at end of file From d97463b45fd6c8e7457988441edc012e51d78368 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 13:21:34 +0000 Subject: [PATCH 26/66] move guid_to_hex_id import to utils Signed-off-by: varun-edachali-dbx --- 
 src/databricks/sql/backend/thrift_backend.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index f41f4b6d8..f90d2897e 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -10,15 +10,13 @@
 if TYPE_CHECKING:
     from databricks.sql.client import Cursor

-from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
 from databricks.sql.backend.types import (
     CommandState,
     SessionId,
     CommandId,
-    BackendType,
-    guid_to_hex_id,
     ExecuteResponse,
 )
+from databricks.sql.backend.utils.guid_utils import guid_to_hex_id

 try:
     import pyarrow

From 139e2466ef9c35a2673e4af6066549004cf16533 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Mon, 9 Jun 2025 13:22:25 +0000
Subject: [PATCH 27/66] reduce diff in guid utils import

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/thrift_backend.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index f90d2897e..4b3e827f2 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -16,7 +16,8 @@
     CommandId,
     ExecuteResponse,
 )
-from databricks.sql.backend.utils.guid_utils import guid_to_hex_id
+from databricks.sql.backend.utils import guid_to_hex_id
+

 try:
     import pyarrow

From e3ee4e4acfd7178db6a78dadce21bc6e7a52b77f Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Mon, 9 Jun 2025 15:24:33 +0000
Subject: [PATCH 28/66] move arrow_schema_bytes back into ExecuteResponse

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/thrift_backend.py | 58 ++++------
 src/databricks/sql/backend/types.py | 1 +
 src/databricks/sql/result_set.py | 4 +-
 tests/unit/test_thrift_backend.py | 106 ++++++++++++++++---
 4 files changed, 116 insertions(+), 53 deletions(-)

diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index 4b3e827f2..d99cf2624 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -801,18 +801,16 @@ def _results_message_to_execute_response(self, resp, operation_state):
         if status is None:
             raise ValueError(f"Unknown command state: {operation_state}")

-        return (
-            ExecuteResponse(
-                command_id=command_id,
-                status=status,
-                description=description,
-                has_more_rows=has_more_rows,
-                results_queue=arrow_queue_opt,
-                has_been_closed_server_side=has_been_closed_server_side,
-                lz4_compressed=lz4_compressed,
-                is_staging_operation=is_staging_operation,
-            ),
-            schema_bytes,
+        return ExecuteResponse(
+            command_id=command_id,
+            status=status,
+            description=description,
+            has_more_rows=has_more_rows,
+            results_queue=arrow_queue_opt,
+            has_been_closed_server_side=has_been_closed_server_side,
+            lz4_compressed=lz4_compressed,
+            is_staging_operation=is_staging_operation,
+            arrow_schema_bytes=schema_bytes,
         )

     def get_execution_result(
@@ -877,6 +875,7 @@ def get_execution_result(
             has_been_closed_server_side=False,
             lz4_compressed=lz4_compressed,
             is_staging_operation=is_staging_operation,
+            arrow_schema_bytes=schema_bytes,
         )

         return ThriftResultSet(
@@ -886,7 +885,6 @@ def get_execution_result(
             buffer_size_bytes=cursor.buffer_size_bytes,
             arraysize=cursor.arraysize,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
-            arrow_schema_bytes=schema_bytes,
         )

     def _wait_until_command_done(self, op_handle, 
initial_operation_status_resp): @@ -999,9 +997,7 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, arrow_schema_bytes = self._handle_execute_response( - resp, cursor - ) + execute_response = self._handle_execute_response(resp, cursor) return ThriftResultSet( connection=cursor.connection, @@ -1010,7 +1006,6 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, ) def get_catalogs( @@ -1032,9 +1027,7 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( - resp, cursor - ) + execute_response = self._handle_execute_response(resp, cursor) return ThriftResultSet( connection=cursor.connection, @@ -1043,7 +1036,6 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, ) def get_schemas( @@ -1069,9 +1061,7 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( - resp, cursor - ) + execute_response = self._handle_execute_response(resp, cursor) return ThriftResultSet( connection=cursor.connection, @@ -1080,7 +1070,6 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, ) def get_tables( @@ -1110,9 +1099,7 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( - resp, cursor - ) + execute_response = self._handle_execute_response(resp, cursor) return ThriftResultSet( connection=cursor.connection, @@ -1121,7 +1108,6 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, ) def get_columns( @@ -1151,9 +1137,7 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( - resp, cursor - ) + execute_response = self._handle_execute_response(resp, cursor) return ThriftResultSet( connection=cursor.connection, @@ -1162,7 +1146,6 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, ) def _handle_execute_response(self, resp, cursor): @@ -1176,11 +1159,10 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - ( - execute_response, - arrow_schema_bytes, - ) = self._results_message_to_execute_response(resp, final_operation_state) - return execute_response, arrow_schema_bytes + execute_response = self._results_message_to_execute_response( + resp, final_operation_state + ) + return execute_response def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3d24a09a1..fed1bc6cd 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -431,3 +431,4 @@ class ExecuteResponse: has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False + arrow_schema_bytes: Optional[bytes] = None diff --git 
a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index e177d495f..23e0fa490 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -157,7 +157,6 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, - arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. @@ -169,10 +168,9 @@ def __init__( buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results - arrow_schema_bytes: Arrow schema bytes for the result set """ # Initialize ThriftResultSet-specific attributes - self._arrow_schema_bytes = arrow_schema_bytes + self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch self.lz4_compressed = execute_response.lz4_compressed diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index b8de970db..dc2b9c038 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -19,7 +19,13 @@ from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.result_set import ResultSet, ThriftResultSet -from databricks.sql.backend.types import CommandId, CommandState, SessionId, BackendType +from databricks.sql.backend.types import ( + CommandId, + CommandState, + SessionId, + BackendType, + ExecuteResponse, +) def retry_policy_factory(): @@ -651,7 +657,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( + execute_response = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -885,7 +891,7 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( + execute_response = thrift_backend._handle_execute_response( execute_resp, Mock() ) @@ -963,11 +969,11 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): t_get_result_set_metadata_resp ) thrift_backend = self._make_fake_thrift_backend() - execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( + execute_response = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -1040,7 +1046,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( + execute_response = thrift_backend._handle_execute_response( execute_resp, Mock() ) @@ -1172,7 +1178,20 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock( + return_value=Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + 
description=Mock(), + has_more_rows=Mock(), + results_queue=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + ) + ) cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1206,7 +1225,20 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock( + return_value=Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + description=Mock(), + has_more_rows=Mock(), + results_queue=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + ) + ) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1237,7 +1269,20 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock( + return_value=Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + description=Mock(), + has_more_rows=Mock(), + results_queue=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + ) + ) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1277,7 +1322,20 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock( + return_value=Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + description=Mock(), + has_more_rows=Mock(), + results_queue=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + ) + ) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1321,7 +1379,20 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock( + return_value=Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + description=Mock(), + has_more_rows=Mock(), + results_queue=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + ) + ) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -2229,7 +2300,18 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", - return_value=(Mock(), Mock()), + return_value=Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + description=Mock(), + has_more_rows=Mock(), + results_queue=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + ), ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class From f448a8f18170c3acd157810b6960605362fcfbd3 Mon Sep 17 00:00:00 2001 From: 
varun-edachali-dbx
Date: Mon, 9 Jun 2025 15:59:50 +0000
Subject: [PATCH 29/66] maintain log

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/thrift_backend.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index d99cf2624..6f05b45a5 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -915,7 +915,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
             self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp)
             state = CommandState.from_thrift_state(operation_state)
             if state is None:
-                raise ValueError(f"Invalid operation state: {operation_state}")
+                raise ValueError(f"Unknown command state: {operation_state}")
             return state
 
     @staticmethod

From 82ca1eefc150da88e637d25f26198fc696400dbe Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Mon, 9 Jun 2025 16:01:48 +0000
Subject: [PATCH 30/66] remove unnecessary assignment

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/thrift_backend.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index 6f05b45a5..0ff68651e 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -1159,10 +1159,7 @@ def _handle_execute_response(self, resp, cursor):
             resp.directResults and resp.directResults.operationStatus,
         )
 
-        execute_response = self._results_message_to_execute_response(
-            resp, final_operation_state
-        )
-        return execute_response
+        return self._results_message_to_execute_response(resp, final_operation_state)
 
     def _handle_execute_response_async(self, resp, cursor):
         command_id = CommandId.from_thrift_handle(resp.operationHandle)

From e96a0785d188171aa79121b15c722a9dfd09cccd Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Mon, 9 Jun 2025 16:06:03 +0000
Subject: [PATCH 31/66] remove unnecessary tuple response

Signed-off-by: varun-edachali-dbx
---
 tests/unit/test_thrift_backend.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index dc2b9c038..733ea17a5 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -929,12 +929,9 @@ def test_handle_execute_response_can_handle_with_direct_results(self):
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
-        thrift_backend._results_message_to_execute_response = Mock(
-            return_value=(Mock(), Mock())
-        )
+        thrift_backend._results_message_to_execute_response = Mock()
 
         thrift_backend._handle_execute_response(execute_resp, Mock())
-
         thrift_backend._results_message_to_execute_response.assert_called_with(
             execute_resp,
             ttypes.TOperationState.FINISHED_STATE,
@@ -1738,9 +1735,7 @@ def test_handle_execute_response_sets_active_op_handle(self):
         thrift_backend = self._make_fake_thrift_backend()
         thrift_backend._check_direct_results_for_error = Mock()
         thrift_backend._wait_until_command_done = Mock()
-        thrift_backend._results_message_to_execute_response = Mock(
-            return_value=(Mock(), Mock())
-        )
+        thrift_backend._results_message_to_execute_response = Mock()
 
         # Create a mock response with a real operation handle
         mock_resp = Mock()

From 27158b1fe5998e3ccaebf2c3a0cc5b462e1f656c Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Mon, 9 Jun 2025 16:10:27 +0000
Subject: [PATCH 32/66] remove
 unnecessary verbose mocking

Signed-off-by: varun-edachali-dbx
---
 tests/unit/test_thrift_backend.py | 75 +++----------------------
 1 file changed, 5 insertions(+), 70 deletions(-)

diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 733ea17a5..c9cb05305 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -1175,20 +1175,7 @@ def test_execute_statement_calls_client_and_handle_execute_response(
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
-        thrift_backend._handle_execute_response = Mock(
-            return_value=Mock(
-                spec=ExecuteResponse,
-                command_id=Mock(),
-                status=Mock(),
-                description=Mock(),
-                has_more_rows=Mock(),
-                results_queue=Mock(),
-                has_been_closed_server_side=Mock(),
-                lz4_compressed=Mock(),
-                is_staging_operation=Mock(),
-                arrow_schema_bytes=Mock(),
-            )
-        )
+        thrift_backend._handle_execute_response = Mock()
         cursor_mock = Mock()
 
         result = thrift_backend.execute_command(
@@ -1222,20 +1209,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
-        thrift_backend._handle_execute_response = Mock(
-            return_value=Mock(
-                spec=ExecuteResponse,
-                command_id=Mock(),
-                status=Mock(),
-                description=Mock(),
-                has_more_rows=Mock(),
-                results_queue=Mock(),
-                has_been_closed_server_side=Mock(),
-                lz4_compressed=Mock(),
-                is_staging_operation=Mock(),
-                arrow_schema_bytes=Mock(),
-            )
-        )
+        thrift_backend._handle_execute_response = Mock()
         cursor_mock = Mock()
 
         result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock)
@@ -1266,20 +1240,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
-        thrift_backend._handle_execute_response = Mock(
-            return_value=Mock(
-                spec=ExecuteResponse,
-                command_id=Mock(),
-                status=Mock(),
-                description=Mock(),
-                has_more_rows=Mock(),
-                results_queue=Mock(),
-                has_been_closed_server_side=Mock(),
-                lz4_compressed=Mock(),
-                is_staging_operation=Mock(),
-                arrow_schema_bytes=Mock(),
-            )
-        )
+        thrift_backend._handle_execute_response = Mock()
         cursor_mock = Mock()
 
         result = thrift_backend.get_schemas(
@@ -1319,20 +1280,7 @@ def test_get_tables_calls_client_and_handle_execute_response(
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
-        thrift_backend._handle_execute_response = Mock(
-            return_value=Mock(
-                spec=ExecuteResponse,
-                command_id=Mock(),
-                status=Mock(),
-                description=Mock(),
-                has_more_rows=Mock(),
-                results_queue=Mock(),
-                has_been_closed_server_side=Mock(),
-                lz4_compressed=Mock(),
-                is_staging_operation=Mock(),
-                arrow_schema_bytes=Mock(),
-            )
-        )
+        thrift_backend._handle_execute_response = Mock()
         cursor_mock = Mock()
 
         result = thrift_backend.get_tables(
@@ -1376,20 +1324,7 @@ def test_get_columns_calls_client_and_handle_execute_response(
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
-        thrift_backend._handle_execute_response = Mock(
-            return_value=Mock(
-                spec=ExecuteResponse,
-                command_id=Mock(),
-                status=Mock(),
-                description=Mock(),
-                has_more_rows=Mock(),
-                results_queue=Mock(),
-                has_been_closed_server_side=Mock(),
-                lz4_compressed=Mock(),
-                is_staging_operation=Mock(),
-                arrow_schema_bytes=Mock(),
-            )
-        )
+        thrift_backend._handle_execute_response = Mock()
         cursor_mock = Mock()
 
         result = thrift_backend.get_columns(

From d3200c49d87ef32184b48877d115353d51b82dd4 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Tue, 10 Jun 2025 05:31:55 +0000
Subject: [PATCH 33/66] move Queue construction to ResultSet
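
The intent of this refactor: ExecuteResponse stops carrying a prebuilt
results_queue, and ThriftResultSet builds the queue itself from the raw
TRowSet plus the fetch settings (download threads, SSL options), so every
backend entry point hands over the same raw inputs. A minimal sketch of
that shape follows; `backend`, `resp`, and `cursor` are illustrative
stand-ins, not names introduced by this patch:

    # Sketch only: mirrors the pattern this patch applies to
    # execute_command / get_catalogs / get_schemas / get_tables / get_columns.
    execute_response = backend._handle_execute_response(resp, cursor)

    # Pass the raw rows through; ThriftResultSet now builds the queue internally.
    t_row_set = None
    if resp.directResults and resp.directResults.resultSet:
        t_row_set = resp.directResults.resultSet.results

    result_set = ThriftResultSet(
        connection=cursor.connection,
        execute_response=execute_response,
        thrift_client=backend,
        buffer_size_bytes=cursor.buffer_size_bytes,
        arraysize=cursor.arraysize,
        use_cloud_fetch=cursor.connection.use_cloud_fetch,
        t_row_set=t_row_set,
        max_download_threads=backend.max_download_threads,
        ssl_options=backend._ssl_options,
    )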
Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 81 ++++++++++++-------- src/databricks/sql/backend/types.py | 6 +- src/databricks/sql/result_set.py | 24 +++++- tests/unit/test_client.py | 9 ++- tests/unit/test_fetches.py | 40 ++++++---- tests/unit/test_thrift_backend.py | 55 ++++++++++--- 6 files changed, 148 insertions(+), 67 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 0ff68651e..2e3e61ca0 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -3,7 +3,6 @@ import logging import math import time -import uuid import threading from typing import List, Union, Any, TYPE_CHECKING @@ -728,7 +727,7 @@ def _col_to_description(col): else: precision, scale = None, None - return col.columnName, cleaned_type, None, None, precision, scale, None + return [col.columnName, cleaned_type, None, None, precision, scale, None] @staticmethod def _hive_schema_to_description(t_table_schema): @@ -778,23 +777,6 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = None lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - if direct_results and direct_results.resultSet: - assert direct_results.resultSet.results.startRowOffset == 0 - assert direct_results.resultSetMetadata - - arrow_queue_opt = ResultSetQueueFactory.build_queue( - row_set_type=t_result_set_metadata_resp.resultFormat, - t_row_set=direct_results.resultSet.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) - else: - arrow_queue_opt = None - command_id = CommandId.from_thrift_handle(resp.operationHandle) status = CommandState.from_thrift_state(operation_state) @@ -806,11 +788,11 @@ def _results_message_to_execute_response(self, resp, operation_state): status=status, description=description, has_more_rows=has_more_rows, - results_queue=arrow_queue_opt, has_been_closed_server_side=has_been_closed_server_side, lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, + is_staging_operation=t_result_set_metadata_resp.isStagingOperation, arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) def get_execution_result( @@ -837,9 +819,6 @@ def get_execution_result( t_result_set_metadata_resp = resp.resultSetMetadata - lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -854,15 +833,9 @@ def get_execution_result( else: schema_bytes = None - queue = ResultSetQueueFactory.build_queue( - row_set_type=resp.resultSetMetadata.resultFormat, - t_row_set=resp.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) + lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + has_more_rows = resp.hasMoreRows status = self.get_query_state(command_id) @@ -871,11 +844,11 @@ def get_execution_result( status=status, description=description, has_more_rows=has_more_rows, - results_queue=queue, 
has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( @@ -885,6 +858,9 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=resp.results, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -999,6 +975,10 @@ def execute_command( else: execute_response = self._handle_execute_response(resp, cursor) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1006,6 +986,9 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, ) def get_catalogs( @@ -1029,6 +1012,10 @@ def get_catalogs( execute_response = self._handle_execute_response(resp, cursor) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1036,6 +1023,9 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, ) def get_schemas( @@ -1063,6 +1053,10 @@ def get_schemas( execute_response = self._handle_execute_response(resp, cursor) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1070,6 +1064,9 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, ) def get_tables( @@ -1101,6 +1098,10 @@ def get_tables( execute_response = self._handle_execute_response(resp, cursor) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1108,6 +1109,9 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, ) def get_columns( @@ -1139,6 +1143,10 @@ def get_columns( execute_response = self._handle_execute_response(resp, cursor) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1146,6 +1154,9 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, ) def _handle_execute_response(self, resp, cursor): @@ -1203,6 
+1214,8 @@ def fetch_results( ) ) + from databricks.sql.utils import ResultSetQueueFactory + queue = ResultSetQueueFactory.build_queue( row_set_type=resp.resultSetMetadata.resultFormat, t_row_set=resp.results, diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index fed1bc6cd..ba2975d7c 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -423,12 +423,10 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None + description: Optional[List[List[Any]]] = None has_more_rows: bool = False - results_queue: Optional[Any] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False arrow_schema_bytes: Optional[bytes] = None + result_format: Optional[Any] = None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 23e0fa490..ab3fb68f2 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -157,6 +157,9 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, + t_row_set=None, + max_download_threads: int = 10, + ssl_options=None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. @@ -168,12 +171,31 @@ def __init__( buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results + t_row_set: The TRowSet containing result data (if available) + max_download_threads: Maximum number of download threads for cloud fetch + ssl_options: SSL options for cloud fetch """ # Initialize ThriftResultSet-specific attributes self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch self.lz4_compressed = execute_response.lz4_compressed + # Build the results queue if t_row_set is provided + results_queue = None + if t_row_set and execute_response.result_format is not None: + from databricks.sql.utils import ResultSetQueueFactory + + # Create the results queue using the provided format + results_queue = ResultSetQueueFactory.build_queue( + row_set_type=execute_response.result_format, + t_row_set=t_row_set, + arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", + max_download_threads=max_download_threads, + lz4_compressed=execute_response.lz4_compressed, + description=execute_response.description, + ssl_options=ssl_options, + ) + # Call parent constructor with common attributes super().__init__( connection=connection, @@ -184,7 +206,7 @@ def __init__( status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, + results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, ) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 090ec255e..63bc92fdc 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -104,6 +104,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Mock the backend that will be used by the real ThriftResultSet mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.staging_allowed_local_path = None + mock_backend.fetch_results.return_value = (Mock(), False) # Configure the 
decorator's mock to return our specific mock_backend mock_thrift_client_class.return_value = mock_backend @@ -184,6 +185,7 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() + mock_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( connection=mock_connection, @@ -210,6 +212,8 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) + mock_thrift_backend.fetch_results.return_value = (Mock(), False) + result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -254,7 +258,10 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - result_set = ThriftResultSet(Mock(), Mock(), Mock()) + mock_backend = Mock() + mock_backend.fetch_results.return_value = (Mock(), False) + + result_set = ThriftResultSet(Mock(), Mock(), mock_backend) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 7249a59e6..18be51da8 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -40,6 +40,17 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) + + # Create a mock backend that will return the queue when _fill_results_buffer is called + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) + mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) + + num_cols = len(initial_results[0]) if initial_results else 0 + description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] + rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( @@ -47,18 +58,16 @@ def make_dummy_result_set_from_initial_results(initial_results): status=None, has_been_closed_server_side=True, has_more_rows=False, - description=Mock(), - lz4_compressed=Mock(), - results_queue=arrow_queue, + description=description, + lz4_compressed=True, is_staging_operation=False, ), - thrift_client=None, + thrift_client=mock_thrift_backend, + t_row_set=None, ) - num_cols = len(initial_results[0]) if initial_results else 0 - rs.description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] + + # Replace the results queue with our arrow_queue + rs.results = arrow_queue return rs @staticmethod @@ -85,6 +94,11 @@ def fetch_results( mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 + description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] + rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( @@ -92,12 +106,8 @@ def fetch_results( status=None, has_been_closed_server_side=False, has_more_rows=True, - description=[ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ], - lz4_compressed=Mock(), - results_queue=None, + description=description, + lz4_compressed=True, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git 
a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index c9cb05305..7165c6259 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -511,10 +511,10 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self): self.assertEqual( description, [ - ("column 1", "int", None, None, None, None, None), - ("column 2", "boolean", None, None, None, None, None), - ("column 2", "map", None, None, None, None, None), - ("", "struct", None, None, None, None, None), + ["column 1", "int", None, None, None, None, None], + ["column 2", "boolean", None, None, None, None, None], + ["column 2", "map", None, None, None, None, None], + ["", "struct", None, None, None, None, None], ], ) @@ -549,7 +549,7 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): self.assertEqual( description, [ - ("column 1", "decimal", None, None, 10, 100, None), + ["column 1", "decimal", None, None, 10, 100, None], ], ) @@ -1161,8 +1161,11 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) def test_execute_statement_calls_client_and_handle_execute_response( - self, tcli_service_class + self, mock_build_queue, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1178,6 +1181,8 @@ def test_execute_statement_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() + thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) + result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) @@ -1195,8 +1200,11 @@ def test_execute_statement_calls_client_and_handle_execute_response( ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) def test_get_catalogs_calls_client_and_handle_execute_response( - self, tcli_service_class + self, mock_build_queue, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1212,6 +1220,8 @@ def test_get_catalogs_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() + thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) + result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet self.assertIsInstance(result, ResultSet) @@ -1226,8 +1236,11 @@ def test_get_catalogs_calls_client_and_handle_execute_response( ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) def test_get_schemas_calls_client_and_handle_execute_response( - self, tcli_service_class + self, mock_build_queue, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1243,6 +1256,8 @@ def test_get_schemas_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() + thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) + result = thrift_backend.get_schemas( Mock(), 100, @@ -1266,8 +1281,11 @@ def test_get_schemas_calls_client_and_handle_execute_response( ) 
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) def test_get_tables_calls_client_and_handle_execute_response( - self, tcli_service_class + self, mock_build_queue, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1283,6 +1301,8 @@ def test_get_tables_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() + thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) + result = thrift_backend.get_tables( Mock(), 100, @@ -1310,8 +1330,11 @@ def test_get_tables_calls_client_and_handle_execute_response( ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) def test_get_columns_calls_client_and_handle_execute_response( - self, tcli_service_class + self, mock_build_queue, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1327,6 +1350,8 @@ def test_get_columns_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() + thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) + result = thrift_backend.get_columns( Mock(), 100, @@ -2228,6 +2253,9 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) @patch( "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", return_value=Mock( @@ -2236,15 +2264,15 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): status=Mock(), description=Mock(), has_more_rows=Mock(), - results_queue=Mock(), has_been_closed_server_side=Mock(), lz4_compressed=Mock(), is_staging_operation=Mock(), arrow_schema_bytes=Mock(), + result_format=Mock(), ), ) def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, tcli_service_class + self, mock_handle_execute_response, mock_build_queue, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value # Iterate through each possible combination of native types (True, False and unset) @@ -2268,6 +2296,9 @@ def test_execute_command_sets_complex_type_fields_correctly( ssl_options=SSLOptions(), **complex_arg_types, ) + + thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) + thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[ 0 From 8a014f01df6137685a3acd58f10852d73fba3c2f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 06:10:58 +0000 Subject: [PATCH 34/66] move description to List[Tuple] Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 +- src/databricks/sql/backend/types.py | 2 +- src/databricks/sql/utils.py | 6 +++--- tests/unit/test_thrift_backend.py | 10 +++++----- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 2e3e61ca0..3792d4935 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -727,7 +727,7 @@ def 
_col_to_description(col):
         else:
             precision, scale = None, None
 
-        return [col.columnName, cleaned_type, None, None, precision, scale, None]
+        return (col.columnName, cleaned_type, None, None, precision, scale, None)
 
     @staticmethod
     def _hive_schema_to_description(t_table_schema):
diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py
index ba2975d7c..249816eab 100644
--- a/src/databricks/sql/backend/types.py
+++ b/src/databricks/sql/backend/types.py
@@ -423,7 +423,7 @@ class ExecuteResponse:
 
     command_id: CommandId
     status: CommandState
-    description: Optional[List[List[Any]]] = None
+    description: Optional[List[Tuple]] = None
     has_more_rows: bool = False
     has_been_closed_server_side: bool = False
     lz4_compressed: bool = True
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index edb13ef6d..d7b1b74b4 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -8,7 +8,7 @@
 from collections.abc import Iterable
 from decimal import Decimal
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import re
 import lz4.frame
@@ -57,7 +57,7 @@ def build_queue(
         max_download_threads: int,
         ssl_options: SSLOptions,
         lz4_compressed: bool = True,
-        description: Optional[List[List[Any]]] = None,
+        description: Optional[List[Tuple]] = None,
     ) -> ResultSetQueue:
         """
         Factory method to build a result set queue.
@@ -206,7 +206,7 @@ def __init__(
         start_row_offset: int = 0,
         result_links: Optional[List[TSparkArrowResultLink]] = None,
         lz4_compressed: bool = True,
-        description: Optional[List[List[Any]]] = None,
+        description: Optional[List[Tuple]] = None,
     ):
         """
         A queue-like wrapper over CloudFetch arrow batches.
diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 7165c6259..aae11c56c 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -511,10 +511,10 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self):
         self.assertEqual(
             description,
             [
-                ["column 1", "int", None, None, None, None, None],
-                ["column 2", "boolean", None, None, None, None, None],
-                ["column 2", "map", None, None, None, None, None],
-                ["", "struct", None, None, None, None, None],
+                ("column 1", "int", None, None, None, None, None),
+                ("column 2", "boolean", None, None, None, None, None),
+                ("column 2", "map", None, None, None, None, None),
+                ("", "struct", None, None, None, None, None),
             ],
         )
@@ -549,7 +549,7 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self):
         self.assertEqual(
             description,
             [
-                ["column 1", "decimal", None, None, 10, 100, None],
+                ("column 1", "decimal", None, None, 10, 100, None),
             ],
         )

From 39c41ab9abf54e0fc4d1fbc8c02abe02271fb866 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Tue, 10 Jun 2025 06:12:10 +0000
Subject: [PATCH 35/66] formatting (black)

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/result_set.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py
index ab3fb68f2..dc72382c6 100644
--- a/src/databricks/sql/result_set.py
+++ b/src/databricks/sql/result_set.py
@@ -184,7 +184,7 @@ def __init__(
         results_queue = None
         if t_row_set and execute_response.result_format is not None:
             from databricks.sql.utils import ResultSetQueueFactory
-            
+
             # Create the results queue using the provided format
             results_queue = ResultSetQueueFactory.build_queue(
row_set_type=execute_response.result_format, From 2cd04dfc331b7ef8335cdca288884a951a4dc269 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 06:13:12 +0000 Subject: [PATCH 36/66] reduce diff (remove explicit tuple conversion) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 3792d4935..f2e95fb66 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -727,7 +727,7 @@ def _col_to_description(col): else: precision, scale = None, None - return (col.columnName, cleaned_type, None, None, precision, scale, None) + return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod def _hive_schema_to_description(t_table_schema): From 067a01967c4fe9b6b5e4bc83792b6457e2666c12 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 08:51:35 +0000 Subject: [PATCH 37/66] remove has_more_rows from ExecuteResponse Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 -- src/databricks/sql/backend/types.py | 1 - src/databricks/sql/result_set.py | 2 +- tests/unit/test_fetches.py | 2 -- tests/unit/test_thrift_backend.py | 14 +++++++------- 5 files changed, 8 insertions(+), 13 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f2e95fb66..46f5ef02e 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -787,7 +787,6 @@ def _results_message_to_execute_response(self, resp, operation_state): command_id=command_id, status=status, description=description, - has_more_rows=has_more_rows, has_been_closed_server_side=has_been_closed_server_side, lz4_compressed=lz4_compressed, is_staging_operation=t_result_set_metadata_resp.isStagingOperation, @@ -843,7 +842,6 @@ def get_execution_result( command_id=command_id, status=status, description=description, - has_more_rows=has_more_rows, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 249816eab..93bd7d525 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -424,7 +424,6 @@ class ExecuteResponse: command_id: CommandId status: CommandState description: Optional[List[Tuple]] = None - has_more_rows: bool = False has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index dc72382c6..fb9b417c1 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -205,7 +205,7 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, + has_more_rows=False, results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 18be51da8..ba9b50aef 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -57,7 +57,6 @@ def 
make_dummy_result_set_from_initial_results(initial_results):
             command_id=None,
             status=None,
             has_been_closed_server_side=True,
-            has_more_rows=False,
             description=description,
             lz4_compressed=True,
             is_staging_operation=False,
@@ -105,7 +104,6 @@ def fetch_results(
             command_id=None,
             status=None,
             has_been_closed_server_side=False,
-            has_more_rows=True,
             description=description,
             lz4_compressed=True,
             is_staging_operation=False,
diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index aae11c56c..bab9cb3ca 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -1009,13 +1009,12 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
         "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
     )
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
-    def test_handle_execute_response_reads_has_more_rows_in_direct_results(
+    def test_handle_execute_response_creates_execute_response(
         self, tcli_service_class, build_queue
     ):
-        for has_more_rows, resp_type in itertools.product(
-            [True, False], self.execute_response_types
-        ):
-            with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type):
+        """Test that _handle_execute_response creates an ExecuteResponse object correctly."""
+        for resp_type in self.execute_response_types:
+            with self.subTest(resp_type=resp_type):
                 tcli_service_instance = tcli_service_class.return_value
                 results_mock = Mock()
                 results_mock.startRowOffset = 0
@@ -1027,7 +1026,7 @@
                     resultSetMetadata=self.metadata_resp,
                     resultSet=ttypes.TFetchResultsResp(
                         status=self.okay_status,
-                        hasMoreRows=has_more_rows,
+                        hasMoreRows=True,
                         results=results_mock,
                     ),
                     closeOperation=Mock(),
@@ -1047,7 +1046,8 @@
                     execute_resp, Mock()
                 )
 
-                self.assertEqual(has_more_rows, execute_response.has_more_rows)
+                self.assertIsNotNone(execute_response)
+                self.assertIsInstance(execute_response, ExecuteResponse)
 
     @patch(
         "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()

From 48c83e095afe26438b2da71a6bdd6be9e03d1d7d Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Tue, 10 Jun 2025 09:02:02 +0000
Subject: [PATCH 38/66] remove unnecessary has_more_rows calc

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/thrift_backend.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index 46f5ef02e..7cdd583d5 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -757,11 +757,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
         )
         direct_results = resp.directResults
         has_been_closed_server_side = direct_results and direct_results.closeOperation
-        has_more_rows = (
-            (not direct_results)
-            or (not direct_results.resultSet)
-            or direct_results.resultSet.hasMoreRows
-        )
+
         description = self._hive_schema_to_description(
             t_result_set_metadata_resp.schema
         )

From 281a9e9675f5b573c87053f47c07517e2a4db2ca Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Tue, 10 Jun 2025 10:33:27 +0000
Subject: [PATCH 39/66] default has_more_rows to True

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/result_set.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/databricks/sql/result_set.py
b/src/databricks/sql/result_set.py index fb9b417c1..cb6c5e1c3 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -205,7 +205,7 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=False, + has_more_rows=True, results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, From 192901d2f51bf4764276c60bdd75a005e0562de0 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 11:40:42 +0000 Subject: [PATCH 40/66] return has_more_rows from ExecResponse conversion during GetRespMetadata Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 28 ++- src/databricks/sql/result_set.py | 4 +- tests/unit/test_thrift_backend.py | 244 +++++++++---------- 3 files changed, 137 insertions(+), 139 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 7cdd583d5..ffbd2885e 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -758,6 +758,12 @@ def _results_message_to_execute_response(self, resp, operation_state): direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation + has_more_rows = ( + (not direct_results) + or (not direct_results.resultSet) + or direct_results.resultSet.hasMoreRows + ) + description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -779,7 +785,7 @@ def _results_message_to_execute_response(self, resp, operation_state): if status is None: raise ValueError(f"Unknown command state: {operation_state}") - return ExecuteResponse( + execute_response = ExecuteResponse( command_id=command_id, status=status, description=description, @@ -790,6 +796,8 @@ def _results_message_to_execute_response(self, resp, operation_state): result_format=t_result_set_metadata_resp.resultFormat, ) + return execute_response, has_more_rows + def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": @@ -855,6 +863,7 @@ def get_execution_result( t_row_set=resp.results, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, + has_more_rows=has_more_rows, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -967,7 +976,9 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response = self._handle_execute_response(resp, cursor) + execute_response, has_more_rows = self._handle_execute_response( + resp, cursor + ) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -983,6 +994,7 @@ def execute_command( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, + has_more_rows=has_more_rows, ) def get_catalogs( @@ -1004,7 +1016,7 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, has_more_rows = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1020,6 +1032,7 @@ def get_catalogs( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, + has_more_rows=has_more_rows, ) def get_schemas( @@ -1045,7 +1058,7 @@ def get_schemas( ) resp 
= self.make_request(self._client.GetSchemas, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, has_more_rows = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1061,6 +1074,7 @@ def get_schemas( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, + has_more_rows=has_more_rows, ) def get_tables( @@ -1090,7 +1104,7 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, has_more_rows = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1106,6 +1120,7 @@ def get_tables( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, + has_more_rows=has_more_rows, ) def get_columns( @@ -1135,7 +1150,7 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, has_more_rows = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1151,6 +1166,7 @@ def get_columns( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, + has_more_rows=has_more_rows, ) def _handle_execute_response(self, resp, cursor): diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index cb6c5e1c3..9857d9e0f 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -160,6 +160,7 @@ def __init__( t_row_set=None, max_download_threads: int = 10, ssl_options=None, + has_more_rows: bool = True, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
@@ -174,6 +175,7 @@ def __init__( t_row_set: The TRowSet containing result data (if available) max_download_threads: Maximum number of download threads for cloud fetch ssl_options: SSL options for cloud fetch + has_more_rows: Whether there are more rows to fetch """ # Initialize ThriftResultSet-specific attributes self._arrow_schema_bytes = execute_response.arrow_schema_bytes @@ -205,7 +207,7 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=True, + has_more_rows=has_more_rows, results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index bab9cb3ca..4f5e14cab 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -82,14 +82,7 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) + thrift_backend = self._create_fake_thrift_client() with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -100,8 +93,22 @@ def _make_type_desc(self, type): ] ) - def _make_fake_thrift_backend(self): - thrift_backend = ThriftDatabricksClient( + def _create_mock_execute_response(self): + """Create a properly mocked ExecuteResponse object with all required attributes.""" + mock_execute_response = Mock() + mock_execute_response.command_id = Mock() + mock_execute_response.status = Mock() + mock_execute_response.description = Mock() + mock_execute_response.has_been_closed_server_side = Mock() + mock_execute_response.lz4_compressed = Mock() + mock_execute_response.is_staging_operation = Mock() + mock_execute_response.arrow_schema_bytes = Mock() + mock_execute_response.result_format = Mock() + return mock_execute_response + + def _create_fake_thrift_client(self): + """Create a fake ThriftDatabricksClient without mocking any methods.""" + return ThriftDatabricksClient( "foobar", 443, "path", @@ -109,10 +116,20 @@ def _make_fake_thrift_backend(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) + + def _make_fake_thrift_backend(self): + """Create a fake ThriftDatabricksClient with mocked methods.""" + thrift_backend = self._create_fake_thrift_client() thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() thrift_backend._create_arrow_table = MagicMock() thrift_backend._create_arrow_table.return_value = (MagicMock(), Mock()) + # Mock _results_message_to_execute_response to return a tuple + mock_execute_response = self._create_mock_execute_response() + mock_has_more_rows = True + thrift_backend._results_message_to_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) + ) return thrift_backend def test_hive_schema_to_arrow_schema_preserves_column_names(self): @@ -558,14 +575,7 @@ def test_make_request_checks_status_code(self): ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ] - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) + thrift_backend = self._create_fake_thrift_client() for code in error_codes: mock_error_response = 
Mock() @@ -602,14 +612,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): closeOperation=None, ), ) - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) + thrift_backend = self._create_fake_thrift_client() with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -657,7 +660,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -832,14 +835,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) + thrift_backend = self._create_fake_thrift_client() with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(error_resp, Mock()) @@ -891,7 +887,7 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( execute_resp, Mock() ) @@ -921,21 +917,22 @@ def test_handle_execute_response_can_handle_with_direct_results(self): operationHandle=self.operation_handle, ) - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), + thrift_backend = self._make_fake_thrift_backend() + mock_execute_response = Mock(spec=ExecuteResponse) + mock_has_more_rows = True + thrift_backend._results_message_to_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) ) - thrift_backend._results_message_to_execute_response = Mock() - thrift_backend._handle_execute_response(execute_resp, Mock()) + result = thrift_backend._handle_execute_response(execute_resp, Mock()) thrift_backend._results_message_to_execute_response.assert_called_with( execute_resp, ttypes.TOperationState.FINISHED_STATE, ) + # Verify the result is a tuple with the expected values + self.assertIsInstance(result, tuple) + self.assertEqual(result[0], mock_execute_response) + self.assertEqual(result[1], mock_has_more_rows) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_use_arrow_schema_if_available(self, tcli_service_class): @@ -965,9 +962,12 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) - thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( - t_execute_resp, Mock() + + thrift_backend = self._create_fake_thrift_client() + + # Call the real _results_message_to_execute_response method + execute_response, _ = thrift_backend._results_message_to_execute_response( + t_execute_resp, ttypes.TOperationState.FINISHED_STATE ) self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @@ -997,8 +997,14 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) 
tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req - thrift_backend = self._make_fake_thrift_backend() - thrift_backend._handle_execute_response(t_execute_resp, Mock()) + + thrift_backend = self._create_fake_thrift_client() + thrift_backend._hive_schema_to_arrow_schema = Mock() + + # Call the real _results_message_to_execute_response method + thrift_backend._results_message_to_execute_response( + t_execute_resp, ttypes.TOperationState.FINISHED_STATE + ) self.assertEqual( hive_schema_mock, @@ -1040,14 +1046,16 @@ def test_handle_execute_response_creates_execute_response( tcli_service_instance.GetResultSetMetadata.return_value = ( self.metadata_resp ) - thrift_backend = self._make_fake_thrift_backend() + thrift_backend = self._create_fake_thrift_client() - execute_response = thrift_backend._handle_execute_response( + execute_response_tuple = thrift_backend._handle_execute_response( execute_resp, Mock() ) - self.assertIsNotNone(execute_response) - self.assertIsInstance(execute_response, ExecuteResponse) + self.assertIsNotNone(execute_response_tuple) + self.assertIsInstance(execute_response_tuple, tuple) + self.assertIsInstance(execute_response_tuple[0], ExecuteResponse) + self.assertIsInstance(execute_response_tuple[1], bool) @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1178,7 +1186,11 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + mock_execute_response = self._create_mock_execute_response() + mock_has_more_rows = True + thrift_backend._handle_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) + ) cursor_mock = Mock() thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) @@ -1209,15 +1221,13 @@ def test_get_catalogs_calls_client_and_handle_execute_response( tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), + thrift_backend = self._make_fake_thrift_backend() + + mock_execute_response = self._create_mock_execute_response() + mock_has_more_rows = True + thrift_backend._handle_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) ) - thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) @@ -1245,15 +1255,12 @@ def test_get_schemas_calls_client_and_handle_execute_response( tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), + thrift_backend = self._make_fake_thrift_backend() + mock_execute_response = self._create_mock_execute_response() + mock_has_more_rows = True + thrift_backend._handle_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) ) - thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) @@ -1290,15 +1297,12 @@ def test_get_tables_calls_client_and_handle_execute_response( tcli_service_instance = tcli_service_class.return_value 
         response = Mock()
         tcli_service_instance.GetTables.return_value = response
-        thrift_backend = ThriftDatabricksClient(
-            "foobar",
-            443,
-            "path",
-            [],
-            auth_provider=AuthProvider(),
-            ssl_options=SSLOptions(),
+        thrift_backend = self._make_fake_thrift_backend()
+        mock_execute_response = self._create_mock_execute_response()
+        mock_has_more_rows = True
+        thrift_backend._handle_execute_response = Mock(
+            return_value=(mock_execute_response, mock_has_more_rows)
         )
-        thrift_backend._handle_execute_response = Mock()
         cursor_mock = Mock()
 
         thrift_backend.fetch_results = Mock(return_value=(Mock(), False))
@@ -1339,15 +1343,12 @@ def test_get_columns_calls_client_and_handle_execute_response(
         tcli_service_instance = tcli_service_class.return_value
         response = Mock()
         tcli_service_instance.GetColumns.return_value = response
-        thrift_backend = ThriftDatabricksClient(
-            "foobar",
-            443,
-            "path",
-            [],
-            auth_provider=AuthProvider(),
-            ssl_options=SSLOptions(),
+        thrift_backend = self._make_fake_thrift_backend()
+        mock_execute_response = self._create_mock_execute_response()
+        mock_has_more_rows = True
+        thrift_backend._handle_execute_response = Mock(
+            return_value=(mock_execute_response, mock_has_more_rows)
         )
-        thrift_backend._handle_execute_response = Mock()
         cursor_mock = Mock()
 
         thrift_backend.fetch_results = Mock(return_value=(Mock(), False))
@@ -1397,14 +1398,7 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class)
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     def test_op_handle_respected_in_close_command(self, tcli_service_class):
         tcli_service_instance = tcli_service_class.return_value
-        thrift_backend = ThriftDatabricksClient(
-            "foobar",
-            443,
-            "path",
-            [],
-            auth_provider=AuthProvider(),
-            ssl_options=SSLOptions(),
-        )
+        thrift_backend = self._create_fake_thrift_client()
         command_id = CommandId.from_thrift_handle(self.operation_handle)
         thrift_backend.close_command(command_id)
         self.assertEqual(
@@ -1415,14 +1409,7 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class):
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     def test_session_handle_respected_in_close_session(self, tcli_service_class):
         tcli_service_instance = tcli_service_class.return_value
-        thrift_backend = ThriftDatabricksClient(
-            "foobar",
-            443,
-            "path",
-            [],
-            auth_provider=AuthProvider(),
-            ssl_options=SSLOptions(),
-        )
+        thrift_backend = self._create_fake_thrift_client()
         session_id = SessionId.from_thrift_handle(self.session_handle)
         thrift_backend.close_session(session_id)
         self.assertEqual(
@@ -1458,7 +1445,8 @@ def test_non_arrow_non_column_based_set_triggers_exception(
         tcli_service_instance.ExecuteStatement.return_value = execute_statement_resp
         tcli_service_instance.GetResultSetMetadata.return_value = metadata_resp
         tcli_service_instance.GetOperationStatus.return_value = operation_status_resp
-        thrift_backend = self._make_fake_thrift_backend()
+
+        thrift_backend = self._create_fake_thrift_client()
 
         with self.assertRaises(OperationalError) as cm:
             thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock())
@@ -1468,14 +1456,7 @@ def test_non_arrow_non_column_based_set_triggers_exception(
 
     def test_create_arrow_table_raises_error_for_unsupported_type(self):
         t_row_set = ttypes.TRowSet()
-        thrift_backend = ThriftDatabricksClient(
-            "foobar",
-            443,
-            "path",
-            [],
-            auth_provider=AuthProvider(),
-            ssl_options=SSLOptions(),
-        )
+        thrift_backend = self._create_fake_thrift_client()
         with self.assertRaises(OperationalError):
             thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock())
@@ -1488,14 +1469,7 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self):
 
     def test_create_arrow_table_calls_correct_conversion_method(
         self, convert_col_mock, convert_arrow_mock
     ):
-        thrift_backend = ThriftDatabricksClient(
-            "foobar",
-            443,
-            "path",
-            [],
-            auth_provider=AuthProvider(),
-            ssl_options=SSLOptions(),
-        )
+        thrift_backend = self._create_fake_thrift_client()
         convert_arrow_mock.return_value = (MagicMock(), Mock())
         convert_col_mock.return_value = (MagicMock(), Mock())
@@ -1695,7 +1669,11 @@ def test_handle_execute_response_sets_active_op_handle(self):
         thrift_backend = self._make_fake_thrift_backend()
         thrift_backend._check_direct_results_for_error = Mock()
         thrift_backend._wait_until_command_done = Mock()
-        thrift_backend._results_message_to_execute_response = Mock()
+        mock_execute_response = Mock(spec=ExecuteResponse)
+        mock_has_more_rows = True
+        thrift_backend._results_message_to_execute_response = Mock(
+            return_value=(mock_execute_response, mock_has_more_rows)
+        )
 
         # Create a mock response with a real operation handle
         mock_resp = Mock()
@@ -2258,17 +2236,19 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class):
         )
     @patch(
         "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response",
-        return_value=Mock(
-            spec=ExecuteResponse,
-            command_id=Mock(),
-            status=Mock(),
-            description=Mock(),
-            has_more_rows=Mock(),
-            has_been_closed_server_side=Mock(),
-            lz4_compressed=Mock(),
-            is_staging_operation=Mock(),
-            arrow_schema_bytes=Mock(),
-            result_format=Mock(),
+        return_value=(
+            Mock(
+                spec=ExecuteResponse,
+                command_id=Mock(),
+                status=Mock(),
+                description=Mock(),
+                has_been_closed_server_side=Mock(),
+                lz4_compressed=Mock(),
+                is_staging_operation=Mock(),
+                arrow_schema_bytes=Mock(),
+                result_format=Mock(),
+            ),
+            True,  # has_more_rows
         ),
     )
     def test_execute_command_sets_complex_type_fields_correctly(

From 55f5c45a9fe18ac76839a4b8ff4955e58af18fe6 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Tue, 10 Jun 2025 12:38:24 +0000
Subject: [PATCH 41/66] remove unnecessary replacement

Signed-off-by: varun-edachali-dbx
---
 tests/unit/test_fetches.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py
index ba9b50aef..a649941e1 100644
--- a/tests/unit/test_fetches.py
+++ b/tests/unit/test_fetches.py
@@ -64,9 +64,6 @@ def make_dummy_result_set_from_initial_results(initial_results):
             thrift_client=mock_thrift_backend,
             t_row_set=None,
         )
-
-        # Replace the results queue with our arrow_queue
-        rs.results = arrow_queue
         return rs
 
     @staticmethod

From edc36b5540d178f6e52bc022eeb265122d6c7d81 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Tue, 10 Jun 2025 12:41:12 +0000
Subject: [PATCH 42/66] better mocked backend naming

Signed-off-by: varun-edachali-dbx
---
 tests/unit/test_thrift_backend.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 4f5e14cab..8582fd7f9 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -117,7 +117,7 @@ def _create_fake_thrift_client(self):
             ssl_options=SSLOptions(),
         )
 
-    def _make_fake_thrift_backend(self):
+    def _create_mocked_thrift_client(self):
         """Create a fake ThriftDatabricksClient with mocked methods."""
         thrift_backend = self._create_fake_thrift_client()
         thrift_backend._hive_schema_to_arrow_schema = Mock()
@@ -184,7 +184,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass):
         )
 
         with self.assertRaises(OperationalError) as cm:
-            thrift_backend = self._make_fake_thrift_backend()
+            thrift_backend = self._create_mocked_thrift_client()
             thrift_backend.open_session({}, None, None)
 
         self.assertIn(
@@ -207,7 +207,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass):
             sessionHandle=self.session_handle,
         )
 
-        thrift_backend = self._make_fake_thrift_backend()
+        thrift_backend = self._create_mocked_thrift_client()
         thrift_backend.open_session({}, None, None)
 
     @patch("databricks.sql.auth.thrift_http_client.THttpClient")
@@ -917,7 +917,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self):
             operationHandle=self.operation_handle,
         )
 
-        thrift_backend = self._make_fake_thrift_backend()
+        thrift_backend = self._create_mocked_thrift_client()
         mock_execute_response = Mock(spec=ExecuteResponse)
         mock_has_more_rows = True
         thrift_backend._results_message_to_execute_response = Mock(
             return_value=(mock_execute_response, mock_has_more_rows)
@@ -1100,7 +1100,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
                 tcli_service_instance.GetResultSetMetadata.return_value = (
                     self.metadata_resp
                 )
-                thrift_backend = self._make_fake_thrift_backend()
+                thrift_backend = self._create_mocked_thrift_client()
 
                 thrift_backend._handle_execute_response(execute_resp, Mock())
                 _, has_more_rows_resp = thrift_backend.fetch_results(
@@ -1221,7 +1221,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
         tcli_service_instance = tcli_service_class.return_value
         response = Mock()
         tcli_service_instance.GetCatalogs.return_value = response
-        thrift_backend = self._make_fake_thrift_backend()
+        thrift_backend = self._create_mocked_thrift_client()
 
         mock_execute_response = self._create_mock_execute_response()
         mock_has_more_rows = True
@@ -1255,7 +1255,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(
         tcli_service_instance = tcli_service_class.return_value
         response = Mock()
         tcli_service_instance.GetSchemas.return_value = response
-        thrift_backend = self._make_fake_thrift_backend()
+        thrift_backend = self._create_mocked_thrift_client()
         mock_execute_response = self._create_mock_execute_response()
         mock_has_more_rows = True
         thrift_backend._handle_execute_response = Mock(
             return_value=(mock_execute_response, mock_has_more_rows)
@@ -1297,7 +1297,7 @@ def test_get_tables_calls_client_and_handle_execute_response(
         tcli_service_instance = tcli_service_class.return_value
         response = Mock()
         tcli_service_instance.GetTables.return_value = response
-        thrift_backend = self._make_fake_thrift_backend()
+        thrift_backend = self._create_mocked_thrift_client()
         mock_execute_response = self._create_mock_execute_response()
         mock_has_more_rows = True
         thrift_backend._handle_execute_response = Mock(
             return_value=(mock_execute_response, mock_has_more_rows)
@@ -1343,7 +1343,7 @@ def test_get_columns_calls_client_and_handle_execute_response(
         tcli_service_instance = tcli_service_class.return_value
         response = Mock()
         tcli_service_instance.GetColumns.return_value = response
-        thrift_backend = self._make_fake_thrift_backend()
+        thrift_backend = self._create_mocked_thrift_client()
         mock_execute_response = self._create_mock_execute_response()
         mock_has_more_rows = True
         thrift_backend._handle_execute_response = Mock(
             return_value=(mock_execute_response, mock_has_more_rows)
@@ -1655,7 +1655,7 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self):
     def test_cancel_command_uses_active_op_handle(self, tcli_service_class):
         tcli_service_instance = tcli_service_class.return_value
 
-        thrift_backend = self._make_fake_thrift_backend()
+        thrift_backend = self._create_mocked_thrift_client()
         # Create a proper CommandId from the existing operation_handle
         command_id = CommandId.from_thrift_handle(self.operation_handle)
         thrift_backend.cancel_command(command_id)
@@ -1666,7 +1666,7 @@ def test_cancel_command_uses_active_op_handle(self, tcli_service_class):
         )
 
     def test_handle_execute_response_sets_active_op_handle(self):
-        thrift_backend = self._make_fake_thrift_backend()
+        thrift_backend = self._create_mocked_thrift_client()
         thrift_backend._check_direct_results_for_error = Mock()
         thrift_backend._wait_until_command_done = Mock()
         mock_execute_response = Mock(spec=ExecuteResponse)

From 81280e701d52609a5ad59deab63d2e24012d2002 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Tue, 10 Jun 2025 12:47:06 +0000
Subject: [PATCH 43/66] remove has_more_rows test in ExecuteResponse

Signed-off-by: varun-edachali-dbx
---
 tests/unit/test_thrift_backend.py | 47 -------------------------------
 1 file changed, 47 deletions(-)

diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 8582fd7f9..2054cb65a 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -990,7 +990,6 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
             operationHandle=self.operation_handle,
         )
 
-        # Mock the operation status response
         op_state = ttypes.TGetOperationStatusResp(
             status=self.okay_status,
             operationState=ttypes.TOperationState.FINISHED_STATE,
@@ -1011,52 +1010,6 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
             thrift_backend._hive_schema_to_arrow_schema.call_args[0][0],
         )
 
-    @patch(
-        "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
-    )
-    @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
-    def test_handle_execute_response_creates_execute_response(
-        self, tcli_service_class, build_queue
-    ):
-        """Test that _handle_execute_response creates an ExecuteResponse object correctly."""
-        for resp_type in self.execute_response_types:
-            with self.subTest(resp_type=resp_type):
-                tcli_service_instance = tcli_service_class.return_value
-                results_mock = Mock()
-                results_mock.startRowOffset = 0
-                direct_results_message = ttypes.TSparkDirectResults(
-                    operationStatus=ttypes.TGetOperationStatusResp(
-                        status=self.okay_status,
-                        operationState=ttypes.TOperationState.FINISHED_STATE,
-                    ),
-                    resultSetMetadata=self.metadata_resp,
-                    resultSet=ttypes.TFetchResultsResp(
-                        status=self.okay_status,
-                        hasMoreRows=True,
-                        results=results_mock,
-                    ),
-                    closeOperation=Mock(),
-                )
-                execute_resp = resp_type(
-                    status=self.okay_status,
-                    directResults=direct_results_message,
-                    operationHandle=self.operation_handle,
-                )
-
-                tcli_service_instance.GetResultSetMetadata.return_value = (
-                    self.metadata_resp
-                )
-                thrift_backend = self._create_fake_thrift_client()
-
-                execute_response_tuple = thrift_backend._handle_execute_response(
-                    execute_resp, Mock()
-                )
-
-                self.assertIsNotNone(execute_response_tuple)
-                self.assertIsInstance(execute_response_tuple, tuple)
-                self.assertIsInstance(execute_response_tuple[0], ExecuteResponse)
-                self.assertIsInstance(execute_response_tuple[1], bool)
-
     @patch(
         "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
     )

From c1d3be2fadc4d1aab3f63136ddcff6e2a4a1931a Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Tue, 10 Jun 2025 13:11:36 +0000
Subject: [PATCH 44/66] introduce replacement of original has_more_rows read
 test

Signed-off-by: varun-edachali-dbx
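The patches above settle on one convention: helpers like _handle_execute_response and _results_message_to_execute_response hand back an (ExecuteResponse, has_more_rows) pair instead of a bare ExecuteResponse. A minimal sketch of the unpacking pattern the updated tests assert, using unittest.mock stand-ins rather than the real ThriftDatabricksClient:

    # Illustrative sketch only; `backend` stands in for ThriftDatabricksClient,
    # whose real implementation lives in databricks/sql/backend/thrift_backend.py.
    from unittest.mock import Mock

    backend = Mock()
    backend._handle_execute_response = Mock(
        return_value=(Mock(name="ExecuteResponse"), True)
    )

    execute_response, has_more_rows = backend._handle_execute_response(Mock(), Mock())
    assert has_more_rows is True  # the boolean travels alongside the response object

Callers that only need one element unpack the other into _, which is the shape most of the rewritten assertions take.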
--- tests/unit/test_thrift_backend.py | 78 +++++++++++++------------------ 1 file changed, 32 insertions(+), 46 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 2054cb65a..3bdf1434d 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -82,7 +82,7 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -106,7 +106,7 @@ def _create_mock_execute_response(self): mock_execute_response.result_format = Mock() return mock_execute_response - def _create_fake_thrift_client(self): + def _create_thrift_client(self): """Create a fake ThriftDatabricksClient without mocking any methods.""" return ThriftDatabricksClient( "foobar", @@ -119,7 +119,7 @@ def _create_fake_thrift_client(self): def _create_mocked_thrift_client(self): """Create a fake ThriftDatabricksClient with mocked methods.""" - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() thrift_backend._create_arrow_table = MagicMock() @@ -575,7 +575,7 @@ def test_make_request_checks_status_code(self): ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ] - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() for code in error_codes: mock_error_response = Mock() @@ -612,7 +612,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): closeOperation=None, ), ) - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -835,7 +835,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(error_resp, Mock()) @@ -963,7 +963,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): t_get_result_set_metadata_resp ) - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() # Call the real _results_message_to_execute_response method execute_response, _ = thrift_backend._results_message_to_execute_response( @@ -997,7 +997,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() thrift_backend._hive_schema_to_arrow_schema = Mock() # Call the real _results_message_to_execute_response method @@ -1014,7 +1014,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - def 
test_handle_execute_response_reads_has_more_rows_in_result_response( + def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): for has_more_rows, resp_type in itertools.product( @@ -1022,48 +1022,34 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( ): with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value - results_mock = MagicMock() + results_mock = Mock() results_mock.startRowOffset = 0 - - execute_resp = resp_type( - status=self.okay_status, - directResults=None, - operationHandle=self.operation_handle, - ) - - fetch_results_resp = ttypes.TFetchResultsResp( - status=self.okay_status, - hasMoreRows=has_more_rows, - results=results_mock, - resultSetMetadata=ttypes.TGetResultSetMetadataResp( - resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET + direct_results_message = ttypes.TSparkDirectResults( + operationStatus=ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, ), + resultSetMetadata=self.metadata_resp, + resultSet=ttypes.TFetchResultsResp( + status=self.okay_status, + hasMoreRows=has_more_rows, + results=results_mock, + ), + closeOperation=Mock(), ) - - operation_status_resp = ttypes.TGetOperationStatusResp( + execute_resp = resp_type( status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - errorMessage="some information about the error", + directResults=direct_results_message, + operationHandle=self.operation_handle, ) - tcli_service_instance.FetchResults.return_value = fetch_results_resp - tcli_service_instance.GetOperationStatus.return_value = ( - operation_status_resp - ) tcli_service_instance.GetResultSetMetadata.return_value = ( self.metadata_resp ) - thrift_backend = self._create_mocked_thrift_client() + thrift_backend = self._create_thrift_client() - thrift_backend._handle_execute_response(execute_resp, Mock()) - _, has_more_rows_resp = thrift_backend.fetch_results( - command_id=Mock(), - max_rows=1, - max_bytes=1, - expected_row_start_offset=0, - lz4_compressed=False, - arrow_schema_bytes=Mock(), - description=Mock(), + _, has_more_rows_resp = thrift_backend._handle_execute_response( + execute_resp, Mock() ) self.assertEqual(has_more_rows, has_more_rows_resp) @@ -1351,7 +1337,7 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() command_id = CommandId.from_thrift_handle(self.operation_handle) thrift_backend.close_command(command_id) self.assertEqual( @@ -1362,7 +1348,7 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() session_id = SessionId.from_thrift_handle(self.session_handle) thrift_backend.close_session(session_id) self.assertEqual( @@ -1399,7 +1385,7 @@ def test_non_arrow_non_column_based_set_triggers_exception( 
tcli_service_instance.GetResultSetMetadata.return_value = metadata_resp tcli_service_instance.GetOperationStatus.return_value = operation_status_resp - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() with self.assertRaises(OperationalError) as cm: thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock()) @@ -1409,7 +1395,7 @@ def test_non_arrow_non_column_based_set_triggers_exception( def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) @@ -1422,7 +1408,7 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): def test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) From 5ee41367701696a2cd4f791a2633b374a36ced0c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 13:14:18 +0000 Subject: [PATCH 45/66] call correct method in test_use_arrow_schema Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 3bdf1434d..fc56feea6 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -966,8 +966,8 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): thrift_backend = self._create_thrift_client() # Call the real _results_message_to_execute_response method - execute_response, _ = thrift_backend._results_message_to_execute_response( - t_execute_resp, ttypes.TOperationState.FINISHED_STATE + execute_response, _ = thrift_backend._handle_execute_response( + t_execute_resp, Mock() ) self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) From b881ab0823f31d709c5d76aa00d9d051506eb835 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 13:15:41 +0000 Subject: [PATCH 46/66] call correct method in test_fall_back_to_hive_schema Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index fc56feea6..cbde1a29b 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -965,7 +965,6 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): thrift_backend = self._create_thrift_client() - # Call the real _results_message_to_execute_response method execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) @@ -1000,10 +999,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): thrift_backend = self._create_thrift_client() thrift_backend._hive_schema_to_arrow_schema = Mock() - # Call the real _results_message_to_execute_response method - thrift_backend._results_message_to_execute_response( - t_execute_resp, ttypes.TOperationState.FINISHED_STATE - ) + thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, From 53bf715a28e59043e7f692ee67b3ef5be36740a0 Mon Sep 17 00:00:00 
2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 13:17:54 +0000 Subject: [PATCH 47/66] re-introduce result response read test Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 58 +++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index cbde1a29b..b7922d729 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1050,6 +1050,64 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( self.assertEqual(has_more_rows, has_more_rows_resp) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + def test_handle_execute_response_reads_has_more_rows_in_result_response( + self, tcli_service_class, build_queue + ): + for has_more_rows, resp_type in itertools.product( + [True, False], self.execute_response_types + ): + with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + tcli_service_instance = tcli_service_class.return_value + results_mock = MagicMock() + results_mock.startRowOffset = 0 + + execute_resp = resp_type( + status=self.okay_status, + directResults=None, + operationHandle=self.operation_handle, + ) + + fetch_results_resp = ttypes.TFetchResultsResp( + status=self.okay_status, + hasMoreRows=has_more_rows, + results=results_mock, + resultSetMetadata=ttypes.TGetResultSetMetadataResp( + resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET + ), + ) + + operation_status_resp = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + errorMessage="some information about the error", + ) + + tcli_service_instance.FetchResults.return_value = fetch_results_resp + tcli_service_instance.GetOperationStatus.return_value = ( + operation_status_resp + ) + tcli_service_instance.GetResultSetMetadata.return_value = ( + self.metadata_resp + ) + thrift_backend = self._create_thrift_client() + + thrift_backend._handle_execute_response(execute_resp, Mock()) + _, has_more_rows_resp = thrift_backend.fetch_results( + command_id=Mock(), + max_rows=1, + max_bytes=1, + expected_row_start_offset=0, + lz4_compressed=False, + arrow_schema_bytes=Mock(), + description=Mock(), + ) + + self.assertEqual(has_more_rows, has_more_rows_resp) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): # make some semi-real arrow batches and check the number of rows is correct in the queue From 45a32be5915927bce570710e0375488580041bf8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 13:20:54 +0000 Subject: [PATCH 48/66] simplify test Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index b7922d729..c54fabf40 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -184,7 +184,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): ) with self.assertRaises(OperationalError) as cm: - thrift_backend = self._create_mocked_thrift_client() + thrift_backend = self._create_thrift_client() thrift_backend.open_session({}, None, None) self.assertIn( @@ -207,7 +207,7 @@ def test_okay_protocol_versions_succeed(self, 
tcli_service_client_cass): sessionHandle=self.session_handle, ) - thrift_backend = self._create_mocked_thrift_client() + thrift_backend = self._create_thrift_client() thrift_backend.open_session({}, None, None) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @@ -918,21 +918,12 @@ def test_handle_execute_response_can_handle_with_direct_results(self): ) thrift_backend = self._create_mocked_thrift_client() - mock_execute_response = Mock(spec=ExecuteResponse) - mock_has_more_rows = True - thrift_backend._results_message_to_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) - ) - result = thrift_backend._handle_execute_response(execute_resp, Mock()) + thrift_backend._handle_execute_response(execute_resp, Mock()) thrift_backend._results_message_to_execute_response.assert_called_with( execute_resp, ttypes.TOperationState.FINISHED_STATE, ) - # Verify the result is a tuple with the expected values - self.assertIsInstance(result, tuple) - self.assertEqual(result[0], mock_execute_response) - self.assertEqual(result[1], mock_has_more_rows) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_use_arrow_schema_if_available(self, tcli_service_class): From e3fe29979743c14099e9d7f88daf2b3f750121a8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 13:35:16 +0000 Subject: [PATCH 49/66] remove excess fetch_results mocks Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 2 -- tests/unit/test_thrift_backend.py | 12 ------------ 2 files changed, 14 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 63bc92fdc..1f0c34025 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -213,7 +213,6 @@ def test_closing_result_set_hard_closes_commands(self): type(mock_connection).session = PropertyMock(return_value=mock_session) mock_thrift_backend.fetch_results.return_value = (Mock(), False) - result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -479,7 +478,6 @@ def make_fake_row_slice(n_rows): mock_aq = Mock() mock_aq.next_n_rows.side_effect = make_fake_row_slice mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq - mock_thrift_backend.fetch_results.return_value = (mock_aq, True) cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.execute("foo") diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index c54fabf40..7a59c6256 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1177,8 +1177,6 @@ def test_execute_statement_calls_client_and_handle_execute_response( ) cursor_mock = Mock() - thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) - result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) @@ -1214,8 +1212,6 @@ def test_get_catalogs_calls_client_and_handle_execute_response( ) cursor_mock = Mock() - thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) - result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet self.assertIsInstance(result, ResultSet) @@ -1247,8 +1243,6 @@ def test_get_schemas_calls_client_and_handle_execute_response( ) cursor_mock = Mock() - thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) - result = thrift_backend.get_schemas( Mock(), 100, @@ -1289,8 +1283,6 @@ def test_get_tables_calls_client_and_handle_execute_response( ) cursor_mock = Mock() - thrift_backend.fetch_results = 
Mock(return_value=(Mock(), False)) - result = thrift_backend.get_tables( Mock(), 100, @@ -1335,8 +1327,6 @@ def test_get_columns_calls_client_and_handle_execute_response( ) cursor_mock = Mock() - thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) - result = thrift_backend.get_columns( Mock(), 100, @@ -2261,8 +2251,6 @@ def test_execute_command_sets_complex_type_fields_correctly( **complex_arg_types, ) - thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) - thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[ 0 From e8038d3ac07ebc368f30f6c9102e578691891c75 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 15:25:19 +0000 Subject: [PATCH 50/66] more minimal changes to thrift_backend tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 347 ++++++++++++++++-------------- 1 file changed, 183 insertions(+), 164 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 7a59c6256..5d9da0e13 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -19,13 +19,7 @@ from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.result_set import ResultSet, ThriftResultSet -from databricks.sql.backend.types import ( - CommandId, - CommandState, - SessionId, - BackendType, - ExecuteResponse, -) +from databricks.sql.backend.types import CommandId, CommandState, SessionId, BackendType def retry_policy_factory(): @@ -82,7 +76,14 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -93,22 +94,8 @@ def _make_type_desc(self, type): ] ) - def _create_mock_execute_response(self): - """Create a properly mocked ExecuteResponse object with all required attributes.""" - mock_execute_response = Mock() - mock_execute_response.command_id = Mock() - mock_execute_response.status = Mock() - mock_execute_response.description = Mock() - mock_execute_response.has_been_closed_server_side = Mock() - mock_execute_response.lz4_compressed = Mock() - mock_execute_response.is_staging_operation = Mock() - mock_execute_response.arrow_schema_bytes = Mock() - mock_execute_response.result_format = Mock() - return mock_execute_response - - def _create_thrift_client(self): - """Create a fake ThriftDatabricksClient without mocking any methods.""" - return ThriftDatabricksClient( + def _make_fake_thrift_backend(self): + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -116,20 +103,10 @@ def _create_thrift_client(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - - def _create_mocked_thrift_client(self): - """Create a fake ThriftDatabricksClient with mocked methods.""" - thrift_backend = self._create_thrift_client() thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() thrift_backend._create_arrow_table = MagicMock() thrift_backend._create_arrow_table.return_value = (MagicMock(), Mock()) - # Mock _results_message_to_execute_response to return a 
tuple - mock_execute_response = self._create_mock_execute_response() - mock_has_more_rows = True - thrift_backend._results_message_to_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) - ) return thrift_backend def test_hive_schema_to_arrow_schema_preserves_column_names(self): @@ -184,7 +161,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): ) with self.assertRaises(OperationalError) as cm: - thrift_backend = self._create_thrift_client() + thrift_backend = self._make_fake_thrift_backend() thrift_backend.open_session({}, None, None) self.assertIn( @@ -207,7 +184,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): sessionHandle=self.session_handle, ) - thrift_backend = self._create_thrift_client() + thrift_backend = self._make_fake_thrift_backend() thrift_backend.open_session({}, None, None) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @@ -575,7 +552,14 @@ def test_make_request_checks_status_code(self): ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ] - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) for code in error_codes: mock_error_response = Mock() @@ -612,7 +596,14 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): closeOperation=None, ), ) - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -628,18 +619,14 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 - - # Create a valid operation status - op_status = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=op_status, + operationStatus=ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ), resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -835,15 +822,23 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value @@ -887,10 +882,10 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), 
ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) - + ( + execute_response, + arrow_schema_bytes, + ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual( execute_response.status, CommandState.SUCCEEDED, @@ -917,9 +912,18 @@ def test_handle_execute_response_can_handle_with_direct_results(self): operationHandle=self.operation_handle, ) - thrift_backend = self._create_mocked_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) + thrift_backend._results_message_to_execute_response.assert_called_with( execute_resp, ttypes.TOperationState.FINISHED_STATE, @@ -944,18 +948,16 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) - - thrift_backend = self._create_thrift_client() - + tcli_service_instance.GetOperationStatus.return_value = ( + ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + ) + thrift_backend = self._make_fake_thrift_backend() execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) @@ -980,17 +982,17 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req - - thrift_backend = self._create_thrift_client() - thrift_backend._hive_schema_to_arrow_schema = Mock() - - thrift_backend._handle_execute_response(t_execute_resp, Mock()) + tcli_service_instance.GetOperationStatus.return_value = ( + ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + ) + thrift_backend = self._make_fake_thrift_backend() + execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( + t_execute_resp, Mock() + ) self.assertEqual( hive_schema_mock, @@ -1033,13 +1035,14 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( tcli_service_instance.GetResultSetMetadata.return_value = ( self.metadata_resp ) - thrift_backend = self._create_thrift_client() + thrift_backend = self._make_fake_thrift_backend() - _, has_more_rows_resp = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) + ( + execute_response, + has_more_rows_result, + ) = thrift_backend._handle_execute_response(execute_resp, Mock()) - self.assertEqual(has_more_rows, has_more_rows_resp) + self.assertEqual(has_more_rows, has_more_rows_result) @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1084,7 +1087,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( tcli_service_instance.GetResultSetMetadata.return_value = ( self.metadata_resp ) 
- thrift_backend = self._create_thrift_client() + thrift_backend = self._make_fake_thrift_backend() thrift_backend._handle_execute_response(execute_resp, Mock()) _, has_more_rows_resp = thrift_backend.fetch_results( @@ -1152,12 +1155,10 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) def test_execute_statement_calls_client_and_handle_execute_response( - self, mock_build_queue, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1170,18 +1171,15 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - mock_execute_response = self._create_mock_execute_response() - mock_has_more_rows = True - thrift_backend._handle_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) - ) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] @@ -1193,28 +1191,29 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) def test_get_catalogs_calls_client_and_handle_execute_response( - self, mock_build_queue, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = self._create_mocked_thrift_client() - - mock_execute_response = self._create_mock_execute_response() - mock_has_more_rows = True - thrift_backend._handle_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] @@ -1225,22 +1224,24 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) def test_get_schemas_calls_client_and_handle_execute_response( - self, 
mock_build_queue, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = self._create_mocked_thrift_client() - mock_execute_response = self._create_mock_execute_response() - mock_has_more_rows = True - thrift_backend._handle_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1252,7 +1253,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( schema_name="schema_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] @@ -1265,22 +1266,24 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) def test_get_tables_calls_client_and_handle_execute_response( - self, mock_build_queue, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = self._create_mocked_thrift_client() - mock_execute_response = self._create_mock_execute_response() - mock_has_more_rows = True - thrift_backend._handle_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1291,10 +1294,10 @@ def test_get_tables_calls_client_and_handle_execute_response( catalog_name="catalog_pattern", schema_name="schema_pattern", table_name="table_pattern", - table_types=["type1", "type2"], + table_types=["VIEW", "TABLE"], ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] @@ -1303,28 +1306,30 @@ def test_get_tables_calls_client_and_handle_execute_response( self.assertEqual(req.catalogName, "catalog_pattern") self.assertEqual(req.schemaName, "schema_pattern") self.assertEqual(req.tableName, "table_pattern") - self.assertEqual(req.tableTypes, ["type1", "type2"]) + self.assertEqual(req.tableTypes, ["VIEW", "TABLE"]) # Check response handling thrift_backend._handle_execute_response.assert_called_with( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) def 
test_get_columns_calls_client_and_handle_execute_response( - self, mock_build_queue, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = self._create_mocked_thrift_client() - mock_execute_response = self._create_mock_execute_response() - mock_has_more_rows = True - thrift_backend._handle_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1338,7 +1343,7 @@ def test_get_columns_calls_client_and_handle_execute_response( column_name="column_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] @@ -1372,7 +1377,14 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) command_id = CommandId.from_thrift_handle(self.operation_handle) thrift_backend.close_command(command_id) self.assertEqual( @@ -1383,7 +1395,14 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) session_id = SessionId.from_thrift_handle(self.session_handle) thrift_backend.close_session(session_id) self.assertEqual( @@ -1419,8 +1438,7 @@ def test_non_arrow_non_column_based_set_triggers_exception( tcli_service_instance.ExecuteStatement.return_value = execute_statement_resp tcli_service_instance.GetResultSetMetadata.return_value = metadata_resp tcli_service_instance.GetOperationStatus.return_value = operation_status_resp - - thrift_backend = self._create_thrift_client() + thrift_backend = self._make_fake_thrift_backend() with self.assertRaises(OperationalError) as cm: thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock()) @@ -1430,7 +1448,14 @@ def test_non_arrow_non_column_based_set_triggers_exception( def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) @@ -1443,7 +1468,14 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): def 
test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) @@ -1629,7 +1661,7 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): def test_cancel_command_uses_active_op_handle(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = self._create_mocked_thrift_client() + thrift_backend = self._make_fake_thrift_backend() # Create a proper CommandId from the existing operation_handle command_id = CommandId.from_thrift_handle(self.operation_handle) thrift_backend.cancel_command(command_id) @@ -1640,14 +1672,10 @@ def test_cancel_command_uses_active_op_handle(self, tcli_service_class): ) def test_handle_execute_response_sets_active_op_handle(self): - thrift_backend = self._create_mocked_thrift_client() + thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - mock_execute_response = Mock(spec=ExecuteResponse) - mock_has_more_rows = True - thrift_backend._results_message_to_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) - ) + thrift_backend._results_message_to_execute_response = Mock() # Create a mock response with a real operation handle mock_resp = Mock() @@ -2204,31 +2232,23 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) - @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", - return_value=( - Mock( - spec=ExecuteResponse, - command_id=Mock(), - status=Mock(), - description=Mock(), - has_been_closed_server_side=Mock(), - lz4_compressed=Mock(), - is_staging_operation=Mock(), - arrow_schema_bytes=Mock(), - result_format=Mock(), - ), - True, # has_more_rows - ), + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" ) def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, mock_build_queue, tcli_service_class + self, mock_handle_execute_response, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value + # Set up the mock to return a tuple with two values + mock_execute_response = Mock() + mock_arrow_schema = Mock() + mock_handle_execute_response.return_value = ( + mock_execute_response, + mock_arrow_schema, + ) + # Iterate through each possible combination of native types (True, False and unset) for complex, timestamp, decimals in itertools.product( [True, False, None], [True, False, None], [True, False, None] @@ -2250,7 +2270,6 @@ def test_execute_command_sets_complex_type_fields_correctly( ssl_options=SSLOptions(), **complex_arg_types, ) - thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[ 0 From 2f6ec19b29dc0bffced7e96ec2ef596880aa7193 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 
Jun 2025 15:33:48 +0000 Subject: [PATCH 51/66] move back to old table types Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 5d9da0e13..61b96e523 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1294,7 +1294,7 @@ def test_get_tables_calls_client_and_handle_execute_response( catalog_name="catalog_pattern", schema_name="schema_pattern", table_name="table_pattern", - table_types=["VIEW", "TABLE"], + table_types=["type1", "type2"], ) # Verify the result is a ResultSet self.assertEqual(result, mock_result_set.return_value) @@ -1306,7 +1306,7 @@ def test_get_tables_calls_client_and_handle_execute_response( self.assertEqual(req.catalogName, "catalog_pattern") self.assertEqual(req.schemaName, "schema_pattern") self.assertEqual(req.tableName, "table_pattern") - self.assertEqual(req.tableTypes, ["VIEW", "TABLE"]) + self.assertEqual(req.tableTypes, ["type1", "type2"]) # Check response handling thrift_backend._handle_execute_response.assert_called_with( response, cursor_mock From 73bc28267f83656b7d7f82cab77721cf93ef013f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 15:35:14 +0000 Subject: [PATCH 52/66] remove outdated arrow_schema_bytes return Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 61b96e523..a05e8cb87 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -884,7 +884,7 @@ def test_handle_execute_response_can_handle_without_direct_results( ) ( execute_response, - arrow_schema_bytes, + _, ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual( execute_response.status, @@ -990,9 +990,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) ) thrift_backend = self._make_fake_thrift_backend() - execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( - t_execute_resp, Mock() - ) + _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, From 4e07f1ee60a163e5fd623b28ad703ffde1bf0ce2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 05:02:24 +0000 Subject: [PATCH 53/66] align SeaResultSet with new structure Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 58 +++++++++++++++++++++++-------- tests/unit/test_sea_result_set.py | 2 +- 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 7ea4976c1..d6f6be3bd 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -19,7 +19,7 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.utils import ColumnTable, ColumnQueue, JsonQueue from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -41,10 +41,11 @@ def __init__( command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, - has_more_rows: bool = False, results_queue=None, description=None, 
is_staging_operation: bool = False, + lz4_compressed: bool = False, + arrow_schema_bytes: bytes = b"", ): """ A ResultSet manages the results of a single command. @@ -72,9 +73,10 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self.has_more_rows = has_more_rows self.results = results_queue self._is_staging_operation = is_staging_operation + self.lz4_compressed = lz4_compressed + self._arrow_schema_bytes = arrow_schema_bytes def __iter__(self): while True: @@ -157,7 +159,10 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, - arrow_schema_bytes: Optional[bytes] = None, + t_row_set=None, + max_download_threads: int = 10, + ssl_options=None, + has_more_rows: bool = True, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. @@ -169,12 +174,30 @@ def __init__( buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results - arrow_schema_bytes: Arrow schema bytes for the result set + t_row_set: The TRowSet containing result data (if available) + max_download_threads: Maximum number of download threads for cloud fetch + ssl_options: SSL options for cloud fetch + has_more_rows: Whether there are more rows to fetch """ # Initialize ThriftResultSet-specific attributes - self._arrow_schema_bytes = arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch - self.lz4_compressed = execute_response.lz4_compressed + self.has_more_rows = has_more_rows + + # Build the results queue if t_row_set is provided + results_queue = None + if t_row_set and execute_response.result_format is not None: + from databricks.sql.utils import ThriftResultSetQueueFactory + + # Create the results queue using the provided format + results_queue = ThriftResultSetQueueFactory.build_queue( + row_set_type=execute_response.result_format, + t_row_set=t_row_set, + arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", + max_download_threads=max_download_threads, + lz4_compressed=execute_response.lz4_compressed, + description=execute_response.description, + ssl_options=ssl_options, + ) # Call parent constructor with common attributes super().__init__( @@ -185,10 +208,11 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, + results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, + lz4_compressed=execute_response.lz4_compressed, + arrow_schema_bytes=execute_response.arrow_schema_bytes, ) # Initialize results queue if not provided @@ -419,7 +443,7 @@ def map_col_type(type_): class SeaResultSet(ResultSet): - """ResultSet implementation for the SEA backend.""" + """ResultSet implementation for SEA backend.""" def __init__( self, @@ -428,17 +452,20 @@ def __init__( sea_client: "SeaDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, + result_data=None, + manifest=None, ): """ Initialize a SeaResultSet with the response from a SEA query execution. 
Args: connection: The parent connection + execute_response: Response from the execute command sea_client: The SeaDatabricksClient instance for direct access buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) + result_data: Result data from SEA response (optional) + manifest: Manifest from SEA response (optional) """ super().__init__( @@ -449,15 +476,15 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, + lz4_compressed=execute_response.lz4_compressed, + arrow_schema_bytes=execute_response.arrow_schema_bytes, ) def _fill_results_buffer(self): """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") def fetchone(self) -> Optional[Row]: """ @@ -480,6 +507,7 @@ def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a list of rows. """ + raise NotImplementedError("fetchall is not implemented for SEA backend") def fetchmany_arrow(self, size: int) -> Any: diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 02421a915..b691872af 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -197,4 +197,4 @@ def test_fill_results_buffer_not_implemented( with pytest.raises( NotImplementedError, match="fetchone is not implemented for SEA backend" ): - result_set._fill_results_buffer() \ No newline at end of file + result_set._fill_results_buffer() From 65e7c6be97f94e6db0031c1501ebcb7f0c43fc9c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 05:05:25 +0000 Subject: [PATCH 54/66] correct sea res set tests Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 6 ++++-- tests/unit/test_sea_result_set.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index d6f6be3bd..3ff0cc378 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -19,7 +19,7 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ColumnTable, ColumnQueue, JsonQueue +from databricks.sql.utils import ColumnTable, ColumnQueue from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -484,7 +484,9 @@ def __init__( def _fill_results_buffer(self): """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") + raise NotImplementedError( + "_fill_results_buffer is not implemented for SEA backend" + ) def fetchone(self) -> Optional[Row]: """ diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index b691872af..d5d8a3667 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -195,6 +195,7 @@ def 
test_fill_results_buffer_not_implemented( ) with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" + NotImplementedError, + match="_fill_results_buffer is not implemented for SEA backend", ): result_set._fill_results_buffer() From 7c483f26579d88b9af646014ec8e7e42fc895ce1 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 06:37:20 +0000 Subject: [PATCH 55/66] remove duplicate import Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index ffbd2885e..2c6e968cb 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -1224,8 +1224,6 @@ def fetch_results( ) ) - from databricks.sql.utils import ResultSetQueueFactory - queue = ResultSetQueueFactory.build_queue( row_set_type=resp.resultSetMetadata.resultFormat, t_row_set=resp.results, From 8cbeb08e6ea1a78ef4be00964325d2768a225648 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 11:08:21 +0000 Subject: [PATCH 56/66] rephrase model docstrings to explicitly denote that they are representations and not used over the wire Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/models/requests.py | 4 ++-- src/databricks/sql/backend/sea/models/responses.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb502..3175132bd 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -4,7 +4,7 @@ @dataclass class CreateSessionRequest: - """Request to create a new session.""" + """Representation of a request to create a new session.""" warehouse_id: str session_confs: Optional[Dict[str, str]] = None @@ -29,7 +29,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class DeleteSessionRequest: - """Request to delete a session.""" + """Representation of a request to delete a session.""" warehouse_id: str session_id: str diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590f..4eeb9eef7 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -4,7 +4,7 @@ @dataclass class CreateSessionResponse: - """Response from creating a new session.""" + """Representation of the response from creating a new session.""" session_id: str From 36b9cfb24c7caac10230334cfbf7ea6afd5511b3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 12:08:01 +0000 Subject: [PATCH 57/66] has_more_rows -> is_direct_results Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 36 ++++++++++++-------- src/databricks/sql/result_set.py | 24 ++++++------- tests/unit/test_client.py | 2 +- tests/unit/test_fetches_bench.py | 2 +- tests/unit/test_thrift_backend.py | 16 ++++----- 5 files changed, 44 insertions(+), 36 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 2c6e968cb..e824de1c2 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -758,7 +758,7 @@ def _results_message_to_execute_response(self, resp, operation_state): direct_results = resp.directResults 
has_been_closed_server_side = direct_results and direct_results.closeOperation - has_more_rows = ( + is_direct_results = ( (not direct_results) or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows @@ -796,7 +796,7 @@ def _results_message_to_execute_response(self, resp, operation_state): result_format=t_result_set_metadata_resp.resultFormat, ) - return execute_response, has_more_rows + return execute_response, is_direct_results def get_execution_result( self, command_id: CommandId, cursor: "Cursor" @@ -838,7 +838,7 @@ def get_execution_result( lz4_compressed = t_result_set_metadata_resp.lz4Compressed is_staging_operation = t_result_set_metadata_resp.isStagingOperation - has_more_rows = resp.hasMoreRows + is_direct_results = resp.hasMoreRows status = self.get_query_state(command_id) @@ -863,7 +863,7 @@ def get_execution_result( t_row_set=resp.results, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, - has_more_rows=has_more_rows, + is_direct_results=is_direct_results, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -976,7 +976,7 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, has_more_rows = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) @@ -994,7 +994,7 @@ def execute_command( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, - has_more_rows=has_more_rows, + is_direct_results=is_direct_results, ) def get_catalogs( @@ -1016,7 +1016,9 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, has_more_rows = self._handle_execute_response(resp, cursor) + execute_response, is_direct_results = self._handle_execute_response( + resp, cursor + ) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1032,7 +1034,7 @@ def get_catalogs( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, - has_more_rows=has_more_rows, + is_direct_results=is_direct_results, ) def get_schemas( @@ -1058,7 +1060,9 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, has_more_rows = self._handle_execute_response(resp, cursor) + execute_response, is_direct_results = self._handle_execute_response( + resp, cursor + ) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1074,7 +1078,7 @@ def get_schemas( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, - has_more_rows=has_more_rows, + is_direct_results=is_direct_results, ) def get_tables( @@ -1104,7 +1108,9 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, has_more_rows = self._handle_execute_response(resp, cursor) + execute_response, is_direct_results = self._handle_execute_response( + resp, cursor + ) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1120,7 +1126,7 @@ def get_tables( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, - has_more_rows=has_more_rows, + is_direct_results=is_direct_results, ) def get_columns( @@ -1150,7 +1156,9 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, has_more_rows = self._handle_execute_response(resp, cursor) + execute_response, is_direct_results = self._handle_execute_response( + 
resp, cursor + ) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1166,7 +1174,7 @@ def get_columns( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, - has_more_rows=has_more_rows, + is_direct_results=is_direct_results, ) def _handle_execute_response(self, resp, cursor): diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 9857d9e0f..3b3a13ba3 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -41,7 +41,7 @@ def __init__( command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, - has_more_rows: bool = False, + is_direct_results: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, @@ -57,7 +57,7 @@ def __init__( command_id: The command ID status: The command status has_been_closed_server_side: Whether the command has been closed on the server - has_more_rows: Whether the command has more rows + is_direct_results: Whether the command has more rows results_queue: The results queue description: column description of the results is_staging_operation: Whether the command is a staging operation @@ -72,7 +72,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self.has_more_rows = has_more_rows + self.is_direct_results = is_direct_results self.results = results_queue self._is_staging_operation = is_staging_operation @@ -160,7 +160,7 @@ def __init__( t_row_set=None, max_download_threads: int = 10, ssl_options=None, - has_more_rows: bool = True, + is_direct_results: bool = True, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. @@ -175,7 +175,7 @@ def __init__( t_row_set: The TRowSet containing result data (if available) max_download_threads: Maximum number of download threads for cloud fetch ssl_options: SSL options for cloud fetch - has_more_rows: Whether there are more rows to fetch + is_direct_results: Whether there are more rows to fetch """ # Initialize ThriftResultSet-specific attributes self._arrow_schema_bytes = execute_response.arrow_schema_bytes @@ -207,7 +207,7 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=has_more_rows, + is_direct_results=is_direct_results, results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, @@ -218,7 +218,7 @@ def __init__( self._fill_results_buffer() def _fill_results_buffer(self): - results, has_more_rows = self.backend.fetch_results( + results, is_direct_results = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -229,7 +229,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self.is_direct_results = is_direct_results def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -313,7 +313,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self.is_direct_results ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -338,7 +338,7 @@ def fetchmany_columnar(self, size: int): while ( 
n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self.is_direct_results ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -353,7 +353,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self.is_direct_results: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -379,7 +379,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self.is_direct_results: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 1f0c34025..2054d01d1 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -48,7 +48,7 @@ def new(cls): is_staging_operation=False, command_id=None, has_been_closed_server_side=True, - has_more_rows=True, + is_direct_results=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index 7e025cf82..e4a9e5cdd 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -36,7 +36,7 @@ def make_dummy_result_set_from_initial_results(arrow_table): execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, - has_more_rows=False, + is_direct_results=False, description=Mock(), command_id=None, arrow_queue=arrow_queue, diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index a05e8cb87..57b5e9b58 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1004,10 +1004,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( + for is_direct_results, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1019,7 +1019,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( resultSetMetadata=self.metadata_resp, resultSet=ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=is_direct_results, results=results_mock, ), closeOperation=Mock(), @@ -1040,7 +1040,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( has_more_rows_result, ) = thrift_backend._handle_execute_response(execute_resp, Mock()) - self.assertEqual(has_more_rows, has_more_rows_result) + self.assertEqual(is_direct_results, has_more_rows_result) @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1049,10 +1049,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( def 
test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( + for is_direct_results, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = MagicMock() results_mock.startRowOffset = 0 @@ -1065,7 +1065,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( fetch_results_resp = ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=is_direct_results, results=results_mock, resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET @@ -1098,7 +1098,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( description=Mock(), ) - self.assertEqual(has_more_rows, has_more_rows_resp) + self.assertEqual(is_direct_results, has_more_rows_resp) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): From c04d583cc0e6211fbe4853ddefdcfcd6acfadcf2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 12:12:12 +0000 Subject: [PATCH 58/66] switch docstring format to align with Connection class Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 46 ++++++++++++++++---------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 3b3a13ba3..cf6940bb2 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -49,18 +49,18 @@ def __init__( """ A ResultSet manages the results of a single command. - Args: - connection: The parent connection - backend: The backend client - arraysize: The max number of rows to fetch at a time (PEP-249) - buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - command_id: The command ID - status: The command status - has_been_closed_server_side: Whether the command has been closed on the server - is_direct_results: Whether the command has more rows - results_queue: The results queue - description: column description of the results - is_staging_operation: Whether the command is a staging operation + Parameters: + :param connection: The parent connection + :param backend: The backend client + :param arraysize: The max number of rows to fetch at a time (PEP-249) + :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + :param command_id: The command ID + :param status: The command status + :param has_been_closed_server_side: Whether the command has been closed on the server + :param is_direct_results: Whether the command has more rows + :param results_queue: The results queue + :param description: column description of the results + :param is_staging_operation: Whether the command is a staging operation """ self.connection = connection @@ -165,17 +165,17 @@ def __init__( """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
- Args: - connection: The parent connection - execute_response: Response from the execute command - thrift_client: The ThriftDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - use_cloud_fetch: Whether to use cloud fetch for retrieving results - t_row_set: The TRowSet containing result data (if available) - max_download_threads: Maximum number of download threads for cloud fetch - ssl_options: SSL options for cloud fetch - is_direct_results: Whether there are more rows to fetch + Parameters: + :param connection: The parent connection + :param execute_response: Response from the execute command + :param thrift_client: The ThriftDatabricksClient instance for direct access + :param buffer_size_bytes: Buffer size for fetching results + :param arraysize: Default number of rows to fetch + :param use_cloud_fetch: Whether to use cloud fetch for retrieving results + :param t_row_set: The TRowSet containing result data (if available) + :param max_download_threads: Maximum number of download threads for cloud fetch + :param ssl_options: SSL options for cloud fetch + :param is_direct_results: Whether there are more rows to fetch """ # Initialize ThriftResultSet-specific attributes self._arrow_schema_bytes = execute_response.arrow_schema_bytes From ed7079e378b64851008f48f26ce288c2d735aaf7 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 12:48:37 +0000 Subject: [PATCH 59/66] has_more_rows -> is_direct_results Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 2 +- tests/unit/test_sea_result_set.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 5f0403f5b..6519ecefc 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -184,7 +184,7 @@ def __init__( # Initialize ThriftResultSet-specific attributes self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch - self.has_more_rows = has_more_rows + self.is_direct_results = is_direct_results # Build the results queue if t_row_set is provided results_queue = None diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index d5d8a3667..c596dbc14 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -34,7 +34,7 @@ def execute_response(self): mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") mock_response.status = CommandState.SUCCEEDED mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False + mock_response.is_direct_results = False mock_response.results_queue = None mock_response.description = [ ("test_value", "INT", None, None, None, None, None) From 0384b659228291964b35fb9490900a9ff6c1bf33 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 12:52:07 +0000 Subject: [PATCH 60/66] fix type errors with arrow_schema_bytes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 6519ecefc..2237e1f47 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -46,7 +46,7 @@ def __init__( description=None, is_staging_operation: bool = False, lz4_compressed: bool = False, - arrow_schema_bytes: bytes = b"", + arrow_schema_bytes: 
Optional[bytes] = b"", ): """ A ResultSet manages the results of a single command. @@ -182,17 +182,16 @@ def __init__( :param is_direct_results: Whether there are more rows to fetch """ # Initialize ThriftResultSet-specific attributes - self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch self.is_direct_results = is_direct_results # Build the results queue if t_row_set is provided results_queue = None if t_row_set and execute_response.result_format is not None: - from databricks.sql.utils import ThriftResultSetQueueFactory + from databricks.sql.utils import ResultSetQueueFactory # Create the results queue using the provided format - results_queue = ThriftResultSetQueueFactory.build_queue( + results_queue = ResultSetQueueFactory.build_queue( row_set_type=execute_response.result_format, t_row_set=t_row_set, arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", From 218e54728334c4f0abddc6d088a4616b16f75bd6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 12:57:58 +0000 Subject: [PATCH 61/66] spaces after multi line pydocs Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 2237e1f47..badf651f8 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -181,6 +181,7 @@ def __init__( :param ssl_options: SSL options for cloud fetch :param is_direct_results: Whether there are more rows to fetch """ + # Initialize ThriftResultSet-specific attributes self._use_cloud_fetch = use_cloud_fetch self.is_direct_results = is_direct_results @@ -325,6 +326,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": An empty sequence is returned when no more rows are available. """ + if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) results = self.results.next_n_rows(size) @@ -349,6 +351,7 @@ def fetchmany_columnar(self, size: int): Fetch the next set of rows of a query result, returning a Columnar Table. An empty sequence is returned when no more rows are available. """ + if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) @@ -413,6 +416,7 @@ def fetchone(self) -> Optional[Row]: Fetch the next row of a query result set, returning a single sequence, or None when no more data is available. """ + if isinstance(self.results, ColumnQueue): res = self._convert_columnar_table(self.fetchmany_columnar(1)) else: @@ -427,6 +431,7 @@ def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a list of rows. """ + if isinstance(self.results, ColumnQueue): return self._convert_columnar_table(self.fetchall_columnar()) else: @@ -438,6 +443,7 @@ def fetchmany(self, size: int) -> List[Row]: An empty sequence is returned when no more rows are available. 
""" + if isinstance(self.results, ColumnQueue): return self._convert_columnar_table(self.fetchmany_columnar(size)) else: From a6788f82ccc21ef36b8cc6beea24501608e5d779 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:00:59 +0000 Subject: [PATCH 62/66] remove duplicate queue init (merge artifact) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index badf651f8..f685f6cdf 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -186,22 +186,6 @@ def __init__( self._use_cloud_fetch = use_cloud_fetch self.is_direct_results = is_direct_results - # Build the results queue if t_row_set is provided - results_queue = None - if t_row_set and execute_response.result_format is not None: - from databricks.sql.utils import ResultSetQueueFactory - - # Create the results queue using the provided format - results_queue = ResultSetQueueFactory.build_queue( - row_set_type=execute_response.result_format, - t_row_set=t_row_set, - arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", - max_download_threads=max_download_threads, - lz4_compressed=execute_response.lz4_compressed, - description=execute_response.description, - ssl_options=ssl_options, - ) - # Build the results queue if t_row_set is provided results_queue = None if t_row_set and execute_response.result_format is not None: From 93468e689c673c2fbfe55c8e1b1b8512e7ce8aa4 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:02:12 +0000 Subject: [PATCH 63/66] reduce diff (remove newlines) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index f685f6cdf..a71ddf7cd 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -310,7 +310,6 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": An empty sequence is returned when no more rows are available. """ - if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) results = self.results.next_n_rows(size) @@ -335,7 +334,6 @@ def fetchmany_columnar(self, size: int): Fetch the next set of rows of a query result, returning a Columnar Table. An empty sequence is returned when no more rows are available. """ - if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) @@ -400,7 +398,6 @@ def fetchone(self) -> Optional[Row]: Fetch the next row of a query result set, returning a single sequence, or None when no more data is available. """ - if isinstance(self.results, ColumnQueue): res = self._convert_columnar_table(self.fetchmany_columnar(1)) else: @@ -415,7 +412,6 @@ def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a list of rows. """ - if isinstance(self.results, ColumnQueue): return self._convert_columnar_table(self.fetchall_columnar()) else: @@ -427,7 +423,6 @@ def fetchmany(self, size: int) -> List[Row]: An empty sequence is returned when no more rows are available. 
""" - if isinstance(self.results, ColumnQueue): return self._convert_columnar_table(self.fetchmany_columnar(size)) else: From a70a6cee277db44d6951604e890f91cae9f92f32 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:14:51 +0000 Subject: [PATCH 64/66] remove un-necessary changes covered by #588 anyway Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 89 +------------------------------- 1 file changed, 2 insertions(+), 87 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a71ddf7cd..cf6940bb2 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -45,8 +45,6 @@ def __init__( results_queue=None, description=None, is_staging_operation: bool = False, - lz4_compressed: bool = False, - arrow_schema_bytes: Optional[bytes] = b"", ): """ A ResultSet manages the results of a single command. @@ -77,8 +75,6 @@ def __init__( self.is_direct_results = is_direct_results self.results = results_queue self._is_staging_operation = is_staging_operation - self.lz4_compressed = lz4_compressed - self._arrow_schema_bytes = arrow_schema_bytes def __iter__(self): while True: @@ -181,10 +177,10 @@ def __init__( :param ssl_options: SSL options for cloud fetch :param is_direct_results: Whether there are more rows to fetch """ - # Initialize ThriftResultSet-specific attributes + self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch - self.is_direct_results = is_direct_results + self.lz4_compressed = execute_response.lz4_compressed # Build the results queue if t_row_set is provided results_queue = None @@ -215,8 +211,6 @@ def __init__( results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, - lz4_compressed=execute_response.lz4_compressed, - arrow_schema_bytes=execute_response.arrow_schema_bytes, ) # Initialize results queue if not provided @@ -444,82 +438,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - result_data=None, - manifest=None, - ): - """ - Initialize a SeaResultSet with the response from a SEA query execution. 
- - Args: - connection: The parent connection - execute_response: Response from the execute command - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - result_data: Result data from SEA response (optional) - manifest: Manifest from SEA response (optional) - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - lz4_compressed=execute_response.lz4_compressed, - arrow_schema_bytes=execute_response.arrow_schema_bytes, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError( - "_fill_results_buffer is not implemented for SEA backend" - ) - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") From de181d81af244cbbe2fd4c9a1adc41f784282120 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:15:54 +0000 Subject: [PATCH 65/66] Revert "remove un-necessary changes" This reverts commit a70a6cee277db44d6951604e890f91cae9f92f32. Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 89 +++++++++++++++++++++++++++++++- 1 file changed, 87 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index cf6940bb2..a71ddf7cd 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -45,6 +45,8 @@ def __init__( results_queue=None, description=None, is_staging_operation: bool = False, + lz4_compressed: bool = False, + arrow_schema_bytes: Optional[bytes] = b"", ): """ A ResultSet manages the results of a single command. 
@@ -75,6 +77,8 @@ def __init__( self.is_direct_results = is_direct_results self.results = results_queue self._is_staging_operation = is_staging_operation + self.lz4_compressed = lz4_compressed + self._arrow_schema_bytes = arrow_schema_bytes def __iter__(self): while True: @@ -177,10 +181,10 @@ def __init__( :param ssl_options: SSL options for cloud fetch :param is_direct_results: Whether there are more rows to fetch """ + # Initialize ThriftResultSet-specific attributes - self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch - self.lz4_compressed = execute_response.lz4_compressed + self.is_direct_results = is_direct_results # Build the results queue if t_row_set is provided results_queue = None @@ -211,6 +215,8 @@ def __init__( results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, + lz4_compressed=execute_response.lz4_compressed, + arrow_schema_bytes=execute_response.arrow_schema_bytes, ) # Initialize results queue if not provided @@ -438,3 +444,82 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] + + +class SeaResultSet(ResultSet): + """ResultSet implementation for SEA backend.""" + + def __init__( + self, + connection: "Connection", + execute_response: "ExecuteResponse", + sea_client: "SeaDatabricksClient", + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + result_data=None, + manifest=None, + ): + """ + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + execute_response: Response from the execute command + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + result_data: Result data from SEA response (optional) + manifest: Manifest from SEA response (optional) + """ + + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + lz4_compressed=execute_response.lz4_compressed, + arrow_schema_bytes=execute_response.arrow_schema_bytes, + ) + + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + raise NotImplementedError( + "_fill_results_buffer is not implemented for SEA backend" + ) + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchmany(self, size: Optional[int] = None) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + + raise NotImplementedError("fetchmany is not implemented for SEA backend") + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. 
+ """ + + raise NotImplementedError("fetchall is not implemented for SEA backend") + + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") From 8c653458ec69f46e6b4efba4bc52862c80d737ee Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 11:34:00 +0000 Subject: [PATCH 66/66] b"" -> None Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a71ddf7cd..38b8a3c2f 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -46,7 +46,7 @@ def __init__( description=None, is_staging_operation: bool = False, lz4_compressed: bool = False, - arrow_schema_bytes: Optional[bytes] = b"", + arrow_schema_bytes: Optional[bytes] = None, ): """ A ResultSet manages the results of a single command. pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy
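To make the contract these patches converge on concrete: below is a minimal, illustrative sketch (not part of any patch in the series) of the fetch-loop semantics after the has_more_rows -> is_direct_results rename in PATCH 57. FakeBackend is a hypothetical stand-in for ThriftDatabricksClient.fetch_results, which returns a (results, is_direct_results) pair; the flag stays True while the server still holds rows for the command. The real loop in ThriftResultSet additionally checks has_been_closed_server_side, omitted here for brevity.

class FakeBackend:
    def __init__(self, batches):
        self._batches = list(batches)

    def fetch_results(self):
        # Pop one batch; report whether more batches remain server-side.
        rows = self._batches.pop(0)
        return rows, bool(self._batches)


def fetchall(backend):
    # Same shape as ThriftResultSet.fetchall_arrow: drain the current
    # buffer, then keep refilling while is_direct_results is True.
    results, is_direct_results = backend.fetch_results()
    while is_direct_results:
        partial, is_direct_results = backend.fetch_results()
        results += partial
    return results


assert fetchall(FakeBackend([[1, 2], [3], [4]])) == [1, 2, 3, 4]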
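Similarly, a hedged sketch of how a caller wires up a ThriftResultSet under the constructor introduced in PATCH 53 and renamed in PATCH 57: the TRowSet and the more-rows flag now come straight from the Thrift response instead of riding on ExecuteResponse. The resp, connection, thrift_backend and ssl_options objects are assumed to exist in the calling scope; the numeric values are the defaults shown in the diffs.

def build_result_set(connection, thrift_backend, execute_response, resp, ssl_options):
    from databricks.sql.result_set import ThriftResultSet

    return ThriftResultSet(
        connection=connection,
        execute_response=execute_response,
        thrift_client=thrift_backend,
        buffer_size_bytes=104857600,   # default buffer size
        arraysize=10000,               # default fetch size
        use_cloud_fetch=True,
        t_row_set=resp.results,        # TRowSet from the Thrift response
        max_download_threads=10,
        ssl_options=ssl_options,
        is_direct_results=resp.hasMoreRows,
    )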
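Finally, since PATCH 65 restores SeaResultSet as a stub, every fetch path currently raises NotImplementedError. A small sketch of what callers can expect until SEA fetch support lands; connection, execute_response and sea_client are assumed to be available, and probe_sea_result_set is a hypothetical helper, not part of the library.

def probe_sea_result_set(connection, execute_response, sea_client):
    from databricks.sql.result_set import SeaResultSet

    result_set = SeaResultSet(connection, execute_response, sea_client)
    try:
        result_set.fetchone()
    except NotImplementedError as err:
        # Per the stub in PATCH 65: "fetchone is not implemented for SEA backend"
        print(err)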