diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index 6bcaabaec..02b8d4604 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -20,7 +20,12 @@
     MetadataCommands,
 )
 from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings
+from databricks.sql.backend.sea.utils.metadata_transforms import (
+    create_table_catalog_transform,
+)
 from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
+from databricks.sql.backend.sea.utils.result_column import ResultColumn
+from databricks.sql.backend.sea.utils.conversion import SqlType
 from databricks.sql.thrift_api.TCLIService import ttypes
 
 if TYPE_CHECKING:
@@ -740,7 +745,23 @@ def get_schemas(
         assert isinstance(
             result, SeaResultSet
         ), "Expected SeaResultSet from SEA backend"
-        result.prepare_metadata_columns(MetadataColumnMappings.SCHEMA_COLUMNS)
+
+        # Create dynamic schema columns with catalog name bound to TABLE_CATALOG
+        schema_columns = []
+        for col in MetadataColumnMappings.SCHEMA_COLUMNS:
+            if col.thrift_col_name == "TABLE_CATALOG":
+                # Create a new column with the catalog transform bound
+                dynamic_col = ResultColumn(
+                    col.thrift_col_name,
+                    col.sea_col_name,
+                    col.thrift_col_type,
+                    create_table_catalog_transform(catalog_name),
+                )
+                schema_columns.append(dynamic_col)
+            else:
+                schema_columns.append(col)
+
+        result.prepare_metadata_columns(schema_columns)
         return result
 
     def get_tables(
diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py
index 09a8df1eb..af68721a1 100644
--- a/src/databricks/sql/backend/sea/result_set.py
+++ b/src/databricks/sql/backend/sea/result_set.py
@@ -320,7 +320,7 @@ def _prepare_column_mapping(self) -> None:
                 None,
                 None,
                 None,
-                True,
+                None,
             )
 
         # Set the mapping
@@ -356,14 +356,20 @@ def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Table":
                 if self._column_index_mapping
                 else None
            )
-
             column = (
                 pyarrow.nulls(table.num_rows)
                 if old_idx is None
                 else table.column(old_idx)
             )
-            new_columns.append(column)
+            # Apply transform if available
+            if result_column.transform_value:
+                # Convert to list, apply transform, and convert back
+                values = column.to_pylist()
+                transformed_values = [result_column.transform_value(v) for v in values]
+                column = pyarrow.array(transformed_values)
+
+            new_columns.append(column)
             column_names.append(result_column.thrift_col_name)
 
         return pyarrow.Table.from_arrays(new_columns, names=column_names)
 
@@ -382,8 +388,11 @@ def _normalise_json_metadata_cols(self, rows: List[List[str]]) -> List[List[Any]]:
                     if self._column_index_mapping
                     else None
                 )
-
                 value = None if old_idx is None else row[old_idx]
+
+                # Apply transform if available
+                if result_column.transform_value:
+                    value = result_column.transform_value(value)
                 new_row.append(value)
             transformed_rows.append(new_row)
         return transformed_rows
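Note on the backend.py and result_set.py hunks above: get_schemas now rebuilds the schema column list per request so that TABLE_CATALOG, which the SEA schema result does not appear to carry, can be backfilled from the request's catalog_name; both the Arrow and JSON normalization paths then apply a column's transform_value hook when one is set. A minimal sketch of the bound closure's behavior (the catalog name "main" is illustrative, not from the diff):

    from databricks.sql.backend.sea.utils.metadata_transforms import (
        create_table_catalog_transform,
    )

    # The factory binds catalog_name in a closure; the transform ignores the
    # incoming cell value, so every TABLE_CATALOG cell is filled uniformly.
    fill_catalog = create_table_catalog_transform("main")
    assert fill_catalog(None) == "main"       # null cells are backfilled
    assert fill_catalog("whatever") == "main" # existing values are overridden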
diff --git a/src/databricks/sql/backend/sea/utils/metadata_mappings.py b/src/databricks/sql/backend/sea/utils/metadata_mappings.py
index 340c4c79e..ff5f2ab8b 100644
--- a/src/databricks/sql/backend/sea/utils/metadata_mappings.py
+++ b/src/databricks/sql/backend/sea/utils/metadata_mappings.py
@@ -1,5 +1,13 @@
 from databricks.sql.backend.sea.utils.result_column import ResultColumn
 from databricks.sql.backend.sea.utils.conversion import SqlType
+from databricks.sql.backend.sea.utils.metadata_transforms import (
+    transform_remarks,
+    transform_is_autoincrement,
+    transform_is_nullable,
+    transform_nullable,
+    transform_data_type,
+    transform_ordinal_position,
+)
 
 
 class MetadataColumnMappings:
@@ -18,7 +26,9 @@ class MetadataColumnMappings:
     SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", SqlType.STRING)
     TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", SqlType.STRING)
     TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", SqlType.STRING)
-    REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", SqlType.STRING)
+    REMARKS_COLUMN = ResultColumn(
+        "REMARKS", "remarks", SqlType.STRING, transform_remarks
+    )
     TYPE_CATALOG_COLUMN = ResultColumn("TYPE_CAT", None, SqlType.STRING)
     TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, SqlType.STRING)
     TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, SqlType.STRING)
@@ -28,7 +38,9 @@ class MetadataColumnMappings:
     REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, SqlType.STRING)
 
     COL_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.STRING)
-    DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, SqlType.INT)
+    DATA_TYPE_COLUMN = ResultColumn(
+        "DATA_TYPE", "columnType", SqlType.INT, transform_data_type
+    )
     COLUMN_TYPE_COLUMN = ResultColumn("TYPE_NAME", "columnType", SqlType.STRING)
     COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", SqlType.INT)
     BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.TINYINT)
@@ -43,14 +55,19 @@ class MetadataColumnMappings:
     ORDINAL_POSITION_COLUMN = ResultColumn(
         "ORDINAL_POSITION",
         "ordinalPosition",
         SqlType.INT,
+        transform_ordinal_position,
     )
-    NULLABLE_COLUMN = ResultColumn("NULLABLE", None, SqlType.INT)
+    NULLABLE_COLUMN = ResultColumn(
+        "NULLABLE", "isNullable", SqlType.INT, transform_nullable
+    )
     COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", None, SqlType.STRING)
     SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, SqlType.INT)
     SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, SqlType.INT)
     CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, SqlType.INT)
-    IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", SqlType.STRING)
+    IS_NULLABLE_COLUMN = ResultColumn(
+        "IS_NULLABLE", "isNullable", SqlType.STRING, transform_is_nullable
+    )
 
     SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, SqlType.STRING)
     SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, SqlType.STRING)
@@ -58,7 +75,10 @@ class MetadataColumnMappings:
     SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, SqlType.SMALLINT)
 
     IS_AUTO_INCREMENT_COLUMN = ResultColumn(
-        "IS_AUTOINCREMENT", "isAutoIncrement", SqlType.STRING
+        "IS_AUTO_INCREMENT",
+        "isAutoIncrement",
+        SqlType.STRING,
+        transform_is_autoincrement,
     )
     IS_GENERATED_COLUMN = ResultColumn(
         "IS_GENERATEDCOLUMN", "isGenerated", SqlType.STRING
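The mappings above now attach per-column transforms (defined in the new metadata_transforms.py below) so SEA's raw values line up with Thrift metadata semantics: REMARKS is never null, NULLABLE becomes an integer, IS_NULLABLE and IS_AUTOINCREMENT become YES/NO strings, and DATA_TYPE becomes a JDBC type code. For example, using the functions from that file:

    from databricks.sql.backend.sea.utils.metadata_transforms import (
        transform_data_type,
        transform_is_nullable,
    )

    # SEA reports column types as names; Thrift's DATA_TYPE is a JDBC code.
    assert transform_data_type("DECIMAL(10,2)") == 3  # parameterized types unwrap
    assert transform_data_type("STRING") == 12        # STRING maps to VARCHAR
    assert transform_is_nullable("true") == "YES"     # booleans become YES/NO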
diff --git a/src/databricks/sql/backend/sea/utils/metadata_transforms.py b/src/databricks/sql/backend/sea/utils/metadata_transforms.py
new file mode 100644
index 000000000..efff2236a
--- /dev/null
+++ b/src/databricks/sql/backend/sea/utils/metadata_transforms.py
@@ -0,0 +1,84 @@
+"""Simple transformation functions for metadata value normalization."""
+
+
+def transform_is_autoincrement(value):
+    """Transform IS_AUTOINCREMENT: boolean to YES/NO string."""
+    if isinstance(value, bool) or value is None:
+        return "YES" if value else "NO"
+    return value
+
+
+def transform_is_nullable(value):
+    """Transform IS_NULLABLE: true/false to YES/NO string."""
+    if value is True or value == "true":
+        return "YES"
+    elif value is False or value == "false":
+        return "NO"
+    return value
+
+
+def transform_remarks(value):
+    """Transform REMARKS: None to empty string."""
+    if value is None:
+        return ""
+    return value
+
+
+def transform_nullable(value):
+    """Transform NULLABLE column: boolean/string to integer."""
+    if value is True or value == "true" or value == "YES":
+        return 1
+    elif value is False or value == "false" or value == "NO":
+        return 0
+    return value
+
+
+# Type code mapping based on JDBC specification
+TYPE_CODE_MAP = {
+    "STRING": 12,  # VARCHAR
+    "VARCHAR": 12,  # VARCHAR
+    "CHAR": 1,  # CHAR
+    "INT": 4,  # INTEGER
+    "INTEGER": 4,  # INTEGER
+    "BIGINT": -5,  # BIGINT
+    "SMALLINT": 5,  # SMALLINT
+    "TINYINT": -6,  # TINYINT
+    "DOUBLE": 8,  # DOUBLE
+    "FLOAT": 6,  # FLOAT
+    "REAL": 7,  # REAL
+    "DECIMAL": 3,  # DECIMAL
+    "NUMERIC": 2,  # NUMERIC
+    "BOOLEAN": 16,  # BOOLEAN
+    "DATE": 91,  # DATE
+    "TIMESTAMP": 93,  # TIMESTAMP
+    "BINARY": -2,  # BINARY
+    "ARRAY": 2003,  # ARRAY
+    "MAP": 2002,  # JAVA_OBJECT
+    "STRUCT": 2002,  # JAVA_OBJECT
+}
+
+
+def transform_data_type(value):
+    """Transform DATA_TYPE: type name to JDBC type code."""
+    if isinstance(value, str):
+        # Handle parameterized types like DECIMAL(10,2)
+        base_type = value.split("(")[0].upper()
+        return TYPE_CODE_MAP.get(base_type, value)
+    return value
+
+
+def transform_ordinal_position(value):
+    """Transform ORDINAL_POSITION: decrement by 1 (1-based to 0-based)."""
+    if isinstance(value, int):
+        return value - 1
+    return value
+
+
+def create_table_catalog_transform(catalog_name):
+    """Factory function to create TABLE_CATALOG transform with bound catalog name."""
+
+    def transform_table_catalog(value):
+        """Transform TABLE_CATALOG: return the catalog name for all rows."""
+        return catalog_name
+
+    return transform_table_catalog
diff --git a/src/databricks/sql/backend/sea/utils/result_column.py b/src/databricks/sql/backend/sea/utils/result_column.py
index a4c1f619b..2980bd8d9 100644
--- a/src/databricks/sql/backend/sea/utils/result_column.py
+++ b/src/databricks/sql/backend/sea/utils/result_column.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Callable, Any
 
 
 @dataclass(frozen=True)
@@ -11,8 +11,10 @@ class ResultColumn:
         thrift_col_name: Column name as returned by Thrift (e.g., "TABLE_CAT")
         sea_col_name: Server result column name from SEA (e.g., "catalog")
         thrift_col_type: SQL type name
+        transform_value: Optional callback to transform values for this column
     """
 
     thrift_col_name: str
     sea_col_name: Optional[str]  # None if SEA doesn't return this column
     thrift_col_type: str
+    transform_value: Optional[Callable[[Any], Any]] = None
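With transform_value defaulting to None on the frozen dataclass, existing columns behave exactly as before; the normalization paths only invoke the hook when it is set. A single-column sketch mirroring the guard in _normalise_json_metadata_cols (the literal "int" stands in for SqlType.INT here):

    from databricks.sql.backend.sea.utils.metadata_transforms import transform_nullable
    from databricks.sql.backend.sea.utils.result_column import ResultColumn

    col = ResultColumn("NULLABLE", "isNullable", "int", transform_nullable)
    raw = "true"  # as SEA returns it in a JSON result row
    value = col.transform_value(raw) if col.transform_value else raw
    assert value == 1  # Thrift's NULLABLE is the integer 1/0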
diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py
index 3fa87b1af..7b86cfbe8 100644
--- a/tests/e2e/test_driver.py
+++ b/tests/e2e/test_driver.py
@@ -562,8 +562,17 @@ def test_get_schemas(self):
         finally:
             cursor.execute("DROP DATABASE IF EXISTS {}".format(database_name))
 
-    def test_get_catalogs(self):
-        with self.cursor({}) as cursor:
+    @pytest.mark.parametrize(
+        "backend_params",
+        [
+            {},
+            {
+                "use_sea": True,
+            },
+        ],
+    )
+    def test_get_catalogs(self, backend_params):
+        with self.cursor(backend_params) as cursor:
             cursor.catalogs()
             cursor.fetchall()
             catalogs_desc = cursor.description
diff --git a/tests/unit/test_metadata_mappings.py b/tests/unit/test_metadata_mappings.py
index 6ab749e0f..f0ffef067 100644
--- a/tests/unit/test_metadata_mappings.py
+++ b/tests/unit/test_metadata_mappings.py
@@ -89,7 +89,7 @@ def test_column_columns_mapping(self):
             "TABLE_SCHEM": ("namespace", SqlType.STRING),
             "TABLE_NAME": ("tableName", SqlType.STRING),
             "COLUMN_NAME": ("col_name", SqlType.STRING),
-            "DATA_TYPE": (None, SqlType.INT),
+            "DATA_TYPE": ("columnType", SqlType.INT),
             "TYPE_NAME": ("columnType", SqlType.STRING),
             "COLUMN_SIZE": ("columnSize", SqlType.INT),
             "DECIMAL_DIGITS": ("decimalDigits", SqlType.INT),
diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py
index 1a2621f06..5f2df8887 100644
--- a/tests/unit/test_sea_backend.py
+++ b/tests/unit/test_sea_backend.py
@@ -826,9 +826,6 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):
 
         # Verify prepare_metadata_columns was called for successful cases
         assert mock_result_set.prepare_metadata_columns.call_count == 2
-        mock_result_set.prepare_metadata_columns.assert_called_with(
-            MetadataColumnMappings.SCHEMA_COLUMNS
-        )
 
     def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
         """Test the get_tables method with various parameter combinations."""
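On the assertion removed from test_sea_backend.py: get_schemas now passes a freshly built list whose TABLE_CATALOG entry carries a bound closure, so assert_called_with(MetadataColumnMappings.SCHEMA_COLUMNS) can no longer hold. A self-contained sketch of why, assuming SCHEMA_COLUMNS includes a TABLE_CATALOG entry as the get_schemas loop implies (the catalog name "main" is illustrative):

    from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings
    from databricks.sql.backend.sea.utils.metadata_transforms import (
        create_table_catalog_transform,
    )
    from databricks.sql.backend.sea.utils.result_column import ResultColumn

    # Rebuild the list the way get_schemas does; frozen-dataclass equality
    # compares transform_value, and the new closure differs from the shared
    # constant's default of None.
    rebuilt = [
        ResultColumn(
            c.thrift_col_name,
            c.sea_col_name,
            c.thrift_col_type,
            create_table_catalog_transform("main"),
        )
        if c.thrift_col_name == "TABLE_CATALOG"
        else c
        for c in MetadataColumnMappings.SCHEMA_COLUMNS
    ]
    assert rebuilt != list(MetadataColumnMappings.SCHEMA_COLUMNS)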