diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index edff10159..20b059fa7 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -15,10 +15,16 @@ from databricks.sql.client import Cursor from databricks.sql.thrift_api.TCLIService import ttypes -from databricks.sql.backend.types import SessionId, CommandId +from databricks.sql.backend.types import SessionId, CommandId, CommandState from databricks.sql.utils import ExecuteResponse from databricks.sql.types import SSLOptions +# Forward reference for type hints +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + class DatabricksClient(ABC): # == Connection and Session Management == @@ -81,7 +87,7 @@ def execute_command( parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Optional[ExecuteResponse]: + ) -> Union["ResultSet", None]: """ Executes a SQL command or query within the specified session. @@ -101,7 +107,7 @@ def execute_command( enforce_embedded_schema_correctness: Whether to enforce schema correctness Returns: - If async_op is False, returns an ExecuteResponse object containing the + If async_op is False, returns a ResultSet object containing the query results and metadata. If async_op is True, returns None and the results must be fetched later using get_execution_result(). @@ -130,7 +136,7 @@ def cancel_command(self, command_id: CommandId) -> None: pass @abstractmethod - def close_command(self, command_id: CommandId) -> ttypes.TStatus: + def close_command(self, command_id: CommandId) -> None: """ Closes a command and releases associated resources. @@ -140,9 +146,6 @@ def close_command(self, command_id: CommandId) -> ttypes.TStatus: Args: command_id: The command identifier to close - Returns: - ttypes.TStatus: The status of the close operation - Raises: ValueError: If the command ID is invalid OperationalError: If there's an error closing the command @@ -150,7 +153,7 @@ def close_command(self, command_id: CommandId) -> ttypes.TStatus: pass @abstractmethod - def get_query_state(self, command_id: CommandId) -> ttypes.TOperationState: + def get_query_state(self, command_id: CommandId) -> CommandState: """ Gets the current state of a query or command. @@ -160,7 +163,7 @@ def get_query_state(self, command_id: CommandId) -> ttypes.TOperationState: command_id: The command identifier to check Returns: - ttypes.TOperationState: The current state of the command + CommandState: The current state of the command Raises: ValueError: If the command ID is invalid @@ -175,7 +178,7 @@ def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves the results of a previously executed command. @@ -187,7 +190,7 @@ def get_execution_result( cursor: The cursor object that will handle the results Returns: - ExecuteResponse: An object containing the query results and metadata + ResultSet: An object containing the query results and metadata Raises: ValueError: If the command ID is invalid @@ -203,7 +206,7 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves a list of available catalogs. 
@@ -217,7 +220,7 @@ def get_catalogs( cursor: The cursor object that will handle the results Returns: - ExecuteResponse: An object containing the catalog metadata + ResultSet: An object containing the catalog metadata Raises: ValueError: If the session ID is invalid @@ -234,7 +237,7 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves a list of schemas, optionally filtered by catalog and schema name patterns. @@ -250,7 +253,7 @@ def get_schemas( schema_name: Optional schema name pattern to filter by Returns: - ExecuteResponse: An object containing the schema metadata + ResultSet: An object containing the schema metadata Raises: ValueError: If the session ID is invalid @@ -269,7 +272,7 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves a list of tables, optionally filtered by catalog, schema, table name, and table types. @@ -287,7 +290,7 @@ def get_tables( table_types: Optional list of table types to filter by (e.g., ['TABLE', 'VIEW']) Returns: - ExecuteResponse: An object containing the table metadata + ResultSet: An object containing the table metadata Raises: ValueError: If the session ID is invalid @@ -306,7 +309,7 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> ExecuteResponse: + ) -> "ResultSet": """ Retrieves a list of columns, optionally filtered by catalog, schema, table, and column name patterns. @@ -324,7 +327,7 @@ def get_columns( column_name: Optional column name pattern to filter by Returns: - ExecuteResponse: An object containing the column metadata + ResultSet: An object containing the column metadata Raises: ValueError: If the session ID is invalid diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index c09397c2f..de388f1d4 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -9,9 +9,11 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( + CommandState, SessionId, CommandId, BackendType, @@ -84,8 +86,8 @@ class ThriftDatabricksClient(DatabricksClient): - CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE - ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE + CLOSED_OP_STATE = CommandState.CLOSED + ERROR_OP_STATE = CommandState.FAILED _retry_delay_min: float _retry_delay_max: float @@ -349,6 +351,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
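The `make_request` comment above describes its retry strategy only in prose: build a range representing the available retries, and stop early once the total elapsed time plus the next retry delay would exceed `_retry_stop_after_attempts_duration`. The following is a minimal, self-contained illustration of that pattern with hypothetical names; it is not the connector's actual retry code, which additionally classifies which errors are retryable.

import time

def call_with_bounded_retries(request_fn, max_attempts, stop_after_duration_s, delay_s):
    # Illustrative sketch only: retry up to max_attempts times, but give up early if
    # the elapsed time plus the next delay would exceed the overall duration budget.
    start = time.monotonic()
    last_error = None
    for _attempt in range(max_attempts):  # the range "represents" the available retries
        try:
            return request_fn()
        except OSError as error:  # real code retries only errors it classifies as retryable
            last_error = error
            elapsed = time.monotonic() - start
            if elapsed + delay_s > stop_after_duration_s:
                break  # the next retry would exceed the duration cap
            time.sleep(delay_s)
    raise last_error

The real client derives its per-attempt delay from the `_retry_delay_min`/`_retry_delay_max` bounds declared just below this comment; the sketch collapses that into a single fixed delay.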
@@ -796,7 +799,7 @@ def _results_message_to_execute_response(self, resp, operation_state): return ExecuteResponse( arrow_queue=arrow_queue_opt, - status=operation_state, + status=CommandState.from_thrift_state(operation_state), has_been_closed_server_side=has_been_closed_server_side, has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, @@ -808,7 +811,9 @@ def _results_message_to_execute_response(self, resp, operation_state): def get_execution_result( self, command_id: CommandId, cursor: "Cursor" - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -857,9 +862,9 @@ def get_execution_result( ssl_options=self._ssl_options, ) - return ExecuteResponse( + execute_response = ExecuteResponse( arrow_queue=queue, - status=resp.status, + status=CommandState.from_thrift_state(resp.status), has_been_closed_server_side=False, has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, @@ -869,6 +874,15 @@ def get_execution_result( arrow_schema_bytes=schema_bytes, ) + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) + def _wait_until_command_done(self, op_handle, initial_operation_status_resp): if initial_operation_status_resp: self._check_command_not_in_error_or_closed_state( @@ -887,7 +901,7 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) return operation_state - def get_query_state(self, command_id: CommandId) -> "TOperationState": + def get_query_state(self, command_id: CommandId) -> CommandState: thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -895,7 +909,10 @@ def get_query_state(self, command_id: CommandId) -> "TOperationState": poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - return operation_state + state = CommandState.from_thrift_state(operation_state) + if state is None: + raise ValueError(f"Unknown command state: {operation_state}") + return state @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -929,7 +946,9 @@ def execute_command( parameters=[], async_op=False, enforce_embedded_schema_correctness=False, - ) -> Optional[ExecuteResponse]: + ) -> Union["ResultSet", None]: + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -976,7 +995,16 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - return self._handle_execute_response(resp, cursor) + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=use_cloud_fetch, + ) def get_catalogs( self, @@ -984,7 +1012,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = 
session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -996,7 +1026,17 @@ def get_catalogs( ), ) resp = self.make_request(self._client.GetCatalogs, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_schemas( self, @@ -1006,7 +1046,9 @@ def get_schemas( cursor: "Cursor", catalog_name=None, schema_name=None, - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1020,7 +1062,17 @@ def get_schemas( schemaName=schema_name, ) resp = self.make_request(self._client.GetSchemas, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_tables( self, @@ -1032,7 +1084,9 @@ def get_tables( schema_name=None, table_name=None, table_types=None, - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1048,7 +1102,17 @@ def get_tables( tableTypes=table_types, ) resp = self.make_request(self._client.GetTables, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_columns( self, @@ -1060,7 +1124,9 @@ def get_columns( schema_name=None, table_name=None, column_name=None, - ) -> ExecuteResponse: + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1076,7 +1142,17 @@ def get_columns( columnName=column_name, ) resp = self.make_request(self._client.GetColumns, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def _handle_execute_response(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) @@ -1154,12 +1230,11 @@ def cancel_command(self, command_id: CommandId) -> None: req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) - def close_command(self, command_id: CommandId): + def close_command(self, command_id: CommandId) -> None: thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") logger.debug("ThriftBackend.close_command(command_id=%s)", 
command_id) req = ttypes.TCloseOperationReq(operationHandle=thrift_handle) - resp = self.make_request(self._client.CloseOperation, req) - return resp.status + self.make_request(self._client.CloseOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 740be0199..9cd21b5e6 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,12 +1,86 @@ from enum import Enum -from typing import Dict, Optional, Any, Union +from typing import Dict, Optional, Any import logging from databricks.sql.backend.utils import guid_to_hex_id +from databricks.sql.thrift_api.TCLIService import ttypes logger = logging.getLogger(__name__) +class CommandState(Enum): + """ + Enum representing the execution state of a command in Databricks SQL. + + This enum maps Thrift operation states to normalized command states, + providing a consistent interface for tracking command execution status + across different backend implementations. + + Attributes: + PENDING: Command is queued or initialized but not yet running + RUNNING: Command is currently executing + SUCCEEDED: Command completed successfully + FAILED: Command failed due to error, timeout, or unknown state + CLOSED: Command has been closed + CANCELLED: Command was cancelled before completion + """ + + PENDING = "PENDING" + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + CLOSED = "CLOSED" + CANCELLED = "CANCELLED" + + @classmethod + def from_thrift_state( + cls, state: ttypes.TOperationState + ) -> Optional["CommandState"]: + """ + Convert a Thrift TOperationState to a normalized CommandState. + + Args: + state: A TOperationState from the Thrift API representing the current + state of an operation + + Returns: + CommandState: The corresponding normalized command state + + Raises: + ValueError: If the provided state is not a recognized TOperationState + + State Mappings: + - INITIALIZED_STATE, PENDING_STATE -> PENDING + - RUNNING_STATE -> RUNNING + - FINISHED_STATE -> SUCCEEDED + - ERROR_STATE, TIMEDOUT_STATE, UKNOWN_STATE -> FAILED + - CLOSED_STATE -> CLOSED + - CANCELED_STATE -> CANCELLED + """ + + if state in ( + ttypes.TOperationState.INITIALIZED_STATE, + ttypes.TOperationState.PENDING_STATE, + ): + return cls.PENDING + elif state == ttypes.TOperationState.RUNNING_STATE: + return cls.RUNNING + elif state == ttypes.TOperationState.FINISHED_STATE: + return cls.SUCCEEDED + elif state in ( + ttypes.TOperationState.ERROR_STATE, + ttypes.TOperationState.TIMEDOUT_STATE, + ttypes.TOperationState.UKNOWN_STATE, + ): + return cls.FAILED + elif state == ttypes.TOperationState.CLOSED_STATE: + return cls.CLOSED + elif state == ttypes.TOperationState.CANCELED_STATE: + return cls.CANCELLED + else: + return None + + class BackendType(Enum): """ Enum representing the type of backend @@ -40,6 +114,7 @@ def __init__( secret: The secret part of the identifier (only used for Thrift) properties: Additional information about the session """ + self.backend_type = backend_type self.guid = guid self.secret = secret @@ -55,6 +130,7 @@ def __str__(self) -> str: Returns: A string representation of the session ID """ + if self.backend_type == BackendType.SEA: return str(self.guid) elif self.backend_type == BackendType.THRIFT: @@ -79,6 +155,7 @@ def from_thrift_handle( Returns: A SessionId instance """ + if session_handle is None: return None @@ -105,6 +182,7 @@ def from_sea_session_id( Returns: A SessionId instance """ + return cls(BackendType.SEA, session_id, properties=properties) def 
to_thrift_handle(self): @@ -114,6 +192,7 @@ def to_thrift_handle(self): Returns: A TSessionHandle object or None if this is not a Thrift session ID """ + if self.backend_type != BackendType.THRIFT: return None @@ -132,6 +211,7 @@ def to_sea_session_id(self): Returns: The session ID string or None if this is not a SEA session ID """ + if self.backend_type != BackendType.SEA: return None @@ -141,6 +221,7 @@ def get_guid(self) -> Any: """ Get the ID of the session. """ + return self.guid def get_hex_guid(self) -> str: @@ -150,6 +231,7 @@ def get_hex_guid(self) -> str: Returns: A hexadecimal string representation """ + if isinstance(self.guid, bytes): return guid_to_hex_id(self.guid) else: @@ -163,6 +245,7 @@ def get_protocol_version(self): The server protocol version or None if it does not exist It is not expected to exist for SEA sessions. """ + return self.properties.get("serverProtocolVersion") @@ -194,6 +277,7 @@ def __init__( has_result_set: Whether the command has a result set modified_row_count: The number of rows modified by the command """ + self.backend_type = backend_type self.guid = guid self.secret = secret @@ -211,6 +295,7 @@ def __str__(self) -> str: Returns: A string representation of the command ID """ + if self.backend_type == BackendType.SEA: return str(self.guid) elif self.backend_type == BackendType.THRIFT: @@ -233,6 +318,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None @@ -259,6 +345,7 @@ def from_sea_statement_id(cls, statement_id: str): Returns: A CommandId instance """ + return cls(BackendType.SEA, statement_id) def to_thrift_handle(self): @@ -268,6 +355,7 @@ def to_thrift_handle(self): Returns: A TOperationHandle object or None if this is not a Thrift command ID """ + if self.backend_type != BackendType.THRIFT: return None @@ -288,6 +376,7 @@ def to_sea_statement_id(self): Returns: The statement ID string or None if this is not a SEA statement ID """ + if self.backend_type != BackendType.SEA: return None @@ -300,6 +389,7 @@ def to_hex_guid(self) -> str: Returns: A hexadecimal string representation """ + if isinstance(self.guid, bytes): return guid_to_hex_id(self.guid) else: diff --git a/src/databricks/sql/backend/utils/guid_utils.py b/src/databricks/sql/backend/utils/guid_utils.py index 28975171f..2c440afd2 100644 --- a/src/databricks/sql/backend/utils/guid_utils.py +++ b/src/databricks/sql/backend/utils/guid_utils.py @@ -14,6 +14,7 @@ def guid_to_hex_id(guid: bytes) -> str: If conversion to hexadecimal fails, a string representation of the original bytes is returned """ + try: this_uuid = uuid.UUID(bytes=guid) except Exception as e: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 1c384c735..9f7c060a7 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -42,14 +42,15 @@ ParameterApproach, ) - +from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence from databricks.sql.session import Session -from databricks.sql.backend.types import CommandId, BackendType +from databricks.sql.backend.types import CommandId, BackendType, CommandState, SessionId from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, TSparkParameter, TOperationState, ) @@ -320,9 +321,17 @@ def protocol_version(self): return 
self.session.protocol_version @staticmethod - def get_protocol_version(openSessionResp): + def get_protocol_version(openSessionResp: TOpenSessionResp): """Delegate to Session class static method""" - return Session.get_protocol_version(openSessionResp) + properties = ( + {"serverProtocolVersion": openSessionResp.serverProtocolVersion} + if openSessionResp.serverProtocolVersion + else {} + ) + session_id = SessionId.from_thrift_handle( + openSessionResp.sessionHandle, properties + ) + return Session.get_protocol_version(session_id) @property def open(self) -> bool: @@ -388,6 +397,7 @@ def __init__( Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately visible by other cursors or connections. """ + self.connection = connection self.rowcount = -1 # Return -1 as this is not supported self.buffer_size_bytes = result_buffer_size_bytes @@ -746,6 +756,7 @@ def execute( :returns self """ + logger.debug( "Cursor.execute(operation=%s, parameters=%s)", operation, parameters ) @@ -771,7 +782,7 @@ def execute( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.execute_command( + self.active_result_set = self.backend.execute_command( operation=prepared_operation, session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, @@ -783,18 +794,8 @@ def execute( async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, ) - assert execute_response is not None # async_op = False above - - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) - if execute_response.is_staging_operation: + if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( staging_allowed_local_path=self.connection.staging_allowed_local_path ) @@ -815,6 +816,7 @@ def execute_async( :param parameters: :return: """ + param_approach = self._determine_parameter_approach(parameters) if param_approach == ParameterApproach.NONE: prepared_params = NO_NATIVE_PARAMS @@ -851,7 +853,7 @@ def execute_async( return self - def get_query_state(self) -> "TOperationState": + def get_query_state(self) -> CommandState: """ Get the state of the async executing query or basically poll the status of the query @@ -869,11 +871,7 @@ def is_query_pending(self): :return: """ operation_state = self.get_query_state() - - return not operation_state or operation_state in [ - ttypes.TOperationState.RUNNING_STATE, - ttypes.TOperationState.PENDING_STATE, - ] + return operation_state in [CommandState.PENDING, CommandState.RUNNING] def get_async_execution_result(self): """ @@ -889,19 +887,12 @@ def get_async_execution_result(self): time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL) operation_state = self.get_query_state() - if operation_state == ttypes.TOperationState.FINISHED_STATE: - execute_response = self.backend.get_execution_result( + if operation_state == CommandState.SUCCEEDED: + self.active_result_set = self.backend.get_execution_result( self.active_command_id, self ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - ) - if execute_response.is_staging_operation: + if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( staging_allowed_local_path=self.connection.staging_allowed_local_path ) @@ -935,20 +926,12 @@ def catalogs(self) -> 
"Cursor": """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.get_catalogs( + self.active_result_set = self.backend.get_catalogs( session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) return self def schemas( @@ -962,7 +945,7 @@ def schemas( """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.get_schemas( + self.active_result_set = self.backend.get_schemas( session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -970,14 +953,6 @@ def schemas( catalog_name=catalog_name, schema_name=schema_name, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) return self def tables( @@ -996,7 +971,7 @@ def tables( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.get_tables( + self.active_result_set = self.backend.get_tables( session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -1006,14 +981,6 @@ def tables( table_name=table_name, table_types=table_types, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) return self def columns( @@ -1032,7 +999,7 @@ def columns( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.backend.get_columns( + self.active_result_set = self.backend.get_columns( session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -1042,14 +1009,6 @@ def columns( table_name=table_name, column_name=column_name, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) return self def fetchall(self) -> List[Row]: @@ -1205,312 +1164,3 @@ def setinputsizes(self, sizes): def setoutputsize(self, size, column=None): """Does nothing by default""" pass - - -class ResultSet: - def __init__( - self, - connection: Connection, - execute_response: ExecuteResponse, - backend: DatabricksClient, - result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, - arraysize: int = 10000, - use_cloud_fetch: bool = True, - ): - """ - A ResultSet manages the results of a single command. 
- - :param connection: The parent connection that was used to execute this command - :param execute_response: A `ExecuteResponse` class returned by a command execution - :param backend: The DatabricksClient instance to use for fetching results - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch amount - :param arraysize: The max number of rows to fetch at a time (PEP-249) - :param use_cloud_fetch: Whether to use cloud fetch for retrieving results - """ - self.connection = connection - self.command_id = execute_response.command_id - self.op_state = execute_response.status - self.has_been_closed_server_side = execute_response.has_been_closed_server_side - self.has_more_rows = execute_response.has_more_rows - self.buffer_size_bytes = result_buffer_size_bytes - self.lz4_compressed = execute_response.lz4_compressed - self.arraysize = arraysize - self.backend = backend - self.description = execute_response.description - self._arrow_schema_bytes = execute_response.arrow_schema_bytes - self._next_row_index = 0 - self._use_cloud_fetch = use_cloud_fetch - - if execute_response.arrow_queue: - # In this case the server has taken the fast path and returned an initial batch of - # results - self.results = execute_response.arrow_queue - else: - # In this case, there are results waiting on the server so we fetch now for simplicity - self._fill_results_buffer() - - def __iter__(self): - while True: - row = self.fetchone() - if row: - yield row - else: - break - - def _fill_results_buffer(self): - if not isinstance(self.backend, ThriftDatabricksClient): - # currently, we are assuming only the Thrift backend exists - raise NotImplementedError( - "Fetching further result batches is currently only implemented for the Thrift backend." - ) - - # Now we know self.backend is ThriftDatabricksClient, so it has fetch_results - thrift_backend_instance = self.backend # type: ThriftDatabricksClient - results, has_more_rows = thrift_backend_instance.fetch_results( - command_id=self.command_id, - max_rows=self.arraysize, - max_bytes=self.buffer_size_bytes, - expected_row_start_offset=self._next_row_index, - lz4_compressed=self.lz4_compressed, - arrow_schema_bytes=self._arrow_schema_bytes, - description=self.description, - use_cloud_fetch=self._use_cloud_fetch, - ) - self.results = results - self.has_more_rows = has_more_rows - - def _convert_columnar_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - result = [] - for row_index in range(table.num_rows): - curr_row = [] - for col_index in range(table.num_columns): - curr_row.append(table.get_item(col_index, row_index)) - result.append(ResultRow(*curr_row)) - - return result - - def _convert_arrow_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - - if self.connection.disable_pandas is True: - return [ - ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) - ] - - # Need to use nullable types, as otherwise type can change when there are missing values. 
- # See https://arrow.apache.org/docs/python/pandas.html#nullable-types - # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - dtype_mapping = { - pyarrow.int8(): pandas.Int8Dtype(), - pyarrow.int16(): pandas.Int16Dtype(), - pyarrow.int32(): pandas.Int32Dtype(), - pyarrow.int64(): pandas.Int64Dtype(), - pyarrow.uint8(): pandas.UInt8Dtype(), - pyarrow.uint16(): pandas.UInt16Dtype(), - pyarrow.uint32(): pandas.UInt32Dtype(), - pyarrow.uint64(): pandas.UInt64Dtype(), - pyarrow.bool_(): pandas.BooleanDtype(), - pyarrow.float32(): pandas.Float32Dtype(), - pyarrow.float64(): pandas.Float64Dtype(), - pyarrow.string(): pandas.StringDtype(), - } - - # Need to rename columns, as the to_pandas function cannot handle duplicate column names - table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - df = table_renamed.to_pandas( - types_mapper=dtype_mapping.get, - date_as_object=True, - timestamp_as_object=True, - ) - - res = df.to_numpy(na_value=None, dtype="object") - return [ResultRow(*v) for v in res] - - @property - def rownumber(self): - return self._next_row_index - - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": - """ - Fetch the next set of rows of a query result, returning a PyArrow table. - - An empty sequence is returned when no more rows are available. - """ - if size < 0: - raise ValueError("size argument for fetchmany is %s but must be >= 0", size) - results = self.results.next_n_rows(size) - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows - - while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows - ): - self._fill_results_buffer() - partial_results = self.results.next_n_rows(n_remaining_rows) - results = pyarrow.concat_tables([results, partial_results]) - n_remaining_rows -= partial_results.num_rows - self._next_row_index += partial_results.num_rows - - return results - - def merge_columnar(self, result1, result2): - """ - Function to merge / combining the columnar results into a single result - :param result1: - :param result2: - :return: - """ - - if result1.column_names != result2.column_names: - raise ValueError("The columns in the results don't match") - - merged_result = [ - result1.column_table[i] + result2.column_table[i] - for i in range(result1.num_columns) - ] - return ColumnTable(merged_result, result1.column_names) - - def fetchmany_columnar(self, size: int): - """ - Fetch the next set of rows of a query result, returning a Columnar Table. - An empty sequence is returned when no more rows are available. 
- """ - if size < 0: - raise ValueError("size argument for fetchmany is %s but must be >= 0", size) - - results = self.results.next_n_rows(size) - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows - - while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows - ): - self._fill_results_buffer() - partial_results = self.results.next_n_rows(n_remaining_rows) - results = self.merge_columnar(results, partial_results) - n_remaining_rows -= partial_results.num_rows - self._next_row_index += partial_results.num_rows - - return results - - def fetchall_arrow(self) -> "pyarrow.Table": - """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" - results = self.results.remaining_rows() - self._next_row_index += results.num_rows - - while not self.has_been_closed_server_side and self.has_more_rows: - self._fill_results_buffer() - partial_results = self.results.remaining_rows() - if isinstance(results, ColumnTable) and isinstance( - partial_results, ColumnTable - ): - results = self.merge_columnar(results, partial_results) - else: - results = pyarrow.concat_tables([results, partial_results]) - self._next_row_index += partial_results.num_rows - - # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table - # Valid only for metadata commands result set - if isinstance(results, ColumnTable) and pyarrow: - data = { - name: col - for name, col in zip(results.column_names, results.column_table) - } - return pyarrow.Table.from_pydict(data) - return results - - def fetchall_columnar(self): - """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" - results = self.results.remaining_rows() - self._next_row_index += results.num_rows - - while not self.has_been_closed_server_side and self.has_more_rows: - self._fill_results_buffer() - partial_results = self.results.remaining_rows() - results = self.merge_columnar(results, partial_results) - self._next_row_index += partial_results.num_rows - - return results - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - if isinstance(self.results, ColumnQueue): - res = self._convert_columnar_table(self.fetchmany_columnar(1)) - else: - res = self._convert_arrow_table(self.fetchmany_arrow(1)) - - if len(res) > 0: - return res[0] - else: - return None - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - if isinstance(self.results, ColumnQueue): - return self._convert_columnar_table(self.fetchall_columnar()) - else: - return self._convert_arrow_table(self.fetchall_arrow()) - - def fetchmany(self, size: int) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - if isinstance(self.results, ColumnQueue): - return self._convert_columnar_table(self.fetchmany_columnar(size)) - else: - return self._convert_arrow_table(self.fetchmany_arrow(size)) - - def close(self) -> None: - """ - Close the cursor. - - If the connection has not been closed, and the cursor has not already - been closed on the server for some other reason, issue a request to the server to close it. 
- """ - # TODO: the state is still thrift specific, define some ENUM for status that each service has to map to - # when we generalise the ResultSet - try: - if ( - self.op_state != ttypes.TOperationState.CLOSED_STATE - and not self.has_been_closed_server_side - and self.connection.open - ): - self.backend.close_command(self.command_id) - except RequestError as e: - if isinstance(e.args[1], CursorAlreadyClosedError): - logger.info("Operation was canceled by a prior request") - finally: - self.has_been_closed_server_side = True - self.op_state = ttypes.TOperationState.CLOSED_STATE - - @staticmethod - def _get_schema_description(table_schema_message): - """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 - """ - - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ - - return [ - (column.name, map_col_type(column.datatype), None, None, None, None, None) - for column in table_schema_message.columns - ] diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py new file mode 100644 index 000000000..a0d8d3579 --- /dev/null +++ b/src/databricks/sql/result_set.py @@ -0,0 +1,412 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Any, Union, TYPE_CHECKING + +import logging +import time +import pandas + +from databricks.sql.backend.types import CommandId, CommandState + +try: + import pyarrow +except ImportError: + pyarrow = None + +if TYPE_CHECKING: + from databricks.sql.backend.databricks_client import DatabricksClient + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + from databricks.sql.client import Connection + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import Row +from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError +from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue + +logger = logging.getLogger(__name__) + + +class ResultSet(ABC): + """ + Abstract base class for result sets returned by different backend implementations. + + This class defines the interface that all concrete result set implementations must follow. + """ + + def __init__( + self, + connection: "Connection", + backend: "DatabricksClient", + command_id: CommandId, + op_state: Optional[CommandState], + has_been_closed_server_side: bool, + arraysize: int, + buffer_size_bytes: int, + ): + """ + A ResultSet manages the results of a single command. 
+ + :param connection: The parent connection that was used to execute this command + :param backend: The specialised backend client to be invoked in the fetch phase + :param execute_response: A `ExecuteResponse` class returned by a command execution + :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + amount :param arraysize: The max number of rows to fetch at a time (PEP-249) + """ + self.command_id = command_id + self.op_state = op_state + self.has_been_closed_server_side = has_been_closed_server_side + self.connection = connection + self.backend = backend + self.arraysize = arraysize + self.buffer_size_bytes = buffer_size_bytes + self._next_row_index = 0 + self.description = None + + def __iter__(self): + while True: + row = self.fetchone() + if row: + yield row + else: + break + + @property + def rownumber(self): + return self._next_row_index + + @property + @abstractmethod + def is_staging_operation(self) -> bool: + """Whether this result set represents a staging operation.""" + pass + + # Define abstract methods that concrete implementations must implement + @abstractmethod + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + pass + + @abstractmethod + def fetchone(self) -> Optional[Row]: + """Fetch the next row of a query result set.""" + pass + + @abstractmethod + def fetchmany(self, size: int) -> List[Row]: + """Fetch the next set of rows of a query result.""" + pass + + @abstractmethod + def fetchall(self) -> List[Row]: + """Fetch all remaining rows of a query result.""" + pass + + @abstractmethod + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + """Fetch the next set of rows as an Arrow table.""" + pass + + @abstractmethod + def fetchall_arrow(self) -> "pyarrow.Table": + """Fetch all remaining rows as an Arrow table.""" + pass + + def close(self) -> None: + """ + Close the result set. + + If the connection has not been closed, and the result set has not already + been closed on the server for some other reason, issue a request to the server to close it. + """ + try: + if ( + self.op_state != CommandState.CLOSED + and not self.has_been_closed_server_side + and self.connection.open + ): + self.backend.close_command(self.command_id) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + finally: + self.has_been_closed_server_side = True + self.op_state = CommandState.CLOSED + + +class ThriftResultSet(ResultSet): + """ResultSet implementation for the Thrift backend.""" + + def __init__( + self, + connection: "Connection", + execute_response: ExecuteResponse, + thrift_client: "ThriftDatabricksClient", + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + use_cloud_fetch: bool = True, + ): + """ + Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
+ + Args: + connection: The parent connection + execute_response: Response from the execute command + thrift_client: The ThriftDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + use_cloud_fetch: Whether to use cloud fetch for retrieving results + """ + super().__init__( + connection, + thrift_client, + execute_response.command_id, + execute_response.status, + execute_response.has_been_closed_server_side, + arraysize, + buffer_size_bytes, + ) + + # Initialize ThriftResultSet-specific attributes + self.has_been_closed_server_side = execute_response.has_been_closed_server_side + self.has_more_rows = execute_response.has_more_rows + self.lz4_compressed = execute_response.lz4_compressed + self.description = execute_response.description + self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._use_cloud_fetch = use_cloud_fetch + self._is_staging_operation = execute_response.is_staging_operation + + # Initialize results queue + if execute_response.arrow_queue: + # In this case the server has taken the fast path and returned an initial batch of + # results + self.results = execute_response.arrow_queue + else: + # In this case, there are results waiting on the server so we fetch now for simplicity + self._fill_results_buffer() + + def _fill_results_buffer(self): + # At initialization or if the server does not have cloud fetch result links available + results, has_more_rows = self.backend.fetch_results( + command_id=self.command_id, + max_rows=self.arraysize, + max_bytes=self.buffer_size_bytes, + expected_row_start_offset=self._next_row_index, + lz4_compressed=self.lz4_compressed, + arrow_schema_bytes=self._arrow_schema_bytes, + description=self.description, + use_cloud_fetch=self._use_cloud_fetch, + ) + self.results = results + self.has_more_rows = has_more_rows + + def _convert_columnar_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + result = [] + for row_index in range(table.num_rows): + curr_row = [] + for col_index in range(table.num_columns): + curr_row.append(table.get_item(col_index, row_index)) + result.append(ResultRow(*curr_row)) + + return result + + def _convert_arrow_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + + if self.connection.disable_pandas is True: + return [ + ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) + ] + + # Need to use nullable types, as otherwise type can change when there are missing values. 
+ # See https://arrow.apache.org/docs/python/pandas.html#nullable-types + # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html + dtype_mapping = { + pyarrow.int8(): pandas.Int8Dtype(), + pyarrow.int16(): pandas.Int16Dtype(), + pyarrow.int32(): pandas.Int32Dtype(), + pyarrow.int64(): pandas.Int64Dtype(), + pyarrow.uint8(): pandas.UInt8Dtype(), + pyarrow.uint16(): pandas.UInt16Dtype(), + pyarrow.uint32(): pandas.UInt32Dtype(), + pyarrow.uint64(): pandas.UInt64Dtype(), + pyarrow.bool_(): pandas.BooleanDtype(), + pyarrow.float32(): pandas.Float32Dtype(), + pyarrow.float64(): pandas.Float64Dtype(), + pyarrow.string(): pandas.StringDtype(), + } + + # Need to rename columns, as the to_pandas function cannot handle duplicate column names + table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) + df = table_renamed.to_pandas( + types_mapper=dtype_mapping.get, + date_as_object=True, + timestamp_as_object=True, + ) + + res = df.to_numpy(na_value=None, dtype="object") + return [ResultRow(*v) for v in res] + + def merge_columnar(self, result1, result2) -> "ColumnTable": + """ + Function to merge / combining the columnar results into a single result + :param result1: + :param result2: + :return: + """ + + if result1.column_names != result2.column_names: + raise ValueError("The columns in the results don't match") + + merged_result = [ + result1.column_table[i] + result2.column_table[i] + for i in range(result1.num_columns) + ] + return ColumnTable(merged_result, result1.column_names) + + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + """ + Fetch the next set of rows of a query result, returning a PyArrow table. + + An empty sequence is returned when no more rows are available. + """ + if size < 0: + raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + results = self.results.next_n_rows(size) + n_remaining_rows = size - results.num_rows + self._next_row_index += results.num_rows + + while ( + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows + ): + self._fill_results_buffer() + partial_results = self.results.next_n_rows(n_remaining_rows) + results = pyarrow.concat_tables([results, partial_results]) + n_remaining_rows -= partial_results.num_rows + self._next_row_index += partial_results.num_rows + + return results + + def fetchmany_columnar(self, size: int): + """ + Fetch the next set of rows of a query result, returning a Columnar Table. + An empty sequence is returned when no more rows are available. 
+ """ + if size < 0: + raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + + results = self.results.next_n_rows(size) + n_remaining_rows = size - results.num_rows + self._next_row_index += results.num_rows + + while ( + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows + ): + self._fill_results_buffer() + partial_results = self.results.next_n_rows(n_remaining_rows) + results = self.merge_columnar(results, partial_results) + n_remaining_rows -= partial_results.num_rows + self._next_row_index += partial_results.num_rows + + return results + + def fetchall_arrow(self) -> "pyarrow.Table": + """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" + results = self.results.remaining_rows() + self._next_row_index += results.num_rows + + while not self.has_been_closed_server_side and self.has_more_rows: + self._fill_results_buffer() + partial_results = self.results.remaining_rows() + if isinstance(results, ColumnTable) and isinstance( + partial_results, ColumnTable + ): + results = self.merge_columnar(results, partial_results) + else: + results = pyarrow.concat_tables([results, partial_results]) + self._next_row_index += partial_results.num_rows + + # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table + # Valid only for metadata commands result set + if isinstance(results, ColumnTable) and pyarrow: + data = { + name: col + for name, col in zip(results.column_names, results.column_table) + } + return pyarrow.Table.from_pydict(data) + return results + + def fetchall_columnar(self): + """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" + results = self.results.remaining_rows() + self._next_row_index += results.num_rows + + while not self.has_been_closed_server_side and self.has_more_rows: + self._fill_results_buffer() + partial_results = self.results.remaining_rows() + results = self.merge_columnar(results, partial_results) + self._next_row_index += partial_results.num_rows + + return results + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + if isinstance(self.results, ColumnQueue): + res = self._convert_columnar_table(self.fetchmany_columnar(1)) + else: + res = self._convert_arrow_table(self.fetchmany_arrow(1)) + + if len(res) > 0: + return res[0] + else: + return None + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchall_columnar()) + else: + return self._convert_arrow_table(self.fetchall_arrow()) + + def fetchmany(self, size: int) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. 
+ """ + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchmany_columnar(size)) + else: + return self._convert_arrow_table(self.fetchmany_arrow(size)) + + @property + def is_staging_operation(self) -> bool: + """Whether this result set represents a staging operation.""" + return self._is_staging_operation + + @staticmethod + def _get_schema_description(table_schema_message): + """ + Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + """ + + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ + + return [ + (column.name, map_col_type(column.datatype), None, None, None, None, None) + for column in table_schema_message.columns + ] diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 2ee5e53f1..6d69b5487 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -31,6 +31,7 @@ def __init__( This class handles all session-related behavior and communication with the backend. """ + self.is_open = False self.host = server_hostname self.port = kwargs.get("_port", 443) diff --git a/src/databricks/sql/types.py b/src/databricks/sql/types.py index fef22cd9f..4d9f8be5f 100644 --- a/src/databricks/sql/types.py +++ b/src/databricks/sql/types.py @@ -158,6 +158,7 @@ def asDict(self, recursive: bool = False) -> Dict[str, Any]: >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}} True """ + if not hasattr(self, "__fields__"): raise TypeError("Cannot convert a Row class into dict") @@ -186,6 +187,7 @@ def __contains__(self, item: Any) -> bool: # let object acts like class def __call__(self, *args: Any) -> "Row": """create new Row object""" + if len(args) > len(self): raise ValueError( "Can not create Row with fields %s, expected %d values " @@ -228,6 +230,7 @@ def __reduce__( self, ) -> Union[str, Tuple[Any, ...]]: """Returns a tuple so Python knows how to pickle Row.""" + if hasattr(self, "__fields__"): return (_create_row, (self.__fields__, tuple(self))) else: @@ -235,6 +238,7 @@ def __reduce__( def __repr__(self) -> str: """Printable representation of Row used in Python REPL.""" + if hasattr(self, "__fields__"): return "Row(%s)" % ", ".join( "%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self)) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index c541ad3fd..2622b1172 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -74,6 +74,7 @@ def build_queue( Returns: ResultSetQueue """ + if row_set_type == TSparkRowSetType.ARROW_BASED_SET: arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes @@ -173,12 +174,14 @@ def __init__( :param n_valid_rows: The index of the last valid row in the table :param start_row_index: The first row in the table we should start fetching from """ + self.cur_row_index = start_row_index self.arrow_table = arrow_table self.n_valid_rows = n_valid_rows def next_n_rows(self, num_rows: int) -> "pyarrow.Table": """Get upto the next n rows of the Arrow dataframe""" + length = min(num_rows, self.n_valid_rows - self.cur_row_index) # Note that the table.slice API is not the same as Python's slice # The second argument should be length, not end index @@ -216,6 +219,7 @@ def __init__( lz4_compressed (bool): Whether the files are lz4 compressed. description (List[List[Any]]): Hive table schema description. 
""" + self.schema_bytes = schema_bytes self.max_download_threads = max_download_threads self.start_row_index = start_row_offset @@ -256,6 +260,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table": Returns: pyarrow.Table """ + if not self.table: logger.debug("CloudFetchQueue: no more rows available") # Return empty pyarrow table to cause retry of fetch @@ -285,6 +290,7 @@ def remaining_rows(self) -> "pyarrow.Table": Returns: pyarrow.Table """ + if not self.table: # Return empty pyarrow table to cause retry of fetch return self._create_empty_table() @@ -566,6 +572,7 @@ def transform_paramstyle( Returns: str """ + output = operation if ( param_structure == ParameterStructure.POSITIONAL diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index c446b6715..22897644f 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -30,6 +30,7 @@ OperationalError, RequestError, ) +from databricks.sql.backend.types import CommandState from tests.e2e.common.predicates import ( pysql_has_version, pysql_supports_arrow, @@ -826,10 +827,7 @@ def test_close_connection_closes_cursors(self): getProgressUpdate=False, ) op_status_at_server = ars.backend._client.GetOperationStatus(status_request) - assert ( - op_status_at_server.operationState - != ttypes.TOperationState.CLOSED_STATE - ) + assert op_status_at_server.operationState != CommandState.CLOSED conn.close() @@ -939,7 +937,7 @@ def test_result_set_close(self): result_set.close() - assert result_set.op_state == result_set.backend.CLOSED_OP_STATE + assert result_set.op_state == CommandState.CLOSED assert result_set.op_state != initial_op_state # Closing the result set again should be a no-op and not raise exceptions diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index fa6fae1d9..1a7950870 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -14,7 +14,9 @@ TOperationHandle, THandleIdentifier, TOperationType, + TOperationState, ) +from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.thrift_backend import ThriftDatabricksClient import databricks.sql @@ -22,7 +24,9 @@ from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError from databricks.sql.exc import RequestError, CursorAlreadyClosedError from databricks.sql.types import Row -from databricks.sql.client import CommandId +from databricks.sql.result_set import ResultSet, ThriftResultSet +from databricks.sql.backend.types import CommandId, CommandState +from databricks.sql.utils import ExecuteResponse from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite @@ -36,12 +40,11 @@ def new(cls): ThriftBackendMock.return_value = ThriftBackendMock cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None) - MockTExecuteStatementResp = MagicMock(spec=TExecuteStatementResp()) + mock_result_set = Mock(spec=ThriftResultSet) cls.apply_property_to_mock( - MockTExecuteStatementResp, + mock_result_set, description=None, - arrow_queue=None, is_staging_operation=False, command_id=None, has_been_closed_server_side=True, @@ -50,7 +53,7 @@ def new(cls): arrow_schema_bytes=b"schema", ) - ThriftBackendMock.execute_command.return_value = MockTExecuteStatementResp + ThriftBackendMock.execute_command.return_value = mock_result_set return ThriftBackendMock @@ -82,25 +85,79 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch( - "%s.session.ThriftDatabricksClient" % PACKAGE_NAME, - 
ThriftDatabricksClientMockFactory.new(), - ) - @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_closing_connection_closes_commands(self, mock_result_set_class): + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_closing_connection_closes_commands(self, mock_thrift_client_class): + """Test that connection.close() properly closes result sets through the real close chain.""" # Test once with has_been_closed_server side, once without for closed in (True, False): with self.subTest(closed=closed): - mock_result_set_class.return_value = Mock() + # Mock the execute response with controlled state + mock_execute_response = Mock(spec=ExecuteResponse) + + mock_execute_response.command_id = Mock(spec=CommandId) + mock_execute_response.status = ( + CommandState.SUCCEEDED if not closed else CommandState.CLOSED + ) + mock_execute_response.has_been_closed_server_side = closed + mock_execute_response.is_staging_operation = False + + # Mock the backend that will be used by the real ThriftResultSet + mock_backend = Mock(spec=ThriftDatabricksClient) + mock_backend.staging_allowed_local_path = None + + # Configure the decorator's mock to return our specific mock_backend + mock_thrift_client_class.return_value = mock_backend + + # Create connection and cursor connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() - cursor.execute("SELECT 1;") - connection.close() - self.assertTrue( - mock_result_set_class.return_value.has_been_closed_server_side + # Create a REAL ThriftResultSet that will be returned by execute_command + real_result_set = ThriftResultSet( + connection=connection, + execute_response=mock_execute_response, + thrift_client=mock_backend, + ) + + # Verify initial state + self.assertEqual(real_result_set.has_been_closed_server_side, closed) + expected_op_state = ( + CommandState.CLOSED if closed else CommandState.SUCCEEDED + ) + self.assertEqual(real_result_set.op_state, expected_op_state) + + # Mock execute_command to return our real result set + cursor.backend.execute_command = Mock(return_value=real_result_set) + + # Execute a command - this should set cursor.active_result_set to our real result set + cursor.execute("SELECT 1") + + # Verify that cursor.execute() set up the result set correctly + self.assertIsInstance(cursor.active_result_set, ThriftResultSet) + self.assertEqual( + cursor.active_result_set.has_been_closed_server_side, closed ) - mock_result_set_class.return_value.close.assert_called_once_with() + + # Close the connection - this should trigger the real close chain: + # connection.close() -> cursor.close() -> result_set.close() + connection.close() + + # Verify the REAL close logic worked through the chain: + # 1. has_been_closed_server_side should always be True after close() + self.assertTrue(real_result_set.has_been_closed_server_side) + + # 2. op_state should always be CLOSED after close() + self.assertEqual(real_result_set.op_state, CommandState.CLOSED) + + # 3. 
Backend close_command should be called appropriately + if not closed: + # Should have called backend.close_command during the close chain + mock_backend.close_command.assert_called_once_with( + mock_execute_response.command_id + ) + else: + # Should NOT have called backend.close_command (already closed) + mock_backend.close_command.assert_not_called() @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class): @@ -127,10 +184,11 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() - result_set = client.ResultSet( + + result_set = ThriftResultSet( connection=mock_connection, - backend=mock_backend, execute_response=Mock(), + thrift_client=mock_backend, ) # Setup session mock on the mock_connection mock_session = Mock() @@ -152,7 +210,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) - result_set = client.ResultSet( + result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -162,17 +220,16 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response.command_id ) - @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_executing_multiple_commands_uses_the_most_recent_command( - self, mock_result_set_class - ): - + def test_executing_multiple_commands_uses_the_most_recent_command(self): mock_result_sets = [Mock(), Mock()] - mock_result_set_class.side_effect = mock_result_sets + # Set is_staging_operation to False to avoid _handle_staging_operation being called + for mock_rs in mock_result_sets: + mock_rs.is_staging_operation = False - cursor = client.Cursor( - connection=Mock(), backend=ThriftDatabricksClientMockFactory.new() - ) + mock_backend = ThriftDatabricksClientMockFactory.new() + mock_backend.execute_command.side_effect = mock_result_sets + + cursor = client.Cursor(connection=Mock(), backend=mock_backend) cursor.execute("SELECT 1;") cursor.execute("SELECT 1;") @@ -197,7 +254,7 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - result_set = client.ResultSet(Mock(), Mock(), Mock()) + result_set = ThriftResultSet(Mock(), Mock(), Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -349,14 +406,15 @@ def test_execute_parameter_passthrough(self): expected_query, ) - @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_executemany_parameter_passhthrough_and_uses_last_result_set( - self, mock_result_set_class - ): + def test_executemany_parameter_passhthrough_and_uses_last_result_set(self): # Create a new mock result set each time the class is instantiated mock_result_set_instances = [Mock(), Mock(), Mock()] - mock_result_set_class.side_effect = mock_result_set_instances + # Set is_staging_operation to False to avoid _handle_staging_operation being called + for mock_rs in mock_result_set_instances: + mock_rs.is_staging_operation = False + mock_backend = ThriftDatabricksClientMockFactory.new() + mock_backend.execute_command.side_effect = mock_result_set_instances cursor = client.Cursor(Mock(), mock_backend) @@ -509,8 +567,9 @@ def test_staging_operation_response_is_handled( ThriftDatabricksClientMockFactory.apply_property_to_mock( mock_execute_response, is_staging_operation=True ) - mock_client_class.execute_command.return_value 
= mock_execute_response - mock_client_class.return_value = mock_client_class + mock_client = mock_client_class.return_value + mock_client.execute_command.return_value = Mock(is_staging_operation=True) + mock_client_class.return_value = mock_client connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() @@ -617,9 +676,9 @@ def mock_close_normal(): def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" - result_set = client.ResultSet.__new__(client.ResultSet) - result_set.thrift_backend = Mock() - result_set.thrift_backend.CLOSED_OP_STATE = "CLOSED" + result_set = client.ThriftResultSet.__new__(client.ThriftResultSet) + result_set.backend = Mock() + result_set.backend.CLOSED_OP_STATE = "CLOSED" result_set.connection = Mock() result_set.connection.open = True result_set.op_state = "RUNNING" @@ -630,31 +689,31 @@ class MockRequestError(Exception): def __init__(self): self.args = ["Error message", CursorAlreadyClosedError()] - result_set.thrift_backend.close_command.side_effect = MockRequestError() + result_set.backend.close_command.side_effect = MockRequestError() original_close = client.ResultSet.close try: try: if ( - result_set.op_state != result_set.thrift_backend.CLOSED_OP_STATE + result_set.op_state != result_set.backend.CLOSED_OP_STATE and not result_set.has_been_closed_server_side and result_set.connection.open ): - result_set.thrift_backend.close_command(result_set.command_id) + result_set.backend.close_command(result_set.command_id) except MockRequestError as e: if isinstance(e.args[1], CursorAlreadyClosedError): pass finally: result_set.has_been_closed_server_side = True - result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE + result_set.op_state = result_set.backend.CLOSED_OP_STATE - result_set.thrift_backend.close_command.assert_called_once_with( + result_set.backend.close_command.assert_called_once_with( result_set.command_id ) assert result_set.has_been_closed_server_side is True - assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + assert result_set.op_state == result_set.backend.CLOSED_OP_STATE finally: pass diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 1c6a1b18d..030510a64 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -10,6 +10,7 @@ import databricks.sql.client as client from databricks.sql.utils import ExecuteResponse, ArrowQueue from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.result_set import ThriftResultSet @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") @@ -38,9 +39,8 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) - rs = client.ResultSet( + rs = ThriftResultSet( connection=Mock(), - backend=None, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, @@ -52,6 +52,7 @@ def make_dummy_result_set_from_initial_results(initial_results): arrow_schema_bytes=schema.serialize().to_pybytes(), is_staging_operation=False, ), + thrift_client=None, ) num_cols = len(initial_results[0]) if initial_results else 0 rs.description = [ @@ -84,9 +85,8 @@ def fetch_results( mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if 
batch_list and batch_list[0] else 0 - rs = client.ResultSet( + rs = ThriftResultSet( connection=Mock(), - backend=mock_thrift_backend, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=False, @@ -101,6 +101,7 @@ def fetch_results( arrow_schema_bytes=None, is_staging_operation=False, ), + thrift_client=mock_thrift_backend, ) return rs diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index 949230d1e..37e6cf1c9 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -64,13 +64,7 @@ class TestSessionHandleChecks(object): ], ) def test_get_protocol_version_fallback_behavior(self, test_input, expected): - properties = ( - {"serverProtocolVersion": test_input.serverProtocolVersion} - if test_input.serverProtocolVersion - else {} - ) - session_id = SessionId.from_thrift_handle(test_input.sessionHandle, properties) - assert Connection.get_protocol_version(session_id) == expected + assert Connection.get_protocol_version(test_input) == expected @pytest.mark.parametrize( "test_input,expected", diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 41a2a5800..57a2a61e3 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -18,7 +18,8 @@ from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.backend.thrift_backend import ThriftDatabricksClient -from databricks.sql.backend.types import CommandId, SessionId, BackendType +from databricks.sql.result_set import ResultSet, ThriftResultSet +from databricks.sql.backend.types import CommandId, CommandState, SessionId, BackendType def retry_policy_factory(): @@ -882,7 +883,7 @@ def test_handle_execute_response_can_handle_without_direct_results( ) self.assertEqual( results_message_response.status, - ttypes.TOperationState.FINISHED_STATE, + CommandState.SUCCEEDED, ) def test_handle_execute_response_can_handle_with_direct_results(self): @@ -1152,7 +1153,12 @@ def test_execute_statement_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.execute_command("foo", Mock(), 100, 200, Mock(), cursor_mock) + result = thrift_backend.execute_command( + "foo", Mock(), 100, 200, Mock(), cursor_mock + ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1181,7 +1187,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) + result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1209,7 +1218,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_schemas( + result = thrift_backend.get_schemas( Mock(), 100, 200, @@ -1217,6 +1226,9 @@ def test_get_schemas_calls_client_and_handle_execute_response( catalog_name="catalog_pattern", schema_name="schema_pattern", ) + # Verify the result is a ResultSet + 
self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1246,7 +1258,7 @@ def test_get_tables_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_tables( + result = thrift_backend.get_tables( Mock(), 100, 200, @@ -1256,6 +1268,9 @@ def test_get_tables_calls_client_and_handle_execute_response( table_name="table_pattern", table_types=["type1", "type2"], ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1287,7 +1302,7 @@ def test_get_columns_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_columns( + result = thrift_backend.get_columns( Mock(), 100, 200, @@ -1297,6 +1312,9 @@ def test_get_columns_calls_client_and_handle_execute_response( table_name="table_pattern", column_name="column_pattern", ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200)
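For reference, below is a minimal sketch (not part of the diff) of the close-chain behavior the updated tests exercise: execute_command now returns a ResultSet, and closing a ThriftResultSet moves it to CommandState.CLOSED and asks the backend to close the command exactly once. The Mock-based wiring is illustrative only and assumes the ThriftResultSet constructor and close() behave as exercised in test_closing_connection_closes_commands above.

# Illustrative sketch only; mirrors the refactored test setup, not a documented API.
from unittest.mock import Mock

from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
from databricks.sql.backend.types import CommandId, CommandState
from databricks.sql.result_set import ThriftResultSet
from databricks.sql.utils import ExecuteResponse

# A successful, still-open execute response (same shape the test mocks up).
execute_response = Mock(spec=ExecuteResponse)
execute_response.command_id = Mock(spec=CommandId)
execute_response.status = CommandState.SUCCEEDED
execute_response.has_been_closed_server_side = False
execute_response.is_staging_operation = False

backend = Mock(spec=ThriftDatabricksClient)

result_set = ThriftResultSet(
    connection=Mock(),  # stands in for an open Connection
    execute_response=execute_response,
    thrift_client=backend,
)

result_set.close()

# Expected end state per the assertions in the tests above.
assert result_set.op_state == CommandState.CLOSED
assert result_set.has_been_closed_server_side is True
backend.close_command.assert_called_once_with(execute_response.command_id)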