diff --git a/examples/streaming_put.py b/examples/streaming_put.py new file mode 100644 index 00000000..4e769709 --- /dev/null +++ b/examples/streaming_put.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +""" +Simple example of streaming PUT operations. + +This demonstrates the basic usage of streaming PUT with the __input_stream__ token. +""" + +import io +import os +from databricks import sql + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + with connection.cursor() as cursor: + # Create a simple data stream + data = b"Hello, streaming world!" + stream = io.BytesIO(data) + + # Get catalog, schema, and volume from environment variables + catalog = os.getenv("DATABRICKS_CATALOG") + schema = os.getenv("DATABRICKS_SCHEMA") + volume = os.getenv("DATABRICKS_VOLUME") + + # Upload to Unity Catalog volume + cursor.execute( + f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/{volume}/hello.txt' OVERWRITE", + input_stream=stream + ) + + print("File uploaded successfully!") \ No newline at end of file diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index e4166f11..7680940a 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1,5 +1,5 @@ import time -from typing import Dict, Tuple, List, Optional, Any, Union, Sequence +from typing import Dict, Tuple, List, Optional, Any, Union, Sequence, BinaryIO import pandas try: @@ -67,6 +67,7 @@ ) from databricks.sql.telemetry.latency_logger import log_latency from databricks.sql.telemetry.models.enums import StatementType +from databricks.sql.common.http import DatabricksHttpClient, HttpMethod logger = logging.getLogger(__name__) @@ -615,8 +616,34 @@ def _check_not_closed(self): session_id_hex=self.connection.get_session_id_hex(), ) + def _validate_staging_http_response( + self, response: requests.Response, operation_name: str = "staging operation" + ) -> None: + + # Check response codes + OK = requests.codes.ok # 200 + CREATED = requests.codes.created # 201 + ACCEPTED = requests.codes.accepted # 202 + NO_CONTENT = requests.codes.no_content # 204 + + if response.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: + raise OperationalError( + f"{operation_name} over HTTP was unsuccessful: {response.status_code}-{response.text}", + session_id_hex=self.connection.get_session_id_hex(), + ) + + if response.status_code == ACCEPTED: + logger.debug( + "Response code %s from server indicates %s was accepted " + "but not yet applied on the server. It's possible this command may fail later.", + ACCEPTED, + operation_name, + ) + def _handle_staging_operation( - self, staging_allowed_local_path: Union[None, str, List[str]] + self, + staging_allowed_local_path: Union[None, str, List[str]], + input_stream: Optional[BinaryIO] = None, ): """Fetch the HTTP request instruction from a staging ingestion command and call the designated handler. @@ -625,6 +652,28 @@ def _handle_staging_operation( is not descended from staging_allowed_local_path. """ + assert self.active_result_set is not None + row = self.active_result_set.fetchone() + assert row is not None + + # Parse headers + headers = ( + json.loads(row.headers) if isinstance(row.headers, str) else row.headers + ) + headers = dict(headers) if headers else {} + + # Handle __input_stream__ token for PUT operations + if ( + row.operation == "PUT" + and getattr(row, "localFile", None) == "__input_stream__" + ): + return self._handle_staging_put_stream( + presigned_url=row.presignedUrl, + stream=input_stream, + headers=headers, + ) + + # For non-streaming operations, validate staging_allowed_local_path if isinstance(staging_allowed_local_path, type(str())): _staging_allowed_local_paths = [staging_allowed_local_path] elif isinstance(staging_allowed_local_path, type(list())): @@ -639,10 +688,6 @@ def _handle_staging_operation( os.path.abspath(i) for i in _staging_allowed_local_paths ] - assert self.active_result_set is not None - row = self.active_result_set.fetchone() - assert row is not None - # Must set to None in cases where server response does not include localFile abs_localFile = None @@ -665,19 +710,16 @@ def _handle_staging_operation( session_id_hex=self.connection.get_session_id_hex(), ) - # May be real headers, or could be json string - headers = ( - json.loads(row.headers) if isinstance(row.headers, str) else row.headers - ) - handler_args = { "presigned_url": row.presignedUrl, "local_file": abs_localFile, - "headers": dict(headers) or {}, + "headers": headers, } logger.debug( - f"Attempting staging operation indicated by server: {row.operation} - {getattr(row, 'localFile', '')}" + "Attempting staging operation indicated by server: %s - %s", + row.operation, + getattr(row, "localFile", ""), ) # TODO: Create a retry loop here to re-attempt if the request times out or fails @@ -696,6 +738,43 @@ def _handle_staging_operation( session_id_hex=self.connection.get_session_id_hex(), ) + @log_latency(StatementType.SQL) + def _handle_staging_put_stream( + self, + presigned_url: str, + stream: BinaryIO, + headers: dict = {}, + ) -> None: + """Handle PUT operation with streaming data. + + Args: + presigned_url: The presigned URL for upload + stream: Binary stream to upload + headers: HTTP headers + + Raises: + ProgrammingError: If no input stream is provided + OperationalError: If the upload fails + """ + + if not stream: + raise ProgrammingError( + "No input stream provided for streaming operation", + session_id_hex=self.connection.get_session_id_hex(), + ) + + http_client = DatabricksHttpClient.get_instance() + + # Stream directly to presigned URL + with http_client.execute( + method=HttpMethod.PUT, + url=presigned_url, + data=stream, + headers=headers, + timeout=300, # 5 minute timeout + ) as response: + self._validate_staging_http_response(response, "stream upload") + @log_latency(StatementType.SQL) def _handle_staging_put( self, presigned_url: str, local_file: str, headers: Optional[dict] = None @@ -714,27 +793,7 @@ def _handle_staging_put( with open(local_file, "rb") as fh: r = requests.put(url=presigned_url, data=fh, headers=headers) - # fmt: off - # Design borrowed from: https://stackoverflow.com/a/2342589/5093960 - - OK = requests.codes.ok # 200 - CREATED = requests.codes.created # 201 - ACCEPTED = requests.codes.accepted # 202 - NO_CONTENT = requests.codes.no_content # 204 - - # fmt: on - - if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: - raise OperationalError( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", - session_id_hex=self.connection.get_session_id_hex(), - ) - - if r.status_code == ACCEPTED: - logger.debug( - f"Response code {ACCEPTED} from server indicates ingestion command was accepted " - + "but not yet applied on the server. It's possible this command may fail later." - ) + self._validate_staging_http_response(r, "file upload") @log_latency(StatementType.SQL) def _handle_staging_get( @@ -784,6 +843,7 @@ def execute( operation: str, parameters: Optional[TParameterCollection] = None, enforce_embedded_schema_correctness=False, + input_stream: Optional[BinaryIO] = None, ) -> "Cursor": """ Execute a query and wait for execution to complete. @@ -820,7 +880,6 @@ def execute( logger.debug( "Cursor.execute(operation=%s, parameters=%s)", operation, parameters ) - param_approach = self._determine_parameter_approach(parameters) if param_approach == ParameterApproach.NONE: prepared_params = NO_NATIVE_PARAMS @@ -857,7 +916,8 @@ def execute( if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( - staging_allowed_local_path=self.connection.staging_allowed_local_path + staging_allowed_local_path=self.connection.staging_allowed_local_path, + input_stream=input_stream, ) return self diff --git a/tests/e2e/common/streaming_put_tests.py b/tests/e2e/common/streaming_put_tests.py new file mode 100644 index 00000000..30e7c619 --- /dev/null +++ b/tests/e2e/common/streaming_put_tests.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +""" +E2E tests for streaming PUT operations. +""" + +import io +import logging +import pytest +from datetime import datetime + +logger = logging.getLogger(__name__) + + +class PySQLStreamingPutTestSuiteMixin: + """Test suite for streaming PUT operations.""" + + def test_streaming_put_basic(self, catalog, schema): + """Test basic streaming PUT functionality.""" + + # Create test data + test_data = b"Hello, streaming world! This is test data." + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"stream_test_{timestamp}.txt" + file_path = f"/Volumes/{catalog}/{schema}/e2etests/{filename}" + + try: + with self.connection() as conn: + with conn.cursor() as cursor: + with io.BytesIO(test_data) as stream: + cursor.execute( + f"PUT '__input_stream__' INTO '{file_path}'", + input_stream=stream + ) + + # Verify file exists + cursor.execute(f"LIST '/Volumes/{catalog}/{schema}/e2etests/'") + files = cursor.fetchall() + + # Check if our file is in the list + file_paths = [row[0] for row in files] + assert file_path in file_paths, f"File {file_path} not found in {file_paths}" + finally: + self._cleanup_test_file(file_path) + + def test_streaming_put_missing_stream(self, catalog, schema): + """Test that missing stream raises appropriate error.""" + + with self.connection() as conn: + with conn.cursor() as cursor: + # Test without providing stream + with pytest.raises(Exception): # Should fail + cursor.execute( + f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/e2etests/test.txt'" + # Note: No input_stream parameter + ) + + def _cleanup_test_file(self, file_path): + """Clean up a test file if it exists.""" + try: + with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with conn.cursor() as cursor: + cursor.execute(f"REMOVE '{file_path}'") + logger.info("Successfully cleaned up test file: %s", file_path) + except Exception as e: + logger.error("Cleanup failed for %s: %s", file_path, e) \ No newline at end of file diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 042fcc10..7a704109 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -47,8 +47,8 @@ ) from tests.e2e.common.staging_ingestion_tests import PySQLStagingIngestionTestSuiteMixin from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin - from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin +from tests.e2e.common.streaming_put_tests import PySQLStreamingPutTestSuiteMixin from databricks.sql.exc import SessionAlreadyClosedError @@ -256,6 +256,7 @@ class TestPySQLCoreSuite( PySQLStagingIngestionTestSuiteMixin, PySQLRetryTestsMixin, PySQLUCVolumeTestSuiteMixin, + PySQLStreamingPutTestSuiteMixin, ): validate_row_value_type = True validate_result = True diff --git a/tests/unit/test_streaming_put.py b/tests/unit/test_streaming_put.py new file mode 100644 index 00000000..e1d3f27e --- /dev/null +++ b/tests/unit/test_streaming_put.py @@ -0,0 +1,171 @@ + +import io +import pytest +from unittest.mock import patch, Mock, MagicMock +import databricks.sql.client as client +from databricks.sql import ProgrammingError +import requests + + +class TestStreamingPut: + """Unit tests for streaming PUT functionality.""" + + @pytest.fixture + def mock_connection(self): + return Mock() + + @pytest.fixture + def mock_backend(self): + return Mock() + + @pytest.fixture + def cursor(self, mock_connection, mock_backend): + return client.Cursor( + connection=mock_connection, + backend=mock_backend + ) + + def _setup_mock_staging_put_stream_response(self, mock_backend): + """Helper method to set up mock staging PUT stream response.""" + mock_result_set = Mock() + mock_result_set.is_staging_operation = True + mock_backend.execute_command.return_value = mock_result_set + + mock_row = Mock() + mock_row.operation = "PUT" + mock_row.localFile = "__input_stream__" + mock_row.presignedUrl = "https://example.com/upload" + mock_row.headers = "{}" + mock_result_set.fetchone.return_value = mock_row + + return mock_result_set + + def test_execute_with_valid_stream(self, cursor, mock_backend): + """Test execute method with valid input stream.""" + + # Mock the backend response + self._setup_mock_staging_put_stream_response(mock_backend) + + # Test with valid stream + test_stream = io.BytesIO(b"test data") + + with patch.object(cursor, '_handle_staging_put_stream') as mock_handler: + cursor.execute( + "PUT '__input_stream__' INTO '/Volumes/test/cat/schema/vol/file.txt'", + input_stream=test_stream + ) + + # Verify staging handler was called + mock_handler.assert_called_once() + + def test_execute_with_invalid_stream_types(self, cursor, mock_backend): + + # Mock the backend response + self._setup_mock_staging_put_stream_response(mock_backend) + + # Test with None input stream + with pytest.raises(client.ProgrammingError) as excinfo: + cursor.execute( + "PUT '__input_stream__' INTO '/Volumes/test/cat/schema/vol/file.txt'", + input_stream=None + ) + assert "No input stream provided for streaming operation" in str(excinfo.value) + + def test_execute_with_none_stream_for_staging_put(self, cursor, mock_backend): + """Test execute method rejects None stream for streaming PUT operations.""" + + # Mock staging operation response for None case + self._setup_mock_staging_put_stream_response(mock_backend) + + # None with __input_stream__ raises ProgrammingError + with pytest.raises(client.ProgrammingError) as excinfo: + cursor.execute( + "PUT '__input_stream__' INTO '/Volumes/test/cat/schema/vol/file.txt'", + input_stream=None + ) + error_msg = str(excinfo.value) + assert "No input stream provided for streaming operation" in error_msg + + def test_handle_staging_put_stream_success(self, cursor): + """Test successful streaming PUT operation.""" + + test_stream = io.BytesIO(b"test data") + presigned_url = "https://example.com/upload" + headers = {"Content-Type": "text/plain"} + + with patch('databricks.sql.client.DatabricksHttpClient') as mock_client_class: + mock_client = Mock() + mock_client_class.get_instance.return_value = mock_client + + # Mock the context manager properly using MagicMock + mock_context = MagicMock() + mock_response = Mock() + mock_response.status_code = 200 + mock_context.__enter__.return_value = mock_response + mock_context.__exit__.return_value = None + mock_client.execute.return_value = mock_context + + cursor._handle_staging_put_stream( + presigned_url=presigned_url, + stream=test_stream, + headers=headers + ) + + # Verify the HTTP client was called correctly + mock_client.execute.assert_called_once() + call_args = mock_client.execute.call_args + assert call_args[1]['method'].value == 'PUT' + assert call_args[1]['url'] == presigned_url + assert call_args[1]['data'] == test_stream + assert call_args[1]['headers'] == headers + + def test_handle_staging_put_stream_http_error(self, cursor): + """Test streaming PUT operation with HTTP error.""" + + test_stream = io.BytesIO(b"test data") + presigned_url = "https://example.com/upload" + + with patch('databricks.sql.client.DatabricksHttpClient') as mock_client_class: + mock_client = Mock() + mock_client_class.get_instance.return_value = mock_client + + # Mock the context manager with error response + mock_context = MagicMock() + mock_response = Mock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_context.__enter__.return_value = mock_response + mock_context.__exit__.return_value = None + mock_client.execute.return_value = mock_context + + with pytest.raises(client.OperationalError) as excinfo: + cursor._handle_staging_put_stream( + presigned_url=presigned_url, + stream=test_stream + ) + + # Check for the actual error message format + assert "500" in str(excinfo.value) + + def test_handle_staging_put_stream_network_error(self, cursor): + """Test streaming PUT operation with network error.""" + + test_stream = io.BytesIO(b"test data") + presigned_url = "https://example.com/upload" + + with patch('databricks.sql.client.DatabricksHttpClient') as mock_client_class: + mock_client = Mock() + mock_client_class.get_instance.return_value = mock_client + + # Mock the context manager to raise an exception + mock_context = MagicMock() + mock_context.__enter__.side_effect = requests.exceptions.RequestException("Network error") + mock_client.execute.return_value = mock_context + + with pytest.raises(requests.exceptions.RequestException) as excinfo: + cursor._handle_staging_put_stream( + presigned_url=presigned_url, + stream=test_stream + ) + + assert "Network error" in str(excinfo.value)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: