diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index b6db61a3c..462d22369 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -1,7 +1,15 @@ name: Code Quality Checks - -on: [pull_request] - +on: + push: + branches: + - main + - sea-migration + - telemetry + pull_request: + branches: + - main + - sea-migration + - telemetry jobs: run-unit-tests: runs-on: ubuntu-latest diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 127c8ff4f..ccd3a580d 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -1,10 +1,14 @@ name: Integration Tests - on: - push: + push: + paths-ignore: + - "**.MD" + - "**.md" + pull_request: branches: - main - pull_request: + - sea-migration + - telemetry jobs: run-e2e-tests: diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py new file mode 100644 index 000000000..dcfcd475f --- /dev/null +++ b/examples/experimental/sea_connector_test.py @@ -0,0 +1,65 @@ +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. + + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + sys.exit(1) + + logger.info(f"Connecting to {server_hostname}") + logger.info(f"HTTP Path: {http_path}") + if catalog: + logger.info(f"Using catalog: {catalog}") + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent + ) + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") + logger.info(f"backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + sys.exit(1) + + logger.info("SEA session test completed successfully") + +if __name__ == "__main__": + test_sea_session() \ No newline at end of file diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py new file mode 100644 index 000000000..29b03bd1a --- /dev/null +++ b/src/databricks/sql/backend/databricks_client.py @@ -0,0 +1,212 @@ +""" +Abstract client interface for interacting with Databricks SQL services. 
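+Concrete implementations include the Thrift backend (ThriftDatabricksClient) and the Statement Execution API backend (SeaDatabricksClient), both introduced in this change.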
+ +Implementations of this class are responsible for: +- Managing connections to Databricks SQL services +- Handling authentication +- Executing SQL queries and commands +- Retrieving query results +- Fetching metadata about catalogs, schemas, tables, and columns +- Managing error handling and retries +""" + +from abc import ABC, abstractmethod +from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.client import Cursor + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.backend.types import SessionId, CommandId, CommandState +from databricks.sql.utils import ExecuteResponse +from databricks.sql.types import SSLOptions + +# Forward reference for type hints +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + + +class DatabricksClient(ABC): + # == Connection and Session Management == + @abstractmethod + def open_session( + self, + session_configuration: Optional[Dict[str, Any]], + catalog: Optional[str], + schema: Optional[str], + ) -> SessionId: + """ + Opens a new session with the Databricks SQL service. + + This method establishes a new session with the server and returns a session + identifier that can be used for subsequent operations. + + Args: + session_configuration: Optional dictionary of configuration parameters for the session + catalog: Optional catalog name to use as the initial catalog for the session + schema: Optional schema name to use as the initial schema for the session + + Returns: + SessionId: A session identifier object that can be used for subsequent operations + + Raises: + Error: If the session configuration is invalid + OperationalError: If there's an error establishing the session + InvalidServerResponseError: If the server response is invalid or unexpected + """ + pass + + @abstractmethod + def close_session(self, session_id: SessionId) -> None: + """ + Closes an existing session with the Databricks SQL service. + + This method terminates the session identified by the given session ID and + releases any resources associated with it. + + Args: + session_id: The session identifier returned by open_session() + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error closing the session + """ + pass + + # == Query Execution, Command Management == + @abstractmethod + def execute_command( + self, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: "Cursor", + use_cloud_fetch: bool, + parameters: List[ttypes.TSparkParameter], + async_op: bool, + enforce_embedded_schema_correctness: bool, + ) -> Union["ResultSet", None]: + pass + + @abstractmethod + def cancel_command(self, command_id: CommandId) -> None: + """ + Cancels a running command or query. + + This method attempts to cancel a command that is currently being executed. + It can be called from a different thread than the one executing the command. 
+ + Args: + command_id: The command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error canceling the command + """ + pass + + @abstractmethod + def close_command(self, command_id: CommandId) -> None: + pass + + @abstractmethod + def get_query_state(self, command_id: CommandId) -> CommandState: + pass + + @abstractmethod + def get_execution_result( + self, + command_id: CommandId, + cursor: "Cursor", + ) -> "ResultSet": + pass + + # == Metadata Operations == + @abstractmethod + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ) -> "ResultSet": + pass + + @abstractmethod + def get_schemas( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + ) -> "ResultSet": + pass + + @abstractmethod + def get_tables( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + table_types: Optional[List[str]] = None, + ) -> "ResultSet": + pass + + @abstractmethod + def get_columns( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + column_name: Optional[str] = None, + ) -> "ResultSet": + pass + + # == Properties == + @property + @abstractmethod + def staging_allowed_local_path(self) -> Union[None, str, List[str]]: + """ + Gets the allowed local paths for staging operations. + + Returns: + Union[None, str, List[str]]: The allowed local paths for staging operations, + or None if staging is not allowed + """ + pass + + @property + @abstractmethod + def ssl_options(self) -> SSLOptions: + """ + Gets the SSL options for this client. + + Returns: + SSLOptions: The SSL configuration options + """ + pass + + @property + @abstractmethod + def max_download_threads(self) -> int: + """ + Gets the maximum number of download threads for cloud fetch operations. + + Returns: + int: The maximum number of download threads + """ + pass diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea_backend.py new file mode 100644 index 000000000..e5dc721ac --- /dev/null +++ b/src/databricks/sql/backend/sea_backend.py @@ -0,0 +1,301 @@ +import logging +import uuid +from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.client import Cursor + +from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.backend.utils.http_client import CustomHttpClient +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import SSLOptions + +logger = logging.getLogger(__name__) + + +class SeaDatabricksClient(DatabricksClient): + """ + Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. 
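+
+    A rough usage sketch for the implemented session operations (the hostname,
+    HTTP path, catalog/schema values, and auth_provider below are hypothetical):
+
+        client = SeaDatabricksClient(
+            server_hostname="example.cloud.databricks.com",
+            port=443,
+            http_path="/sql/1.0/warehouses/abc123",
+            http_headers=[],
+            auth_provider=auth_provider,
+            ssl_options=SSLOptions(),
+        )
+        session_id = client.open_session(None, "main", "default")
+        client.close_session(session_id)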
+ """ + + # SEA API paths + BASE_PATH = "/api/2.0/sql/" + SESSION_PATH = BASE_PATH + "sessions" + SESSION_PATH_WITH_ID = SESSION_PATH + "/{}" + STATEMENT_PATH = BASE_PATH + "statements" + STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" + CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" + + def __init__( + self, + server_hostname: str, + port: int, + http_path: str, + http_headers: List[Tuple[str, str]], + auth_provider, + ssl_options: SSLOptions, + staging_allowed_local_path: Union[None, str, List[str]] = None, + **kwargs, + ): + """ + Initialize the SEA backend client. + + Args: + server_hostname: Hostname of the Databricks server + port: Port number for the connection + http_path: HTTP path for the connection + http_headers: List of HTTP headers to include in requests + auth_provider: Authentication provider + ssl_options: SSL configuration options + staging_allowed_local_path: Allowed local paths for staging operations + **kwargs: Additional keyword arguments + """ + logger.debug( + "SEADatabricksClient.__init__(server_hostname=%s, port=%s, http_path=%s)", + server_hostname, + port, + http_path, + ) + + self._staging_allowed_local_path = staging_allowed_local_path + self._ssl_options = ssl_options + self._max_download_threads = kwargs.get("max_download_threads", 10) + + # Extract warehouse ID from http_path + self.warehouse_id = self._extract_warehouse_id(http_path) + + # Initialize HTTP client + self.http_client = CustomHttpClient( + server_hostname=server_hostname, + port=port, + http_path=http_path, + http_headers=http_headers, + auth_provider=auth_provider, + ssl_options=ssl_options, + **kwargs, + ) + + def _extract_warehouse_id(self, http_path: str) -> str: + """ + Extract the warehouse ID from the HTTP path. + + The warehouse ID is expected to be the last segment of the path when the + second-to-last segment is either 'warehouses' or 'endpoints'. + This matches the JDBC implementation which supports both formats. + + Args: + http_path: The HTTP path from which to extract the warehouse ID + + Returns: + The extracted warehouse ID + + Raises: + Error: If the warehouse ID cannot be extracted from the path + """ + path_parts = http_path.strip("/").split("/") + warehouse_id = None + + if len(path_parts) >= 3 and path_parts[-2] in ["warehouses", "endpoints"]: + warehouse_id = path_parts[-1] + logger.debug( + f"Extracted warehouse ID: {warehouse_id} from path: {http_path}" + ) + + if not warehouse_id: + error_message = ( + f"Could not extract warehouse ID from http_path: {http_path}. " + f"Expected format: /path/to/warehouses/{{warehouse_id}} or " + f"/path/to/endpoints/{{warehouse_id}}" + ) + logger.error(error_message) + raise ValueError(error_message) + + return warehouse_id + + @property + def staging_allowed_local_path(self) -> Union[None, str, List[str]]: + """Get the allowed local paths for staging operations.""" + return self._staging_allowed_local_path + + @property + def ssl_options(self) -> SSLOptions: + """Get the SSL options for this client.""" + return self._ssl_options + + @property + def max_download_threads(self) -> int: + """Get the maximum number of download threads for cloud fetch operations.""" + return self._max_download_threads + + def open_session( + self, + session_configuration: Optional[Dict[str, str]], + catalog: Optional[str], + schema: Optional[str], + ) -> SessionId: + """ + Opens a new session with the Databricks SQL service using SEA. 
+ + Args: + session_configuration: Optional dictionary of configuration parameters for the session + catalog: Optional catalog name to use as the initial catalog for the session + schema: Optional schema name to use as the initial schema for the session + + Returns: + SessionId: A session identifier object that can be used for subsequent operations + + Raises: + Error: If the session configuration is invalid + OperationalError: If there's an error establishing the session + """ + logger.debug( + "SEADatabricksClient.open_session(session_configuration=%s, catalog=%s, schema=%s)", + session_configuration, + catalog, + schema, + ) + + request_data: Dict[str, Any] = {"warehouse_id": self.warehouse_id} + if session_configuration: + request_data["session_confs"] = session_configuration + if catalog: + request_data["catalog"] = catalog + if schema: + request_data["schema"] = schema + + response = self.http_client._make_request( + method="POST", path=self.SESSION_PATH, data=request_data + ) + + session_id = response.get("session_id") + if not session_id: + raise Error("Failed to create session: No session ID returned") + + return SessionId.from_sea_session_id(session_id) + + def close_session(self, session_id: SessionId) -> None: + """ + Closes an existing session with the Databricks SQL service. + + Args: + session_id: The session identifier returned by open_session() + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error closing the session + """ + logger.debug("SEADatabricksClient.close_session(session_id=%s)", session_id) + + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + sea_session_id = session_id.to_sea_session_id() + + request_data = {"warehouse_id": self.warehouse_id} + + self.http_client._make_request( + method="DELETE", + path=self.SESSION_PATH_WITH_ID.format(sea_session_id), + data=request_data, + ) + + # == Not Implemented Operations == + # These methods will be implemented in future iterations + + def execute_command( + self, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: "Cursor", + use_cloud_fetch: bool, + parameters: List[ttypes.TSparkParameter], + async_op: bool, + enforce_embedded_schema_correctness: bool, + ): + """Not implemented yet.""" + raise NotSupportedError( + "execute_command is not yet implemented for SEA backend" + ) + + def cancel_command(self, command_id: CommandId) -> None: + """Not implemented yet.""" + raise NotSupportedError("cancel_command is not yet implemented for SEA backend") + + def close_command(self, command_id: CommandId) -> None: + """Not implemented yet.""" + raise NotSupportedError("close_command is not yet implemented for SEA backend") + + def get_query_state(self, command_id: CommandId) -> CommandState: + """Not implemented yet.""" + raise NotSupportedError( + "get_query_state is not yet implemented for SEA backend" + ) + + def get_execution_result( + self, + command_id: CommandId, + cursor: "Cursor", + ): + """Not implemented yet.""" + raise NotSupportedError( + "get_execution_result is not yet implemented for SEA backend" + ) + + # == Metadata Operations == + + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ): + """Not implemented yet.""" + raise NotSupportedError("get_catalogs is not yet implemented for SEA backend") + + def get_schemas( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + 
catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + ): + """Not implemented yet.""" + raise NotSupportedError("get_schemas is not yet implemented for SEA backend") + + def get_tables( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + table_types: Optional[List[str]] = None, + ): + """Not implemented yet.""" + raise NotSupportedError("get_tables is not yet implemented for SEA backend") + + def get_columns( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + column_name: Optional[str] = None, + ): + """Not implemented yet.""" + raise NotSupportedError("get_columns is not yet implemented for SEA backend") diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py similarity index 82% rename from src/databricks/sql/thrift_backend.py rename to src/databricks/sql/backend/thrift_backend.py index e3dc38ad5..2a9653cab 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,9 +5,19 @@ import time import uuid import threading -from typing import List, Union +from typing import List, Union, Any, TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.client import Cursor from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState +from databricks.sql.backend.types import ( + CommandState, + SessionId, + CommandId, + BackendType, + guid_to_hex_id, +) try: import pyarrow @@ -41,6 +51,8 @@ convert_column_based_set_to_arrow_table, ) from databricks.sql.types import SSLOptions +from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.result_set import ResultSet, ThriftResultSet logger = logging.getLogger(__name__) @@ -73,7 +85,7 @@ } -class ThriftBackend: +class ThriftDatabricksClient(DatabricksClient): CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE @@ -150,7 +162,7 @@ def __init__( else: raise ValueError("No valid connection settings.") - self.staging_allowed_local_path = staging_allowed_local_path + self._staging_allowed_local_path = staging_allowed_local_path self._initialize_retry_args(kwargs) self._use_arrow_native_complex_types = kwargs.get( "_use_arrow_native_complex_types", True @@ -161,7 +173,7 @@ def __init__( ) # Cloud fetch - self.max_download_threads = kwargs.get("max_download_threads", 10) + self._max_download_threads = kwargs.get("max_download_threads", 10) self._ssl_options = ssl_options @@ -224,6 +236,18 @@ def __init__( self._request_lock = threading.RLock() + @property + def staging_allowed_local_path(self) -> Union[None, str, List[str]]: + return self._staging_allowed_local_path + + @property + def ssl_options(self) -> SSLOptions: + return self._ssl_options + + @property + def max_download_threads(self) -> int: + return self._max_download_threads + # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): # Configure retries & timing: use user-settings or defaults, and bound @@ -446,8 +470,10 @@ def attempt_request(attempt): logger.error("ThriftBackend.attempt_request: Exception: %s", err) error = err retry_delay = extract_retry_delay(attempt) - error_message = ThriftBackend._extract_error_message_from_headers( - 
getattr(self._transport, "headers", {}) + error_message = ( + ThriftDatabricksClient._extract_error_message_from_headers( + getattr(self._transport, "headers", {}) + ) ) finally: # Calling `close()` here releases the active HTTP connection back to the pool @@ -483,7 +509,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response) + ThriftDatabricksClient._check_response_for_error(response) return response error_info = response_or_error_info @@ -534,7 +560,7 @@ def _check_session_configuration(self, session_configuration): ) ) - def open_session(self, session_configuration, catalog, schema): + def open_session(self, session_configuration, catalog, schema) -> SessionId: try: self._transport.open() session_configuration = { @@ -562,13 +588,22 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) - return response + properties = ( + {"serverProtocolVersion": response.serverProtocolVersion} + if response.serverProtocolVersion + else {} + ) + return SessionId.from_thrift_handle(response.sessionHandle, properties) except: self._transport.close() raise - def close_session(self, session_handle) -> None: - req = ttypes.TCloseSessionReq(sessionHandle=session_handle) + def close_session(self, session_id: SessionId) -> None: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") + + req = ttypes.TCloseSessionReq(sessionHandle=thrift_handle) try: self.make_request(self._client.CloseSession, req) finally: @@ -583,7 +618,7 @@ def _check_command_not_in_error_or_closed_state( get_operations_resp.displayMessage, { "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid), + and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, ) @@ -592,18 +627,18 @@ def _check_command_not_in_error_or_closed_state( get_operations_resp.errorMessage, { "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid), + and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( "Command {} unexpectedly closed server side".format( - op_handle and self.guid_to_hex_id(op_handle.operationId.guid) + op_handle and guid_to_hex_id(op_handle.operationId.guid) ), { "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid) + and guid_to_hex_id(op_handle.operationId.guid) }, ) @@ -707,7 +742,8 @@ def _col_to_description(col): @staticmethod def _hive_schema_to_description(t_table_schema): return [ - ThriftBackend._col_to_description(col) for col in t_table_schema.columns + ThriftDatabricksClient._col_to_description(col) + for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -767,6 +803,9 @@ def _results_message_to_execute_response(self, resp, operation_state): ) else: arrow_queue_opt = None + + command_id = CommandId.from_thrift_handle(resp.operationHandle) + return ExecuteResponse( arrow_queue=arrow_queue_opt, status=operation_state, @@ -774,21 +813,24 @@ def _results_message_to_execute_response(self, resp, operation_state): 
has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_handle=resp.operationHandle, + command_id=command_id, description=description, arrow_schema_bytes=schema_bytes, ) - def get_execution_result(self, op_handle, cursor): - - assert op_handle is not None + def get_execution_result( + self, command_id: CommandId, cursor: "Cursor" + ) -> "ResultSet": + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") req = ttypes.TFetchResultsReq( operationHandle=ttypes.TOperationHandle( - op_handle.operationId, - op_handle.operationType, + thrift_handle.operationId, + thrift_handle.operationType, False, - op_handle.modifiedRowCount, + thrift_handle.modifiedRowCount, ), maxRows=cursor.arraysize, maxBytes=cursor.buffer_size_bytes, @@ -827,18 +869,27 @@ def get_execution_result(self, op_handle, cursor): ssl_options=self._ssl_options, ) - return ExecuteResponse( + execute_response = ExecuteResponse( arrow_queue=queue, status=resp.status, has_been_closed_server_side=False, has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_handle=op_handle, + command_id=command_id, description=description, arrow_schema_bytes=schema_bytes, ) + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) + def _wait_until_command_done(self, op_handle, initial_operation_status_resp): if initial_operation_status_resp: self._check_command_not_in_error_or_closed_state( @@ -857,51 +908,57 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) return operation_state - def get_query_state(self, op_handle) -> "TOperationState": - poll_resp = self._poll_for_status(op_handle) + def get_query_state(self, command_id: CommandId) -> CommandState: + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") + + poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState - self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) - return operation_state + self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) + return CommandState.from_thrift_state(operation_state) @staticmethod def _check_direct_results_for_error(t_spark_direct_results): if t_spark_direct_results: if t_spark_direct_results.operationStatus: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.operationStatus ) if t_spark_direct_results.resultSetMetadata: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSetMetadata ) if t_spark_direct_results.resultSet: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSet ) if t_spark_direct_results.closeOperation: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.closeOperation ) def execute_command( self, - operation, - session_handle, - max_rows, - max_bytes, - lz4_compression, - cursor, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + 
lz4_compression: bool, + cursor: "Cursor", use_cloud_fetch=True, parameters=[], async_op=False, enforce_embedded_schema_correctness=False, - ): - assert session_handle is not None + ) -> Union["ResultSet", None]: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") logger.debug( "ThriftBackend.execute_command(operation=%s, session_handle=%s)", operation, - session_handle, + thrift_handle, ) spark_arrow_types = ttypes.TSparkArrowTypes( @@ -913,7 +970,7 @@ def execute_command( intervalTypesAsArrow=False, ) req = ttypes.TExecuteStatementReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, statement=operation, runAsync=True, # For async operation we don't want the direct results @@ -938,34 +995,64 @@ def execute_command( if async_op: self._handle_execute_response_async(resp, cursor) + return None else: - return self._handle_execute_response(resp, cursor) + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=use_cloud_fetch, + ) - def get_catalogs(self, session_handle, max_rows, max_bytes, cursor): - assert session_handle is not None + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ) -> "ResultSet": + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetCatalogsReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), ) resp = self.make_request(self._client.GetCatalogs, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_schemas( self, - session_handle, - max_rows, - max_bytes, - cursor, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", catalog_name=None, schema_name=None, - ): - assert session_handle is not None + ) -> "ResultSet": + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetSchemasReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -973,23 +1060,35 @@ def get_schemas( schemaName=schema_name, ) resp = self.make_request(self._client.GetSchemas, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_tables( self, - session_handle, - max_rows, - max_bytes, - cursor, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", catalog_name=None, schema_name=None, table_name=None, table_types=None, - ): - assert session_handle is not None + ) -> "ResultSet": + thrift_handle = session_id.to_thrift_handle() + if not 
thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetTablesReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -999,23 +1098,35 @@ def get_tables( tableTypes=table_types, ) resp = self.make_request(self._client.GetTables, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_columns( self, - session_handle, - max_rows, - max_bytes, - cursor, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", catalog_name=None, schema_name=None, table_name=None, column_name=None, - ): - assert session_handle is not None + ) -> "ResultSet": + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetColumnsReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -1025,10 +1136,22 @@ def get_columns( columnName=column_name, ) resp = self.make_request(self._client.GetColumns, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def _handle_execute_response(self, resp, cursor): - cursor.active_op_handle = resp.operationHandle + command_id = CommandId.from_thrift_handle(resp.operationHandle) + + cursor.active_command_id = command_id self._check_direct_results_for_error(resp.directResults) final_operation_state = self._wait_until_command_done( @@ -1036,31 +1159,38 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response(resp, final_operation_state) + execute_response = self._results_message_to_execute_response( + resp, final_operation_state + ) + execute_response = execute_response._replace(command_id=command_id) + return execute_response def _handle_execute_response_async(self, resp, cursor): - cursor.active_op_handle = resp.operationHandle + command_id = CommandId.from_thrift_handle(resp.operationHandle) + cursor.active_command_id = command_id self._check_direct_results_for_error(resp.directResults) def fetch_results( self, - op_handle, - max_rows, - max_bytes, - expected_row_start_offset, - lz4_compressed, + command_id: CommandId, + max_rows: int, + max_bytes: int, + expected_row_start_offset: int, + lz4_compressed: bool, arrow_schema_bytes, description, use_cloud_fetch=True, ): - assert op_handle is not None + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") req = ttypes.TFetchResultsReq( operationHandle=ttypes.TOperationHandle( - op_handle.operationId, - op_handle.operationType, + thrift_handle.operationId, + thrift_handle.operationType, False, - op_handle.modifiedRowCount, + thrift_handle.modifiedRowCount, ), maxRows=max_rows, maxBytes=max_bytes, @@ -1089,46 +1219,21 @@ def fetch_results( 
return queue, resp.hasMoreRows - def close_command(self, op_handle): - logger.debug("ThriftBackend.close_command(op_handle=%s)", op_handle) - req = ttypes.TCloseOperationReq(operationHandle=op_handle) - resp = self.make_request(self._client.CloseOperation, req) - return resp.status + def cancel_command(self, command_id: CommandId) -> None: + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") - def cancel_command(self, active_op_handle): - logger.debug( - "Cancelling command {}".format( - self.guid_to_hex_id(active_op_handle.operationId.guid) - ) - ) - req = ttypes.TCancelOperationReq(active_op_handle) + logger.debug("Cancelling command {}".format(command_id.guid)) + req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) - @staticmethod - def handle_to_id(session_handle): - return session_handle.sessionId.guid - - @staticmethod - def handle_to_hex_id(session_handle: TCLIService.TSessionHandle): - this_uuid = uuid.UUID(bytes=session_handle.sessionId.guid) - return str(this_uuid) - - @staticmethod - def guid_to_hex_id(guid: bytes) -> str: - """Return a hexadecimal string instead of bytes - - Example: - IN b'\x01\xee\x1d)\xa4\x19\x1d\xb6\xa9\xc0\x8d\xf1\xfe\xbaB\xdd' - OUT '01ee1d29-a419-1db6-a9c0-8df1feba42dd' + def close_command(self, command_id: CommandId) -> None: + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") - If conversion to hexadecimal fails, the original bytes are returned - """ - - this_uuid: Union[bytes, uuid.UUID] - - try: - this_uuid = uuid.UUID(bytes=guid) - except Exception as e: - logger.debug(f"Unable to convert bytes to UUID: {bytes} -- {str(e)}") - this_uuid = guid - return str(this_uuid) + logger.debug("ThriftBackend.close_command(command_id=%s)", command_id) + req = ttypes.TCloseOperationReq(operationHandle=thrift_handle) + resp = self.make_request(self._client.CloseOperation, req) + return resp.status diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py new file mode 100644 index 000000000..afab46e28 --- /dev/null +++ b/src/databricks/sql/backend/types.py @@ -0,0 +1,329 @@ +from enum import Enum +from typing import Dict, Optional, Any, Union +import uuid +import logging + +from databricks.sql.thrift_api.TCLIService import ttypes + +logger = logging.getLogger(__name__) + + +class CommandState(Enum): + PENDING = "PENDING" + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + CLOSED = "CLOSED" + CANCELLED = "CANCELLED" + + @classmethod + def from_thrift_state(cls, state: ttypes.TOperationState) -> "CommandState": + if state in ( + ttypes.TOperationState.INITIALIZED_STATE, + ttypes.TOperationState.PENDING_STATE, + ): + return cls.PENDING + elif state == ttypes.TOperationState.RUNNING_STATE: + return cls.RUNNING + elif state == ttypes.TOperationState.FINISHED_STATE: + return cls.SUCCEEDED + elif state in ( + ttypes.TOperationState.ERROR_STATE, + ttypes.TOperationState.TIMEDOUT_STATE, + ttypes.TOperationState.UKNOWN_STATE, + ): + return cls.FAILED + elif state == ttypes.TOperationState.CLOSED_STATE: + return cls.CLOSED + elif state == ttypes.TOperationState.CANCELED_STATE: + return cls.CANCELLED + else: + raise ValueError(f"Unknown command state: {state}") + + +def guid_to_hex_id(guid: bytes) -> str: + """Return a hexadecimal string instead of bytes + + Example: + IN b'\x01\xee\x1d)\xa4\x19\x1d\xb6\xa9\xc0\x8d\xf1\xfe\xbaB\xdd' + OUT 
'01ee1d29-a419-1db6-a9c0-8df1feba42dd' + + If conversion to hexadecimal fails, the original bytes are returned + """ + try: + this_uuid = uuid.UUID(bytes=guid) + except Exception as e: + logger.debug(f"Unable to convert bytes to UUID: {guid!r} -- {str(e)}") + return str(guid) + return str(this_uuid) + + +class BackendType(Enum): + """Enum representing the type of backend.""" + + THRIFT = "thrift" + SEA = "sea" + + +class SessionId: + """ + A normalized session identifier that works with both Thrift and SEA backends. + + This class abstracts away the differences between Thrift's TSessionHandle and + SEA's session ID string, providing a consistent interface for the connector. + """ + + def __init__( + self, + backend_type: BackendType, + guid: Any, + secret: Optional[Any] = None, + properties: Optional[Dict[str, Any]] = None, + ): + """ + Initialize a SessionId. + + Args: + backend_type: The type of backend (THRIFT or SEA) + guid: The primary identifier for the session + secret: The secret part of the identifier (only used for Thrift) + info: Additional information about the session + """ + self.backend_type = backend_type + self.guid = guid + self.secret = secret + self.properties = properties or {} + + def __str__(self) -> str: + """ + Return a string representation of the SessionId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the session ID + """ + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + return f"{self.get_hex_id()}|{guid_to_hex_id(self.secret) if isinstance(self.secret, bytes) else str(self.secret)}" + return str(self.guid) + + @classmethod + def from_thrift_handle( + cls, session_handle, properties: Optional[Dict[str, Any]] = None + ): + """ + Create a SessionId from a Thrift session handle. + + Args: + session_handle: A TSessionHandle object from the Thrift API + + Returns: + A SessionId instance + """ + if session_handle is None: + return None + + guid_bytes = session_handle.sessionId.guid + secret_bytes = session_handle.sessionId.secret + + if session_handle.serverProtocolVersion is not None: + if properties is None: + properties = {} + properties["serverProtocolVersion"] = session_handle.serverProtocolVersion + + return cls(BackendType.THRIFT, guid_bytes, secret_bytes, properties) + + @classmethod + def from_sea_session_id( + cls, session_id: str, properties: Optional[Dict[str, Any]] = None + ): + """ + Create a SessionId from a SEA session ID. + + Args: + session_id: The SEA session ID string + + Returns: + A SessionId instance + """ + return cls(BackendType.SEA, session_id, properties=properties) + + def to_thrift_handle(self): + """ + Convert this SessionId to a Thrift TSessionHandle. + + Returns: + A TSessionHandle object or None if this is not a Thrift session ID + """ + if self.backend_type != BackendType.THRIFT: + return None + + from databricks.sql.thrift_api.TCLIService import ttypes + + handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) + server_protocol_version = self.properties.get("serverProtocolVersion") + return ttypes.TSessionHandle( + sessionId=handle_identifier, serverProtocolVersion=server_protocol_version + ) + + def to_sea_session_id(self): + """ + Get the SEA session ID string. 
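+        For SEA sessions, the guid field already holds the server-issued string ID,
+        so it is returned as-is.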
+ + Returns: + The session ID string or None if this is not a SEA session ID + """ + if self.backend_type != BackendType.SEA: + return None + + return self.guid + + def get_id(self) -> Any: + """ + Get the ID of the session. + """ + return self.guid + + def get_hex_id(self) -> str: + """ + Get a hexadecimal string representation of the session ID. + + Returns: + A hexadecimal string representation + """ + if isinstance(self.guid, bytes): + return guid_to_hex_id(self.guid) + else: + return str(self.guid) + + def get_protocol_version(self): + """ + Get the server protocol version for this session. + + Returns: + The server protocol version or None if it does not exist + It is not expected to exist for SEA sessions. + """ + return self.properties.get("serverProtocolVersion") + + +class CommandId: + """ + A normalized command identifier that works with both Thrift and SEA backends. + + This class abstracts away the differences between Thrift's TOperationHandle and + SEA's statement ID string, providing a consistent interface for the connector. + """ + + def __init__( + self, + backend_type: BackendType, + guid: Any, + secret: Optional[Any] = None, + operation_type: Optional[int] = None, + has_result_set: bool = False, + modified_row_count: Optional[int] = None, + ): + """ + Initialize a CommandId. + + Args: + backend_type: The type of backend (THRIFT or SEA) + guid: The primary identifier for the command + secret: The secret part of the identifier (only used for Thrift) + operation_type: The operation type (only used for Thrift) + has_result_set: Whether the command has a result set + modified_row_count: The number of rows modified by the command + """ + self.backend_type = backend_type + self.guid = guid + self.secret = secret + self.operation_type = operation_type + self.has_result_set = has_result_set + self.modified_row_count = modified_row_count + + @classmethod + def from_thrift_handle(cls, operation_handle): + """ + Create a CommandId from a Thrift operation handle. + + Args: + operation_handle: A TOperationHandle object from the Thrift API + + Returns: + A CommandId instance + """ + if operation_handle is None: + return None + + guid_bytes = operation_handle.operationId.guid + secret_bytes = operation_handle.operationId.secret + + return cls( + BackendType.THRIFT, + guid_bytes, + secret_bytes, + operation_handle.operationType, + operation_handle.hasResultSet, + operation_handle.modifiedRowCount, + ) + + @classmethod + def from_sea_statement_id(cls, statement_id: str): + """ + Create a CommandId from a SEA statement ID. + + Args: + statement_id: The SEA statement ID string + + Returns: + A CommandId instance + """ + return cls(BackendType.SEA, statement_id) + + def to_thrift_handle(self): + """ + Convert this CommandId to a Thrift TOperationHandle. + + Returns: + A TOperationHandle object or None if this is not a Thrift command ID + """ + if self.backend_type != BackendType.THRIFT: + return None + + from databricks.sql.thrift_api.TCLIService import ttypes + + handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) + return ttypes.TOperationHandle( + operationId=handle_identifier, + operationType=self.operation_type, + hasResultSet=self.has_result_set, + modifiedRowCount=self.modified_row_count, + ) + + def to_sea_statement_id(self): + """ + Get the SEA statement ID string. 
+ + Returns: + The statement ID string or None if this is not a SEA statement ID + """ + if self.backend_type != BackendType.SEA: + return None + + return self.guid + + def to_hex_id(self) -> str: + """ + Get a hexadecimal string representation of the command ID. + + Returns: + A hexadecimal string representation + """ + if isinstance(self.guid, bytes): + return guid_to_hex_id(self.guid) + else: + return str(self.guid) diff --git a/src/databricks/sql/backend/utils/http_client.py b/src/databricks/sql/backend/utils/http_client.py new file mode 100644 index 000000000..8cc229850 --- /dev/null +++ b/src/databricks/sql/backend/utils/http_client.py @@ -0,0 +1,172 @@ +import json +import logging +import requests +from typing import Dict, Any, Optional, Union, List +from urllib.parse import urljoin + +from databricks.sql.auth.authenticators import AuthProvider +from databricks.sql.types import SSLOptions + +logger = logging.getLogger(__name__) + + +class CustomHttpClient: + """ + HTTP client for Statement Execution API (SEA). + + This client handles the HTTP communication with the SEA endpoints, + including authentication, request formatting, and response parsing. + """ + + def __init__( + self, + server_hostname: str, + port: int, + http_path: str, + http_headers: List[tuple], + auth_provider: AuthProvider, + ssl_options: SSLOptions, + **kwargs, + ): + """ + Initialize the SEA HTTP client. + + Args: + server_hostname: Hostname of the Databricks server + port: Port number for the connection + http_path: HTTP path for the connection + http_headers: List of HTTP headers to include in requests + auth_provider: Authentication provider + ssl_options: SSL configuration options + **kwargs: Additional keyword arguments + """ + self.server_hostname = server_hostname + self.port = port + self.http_path = http_path + self.auth_provider = auth_provider + self.ssl_options = ssl_options + + self.base_url = f"https://{server_hostname}:{port}" + + self.headers = dict(http_headers) + self.headers.update({"Content-Type": "application/json"}) + + self.max_retries = kwargs.get("_retry_stop_after_attempts_count", 30) + + # Create a session for connection pooling + self.session = requests.Session() + + # Configure SSL verification + if ssl_options.tls_verify: + self.session.verify = ssl_options.tls_trusted_ca_file or True + else: + self.session.verify = False + + # Configure client certificates if provided + if ssl_options.tls_client_cert_file: + client_cert = ssl_options.tls_client_cert_file + client_key = ssl_options.tls_client_cert_key_file + client_key_password = ssl_options.tls_client_cert_key_password + + if client_key: + self.session.cert = (client_cert, client_key) + else: + self.session.cert = client_cert + + if client_key_password: + # Note: requests doesn't directly support key passwords + # This would require more complex handling with libraries like pyOpenSSL + logger.warning( + "Client key password provided but not supported by requests library" + ) + + def _get_auth_headers(self) -> Dict[str, str]: + """Get authentication headers from the auth provider.""" + headers: Dict[str, str] = {} + self.auth_provider.add_headers(headers) + return headers + + def _make_request( + self, method: str, path: str, data: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Make an HTTP request to the SEA endpoint. 
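+        For GET and DELETE requests, data is sent as query parameters; for POST
+        requests, it is sent as the JSON body.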
+ + Args: + method: HTTP method (GET, POST, DELETE) + path: API endpoint path + data: Request payload data + + Returns: + Dict[str, Any]: Response data parsed from JSON + + Raises: + RequestError: If the request fails + """ + url = urljoin(self.base_url, path) + headers = {**self.headers, **self._get_auth_headers()} + + logger.debug(f"making {method} request to {url}") + + try: + if method.upper() == "GET": + response = self.session.get(url, headers=headers, params=data) + elif method.upper() == "POST": + response = self.session.post(url, headers=headers, json=data) + elif method.upper() == "DELETE": + # For DELETE requests, use params for data (query parameters) + response = self.session.delete(url, headers=headers, params=data) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + # Check for HTTP errors + response.raise_for_status() + + # Log response details + logger.debug(f"Response status: {response.status_code}") + + # Parse JSON response + if response.content: + result = response.json() + # Log response content (but limit it for large responses) + content_str = json.dumps(result) + if len(content_str) > 1000: + logger.debug( + f"Response content (truncated): {content_str[:1000]}..." + ) + else: + logger.debug(f"Response content: {content_str}") + return result + return {} + + except requests.exceptions.RequestException as e: + # Handle request errors + error_message = f"SEA HTTP request failed: {str(e)}" + logger.error(error_message) + + # Extract error details from response if available + if hasattr(e, "response") and e.response is not None: + try: + error_details = e.response.json() + error_message = ( + f"{error_message}: {error_details.get('message', '')}" + ) + logger.error( + f"Response status: {e.response.status_code}, Error details: {error_details}" + ) + except (ValueError, KeyError): + # If we can't parse the JSON, just log the raw content + content_str = ( + e.response.content.decode("utf-8", errors="replace") + if isinstance(e.response.content, bytes) + else str(e.response.content) + ) + logger.error( + f"Response status: {e.response.status_code}, Raw content: {content_str}" + ) + pass + + # Re-raise as a RequestError + from databricks.sql.exc import RequestError + + raise RequestError(error_message, e) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index d6a9e6b08..ab52896c9 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -21,7 +21,8 @@ CursorAlreadyClosedError, ) from databricks.sql.thrift_api.TCLIService import ttypes -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.utils import ( ExecuteResponse, ParamEscaper, @@ -41,11 +42,12 @@ ParameterApproach, ) - +from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence from databricks.sql.session import Session +from databricks.sql.backend.types import CommandId, BackendType, CommandState, SessionId from databricks.sql.thrift_api.TCLIService.ttypes import ( TSparkParameter, @@ -230,7 +232,6 @@ def read(self) -> Optional[OAuthToken]: self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) self._cursors = [] # type: List[Cursor] - # Create the session self.session = Session( 
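+            # the constructed Session exposes the chosen backend client via session.backend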
server_hostname, http_path, @@ -243,11 +244,6 @@ def read(self) -> Optional[OAuthToken]: ) self.session.open() - logger.info( - "Successfully opened connection with session " - + str(self.get_session_id_hex()) - ) - self.use_inline_params = self._set_use_inline_params_with_warning( kwargs.get("use_inline_params", False) ) @@ -305,11 +301,11 @@ def __del__(self): logger.debug("Couldn't close unclosed connection: {}".format(e.message)) def get_session_id(self): - """Get the session ID from the Session object""" + """Get the raw session ID (backend-specific)""" return self.session.get_id() def get_session_id_hex(self): - """Get the session ID in hex format from the Session object""" + """Get the session ID in hex format""" return self.session.get_id_hex() @staticmethod @@ -323,9 +319,9 @@ def protocol_version(self): return self.session.protocol_version @staticmethod - def get_protocol_version(openSessionResp): + def get_protocol_version(session_id: SessionId): """Delegate to Session class static method""" - return Session.get_protocol_version(openSessionResp) + return Session.get_protocol_version(session_id) @property def open(self) -> bool: @@ -347,7 +343,7 @@ def cursor( cursor = Cursor( self, - self.session.thrift_backend, + self.session.backend, arraysize=arraysize, result_buffer_size_bytes=buffer_size_bytes, ) @@ -380,7 +376,7 @@ class Cursor: def __init__( self, connection: Connection, - thrift_backend: ThriftBackend, + backend: DatabricksClient, result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, arraysize: int = DEFAULT_ARRAY_SIZE, ) -> None: @@ -399,8 +395,8 @@ def __init__( # Note that Cursor closed => active result set closed, but not vice versa self.open = True self.executing_command_id = None - self.thrift_backend = thrift_backend - self.active_op_handle = None + self.backend = backend + self.active_command_id = None self.escaper = ParamEscaper() self.lastrowid = None @@ -774,9 +770,9 @@ def execute( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.execute_command( + self.active_result_set = self.backend.execute_command( operation=prepared_operation, - session_handle=self.connection.session.get_handle(), + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -786,18 +782,10 @@ def execute( async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.thrift_backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) - if execute_response.is_staging_operation: + if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( - staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path + staging_allowed_local_path=self.backend.staging_allowed_local_path ) return self @@ -837,9 +825,9 @@ def execute_async( self._check_not_closed() self._close_and_clear_active_result_set() - self.thrift_backend.execute_command( + self.backend.execute_command( operation=prepared_operation, - session_handle=self.connection.session.get_handle(), + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -852,14 +840,16 @@ def execute_async( return self - def get_query_state(self) -> "TOperationState": + def 
get_query_state(self) -> CommandState: """ Get the state of the async executing query or basically poll the status of the query :return: """ self._check_not_closed() - return self.thrift_backend.get_query_state(self.active_op_handle) + if self.active_command_id is None: + raise Error("No active command to get state for") + return self.backend.get_query_state(self.active_command_id) def is_query_pending(self): """ @@ -868,11 +858,7 @@ def is_query_pending(self): :return: """ operation_state = self.get_query_state() - - return not operation_state or operation_state in [ - ttypes.TOperationState.RUNNING_STATE, - ttypes.TOperationState.PENDING_STATE, - ] + return operation_state in [CommandState.PENDING, CommandState.RUNNING] def get_async_execution_result(self): """ @@ -888,21 +874,14 @@ def get_async_execution_result(self): time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL) operation_state = self.get_query_state() - if operation_state == ttypes.TOperationState.FINISHED_STATE: - execute_response = self.thrift_backend.get_execution_result( - self.active_op_handle, self + if operation_state == CommandState.SUCCEEDED: + self.active_result_set = self.backend.get_execution_result( + self.active_command_id, self ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.thrift_backend, - self.buffer_size_bytes, - self.arraysize, - ) - if execute_response.is_staging_operation: + if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( - staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path + staging_allowed_local_path=self.backend.staging_allowed_local_path ) return self @@ -934,19 +913,12 @@ def catalogs(self) -> "Cursor": """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_catalogs( - session_handle=self.connection.session.get_handle(), + self.active_result_set = self.backend.get_catalogs( + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.thrift_backend, - self.buffer_size_bytes, - self.arraysize, - ) return self def schemas( @@ -960,21 +932,14 @@ def schemas( """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_schemas( - session_handle=self.connection.session.get_handle(), + self.active_result_set = self.backend.get_schemas( + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, catalog_name=catalog_name, schema_name=schema_name, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.thrift_backend, - self.buffer_size_bytes, - self.arraysize, - ) return self def tables( @@ -993,8 +958,8 @@ def tables( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_tables( - session_handle=self.connection.session.get_handle(), + self.active_result_set = self.backend.get_tables( + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1003,13 +968,6 @@ def tables( table_name=table_name, table_types=table_types, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.thrift_backend, - self.buffer_size_bytes, - self.arraysize, - ) return self def columns( @@ -1028,8 +986,8 @@ def columns(
self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_columns( - session_handle=self.connection.session.get_handle(), + self.active_result_set = self.backend.get_columns( + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1038,13 +996,6 @@ def columns( table_name=table_name, column_name=column_name, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.thrift_backend, - self.buffer_size_bytes, - self.arraysize, - ) return self def fetchall(self) -> List[Row]: @@ -1117,8 +1068,8 @@ def cancel(self) -> None: The command should be closed to free resources from the server. This method can be called from another thread. """ - if self.active_op_handle is not None: - self.thrift_backend.cancel_command(self.active_op_handle) + if self.active_command_id is not None: + self.backend.cancel_command(self.active_command_id) else: logger.warning( "Attempting to cancel a command, but there is no " @@ -1130,9 +1081,9 @@ def close(self) -> None: self.open = False # Close active operation handle if it exists - if self.active_op_handle: + if self.active_command_id: try: - self.thrift_backend.close_command(self.active_op_handle) + self.backend.close_command(self.active_command_id) except RequestError as e: if isinstance(e.args[1], CursorAlreadyClosedError): logger.info("Operation was canceled by a prior request") @@ -1141,7 +1092,7 @@ def close(self) -> None: except Exception as e: logging.warning(f"Error closing operation handle: {e}") finally: - self.active_op_handle = None + self.active_command_id = None if self.active_result_set: self._close_and_clear_active_result_set() @@ -1154,8 +1105,8 @@ def query_id(self) -> Optional[str]: This attribute will be ``None`` if the cursor has not had an operation invoked via the execute method yet, or if cursor was closed. """ - if self.active_op_handle is not None: - return str(UUID(bytes=self.active_op_handle.operationId.guid)) + if self.active_command_id is not None: + return self.active_command_id.to_hex_id() return None @property @@ -1200,301 +1151,3 @@ def setinputsizes(self, sizes): def setoutputsize(self, size, column=None): """Does nothing by default""" pass - - -class ResultSet: - def __init__( - self, - connection: Connection, - execute_response: ExecuteResponse, - thrift_backend: ThriftBackend, - result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, - arraysize: int = 10000, - use_cloud_fetch: bool = True, - ): - """ - A ResultSet manages the results of a single command. 
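(The ResultSet class removed below is re-homed as an abstract base plus a ThriftResultSet implementation in the new src/databricks/sql/result_set.py later in this diff.) The cursor changes above also swap Thrift TOperationState polling for the backend-neutral CommandState enum. A rough usage sketch of the resulting async flow, not part of this diff; the connection parameters are assumed to come from environment variables:

```python
import os
import time

from databricks import sql

# Sketch only: exercises the async cursor API refactored above.
with sql.connect(
    server_hostname=os.environ["DATABRICKS_SERVER_HOSTNAME"],
    http_path=os.environ["DATABRICKS_HTTP_PATH"],
    access_token=os.environ["DATABRICKS_TOKEN"],
) as connection:
    cursor = connection.cursor()
    cursor.execute_async("SELECT 1 AS test")

    # is_query_pending() now reports True for CommandState.PENDING / RUNNING
    # instead of the Thrift TOperationState values.
    while cursor.is_query_pending():
        time.sleep(1)

    # On CommandState.SUCCEEDED this attaches the backend-provided result
    # set to the cursor, after which the normal fetch APIs work.
    cursor.get_async_execution_result()
    print(cursor.fetchall())
```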
- - :param connection: The parent connection that was used to execute this command - :param execute_response: A `ExecuteResponse` class returned by a command execution - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - amount :param arraysize: The max number of rows to fetch at a time (PEP-249) - """ - self.connection = connection - self.command_id = execute_response.command_handle - self.op_state = execute_response.status - self.has_been_closed_server_side = execute_response.has_been_closed_server_side - self.has_more_rows = execute_response.has_more_rows - self.buffer_size_bytes = result_buffer_size_bytes - self.lz4_compressed = execute_response.lz4_compressed - self.arraysize = arraysize - self.thrift_backend = thrift_backend - self.description = execute_response.description - self._arrow_schema_bytes = execute_response.arrow_schema_bytes - self._next_row_index = 0 - self._use_cloud_fetch = use_cloud_fetch - - if execute_response.arrow_queue: - # In this case the server has taken the fast path and returned an initial batch of - # results - self.results = execute_response.arrow_queue - else: - # In this case, there are results waiting on the server so we fetch now for simplicity - self._fill_results_buffer() - - def __iter__(self): - while True: - row = self.fetchone() - if row: - yield row - else: - break - - def _fill_results_buffer(self): - # At initialization or if the server does not have cloud fetch result links available - results, has_more_rows = self.thrift_backend.fetch_results( - op_handle=self.command_id, - max_rows=self.arraysize, - max_bytes=self.buffer_size_bytes, - expected_row_start_offset=self._next_row_index, - lz4_compressed=self.lz4_compressed, - arrow_schema_bytes=self._arrow_schema_bytes, - description=self.description, - use_cloud_fetch=self._use_cloud_fetch, - ) - self.results = results - self.has_more_rows = has_more_rows - - def _convert_columnar_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - result = [] - for row_index in range(table.num_rows): - curr_row = [] - for col_index in range(table.num_columns): - curr_row.append(table.get_item(col_index, row_index)) - result.append(ResultRow(*curr_row)) - - return result - - def _convert_arrow_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - - if self.connection.disable_pandas is True: - return [ - ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) - ] - - # Need to use nullable types, as otherwise type can change when there are missing values. 
- # See https://arrow.apache.org/docs/python/pandas.html#nullable-types - # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - dtype_mapping = { - pyarrow.int8(): pandas.Int8Dtype(), - pyarrow.int16(): pandas.Int16Dtype(), - pyarrow.int32(): pandas.Int32Dtype(), - pyarrow.int64(): pandas.Int64Dtype(), - pyarrow.uint8(): pandas.UInt8Dtype(), - pyarrow.uint16(): pandas.UInt16Dtype(), - pyarrow.uint32(): pandas.UInt32Dtype(), - pyarrow.uint64(): pandas.UInt64Dtype(), - pyarrow.bool_(): pandas.BooleanDtype(), - pyarrow.float32(): pandas.Float32Dtype(), - pyarrow.float64(): pandas.Float64Dtype(), - pyarrow.string(): pandas.StringDtype(), - } - - # Need to rename columns, as the to_pandas function cannot handle duplicate column names - table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - df = table_renamed.to_pandas( - types_mapper=dtype_mapping.get, - date_as_object=True, - timestamp_as_object=True, - ) - - res = df.to_numpy(na_value=None, dtype="object") - return [ResultRow(*v) for v in res] - - @property - def rownumber(self): - return self._next_row_index - - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": - """ - Fetch the next set of rows of a query result, returning a PyArrow table. - - An empty sequence is returned when no more rows are available. - """ - if size < 0: - raise ValueError("size argument for fetchmany is %s but must be >= 0", size) - results = self.results.next_n_rows(size) - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows - - while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows - ): - self._fill_results_buffer() - partial_results = self.results.next_n_rows(n_remaining_rows) - results = pyarrow.concat_tables([results, partial_results]) - n_remaining_rows -= partial_results.num_rows - self._next_row_index += partial_results.num_rows - - return results - - def merge_columnar(self, result1, result2): - """ - Function to merge / combining the columnar results into a single result - :param result1: - :param result2: - :return: - """ - - if result1.column_names != result2.column_names: - raise ValueError("The columns in the results don't match") - - merged_result = [ - result1.column_table[i] + result2.column_table[i] - for i in range(result1.num_columns) - ] - return ColumnTable(merged_result, result1.column_names) - - def fetchmany_columnar(self, size: int): - """ - Fetch the next set of rows of a query result, returning a Columnar Table. - An empty sequence is returned when no more rows are available. 
- """ - if size < 0: - raise ValueError("size argument for fetchmany is %s but must be >= 0", size) - - results = self.results.next_n_rows(size) - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows - - while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows - ): - self._fill_results_buffer() - partial_results = self.results.next_n_rows(n_remaining_rows) - results = self.merge_columnar(results, partial_results) - n_remaining_rows -= partial_results.num_rows - self._next_row_index += partial_results.num_rows - - return results - - def fetchall_arrow(self) -> "pyarrow.Table": - """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" - results = self.results.remaining_rows() - self._next_row_index += results.num_rows - - while not self.has_been_closed_server_side and self.has_more_rows: - self._fill_results_buffer() - partial_results = self.results.remaining_rows() - if isinstance(results, ColumnTable) and isinstance( - partial_results, ColumnTable - ): - results = self.merge_columnar(results, partial_results) - else: - results = pyarrow.concat_tables([results, partial_results]) - self._next_row_index += partial_results.num_rows - - # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table - # Valid only for metadata commands result set - if isinstance(results, ColumnTable) and pyarrow: - data = { - name: col - for name, col in zip(results.column_names, results.column_table) - } - return pyarrow.Table.from_pydict(data) - return results - - def fetchall_columnar(self): - """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" - results = self.results.remaining_rows() - self._next_row_index += results.num_rows - - while not self.has_been_closed_server_side and self.has_more_rows: - self._fill_results_buffer() - partial_results = self.results.remaining_rows() - results = self.merge_columnar(results, partial_results) - self._next_row_index += partial_results.num_rows - - return results - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - if isinstance(self.results, ColumnQueue): - res = self._convert_columnar_table(self.fetchmany_columnar(1)) - else: - res = self._convert_arrow_table(self.fetchmany_arrow(1)) - - if len(res) > 0: - return res[0] - else: - return None - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - if isinstance(self.results, ColumnQueue): - return self._convert_columnar_table(self.fetchall_columnar()) - else: - return self._convert_arrow_table(self.fetchall_arrow()) - - def fetchmany(self, size: int) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - if isinstance(self.results, ColumnQueue): - return self._convert_columnar_table(self.fetchmany_columnar(size)) - else: - return self._convert_arrow_table(self.fetchmany_arrow(size)) - - def close(self) -> None: - """ - Close the cursor. - - If the connection has not been closed, and the cursor has not already - been closed on the server for some other reason, issue a request to the server to close it. 
- """ - try: - if ( - self.op_state != self.thrift_backend.CLOSED_OP_STATE - and not self.has_been_closed_server_side - and self.connection.open - ): - self.thrift_backend.close_command(self.command_id) - except RequestError as e: - if isinstance(e.args[1], CursorAlreadyClosedError): - logger.info("Operation was canceled by a prior request") - finally: - self.has_been_closed_server_side = True - self.op_state = self.thrift_backend.CLOSED_OP_STATE - - @staticmethod - def _get_schema_description(table_schema_message): - """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 - """ - - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ - - return [ - (column.name, map_col_type(column.datatype), None, None, None, None, None) - for column in table_schema_message.columns - ] diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py new file mode 100644 index 000000000..5d3c12920 --- /dev/null +++ b/src/databricks/sql/result_set.py @@ -0,0 +1,363 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Any, Union + +import logging +import time +import pandas + +try: + import pyarrow +except ImportError: + pyarrow = None + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import Row +from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError +from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue + +logger = logging.getLogger(__name__) + + +class ResultSet(ABC): + """ + Abstract base class for result sets returned by different backend implementations. + + This class defines the interface that all concrete result set implementations must follow. + """ + + def __init__(self, connection, backend, arraysize: int, buffer_size_bytes: int): + """Initialize the base ResultSet with common properties.""" + self.connection = connection + self.backend = backend # Store the backend client directly + self.arraysize = arraysize + self.buffer_size_bytes = buffer_size_bytes + self._next_row_index = 0 + self.description = None + + def __iter__(self): + while True: + row = self.fetchone() + if row: + yield row + else: + break + + @property + def rownumber(self): + return self._next_row_index + + @property + @abstractmethod + def is_staging_operation(self) -> bool: + """Whether this result set represents a staging operation.""" + pass + + # Define abstract methods that concrete implementations must implement + @abstractmethod + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + pass + + @abstractmethod + def fetchone(self) -> Optional[Row]: + """Fetch the next row of a query result set.""" + pass + + @abstractmethod + def fetchmany(self, size: int) -> List[Row]: + """Fetch the next set of rows of a query result.""" + pass + + @abstractmethod + def fetchall(self) -> List[Row]: + """Fetch all remaining rows of a query result.""" + pass + + @abstractmethod + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + pass + + @abstractmethod + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + pass + + @abstractmethod + def close(self) -> None: + """Close the result set and release any resources.""" + pass + + +class ThriftResultSet(ResultSet): + """ResultSet implementation for the Thrift backend.""" + + def __init__( + self, + connection, + execute_response: ExecuteResponse, + thrift_client, # Pass the specific 
ThriftDatabricksClient instance + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + use_cloud_fetch: bool = True, + ): + """ + Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. + + Args: + connection: The parent connection + execute_response: Response from the execute command + thrift_client: The ThriftDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + use_cloud_fetch: Whether to use cloud fetch for retrieving results + """ + super().__init__(connection, thrift_client, arraysize, buffer_size_bytes) + + # Initialize ThriftResultSet-specific attributes + self.command_id = execute_response.command_id + self.op_state = execute_response.status + self.has_been_closed_server_side = execute_response.has_been_closed_server_side + self.has_more_rows = execute_response.has_more_rows + self.lz4_compressed = execute_response.lz4_compressed + self.description = execute_response.description + self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._use_cloud_fetch = use_cloud_fetch + self._is_staging_operation = execute_response.is_staging_operation + + # Initialize results queue + if execute_response.arrow_queue: + self.results = execute_response.arrow_queue + else: + self._fill_results_buffer() + + def _fill_results_buffer(self): + results, has_more_rows = self.backend.fetch_results( + command_id=self.command_id, + max_rows=self.arraysize, + max_bytes=self.buffer_size_bytes, + expected_row_start_offset=self._next_row_index, + lz4_compressed=self.lz4_compressed, + arrow_schema_bytes=self._arrow_schema_bytes, + description=self.description, + use_cloud_fetch=self._use_cloud_fetch, + ) + self.results = results + self.has_more_rows = has_more_rows + + def _convert_columnar_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + result = [] + for row_index in range(table.num_rows): + curr_row = [] + for col_index in range(table.num_columns): + curr_row.append(table.get_item(col_index, row_index)) + result.append(ResultRow(*curr_row)) + + return result + + def _convert_arrow_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + + if self.connection.disable_pandas is True: + return [ + ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) + ] + + # Need to use nullable types, as otherwise type can change when there are missing values. 
+ # See https://arrow.apache.org/docs/python/pandas.html#nullable-types + # NOTE: This API is experimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html + dtype_mapping = { + pyarrow.int8(): pandas.Int8Dtype(), + pyarrow.int16(): pandas.Int16Dtype(), + pyarrow.int32(): pandas.Int32Dtype(), + pyarrow.int64(): pandas.Int64Dtype(), + pyarrow.uint8(): pandas.UInt8Dtype(), + pyarrow.uint16(): pandas.UInt16Dtype(), + pyarrow.uint32(): pandas.UInt32Dtype(), + pyarrow.uint64(): pandas.UInt64Dtype(), + pyarrow.bool_(): pandas.BooleanDtype(), + pyarrow.float32(): pandas.Float32Dtype(), + pyarrow.float64(): pandas.Float64Dtype(), + pyarrow.string(): pandas.StringDtype(), + } + + # Need to rename columns, as the to_pandas function cannot handle duplicate column names + table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) + df = table_renamed.to_pandas( + types_mapper=dtype_mapping.get, + date_as_object=True, + timestamp_as_object=True, + ) + + res = df.to_numpy(na_value=None, dtype="object") + return [ResultRow(*v) for v in res] + + def merge_columnar(self, result1, result2): + """ + Merge two columnar results into a single result. + :param result1: the first ColumnTable + :param result2: the second ColumnTable + :return: a ColumnTable containing the rows of both inputs + """ + + if result1.column_names != result2.column_names: + raise ValueError("The columns in the results don't match") + + merged_result = [ + result1.column_table[i] + result2.column_table[i] + for i in range(result1.num_columns) + ] + return ColumnTable(merged_result, result1.column_names) + + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + """ + Fetch the next set of rows of a query result, returning a PyArrow table. + + An empty sequence is returned when no more rows are available. + """ + if size < 0: + raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") + results = self.results.next_n_rows(size) + n_remaining_rows = size - results.num_rows + self._next_row_index += results.num_rows + + while ( + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows + ): + self._fill_results_buffer() + partial_results = self.results.next_n_rows(n_remaining_rows) + results = pyarrow.concat_tables([results, partial_results]) + n_remaining_rows -= partial_results.num_rows + self._next_row_index += partial_results.num_rows + + return results
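The _convert_arrow_table implementation above depends on pandas nullable extension dtypes so that missing values cannot silently change a column's type during to_pandas conversion. A self-contained sketch of that pattern with toy data, independent of the connector:

```python
import pandas
import pyarrow

# Toy demonstration of the nullable-dtype mapping used above: without a
# types_mapper, an int64 column containing a null is coerced to float64;
# with pandas' nullable Int64Dtype it stays integral.
table = pyarrow.table({"n": pyarrow.array([1, None, 3], type=pyarrow.int64())})

dtype_mapping = {pyarrow.int64(): pandas.Int64Dtype()}

df_default = table.to_pandas()                                  # dtype: float64
df_nullable = table.to_pandas(types_mapper=dtype_mapping.get)   # dtype: Int64

print(df_default["n"].dtype, df_nullable["n"].dtype)
```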
+ """ + if size < 0: + raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + + results = self.results.next_n_rows(size) + n_remaining_rows = size - results.num_rows + self._next_row_index += results.num_rows + + while ( + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows + ): + self._fill_results_buffer() + partial_results = self.results.next_n_rows(n_remaining_rows) + results = self.merge_columnar(results, partial_results) + n_remaining_rows -= partial_results.num_rows + self._next_row_index += partial_results.num_rows + + return results + + def fetchall_arrow(self) -> "pyarrow.Table": + """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" + results = self.results.remaining_rows() + self._next_row_index += results.num_rows + + while not self.has_been_closed_server_side and self.has_more_rows: + self._fill_results_buffer() + partial_results = self.results.remaining_rows() + if isinstance(results, ColumnTable) and isinstance( + partial_results, ColumnTable + ): + results = self.merge_columnar(results, partial_results) + else: + results = pyarrow.concat_tables([results, partial_results]) + self._next_row_index += partial_results.num_rows + + # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table + # Valid only for metadata commands result set + if isinstance(results, ColumnTable) and pyarrow: + data = { + name: col + for name, col in zip(results.column_names, results.column_table) + } + return pyarrow.Table.from_pydict(data) + return results + + def fetchall_columnar(self): + """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" + results = self.results.remaining_rows() + self._next_row_index += results.num_rows + + while not self.has_been_closed_server_side and self.has_more_rows: + self._fill_results_buffer() + partial_results = self.results.remaining_rows() + results = self.merge_columnar(results, partial_results) + self._next_row_index += partial_results.num_rows + + return results + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + if isinstance(self.results, ColumnQueue): + res = self._convert_columnar_table(self.fetchmany_columnar(1)) + else: + res = self._convert_arrow_table(self.fetchmany_arrow(1)) + + if len(res) > 0: + return res[0] + else: + return None + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchall_columnar()) + else: + return self._convert_arrow_table(self.fetchall_arrow()) + + def fetchmany(self, size: int) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchmany_columnar(size)) + else: + return self._convert_arrow_table(self.fetchmany_arrow(size)) + + def close(self) -> None: + """ + Close the cursor. + + If the connection has not been closed, and the cursor has not already + been closed on the server for some other reason, issue a request to the server to close it. 
+ """ + try: + if ( + self.op_state != ttypes.TOperationState.CLOSED_STATE + and not self.has_been_closed_server_side + and self.connection.open + ): + self.backend.close_command(self.command_id) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + finally: + self.has_been_closed_server_side = True + self.op_state = ttypes.TOperationState.CLOSED_STATE + + @property + def is_staging_operation(self) -> bool: + """Whether this result set represents a staging operation.""" + return self._is_staging_operation diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index f2f38d572..571a61fcd 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -7,7 +7,10 @@ from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.sea_backend import SeaDatabricksClient +from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.backend.types import SessionId, BackendType logger = logging.getLogger(__name__) @@ -71,42 +74,47 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) - self.thrift_backend = ThriftBackend( - self.host, - self.port, - http_path, - (http_headers or []) + base_headers, - auth_provider, - ssl_options=self._ssl_options, - _use_arrow_native_complex_types=_use_arrow_native_complex_types, - **kwargs, - ) + # Determine which backend to use + use_sea = kwargs.get("use_sea", False) + + if use_sea: + self.backend: DatabricksClient = SeaDatabricksClient( + self.host, + self.port, + http_path, + (http_headers or []) + base_headers, + auth_provider, + ssl_options=self._ssl_options, + _use_arrow_native_complex_types=_use_arrow_native_complex_types, + **kwargs, + ) + else: + self.backend = ThriftDatabricksClient( + self.host, + self.port, + http_path, + (http_headers or []) + base_headers, + auth_provider, + ssl_options=self._ssl_options, + _use_arrow_native_complex_types=_use_arrow_native_complex_types, + **kwargs, + ) - self._handle = None self.protocol_version = None - def open(self) -> None: - self._open_session_resp = self.thrift_backend.open_session( - self.session_configuration, self.catalog, self.schema + def open(self): + self._session_id = self.backend.open_session( + session_configuration=self.session_configuration, + catalog=self.catalog, + schema=self.schema, ) - self._handle = self._open_session_resp.sessionHandle - self.protocol_version = self.get_protocol_version(self._open_session_resp) + self.protocol_version = self.get_protocol_version(self._session_id) self.is_open = True logger.info("Successfully opened session " + str(self.get_id_hex())) @staticmethod - def get_protocol_version(openSessionResp): - """ - Since the sessionHandle will sometimes have a serverProtocolVersion, it takes - precedence over the serverProtocolVersion defined in the OpenSessionResponse. 
- """ - if ( - openSessionResp.sessionHandle - and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") - and openSessionResp.sessionHandle.serverProtocolVersion - ): - return openSessionResp.sessionHandle.serverProtocolVersion - return openSessionResp.serverProtocolVersion + def get_protocol_version(session_id: SessionId): + return session_id.get_protocol_version() @staticmethod def server_parameterized_queries_enabled(protocolVersion): @@ -118,20 +126,17 @@ def server_parameterized_queries_enabled(protocolVersion): else: return False - def get_handle(self): - return self._handle + def get_session_id(self) -> SessionId: + """Get the normalized session ID""" + return self._session_id def get_id(self): - handle = self.get_handle() - if handle is None: - return None - return self.thrift_backend.handle_to_id(handle) + """Get the raw session ID (backend-specific)""" + return self._session_id.get_id() - def get_id_hex(self): - handle = self.get_handle() - if handle is None: - return None - return self.thrift_backend.handle_to_hex_id(handle) + def get_id_hex(self) -> str: + """Get the session ID in hex format""" + return self._session_id.get_hex_id() def close(self) -> None: """Close the underlying session.""" @@ -141,7 +146,7 @@ def close(self) -> None: return try: - self.thrift_backend.close_session(self.get_handle()) + self.backend.close_session(self._session_id) except RequestError as e: if isinstance(e.args[1], SessionAlreadyClosedError): logger.info("Session was closed by a prior request") diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 186f13dd6..c541ad3fd 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -26,6 +26,7 @@ TSparkRowSetType, ) from databricks.sql.types import SSLOptions +from databricks.sql.backend.types import CommandId from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter @@ -345,7 +346,7 @@ def _create_empty_table(self) -> "pyarrow.Table": ExecuteResponse = namedtuple( "ExecuteResponse", "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation " - "command_handle arrow_queue arrow_schema_bytes", + "command_id arrow_queue arrow_schema_bytes", ) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index abe0e22d2..c446b6715 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -822,11 +822,10 @@ def test_close_connection_closes_cursors(self): # We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True # Cursor op state should be open before connection is closed status_request = ttypes.TGetOperationStatusReq( - operationHandle=ars.command_id, getProgressUpdate=False - ) - op_status_at_server = ars.thrift_backend._client.GetOperationStatus( - status_request + operationHandle=ars.command_id.to_thrift_handle(), + getProgressUpdate=False, ) + op_status_at_server = ars.backend._client.GetOperationStatus(status_request) assert ( op_status_at_server.operationState != ttypes.TOperationState.CLOSED_STATE @@ -836,7 +835,7 @@ def test_close_connection_closes_cursors(self): # When connection closes, any cursor operations should no longer exist at the server with pytest.raises(SessionAlreadyClosedError) as cm: - op_status_at_server = ars.thrift_backend._client.GetOperationStatus( + op_status_at_server = ars.backend._client.GetOperationStatus( status_request ) @@ -866,9 +865,9 @@ def test_cursor_close_properly_closes_operation(self): cursor = conn.cursor() try: 
cursor.execute("SELECT 1 AS test") - assert cursor.active_op_handle is not None + assert cursor.active_command_id is not None cursor.close() - assert cursor.active_op_handle is None + assert cursor.active_command_id is None assert not cursor.open finally: if cursor.open: @@ -894,19 +893,19 @@ def test_nested_cursor_context_managers(self): with self.connection() as conn: with conn.cursor() as cursor1: cursor1.execute("SELECT 1 AS test1") - assert cursor1.active_op_handle is not None + assert cursor1.active_command_id is not None with conn.cursor() as cursor2: cursor2.execute("SELECT 2 AS test2") - assert cursor2.active_op_handle is not None + assert cursor2.active_command_id is not None # After inner context manager exit, cursor2 should be not open assert not cursor2.open - assert cursor2.active_op_handle is None + assert cursor2.active_command_id is None # After outer context manager exit, cursor1 should be not open assert not cursor1.open - assert cursor1.active_op_handle is None + assert cursor1.active_command_id is None def test_cursor_error_handling(self): """Test that cursor close handles errors properly to prevent orphaned operations.""" @@ -915,12 +914,12 @@ def test_cursor_error_handling(self): cursor.execute("SELECT 1 AS test") - op_handle = cursor.active_op_handle + op_handle = cursor.active_command_id assert op_handle is not None # Manually close the operation to simulate server-side closure - conn.session.thrift_backend.close_command(op_handle) + conn.session.backend.close_command(op_handle) cursor.close() @@ -940,7 +939,7 @@ def test_result_set_close(self): result_set.close() - assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + assert result_set.op_state == result_set.backend.CLOSED_OP_STATE assert result_set.op_state != initial_op_state # Closing the result set again should be a no-op and not raise exceptions diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index a9c7a43a9..ce6fd0e93 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -15,41 +15,42 @@ THandleIdentifier, TOperationType, ) -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient import databricks.sql import databricks.sql.client as client from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError from databricks.sql.exc import RequestError, CursorAlreadyClosedError from databricks.sql.types import Row +from databricks.sql.result_set import ResultSet, ThriftResultSet +from databricks.sql.backend.types import CommandId from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite from tests.unit.test_arrow_queue import ArrowQueueSuite -class ThriftBackendMockFactory: +class ThriftDatabricksClientMockFactory: @classmethod def new(cls): - ThriftBackendMock = Mock(spec=ThriftBackend) + ThriftBackendMock = Mock(spec=ThriftDatabricksClient) ThriftBackendMock.return_value = ThriftBackendMock cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None) - MockTExecuteStatementResp = MagicMock(spec=TExecuteStatementResp()) + mock_result_set = Mock(spec=ThriftResultSet) cls.apply_property_to_mock( - MockTExecuteStatementResp, + mock_result_set, description=None, - arrow_queue=None, is_staging_operation=False, - command_handle=b"\x22", + command_id=None, has_been_closed_server_side=True, has_more_rows=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) - 
ThriftBackendMock.execute_command.return_value = MockTExecuteStatementResp + ThriftBackendMock.execute_command.return_value = mock_result_set return ThriftBackendMock @@ -81,24 +82,30 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch( + "%s.session.ThriftDatabricksClient" % PACKAGE_NAME, + ThriftDatabricksClientMockFactory.new(), + ) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_closing_connection_closes_commands(self, mock_result_set_class): # Test once with has_been_closed_server side, once without for closed in (True, False): with self.subTest(closed=closed): - mock_result_set_class.return_value = Mock() connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() - cursor.execute("SELECT 1;") + + # Create a mock result set and set it as the active result set + mock_result_set = Mock() + mock_result_set.has_been_closed_server_side = closed + cursor.active_result_set = mock_result_set + + # Close the connection connection.close() - self.assertTrue( - mock_result_set_class.return_value.has_been_closed_server_side - ) - mock_result_set_class.return_value.close.assert_called_once_with() + # Check that the manually created mock result set's close method was called + mock_result_set.close.assert_called_once_with() - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) self.assertTrue(connection.open) @@ -108,7 +115,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection.cursor() self.assertIn("closed", str(cm.exception)) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) def test_arraysize_buffer_size_passthrough( self, mock_cursor_class, mock_client_class @@ -123,10 +130,11 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() - result_set = client.ResultSet( + + result_set = ThriftResultSet( connection=mock_connection, - thrift_backend=mock_backend, execute_response=Mock(), + thrift_client=mock_backend, ) # Setup session mock on the mock_connection mock_session = Mock() @@ -148,27 +156,31 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) - result_set = client.ResultSet( + result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) result_set.close() mock_thrift_backend.close_command.assert_called_once_with( - mock_results_response.command_handle + mock_results_response.command_id ) - @patch("%s.client.ResultSet" % PACKAGE_NAME) + @patch("%s.result_set.ThriftResultSet" % PACKAGE_NAME) def test_executing_multiple_commands_uses_the_most_recent_command( self, mock_result_set_class ): - mock_result_sets = [Mock(), Mock()] + # Set is_staging_operation to False to avoid _handle_staging_operation being called + for mock_rs in mock_result_sets: + mock_rs.is_staging_operation = False + mock_result_set_class.side_effect = mock_result_sets - cursor = client.Cursor( - connection=Mock(), thrift_backend=ThriftBackendMockFactory.new() - ) + mock_backend = 
ThriftDatabricksClientMockFactory.new() + mock_backend.execute_command.side_effect = mock_result_sets + + cursor = client.Cursor(connection=Mock(), backend=mock_backend) cursor.execute("SELECT 1;") cursor.execute("SELECT 1;") @@ -193,7 +205,7 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - result_set = client.ResultSet(Mock(), Mock(), Mock()) + result_set = ThriftResultSet(Mock(), Mock(), Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -205,11 +217,11 @@ def test_context_manager_closes_cursor(self): mock_close.assert_called_once_with() cursor = client.Cursor(Mock(), Mock()) - cursor.close = Mock() + cursor.close = Mock() try: with self.assertRaises(KeyboardInterrupt): - with cursor: + with cursor: raise KeyboardInterrupt("Simulated interrupt") finally: cursor.close.assert_called() @@ -226,7 +238,7 @@ def dict_product(self, dicts): """ return (dict(zip(dicts.keys(), x)) for x in itertools.product(*dicts.values())) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backend): req_args_combinations = self.dict_product( dict( @@ -247,7 +259,7 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe for k, v in req_args.items(): self.assertEqual(v, call_args[k]) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backend): req_args_combinations = self.dict_product( dict( @@ -270,7 +282,7 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen for k, v in req_args.items(): self.assertEqual(v, call_args[k]) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backend): req_args_combinations = self.dict_product( dict( @@ -296,10 +308,10 @@ def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe def test_cancel_command_calls_the_backend(self): mock_thrift_backend = Mock() cursor = client.Cursor(Mock(), mock_thrift_backend) - mock_op_handle = Mock() - cursor.active_op_handle = mock_op_handle + mock_command_id = Mock() + cursor.active_command_id = mock_command_id cursor.cancel() - mock_thrift_backend.cancel_command.assert_called_with(mock_op_handle) + mock_thrift_backend.cancel_command.assert_called_with(mock_command_id) @patch("databricks.sql.client.logger") def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( @@ -321,7 +333,7 @@ def test_version_is_canonical(self): self.assertIsNotNone(re.match(canonical_version_re, version)) def test_execute_parameter_passthrough(self): - mock_thrift_backend = ThriftBackendMockFactory.new() + mock_thrift_backend = ThriftDatabricksClientMockFactory.new() cursor = client.Cursor(Mock(), mock_thrift_backend) tests = [ @@ -345,16 +357,21 @@ def test_execute_parameter_passthrough(self): expected_query, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - @patch("%s.client.ResultSet" % PACKAGE_NAME) + @patch("%s.result_set.ThriftResultSet" % PACKAGE_NAME) def test_executemany_parameter_passhthrough_and_uses_last_result_set( - self, mock_result_set_class, mock_thrift_backend + self, mock_result_set_class ): # Create a new mock result set each time the 
class is instantiated mock_result_set_instances = [Mock(), Mock(), Mock()] + # Set is_staging_operation to False to avoid _handle_staging_operation being called + for mock_rs in mock_result_set_instances: + mock_rs.is_staging_operation = False + mock_result_set_class.side_effect = mock_result_set_instances - mock_thrift_backend = ThriftBackendMockFactory.new() - cursor = client.Cursor(Mock(), mock_thrift_backend()) + mock_backend = ThriftDatabricksClientMockFactory.new() + mock_backend.execute_command.side_effect = mock_result_set_instances + + cursor = client.Cursor(Mock(), mock_backend) params = [{"x": None}, {"x": "foo1"}, {"x": "bar2"}] expected_queries = ["SELECT NULL", "SELECT 'foo1'", "SELECT 'bar2'"] @@ -362,13 +379,13 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( cursor.executemany("SELECT %(x)s", seq_of_parameters=params) self.assertEqual( - len(mock_thrift_backend.execute_command.call_args_list), + len(mock_backend.execute_command.call_args_list), len(expected_queries), "Expected execute_command to be called the same number of times as params were passed", ) for expected_query, call_args in zip( - expected_queries, mock_thrift_backend.execute_command.call_args_list + expected_queries, mock_backend.execute_command.call_args_list ): self.assertEqual(call_args[1]["operation"], expected_query) @@ -379,7 +396,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( "last operation", ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_commit_a_noop(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) c.commit() @@ -392,14 +409,14 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_rollback_not_supported(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) with self.assertRaises(NotSupportedError): c.rollback() @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_row_number_respected(self, mock_thrift_backend_class): def make_fake_row_slice(n_rows): mock_slice = Mock() @@ -424,7 +441,7 @@ def make_fake_row_slice(n_rows): self.assertEqual(cursor.rownumber, 29) @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_disable_pandas_respected(self, mock_thrift_backend_class): mock_thrift_backend = mock_thrift_backend_class.return_value mock_table = Mock() @@ -477,7 +494,7 @@ def test_column_name_api(self): }, ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value @@ -496,17 +513,18 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( self, mock_client_class, 
mock_handle_staging_operation, mock_execute_response ): # If server sets ExecuteResponse.is_staging_operation True then _handle_staging_operation should be called - ThriftBackendMockFactory.apply_property_to_mock( + ThriftDatabricksClientMockFactory.apply_property_to_mock( mock_execute_response, is_staging_operation=True ) - mock_client_class.execute_command.return_value = mock_execute_response - mock_client_class.return_value = mock_client_class + mock_client = mock_client_class.return_value + mock_client.execute_command.return_value = Mock(is_staging_operation=True) + mock_client_class.return_value = mock_client connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() @@ -515,7 +533,10 @@ def test_staging_operation_response_is_handled( mock_handle_staging_operation.call_count == 1 - @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch( + "%s.session.ThriftDatabricksClient" % PACKAGE_NAME, + ThriftDatabricksClientMockFactory.new(), + ) def test_access_current_query_id(self): operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821" @@ -524,9 +545,13 @@ def test_access_current_query_id(self): self.assertIsNone(cursor.query_id) - cursor.active_op_handle = TOperationHandle( - operationId=THandleIdentifier(guid=UUID(operation_id).bytes, secret=0x00), - operationType=TOperationType.EXECUTE_STATEMENT, + cursor.active_command_id = CommandId.from_thrift_handle( + TOperationHandle( + operationId=THandleIdentifier( + guid=UUID(operation_id).bytes, secret=0x00 + ), + operationType=TOperationType.EXECUTE_STATEMENT, + ) ) self.assertEqual(cursor.query_id.upper(), operation_id.upper()) @@ -537,18 +562,18 @@ def test_cursor_close_handles_exception(self): """Test that Cursor.close() handles exceptions from close_command properly.""" mock_backend = Mock() mock_connection = Mock() - mock_op_handle = Mock() + mock_command_id = Mock() mock_backend.close_command.side_effect = Exception("Test error") cursor = client.Cursor(mock_connection, mock_backend) - cursor.active_op_handle = mock_op_handle + cursor.active_command_id = mock_command_id cursor.close() - mock_backend.close_command.assert_called_once_with(mock_op_handle) + mock_backend.close_command.assert_called_once_with(mock_command_id) - self.assertIsNone(cursor.active_op_handle) + self.assertIsNone(cursor.active_command_id) self.assertFalse(cursor.open) @@ -606,9 +631,9 @@ def mock_close_normal(): def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" - result_set = client.ResultSet.__new__(client.ResultSet) - result_set.thrift_backend = Mock() - result_set.thrift_backend.CLOSED_OP_STATE = "CLOSED" + result_set = client.ThriftResultSet.__new__(client.ThriftResultSet) + result_set.backend = Mock() + result_set.backend.CLOSED_OP_STATE = "CLOSED" result_set.connection = Mock() result_set.connection.open = True result_set.op_state = "RUNNING" @@ -619,31 +644,31 @@ class MockRequestError(Exception): def __init__(self): self.args = ["Error message", CursorAlreadyClosedError()] - result_set.thrift_backend.close_command.side_effect = MockRequestError() + result_set.backend.close_command.side_effect = MockRequestError() original_close = client.ResultSet.close try: try: if ( - result_set.op_state != result_set.thrift_backend.CLOSED_OP_STATE + result_set.op_state != result_set.backend.CLOSED_OP_STATE and not result_set.has_been_closed_server_side and result_set.connection.open ): - 
result_set.thrift_backend.close_command(result_set.command_id) + result_set.backend.close_command(result_set.command_id) except MockRequestError as e: if isinstance(e.args[1], CursorAlreadyClosedError): pass finally: result_set.has_been_closed_server_side = True - result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE + result_set.op_state = result_set.backend.CLOSED_OP_STATE - result_set.thrift_backend.close_command.assert_called_once_with( + result_set.backend.close_command.assert_called_once_with( result_set.command_id ) assert result_set.has_been_closed_server_side is True - assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + assert result_set.op_state == result_set.backend.CLOSED_OP_STATE finally: pass diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 71766f2cb..030510a64 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -9,6 +9,8 @@ import databricks.sql.client as client from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.result_set import ThriftResultSet @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") @@ -37,20 +39,20 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) - rs = client.ResultSet( + rs = ThriftResultSet( connection=Mock(), - thrift_backend=None, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), lz4_compressed=Mock(), - command_handle=None, + command_id=None, arrow_queue=arrow_queue, arrow_schema_bytes=schema.serialize().to_pybytes(), is_staging_operation=False, ), + thrift_client=None, ) num_cols = len(initial_results[0]) if initial_results else 0 rs.description = [ @@ -64,7 +66,7 @@ def make_dummy_result_set_from_batch_list(batch_list): batch_index = 0 def fetch_results( - op_handle, + command_id, max_rows, max_bytes, expected_row_start_offset, @@ -79,13 +81,12 @@ def fetch_results( return results, batch_index < len(batch_list) - mock_thrift_backend = Mock() + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 - rs = client.ResultSet( + rs = ThriftResultSet( connection=Mock(), - thrift_backend=mock_thrift_backend, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=False, @@ -95,11 +96,12 @@ def fetch_results( for col_id in range(num_cols) ], lz4_compressed=Mock(), - command_handle=None, + command_id=None, arrow_queue=None, arrow_schema_bytes=None, is_staging_operation=False, ), + thrift_client=mock_thrift_backend, ) return rs diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index 552872221..b302c00da 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -31,13 +31,13 @@ def make_dummy_result_set_from_initial_results(arrow_table): arrow_queue = ArrowQueue(arrow_table, arrow_table.num_rows, 0) rs = client.ResultSet( connection=None, - thrift_backend=None, + backend=None, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), - command_handle=None, + command_id=None, arrow_queue=arrow_queue, 
arrow_schema=arrow_table.schema, ), diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index eec921e4d..949230d1e 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -22,6 +22,7 @@ TinyIntParameter, VoidParameter, ) +from databricks.sql.backend.types import SessionId from databricks.sql.parameters.native import ( TDbsqlParameter, TSparkParameterValue, @@ -42,7 +43,10 @@ class TestSessionHandleChecks(object): ( TOpenSessionResp( serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, - sessionHandle=TSessionHandle(1, None), + sessionHandle=TSessionHandle( + sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37), + serverProtocolVersion=None, + ), ), ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, ), @@ -51,7 +55,8 @@ class TestSessionHandleChecks(object): TOpenSessionResp( serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, sessionHandle=TSessionHandle( - 1, ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 + sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37), + serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, ), ), ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, @@ -59,7 +64,13 @@ class TestSessionHandleChecks(object): ], ) def test_get_protocol_version_fallback_behavior(self, test_input, expected): - assert Connection.get_protocol_version(test_input) == expected + properties = ( + {"serverProtocolVersion": test_input.serverProtocolVersion} + if test_input.serverProtocolVersion + else {} + ) + session_id = SessionId.from_thrift_handle(test_input.sessionHandle, properties) + assert Connection.get_protocol_version(session_id) == expected @pytest.mark.parametrize( "test_input,expected", diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py new file mode 100644 index 000000000..72009d6cf --- /dev/null +++ b/tests/unit/test_sea_backend.py @@ -0,0 +1,168 @@ +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.sea_backend import SeaDatabricksClient +from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.types import SSLOptions +from databricks.sql.auth.authenticators import AuthProvider +from databricks.sql.exc import Error, NotSupportedError + + +class TestSeaBackend: + """Test suite for the SeaDatabricksClient class.""" + + @pytest.fixture + def mock_http_client(self): + """Create a mock HTTP client.""" + with patch( + "databricks.sql.backend.sea_backend.CustomHttpClient" + ) as mock_client_class: + mock_client = mock_client_class.return_value + yield mock_client + + @pytest.fixture + def sea_client(self, mock_http_client): + """Create a SeaDatabricksClient instance with mocked dependencies.""" + server_hostname = "test-server.databricks.com" + port = 443 + http_path = "/sql/warehouses/abc123" + http_headers = [("header1", "value1"), ("header2", "value2")] + auth_provider = AuthProvider() + ssl_options = SSLOptions() + + client = SeaDatabricksClient( + server_hostname=server_hostname, + port=port, + http_path=http_path, + http_headers=http_headers, + auth_provider=auth_provider, + ssl_options=ssl_options, + ) + + return client + + def test_init_extracts_warehouse_id(self, mock_http_client): + """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" + # Test with warehouses format + client1 = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + 
http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client1.warehouse_id == "abc123" + + # Test with endpoints format + client2 = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/endpoints/def456", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client2.warehouse_id == "def456" + + def test_init_raises_error_for_invalid_http_path(self, mock_http_client): + """Test that the constructor raises an error for invalid HTTP paths.""" + with pytest.raises(ValueError) as excinfo: + SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/invalid/path", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert "Could not extract warehouse ID" in str(excinfo.value) + + def test_open_session_basic(self, sea_client, mock_http_client): + """Test the open_session method with minimal parameters.""" + # Set up mock response + mock_http_client._make_request.return_value = {"session_id": "test-session-123"} + + # Call the method + session_id = sea_client.open_session(None, None, None) + + # Verify the result + assert isinstance(session_id, SessionId) + assert session_id.backend_type == BackendType.SEA + assert session_id.guid == "test-session-123" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once_with( + method="POST", path=sea_client.SESSION_PATH, data={"warehouse_id": "abc123"} + ) + + def test_open_session_with_all_parameters(self, sea_client, mock_http_client): + """Test the open_session method with all parameters.""" + # Set up mock response + mock_http_client._make_request.return_value = {"session_id": "test-session-456"} + + # Call the method with all parameters + session_config = {"spark.sql.shuffle.partitions": "10"} + catalog = "test_catalog" + schema = "test_schema" + + session_id = sea_client.open_session(session_config, catalog, schema) + + # Verify the result + assert isinstance(session_id, SessionId) + assert session_id.backend_type == BackendType.SEA + assert session_id.guid == "test-session-456" + + # Verify the HTTP request + expected_data = { + "warehouse_id": "abc123", + "session_confs": session_config, + "catalog": catalog, + "schema": schema, + } + mock_http_client._make_request.assert_called_once_with( + method="POST", path=sea_client.SESSION_PATH, data=expected_data + ) + + def test_open_session_error_handling(self, sea_client, mock_http_client): + """Test error handling in the open_session method.""" + # Set up mock response without session_id + mock_http_client._make_request.return_value = {} + + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.open_session(None, None, None) + + assert "Failed to create session" in str(excinfo.value) + + def test_close_session_valid_id(self, sea_client, mock_http_client): + """Test closing a session with a valid session ID.""" + # Create a valid SEA session ID + session_id = SessionId.from_sea_session_id("test-session-789") + + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_session(session_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once_with( + method="DELETE", + path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"), + data={"warehouse_id": "abc123"}, + ) + + def test_close_session_invalid_id_type(self, sea_client): + """Test closing a session with an invalid 
session ID type.""" + # Create a Thrift session ID (not SEA) + mock_thrift_handle = MagicMock() + mock_thrift_handle.sessionId.guid = b"guid" + mock_thrift_handle.sessionId.secret = b"secret" + session_id = SessionId.from_thrift_handle(mock_thrift_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + sea_client.close_session(session_id) + + assert "Not a valid SEA session ID" in str(excinfo.value) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index eb392a229..858119f92 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -4,7 +4,10 @@ from databricks.sql.thrift_api.TCLIService.ttypes import ( TOpenSessionResp, + TSessionHandle, + THandleIdentifier, ) +from databricks.sql.backend.types import SessionId, BackendType import databricks.sql @@ -21,22 +24,23 @@ class SessionTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_close_uses_the_correct_session_id(self, mock_client_class): instance = mock_client_class.return_value - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + instance.open_session.return_value = mock_session_id connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) connection.close() - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") + # Check that close_session was called with the correct SessionId + close_session_call_args = instance.close_session.call_args[0][0] + self.assertEqual(close_session_call_args.guid, b"\x22") + self.assertEqual(close_session_call_args.secret, b"\x33") - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_auth_args(self, mock_client_class): # Test that the following auth args work: # token = foo, @@ -63,7 +67,7 @@ def test_auth_args(self, mock_client_class): self.assertEqual(args["http_path"], http_path) connection.close() - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_http_header_passthrough(self, mock_client_class): http_headers = [("foo", "bar")] databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) @@ -71,7 +75,7 @@ def test_http_header_passthrough(self, mock_client_class): call_args = mock_client_class.call_args[0][3] self.assertIn(("foo", "bar"), call_args) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_tls_arg_passthrough(self, mock_client_class): databricks.sql.connect( **self.DUMMY_CONNECTION_ARGS, @@ -87,7 +91,7 @@ def test_tls_arg_passthrough(self, mock_client_class): self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_useragent_header(self, mock_client_class): databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) @@ -108,22 +112,23 @@ def test_useragent_header(self, mock_client_class): 
http_headers = mock_client_class.call_args[0][3] self.assertIn(user_agent_header_with_entry, http_headers) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_context_manager_closes_connection(self, mock_client_class): instance = mock_client_class.return_value - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + instance.open_session.return_value = mock_session_id with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: pass - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") + # Check that close_session was called with the correct SessionId + close_session_call_args = instance.close_session.call_args[0][0] + self.assertEqual(close_session_call_args.guid, b"\x22") + self.assertEqual(close_session_call_args.secret, b"\x33") - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_max_number_of_retries_passthrough(self, mock_client_class): databricks.sql.connect( _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS @@ -133,54 +138,62 @@ def test_max_number_of_retries_passthrough(self, mock_client_class): mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_socket_timeout_passthrough(self, mock_client_class): databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): mock_session_config = Mock() + + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + mock_client_class.return_value.open_session.return_value = mock_session_id + databricks.sql.connect( session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][0], - mock_session_config, - ) + # Check that open_session was called with the correct session_configuration as keyword argument + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + self.assertEqual(call_kwargs["session_configuration"], mock_session_config) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): mock_cat = Mock() mock_schem = Mock() + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + mock_client_class.return_value.open_session.return_value = mock_session_id + databricks.sql.connect( **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][1], mock_cat - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][2], mock_schem - ) - 
@patch("%s.session.ThriftBackend" % PACKAGE_NAME) + # Check that open_session was called with the correct catalog and schema as keyword arguments + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + self.assertEqual(call_kwargs["catalog"], mock_cat) + self.assertEqual(call_kwargs["schema"], mock_schem) + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_finalizer_closes_abandoned_connection(self, mock_client_class): instance = mock_client_class.return_value - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + instance.open_session.return_value = mock_session_id databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) # not strictly necessary as the refcount is 0, but just to be sure gc.collect() - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") + # Check that close_session was called with the correct SessionId + close_session_call_args = instance.close_session.call_args[0][0] + self.assertEqual(close_session_call_args.guid, b"\x22") + self.assertEqual(close_session_call_args.secret, b"\x33") if __name__ == "__main__": diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 458ea9a82..de9016dfa 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -17,7 +17,9 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.result_set import ResultSet, ThriftResultSet +from databricks.sql.backend.types import CommandId, SessionId, BackendType def retry_policy_factory(): @@ -73,7 +75,7 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -92,7 +94,7 @@ def _make_type_desc(self, type): ) def _make_fake_thrift_backend(self): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -126,14 +128,16 @@ def test_hive_schema_to_arrow_schema_preserves_column_names(self): ] t_table_schema = ttypes.TTableSchema(columns) - arrow_schema = ThriftBackend._hive_schema_to_arrow_schema(t_table_schema) + arrow_schema = ThriftDatabricksClient._hive_schema_to_arrow_schema( + t_table_schema + ) self.assertEqual(arrow_schema.field(0).name, "column 1") self.assertEqual(arrow_schema.field(1).name, "column 2") self.assertEqual(arrow_schema.field(2).name, "column 2") self.assertEqual(arrow_schema.field(3).name, "") - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): t_http_client_instance = tcli_service_client_cass.return_value bad_protocol_versions = [ @@ -163,7 +167,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): "expected server to use a 
protocol version", str(cm.exception) ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): t_http_client_instance = tcli_service_client_cass.return_value good_protocol_versions = [ @@ -174,7 +178,9 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): for protocol_version in good_protocol_versions: t_http_client_instance.OpenSession.return_value = ttypes.TOpenSessionResp( - status=self.okay_status, serverProtocolVersion=protocol_version + status=self.okay_status, + serverProtocolVersion=protocol_version, + sessionHandle=self.session_handle, ) thrift_backend = self._make_fake_thrift_backend() @@ -182,7 +188,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_headers_are_set(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -229,7 +235,7 @@ def test_tls_cert_args_are_propagated( mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called_once_with(cafile=mock_trusted_ca_file) - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -315,7 +321,7 @@ def test_tls_no_verify_is_respected( mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called() - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -339,7 +345,7 @@ def test_tls_verify_hostname_is_respected( mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called() - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -356,7 +362,7 @@ def test_tls_verify_hostname_is_respected( @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -371,7 +377,7 @@ def test_port_and_host_are_respected(self, t_http_client_class): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_https_does_not_duplicate(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "https://hostname", 123, "path_value", @@ -386,7 +392,7 @@ def test_host_with_https_does_not_duplicate(self, t_http_client_class): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "https://hostname/", 123, "path_value", @@ -401,7 +407,7 @@ def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_cla @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -413,7 +419,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): self.assertEqual( t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000 ) - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -423,7 +429,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): _socket_timeout=0, ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -434,7 +440,7 @@ def test_socket_timeout_is_propagated(self, 
t_http_client_class): self.assertEqual( t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000 ) - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -467,9 +473,9 @@ def test_non_primitive_types_raise_error(self): t_table_schema = ttypes.TTableSchema(columns) with self.assertRaises(OperationalError): - ThriftBackend._hive_schema_to_arrow_schema(t_table_schema) + ThriftDatabricksClient._hive_schema_to_arrow_schema(t_table_schema) with self.assertRaises(OperationalError): - ThriftBackend._hive_schema_to_description(t_table_schema) + ThriftDatabricksClient._hive_schema_to_description(t_table_schema) def test_hive_schema_to_description_preserves_column_names_and_types(self): # Full coverage of all types is done in integration tests, this is just a @@ -493,7 +499,7 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self): ] t_table_schema = ttypes.TTableSchema(columns) - description = ThriftBackend._hive_schema_to_description(t_table_schema) + description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema) self.assertEqual( description, @@ -532,7 +538,7 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): ] t_table_schema = ttypes.TTableSchema(columns) - description = ThriftBackend._hive_schema_to_description(t_table_schema) + description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema) self.assertEqual( description, [ @@ -545,7 +551,7 @@ def test_make_request_checks_status_code(self): ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ] - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -589,7 +595,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): closeOperation=None, ), ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -628,7 +634,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( closeOperation=None, ), ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -642,7 +648,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_checks_operation_state_in_polls( self, tcli_service_class ): @@ -672,7 +678,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( ) tcli_service_instance.GetOperationStatus.return_value = op_state_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -686,7 +692,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( if op_state_resp.errorMessage: self.assertIn(op_state_resp.errorMessage, str(cm.exception)) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_status_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -710,7 +716,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): ) tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, 
"path", @@ -724,7 +730,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): self.assertEqual(display_message, str(cm.exception)) self.assertIn(diagnostic_info, str(cm.exception.message_with_context())) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_direct_results_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -750,7 +756,7 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -812,7 +818,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -825,7 +831,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( self, tcli_service_class ): @@ -863,7 +869,7 @@ def test_handle_execute_response_can_handle_without_direct_results( op_state_2, op_state_3, ] - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -900,7 +906,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): operationHandle=self.operation_handle, ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -917,7 +923,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): ttypes.TOperationState.FINISHED_STATE, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_use_arrow_schema_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value arrow_schema_mock = MagicMock(name="Arrow schema mock") @@ -946,7 +952,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value hive_schema_mock = MagicMock(name="Hive schema mock") @@ -976,7 +982,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): @@ -1020,7 +1026,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( 
@patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): @@ -1064,7 +1070,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( thrift_backend._handle_execute_response(execute_resp, Mock()) _, has_more_rows_resp = thrift_backend.fetch_results( - op_handle=Mock(), + command_id=Mock(), max_rows=1, max_bytes=1, expected_row_start_offset=0, @@ -1075,7 +1081,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( self.assertEqual(has_more_rows, has_more_rows_resp) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): # make some semi-real arrow batches and check the number of rows is correct in the queue tcli_service_instance = tcli_service_class.return_value @@ -1108,7 +1114,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): .to_pybytes() ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1117,7 +1123,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): ssl_options=SSLOptions(), ) arrow_queue, has_more_results = thrift_backend.fetch_results( - op_handle=Mock(), + command_id=Mock(), max_rows=1, max_bytes=1, expected_row_start_offset=0, @@ -1128,14 +1134,14 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.ExecuteStatement.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1146,7 +1152,12 @@ def test_execute_statement_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.execute_command("foo", Mock(), 100, 200, Mock(), cursor_mock) + result = thrift_backend.execute_command( + "foo", Mock(), 100, 200, Mock(), cursor_mock + ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1157,14 +1168,14 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1175,7 +1186,10 @@ def 
test_get_catalogs_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) + result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1185,14 +1199,14 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1203,7 +1217,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_schemas( + result = thrift_backend.get_schemas( Mock(), 100, 200, @@ -1211,6 +1225,9 @@ def test_get_schemas_calls_client_and_handle_execute_response( catalog_name="catalog_pattern", schema_name="schema_pattern", ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1222,14 +1239,14 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1240,7 +1257,7 @@ def test_get_tables_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_tables( + result = thrift_backend.get_tables( Mock(), 100, 200, @@ -1250,6 +1267,9 @@ def test_get_tables_calls_client_and_handle_execute_response( table_name="table_pattern", table_types=["type1", "type2"], ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1263,14 +1283,14 @@ def test_get_tables_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1281,7 +1301,7 @@ def 
test_get_columns_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_columns( + result = thrift_backend.get_columns( Mock(), 100, 200, @@ -1291,6 +1311,9 @@ def test_get_columns_calls_client_and_handle_execute_response( table_name="table_pattern", column_name="column_pattern", ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1304,12 +1327,12 @@ def test_get_columns_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_open_session_user_provided_session_id_optional(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1320,10 +1343,10 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1331,16 +1354,17 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend.close_command(self.operation_handle) + command_id = CommandId.from_thrift_handle(self.operation_handle) + thrift_backend.close_command(command_id) self.assertEqual( tcli_service_instance.CloseOperation.call_args[0][0].operationHandle, self.operation_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1348,13 +1372,14 @@ def test_session_handle_respected_in_close_session(self, tcli_service_class): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend.close_session(self.session_handle) + session_id = SessionId.from_thrift_handle(self.session_handle) + thrift_backend.close_session(session_id) self.assertEqual( tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_non_arrow_non_column_based_set_triggers_exception( self, tcli_service_class ): @@ -1392,7 +1417,7 @@ def test_non_arrow_non_column_based_set_triggers_exception( def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( 
"foobar", 443, "path", @@ -1403,12 +1428,16 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) - @patch("databricks.sql.thrift_backend.convert_arrow_based_set_to_arrow_table") - @patch("databricks.sql.thrift_backend.convert_column_based_set_to_arrow_table") + @patch( + "databricks.sql.backend.thrift_backend.convert_arrow_based_set_to_arrow_table" + ) + @patch( + "databricks.sql.backend.thrift_backend.convert_column_based_set_to_arrow_table" + ) def test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1443,7 +1472,7 @@ def test_create_arrow_table_calls_correct_conversion_method( def test_convert_arrow_based_set_to_arrow_table( self, open_stream_mock, lz4_decompress_mock ): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1597,17 +1626,18 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): self.assertEqual(arrow_table.column(2).to_pylist(), [1.15, 2.2, 3.3]) self.assertEqual(arrow_table.column(3).to_pylist(), [b"\x11", b"\x22", b"\x33"]) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_cancel_command_uses_active_op_handle(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value thrift_backend = self._make_fake_thrift_backend() - active_op_handle_mock = Mock() - thrift_backend.cancel_command(active_op_handle_mock) + # Create a proper CommandId from the existing operation_handle + command_id = CommandId.from_thrift_handle(self.operation_handle) + thrift_backend.cancel_command(command_id) self.assertEqual( tcli_service_instance.CancelOperation.call_args[0][0].operationHandle, - active_op_handle_mock, + self.operation_handle, ) def test_handle_execute_response_sets_active_op_handle(self): @@ -1615,19 +1645,27 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() thrift_backend._results_message_to_execute_response = Mock() + + # Create a mock response with a real operation handle mock_resp = Mock() + mock_resp.operationHandle = ( + self.operation_handle + ) # Use the real operation handle from the test class mock_cursor = Mock() thrift_backend._handle_execute_response(mock_resp, mock_cursor) - self.assertEqual(mock_resp.operationHandle, mock_cursor.active_op_handle) + self.assertEqual( + mock_resp.operationHandle, mock_cursor.active_command_id.to_thrift_handle() + ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch( "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" ) @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + "databricks.sql.backend.thrift_backend._retry_policy", + new_callable=retry_policy_factory, ) def test_make_request_will_retry_GetOperationStatus( self, mock_retry_policy, mock_GetOperationStatus, t_transport_class @@ -1654,7 +1692,7 @@ def test_make_request_will_retry_GetOperationStatus( EXPECTED_RETRIES = 2 - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1681,7 +1719,7 @@ def 
test_make_request_will_retry_GetOperationStatus( ) with self.assertLogs( - "databricks.sql.thrift_backend", level=logging.WARNING + "databricks.sql.backend.thrift_backend", level=logging.WARNING ) as cm: with self.assertRaises(RequestError): thrift_backend.make_request(client.GetOperationStatus, req) @@ -1702,7 +1740,8 @@ def test_make_request_will_retry_GetOperationStatus( "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" ) @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + "databricks.sql.backend.thrift_backend._retry_policy", + new_callable=retry_policy_factory, ) def test_make_request_will_retry_GetOperationStatus_for_http_error( self, mock_retry_policy, mock_gos @@ -1731,7 +1770,7 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( EXPECTED_RETRIES = 2 - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1763,7 +1802,7 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503( mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1779,7 +1818,8 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503( @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + "databricks.sql.backend.thrift_backend._retry_policy", + new_callable=retry_policy_factory, ) def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self, mock_retry_policy, t_transport_class @@ -1791,7 +1831,7 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1820,7 +1860,7 @@ def test_make_request_will_read_error_message_headers_if_set( mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1944,7 +1984,7 @@ def test_retry_args_passthrough(self, mock_http_client): "_retry_stop_after_attempts_count": 1, "_retry_stop_after_attempts_duration": 100, } - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1959,7 +1999,12 @@ def test_retry_args_passthrough(self, mock_http_client): @patch("thrift.transport.THttpClient.THttpClient") def test_retry_args_bounding(self, mock_http_client): retry_delay_test_args_and_expected_values = {} - for k, (_, _, min, max) in databricks.sql.thrift_backend._retry_policy.items(): + for k, ( + _, + _, + min, + max, + ) in databricks.sql.backend.thrift_backend._retry_policy.items(): retry_delay_test_args_and_expected_values[k] = ( (min - 1, min), (max + 1, max), @@ -1970,7 +2015,7 @@ def test_retry_args_bounding(self, mock_http_client): k: v[i][0] for (k, v) in retry_delay_test_args_and_expected_values.items() } - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1986,7 +2031,7 @@ def test_retry_args_bounding(self, mock_http_client): for arg, val in retry_delay_expected_vals.items(): self.assertEqual(getattr(backend, arg), val) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + 
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_configuration_passthrough(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp @@ -1998,7 +2043,7 @@ def test_configuration_passthrough(self, tcli_client_class): "42": "42", } - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2011,12 +2056,12 @@ def test_configuration_passthrough(self, tcli_client_class): open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertEqual(open_session_req.configuration, expected_config) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp mock_config = {"spark.thriftserver.arrowBasedRowSet.timestampAsString": True} - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2036,13 +2081,14 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, canUseMultipleCatalogs=can_use_multiple_cats, initialNamespace=ttypes.TNamespace(catalogName=cat, schemaName=schem), + sessionHandle=self.session_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2066,14 +2112,14 @@ def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): self.assertEqual(open_session_req.initialNamespace.catalogName, cat) self.assertEqual(open_session_req.initialNamespace.schemaName, schem) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_can_use_multiple_catalogs_is_set_in_open_session_req( self, tcli_client_class ): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2086,13 +2132,13 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req( open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertTrue(open_session_req.canUseMultipleCatalogs) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( self, tcli_client_class ): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2126,7 +2172,7 @@ def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( ) backend.open_session({}, cat, schem) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def 
test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value @@ -2135,9 +2181,10 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3, canUseMultipleCatalogs=True, initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem"), + sessionHandle=self.session_handle, ) - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2154,8 +2201,10 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - @patch("databricks.sql.thrift_backend.ThriftBackend._handle_execute_response") + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" + ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class ): @@ -2172,7 +2221,7 @@ def test_execute_command_sets_complex_type_fields_correctly( if decimals is not None: complex_arg_types["_use_arrow_native_decimals"] = decimals - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path",
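For orientation, a minimal sketch of the new SessionId contract that the tests above exercise follows. It is distilled from the assertions in tests/unit/test_session.py, tests/unit/test_sea_backend.py, and tests/unit/test_parameters.py rather than from any reference documentation, so treat the attribute semantics as inferred, not authoritative.

from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.thrift_api.TCLIService.ttypes import TSessionHandle
from databricks.sql.backend.types import SessionId, BackendType

# Thrift sessions are identified by an opaque (guid, secret) pair; the
# positional constructor mirrors test_close_uses_the_correct_session_id.
thrift_session = SessionId(BackendType.THRIFT, b"\x22", b"\x33")
assert thrift_session.backend_type == BackendType.THRIFT
assert thrift_session.guid == b"\x22"
assert thrift_session.secret == b"\x33"

# A SessionId can also wrap a raw TSessionHandle, optionally carrying extra
# properties such as the negotiated protocol version, as in
# test_get_protocol_version_fallback_behavior.
handle = TSessionHandle(
    sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37),
    serverProtocolVersion=None,
)
wrapped = SessionId.from_thrift_handle(
    handle,
    {"serverProtocolVersion": ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8},
)

# SEA sessions are identified by a plain string, which the tests show is
# surfaced as the guid (test_open_session_basic, test_close_session_valid_id).
sea_session = SessionId.from_sea_session_id("test-session-789")
assert sea_session.backend_type == BackendType.SEA
assert sea_session.guid == "test-session-789"

The same SessionId value is what the updated tests pass to ThriftDatabricksClient.close_session and Connection.get_protocol_version in place of the raw Thrift handles used previously.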
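Similarly, here is a hedged sketch of the SEA open/close round trip that TestSeaBackend pins down, driven against a mocked HTTP client. The patch target, constructor keywords, and request/response shapes are copied from the fixtures and assertions above; the session id string itself is illustrative.

from unittest.mock import patch

from databricks.sql.auth.authenticators import AuthProvider
from databricks.sql.backend.sea_backend import SeaDatabricksClient
from databricks.sql.types import SSLOptions

with patch("databricks.sql.backend.sea_backend.CustomHttpClient") as http_cls:
    http = http_cls.return_value
    # open_session POSTs {"warehouse_id": ...} to SESSION_PATH and reads the
    # session id out of the JSON response, per test_open_session_basic.
    http._make_request.return_value = {"session_id": "illustrative-session-id"}

    client = SeaDatabricksClient(
        server_hostname="test-server.databricks.com",
        port=443,
        http_path="/sql/warehouses/abc123",  # warehouse id is parsed from here
        http_headers=[],
        auth_provider=AuthProvider(),
        ssl_options=SSLOptions(),
    )
    assert client.warehouse_id == "abc123"

    session_id = client.open_session(None, None, None)
    assert session_id.guid == "illustrative-session-id"

    # close_session issues a DELETE to SESSION_PATH_WITH_ID for the same id,
    # per test_close_session_valid_id.
    http._make_request.return_value = {}
    client.close_session(session_id)

Passing a Thrift-backed SessionId to close_session instead raises ValueError("Not a valid SEA session ID ..."), which is exactly what test_close_session_invalid_id_type asserts.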