diff --git a/docs/components/results.md b/docs/components/results.md
index 7cea19a5..765571fa 100644
--- a/docs/components/results.md
+++ b/docs/components/results.md
@@ -14,8 +14,9 @@ Currently there are two results:
 ### Result
 
 #### Parameters
+
 - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md)
-- `as_tuple`: return result as a tuple instead of dict.
+- `as_tuple`: return rows as tuples of values (without column names) instead of dicts.
 
 Get the result as a list of dicts
 
@@ -32,7 +33,7 @@ async def main() -> None:
     list_dict_result: List[Dict[str, Any]] = query_result.result()
 
     # Result as tuple
-    list_tuple_result: List[Tuple[Tuple[str, typing.Any], ...]] = query_result.result(
+    list_tuple_result: List[Tuple[typing.Any, ...]] = query_result.result(
         as_tuple=True,
     )
 ```
@@ -40,6 +41,7 @@ async def main() -> None:
 ### As class
 
 #### Parameters
+
 - `as_class`: Custom class from Python.
 - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md)
 
@@ -68,6 +70,7 @@ async def main() -> None:
 ### Row Factory
 
 #### Parameters
+
 - `row_factory`: custom callable object.
 - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md)
 
@@ -78,8 +81,9 @@ async def main() -> None:
 ### Result
 
 #### Parameters
+
 - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md)
-- `as_tuple`: return result as a tuple instead of dict.
+- `as_tuple`: return the row as a tuple of values (without column names) instead of a dict.
 
 Get the result as a dict
 
@@ -96,7 +100,7 @@ async def main() -> None:
     dict_result: Dict[str, Any] = query_result.result()
 
     # Result as tuple
-    tuple_result: Tuple[Tuple[str, typing.Any], ...] = query_result.result(
+    tuple_result: Tuple[typing.Any, ...] = query_result.result(
         as_tuple=True,
     )
 ```
@@ -104,6 +108,7 @@ async def main() -> None:
 ### As class
 
 #### Parameters
+
 - `as_class`: Custom class from Python.
 - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md)
 
@@ -131,6 +136,7 @@ async def main() -> None:
 ### Row Factory
 
 #### Parameters
+
 - `row_factory`: custom callable object.
 - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md)
diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi
index 2665678b..17a2d482 100644
--- a/python/psqlpy/_internal/__init__.pyi
+++ b/python/psqlpy/_internal/__init__.pyi
@@ -36,7 +36,7 @@ class QueryResult:
         self: Self,
         as_tuple: typing.Literal[True],
         custom_decoders: dict[str, Callable[[bytes], Any]] | None = None,
-    ) -> list[tuple[tuple[str, typing.Any], ...]]: ...
+    ) -> list[tuple[typing.Any, ...]]: ...
     @typing.overload
     def result(
         self: Self,
@@ -50,6 +50,7 @@ class QueryResult:
         `custom_decoders` must be used when you use PostgreSQL Type
         which isn't supported, read more in our docs.
         """
+
     def as_class(
         self: Self,
         as_class: Callable[..., _CustomClass],
@@ -83,6 +84,7 @@ class QueryResult:
         )
         ```
         """
+
     def row_factory(
         self,
         row_factory: Callable[[dict[str, Any]], _RowFactoryRV],
@@ -124,7 +126,7 @@ class SingleQueryResult:
         self: Self,
         as_tuple: typing.Literal[True],
         custom_decoders: dict[str, Callable[[bytes], Any]] | None = None,
-    ) -> tuple[tuple[str, typing.Any]]: ...
+    ) -> tuple[typing.Any, ...]: ...
     @typing.overload
     def result(
         self: Self,
@@ -138,6 +140,7 @@ class SingleQueryResult:
         `custom_decoders` must be used when you use PostgreSQL Type
         which isn't supported, read more in our docs.
""" + def as_class( self: Self, as_class: Callable[..., _CustomClass], @@ -174,6 +177,7 @@ class SingleQueryResult: ) ``` """ + def row_factory( self, row_factory: Callable[[dict[str, Any]], _RowFactoryRV], @@ -328,11 +332,13 @@ class Cursor: Execute DECLARE command for the cursor. """ + def close(self: Self) -> None: """Close the cursor. Execute CLOSE command for the cursor. """ + async def execute( self: Self, querystring: str, @@ -343,10 +349,13 @@ class Cursor: Method should be used instead of context manager and `start` method. """ + async def fetchone(self: Self) -> QueryResult: """Return next one row from the cursor.""" + async def fetchmany(self: Self, size: int | None = None) -> QueryResult: """Return rows from the cursor.""" + async def fetchall(self: Self, size: int | None = None) -> QueryResult: """Return all remaining rows from the cursor.""" @@ -379,6 +388,7 @@ class Transaction: `begin()` can be called only once per transaction. """ + async def commit(self: Self) -> None: """Commit the transaction. @@ -386,6 +396,7 @@ class Transaction: `commit()` can be called only once per transaction. """ + async def rollback(self: Self) -> None: """Rollback all queries in the transaction. @@ -406,6 +417,7 @@ class Transaction: await transaction.rollback() ``` """ + async def execute( self: Self, querystring: str, @@ -443,6 +455,7 @@ class Transaction: await transaction.commit() ``` """ + async def execute_batch( self: Self, querystring: str, @@ -458,6 +471,7 @@ class Transaction: ### Parameters: - `querystring`: querystrings separated by semicolons. """ + async def execute_many( self: Self, querystring: str, @@ -516,6 +530,7 @@ class Transaction: - `prepared`: should the querystring be prepared before the request. By default any querystring will be prepared. """ + async def fetch_row( self: Self, querystring: str, @@ -555,6 +570,7 @@ class Transaction: await transaction.commit() ``` """ + async def fetch_val( self: Self, querystring: str, @@ -595,6 +611,7 @@ class Transaction: ) ``` """ + async def pipeline( self, queries: list[tuple[str, list[Any] | None]], @@ -659,6 +676,7 @@ class Transaction: ) ``` """ + async def create_savepoint(self: Self, savepoint_name: str) -> None: """Create new savepoint. @@ -687,6 +705,7 @@ class Transaction: await transaction.rollback_savepoint("my_savepoint") ``` """ + async def rollback_savepoint(self: Self, savepoint_name: str) -> None: """ROLLBACK to the specified `savepoint_name`. @@ -712,6 +731,7 @@ class Transaction: await transaction.rollback_savepoint("my_savepoint") ``` """ + async def release_savepoint(self: Self, savepoint_name: str) -> None: """Execute ROLLBACK TO SAVEPOINT. @@ -736,6 +756,7 @@ class Transaction: await transaction.release_savepoint ``` """ + def cursor( self: Self, querystring: str, @@ -779,6 +800,7 @@ class Transaction: await cursor.close() ``` """ + async def binary_copy_to_table( self: Self, source: bytes | bytearray | Buffer | BytesIO, @@ -860,6 +882,7 @@ class Connection: Return representation of prepared statement. """ + async def execute( self: Self, querystring: str, @@ -896,6 +919,7 @@ class Connection: dict_result: List[Dict[Any, Any]] = query_result.result() ``` """ + async def execute_batch( self: Self, querystring: str, @@ -911,6 +935,7 @@ class Connection: ### Parameters: - `querystring`: querystrings separated by semicolons. """ + async def execute_many( self: Self, querystring: str, @@ -964,6 +989,7 @@ class Connection: - `prepared`: should the querystring be prepared before the request. 
By default any querystring will be prepared. """ + async def fetch_row( self: Self, querystring: str, @@ -1000,6 +1026,7 @@ class Connection: dict_result: Dict[Any, Any] = query_result.result() ``` """ + async def fetch_val( self: Self, querystring: str, @@ -1039,6 +1066,7 @@ class Connection: ) ``` """ + def transaction( self, isolation_level: IsolationLevel | None = None, @@ -1052,6 +1080,7 @@ class Connection: - `read_variant`: configure read variant of the transaction. - `deferrable`: configure deferrable of the transaction. """ + def cursor( self: Self, querystring: str, @@ -1090,6 +1119,7 @@ class Connection: ... # do something with this result. ``` """ + def close(self: Self) -> None: """Return connection back to the pool. @@ -1234,6 +1264,7 @@ class ConnectionPool: - `ca_file`: Loads trusted root certificates from a file. The file should contain a sequence of PEM-formatted CA certificates. """ + def __iter__(self: Self) -> Self: ... def __enter__(self: Self) -> Self: ... def __exit__( @@ -1248,6 +1279,7 @@ class ConnectionPool: ### Returns `ConnectionPoolStatus` """ + def resize(self: Self, new_max_size: int) -> None: """Resize the connection pool. @@ -1257,11 +1289,13 @@ class ConnectionPool: ### Parameters: - `new_max_size`: new size for the connection pool. """ + async def connection(self: Self) -> Connection: """Create new connection. It acquires new connection from the database pool. """ + def acquire(self: Self) -> Connection: """Create new connection for async context manager. @@ -1279,6 +1313,7 @@ class ConnectionPool: res = await connection.execute(...) ``` """ + def listener(self: Self) -> Listener: """Create new listener.""" @@ -1390,6 +1425,7 @@ class ConnectionPoolBuilder: def __init__(self: Self) -> None: """Initialize new instance of `ConnectionPoolBuilder`.""" + def build(self: Self) -> ConnectionPool: """ Build `ConnectionPool`. @@ -1397,6 +1433,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPool` """ + def max_pool_size(self: Self, pool_size: int) -> Self: """ Set maximum connection pool size. @@ -1407,6 +1444,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def conn_recycling_method( self: Self, conn_recycling_method: ConnRecyclingMethod, @@ -1422,6 +1460,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def user(self: Self, user: str) -> Self: """ Set username to `PostgreSQL`. @@ -1432,6 +1471,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def password(self: Self, password: str) -> Self: """ Set password for `PostgreSQL`. @@ -1442,6 +1482,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def dbname(self: Self, dbname: str) -> Self: """ Set database name for the `PostgreSQL`. @@ -1452,6 +1493,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def options(self: Self, options: str) -> Self: """ Set command line options used to configure the server. @@ -1462,6 +1504,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def application_name(self: Self, application_name: str) -> Self: """ Set the value of the `application_name` runtime parameter. @@ -1472,6 +1515,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def ssl_mode(self: Self, ssl_mode: SslMode) -> Self: """ Set the SSL configuration. @@ -1482,6 +1526,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def ca_file(self: Self, ca_file: str) -> Self: """ Set ca_file for SSL. 
@@ -1492,6 +1537,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def host(self: Self, host: str) -> Self: """ Add a host to the configuration. @@ -1509,6 +1555,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def hostaddr(self: Self, hostaddr: IPv4Address | IPv6Address) -> Self: """ Add a hostaddr to the configuration. @@ -1524,6 +1571,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def port(self: Self, port: int) -> Self: """ Add a port to the configuration. @@ -1540,6 +1588,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def connect_timeout(self: Self, connect_timeout: int) -> Self: """ Set the timeout applied to socket-level connection attempts. @@ -1554,6 +1603,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def tcp_user_timeout(self: Self, tcp_user_timeout: int) -> Self: """ Set the TCP user timeout. @@ -1569,6 +1619,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def target_session_attrs( self: Self, target_session_attrs: TargetSessionAttrs, @@ -1586,6 +1637,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def load_balance_hosts( self: Self, load_balance_hosts: LoadBalanceHosts, @@ -1601,6 +1653,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def keepalives( self: Self, keepalives: bool, @@ -1618,6 +1671,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def keepalives_idle( self: Self, keepalives_idle: int, @@ -1636,6 +1690,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def keepalives_interval( self: Self, keepalives_interval: int, @@ -1655,6 +1710,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def keepalives_retries( self: Self, keepalives_retries: int, @@ -1747,11 +1803,13 @@ class Listener: Each listener MUST be started up. """ + async def shutdown(self: Self) -> None: """Shutdown the listener. Abort listen and release underlying connection. 
""" + async def add_callback( self: Self, channel: str, @@ -1814,7 +1872,9 @@ class Column: class PreparedStatement: async def execute(self: Self) -> QueryResult: """Execute prepared statement.""" + def cursor(self: Self) -> Cursor: """Create new server-side cursor based on prepared statement.""" + def columns(self: Self) -> list[Column]: """Return information about statement columns.""" diff --git a/python/tests/test_query_results.py b/python/tests/test_query_results.py index 95de93c7..ff136fb4 100644 --- a/python/tests/test_query_results.py +++ b/python/tests/test_query_results.py @@ -39,7 +39,7 @@ async def test_result_as_tuple( assert isinstance(conn_result, QueryResult) assert isinstance(single_tuple_row, tuple) - assert single_tuple_row[0][0] == "id" + assert single_tuple_row[0] == 1 async def test_single_result_as_dict( @@ -73,4 +73,4 @@ async def test_single_result_as_tuple( assert isinstance(conn_result, SingleQueryResult) assert isinstance(result_tuple, tuple) - assert result_tuple[0][0] == "id" + assert result_tuple[0] == 1 diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index 07833848..122201ef 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -672,7 +672,7 @@ def point_encoder(point_bytes: bytes) -> str: # noqa: ARG001 as_tuple=True, ) - assert result[0][0][1] == "Just An Example" + assert result[0][0] == "Just An Example" async def test_row_factory_query_result( diff --git a/src/connection/impls.rs b/src/connection/impls.rs index 62d9a830..ebeb9bb0 100644 --- a/src/connection/impls.rs +++ b/src/connection/impls.rs @@ -409,42 +409,62 @@ impl PSQLPyConnection { parameters: Option>>, prepared: Option, ) -> PSQLPyResult<()> { - let mut statements: Vec = vec![]; - if let Some(parameters) = parameters { - for vec_of_py_any in parameters { - // TODO: Fix multiple qs creation - let statement = - StatementBuilder::new(&querystring, &Some(vec_of_py_any), self, prepared) - .build() - .await?; - - statements.push(statement); - } - } + let Some(parameters) = parameters else { + return Ok(()); + }; let prepared = prepared.unwrap_or(true); - for statement in statements { - let querystring_result = if prepared { - let prepared_stmt = &self.prepare(statement.raw_query(), true).await; - if let Err(error) = prepared_stmt { - return Err(RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement in execute_many, operation rolled back {error}", - ))); - } - self.query( - &self.prepare(statement.raw_query(), true).await?, - &statement.params(), - ) - .await - } else { - self.query(statement.raw_query(), &statement.params()).await - }; + let mut statements: Vec = Vec::with_capacity(parameters.len()); + + for param_set in parameters { + let statement = + StatementBuilder::new(&querystring, &Some(param_set), self, Some(prepared)) + .build() + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot build statement in execute_many: {err}" + )) + })?; + statements.push(statement); + } + + if statements.is_empty() { + return Ok(()); + } - if let Err(error) = querystring_result { - return Err(RustPSQLDriverError::ConnectionExecuteError(format!( - "Error occured in `execute_many` statement: {error}" - ))); + if prepared { + let first_statement = &statements[0]; + let prepared_stmt = self + .prepare(first_statement.raw_query(), true) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot prepare statement in execute_many: {err}" 
+ )) + })?; + + // Execute all statements using the same prepared statement + for statement in statements { + self.query(&prepared_stmt, &statement.params()) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Error occurred in `execute_many` statement: {err}" + )) + })?; + } + } else { + // Execute each statement without preparation + for statement in statements { + self.query(statement.raw_query(), &statement.params()) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Error occurred in `execute_many` statement: {err}" + )) + })?; } } diff --git a/src/driver/common.rs b/src/driver/common.rs index 3c22517a..b2ff6a52 100644 --- a/src/driver/common.rs +++ b/src/driver/common.rs @@ -16,7 +16,7 @@ use crate::{ use bytes::BytesMut; use futures_util::pin_mut; -use pyo3::{buffer::PyBuffer, PyErr, Python}; +use pyo3::{buffer::PyBuffer, Python}; use tokio_postgres::binary_copy::BinaryCopyInWriter; use crate::format_helpers::quote_ident; @@ -77,19 +77,13 @@ macro_rules! impl_config_py_methods { #[cfg(not(unix))] #[getter] fn hosts(&self) -> Vec { - let mut hosts_vec = vec![]; - - let hosts = self.pg_config.get_hosts(); - for host in hosts { - match host { - Host::Tcp(host) => { - hosts_vec.push(host.to_string()); - } - _ => unreachable!(), - } - } - - hosts_vec + self.pg_config + .get_hosts() + .iter() + .map(|host| match host { + Host::Tcp(host) => host.to_string(), + }) + .collect() } #[getter] @@ -255,50 +249,70 @@ macro_rules! impl_binary_copy_method { columns: Option>, schema_name: Option, ) -> PSQLPyResult { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone()); - let mut table_name = quote_ident(&table_name); - if let Some(schema_name) = schema_name { - table_name = format!("{}.{}", quote_ident(&schema_name), table_name); - } - - let mut formated_columns = String::default(); - if let Some(columns) = columns { - formated_columns = format!("({})", columns.join(", ")); - } + let (db_client, mut bytes_mut) = + Python::with_gil(|gil| -> PSQLPyResult<(Option<_>, BytesMut)> { + let db_client = self_.borrow(gil).conn.clone(); + + let Some(db_client) = db_client else { + return Ok((None, BytesMut::new())); + }; + + let data_bytes_mut = + if let Ok(py_buffer) = source.extract::>(gil) { + let buffer_len = py_buffer.len_bytes(); + let mut bytes_mut = BytesMut::zeroed(buffer_len); + + py_buffer.copy_to_slice(gil, &mut bytes_mut[..])?; + bytes_mut + } else if let Ok(py_bytes) = source.call_method0(gil, "getvalue") { + if let Ok(bytes_vec) = py_bytes.extract::>(gil) { + let bytes_mut = BytesMut::from(&bytes_vec[..]); + bytes_mut + } else { + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "source must be bytes or support Buffer protocol".into(), + )); + } + } else { + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "source must be bytes or support Buffer protocol".into(), + )); + }; + + Ok((Some(db_client), data_bytes_mut)) + })?; - let copy_qs = - format!("COPY {table_name}{formated_columns} FROM STDIN (FORMAT binary)"); + let Some(db_client) = db_client else { + return Ok(0); + }; - if let Some(db_client) = db_client { - let mut psql_bytes: BytesMut = Python::with_gil(|gil| { - let possible_py_buffer: Result, PyErr> = - source.extract::>(gil); - if let Ok(py_buffer) = possible_py_buffer { - let vec_buf = py_buffer.to_vec(gil)?; - return Ok(BytesMut::from(vec_buf.as_slice())); - } + let full_table_name = match schema_name { + Some(schema) => { + format!("{}.{}", 
quote_ident(&schema), quote_ident(&table_name)) + } + None => quote_ident(&table_name), + }; - if let Ok(py_bytes) = source.call_method0(gil, "getvalue") { - if let Ok(bytes) = py_bytes.extract::>(gil) { - return Ok(BytesMut::from(bytes.as_slice())); - } - } + let copy_qs = match columns { + Some(ref cols) if !cols.is_empty() => { + format!( + "COPY {}({}) FROM STDIN (FORMAT binary)", + full_table_name, + cols.join(", ") + ) + } + _ => format!("COPY {} FROM STDIN (FORMAT binary)", full_table_name), + }; - Err(RustPSQLDriverError::PyToRustValueConversionError( - "source must be bytes or support Buffer protocol".into(), - )) - })?; + let read_conn_g = db_client.read().await; + let sink = read_conn_g.copy_in(©_qs).await?; + let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]); + pin_mut!(writer); - let read_conn_g = db_client.read().await; - let sink = read_conn_g.copy_in(©_qs).await?; - let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]); - pin_mut!(writer); - writer.as_mut().write_raw_bytes(&mut psql_bytes).await?; - let rows_created = writer.as_mut().finish_empty().await?; - return Ok(rows_created); - } + writer.as_mut().write_raw_bytes(&mut bytes_mut).await?; + let rows_created = writer.as_mut().finish_empty().await?; - Ok(0) + Ok(rows_created) } } }; diff --git a/src/query_result.rs b/src/query_result.rs index a5af132d..46047848 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -42,14 +42,15 @@ fn row_to_tuple<'a>( postgres_row: &'a Row, custom_decoders: &Option>, ) -> PSQLPyResult> { - let mut rows: Vec> = vec![]; + let columns = postgres_row.columns(); + let mut tuple_items = Vec::with_capacity(columns.len()); - for (column_idx, column) in postgres_row.columns().iter().enumerate() { - let python_type = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?; - let timed_tuple = PyTuple::new(py, vec![column.name().into_py_any(py)?, python_type])?; - rows.push(timed_tuple); + for (column_idx, column) in columns.iter().enumerate() { + let python_value = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?; + tuple_items.push(python_value); } - Ok(PyTuple::new(py, rows)?) + + Ok(PyTuple::new(py, tuple_items)?) 
} #[pyclass(name = "QueryResult")] diff --git a/src/value_converter/dto/funcs.rs b/src/value_converter/dto/funcs.rs index eec045e0..e869966c 100644 --- a/src/value_converter/dto/funcs.rs +++ b/src/value_converter/dto/funcs.rs @@ -4,7 +4,7 @@ use postgres_types::Type; pub fn array_type_to_single_type(array_type: &Type) -> Type { match *array_type { Type::BOOL_ARRAY => Type::BOOL, - Type::UUID_ARRAY => Type::UUID_ARRAY, + Type::UUID_ARRAY => Type::UUID, Type::VARCHAR_ARRAY => Type::VARCHAR, Type::TEXT_ARRAY => Type::TEXT, Type::INT2_ARRAY => Type::INT2, diff --git a/src/value_converter/models/serde_value.rs b/src/value_converter/models/serde_value.rs index 222ffe56..ebe95f58 100644 --- a/src/value_converter/models/serde_value.rs +++ b/src/value_converter/models/serde_value.rs @@ -3,8 +3,8 @@ use postgres_types::FromSql; use serde_json::{json, Map, Value}; use pyo3::{ - types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyTuple}, - Bound, FromPyObject, IntoPyObject, Py, PyAny, PyResult, Python, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods}, + Bound, FromPyObject, IntoPyObject, PyAny, PyResult, Python, }; use tokio_postgres::types::Type; @@ -37,7 +37,7 @@ impl<'py> IntoPyObject<'py> for InternalSerdeValue { type Error = RustPSQLDriverError; fn into_pyobject(self, py: Python<'py>) -> Result { - match build_python_from_serde_value(py, self.0.clone()) { + match build_python_from_serde_value(py, self.0) { Ok(ok_value) => Ok(ok_value.bind(py).clone()), Err(err) => Err(err), } @@ -57,25 +57,29 @@ impl<'a> FromSql<'a> for InternalSerdeValue { } } -fn serde_value_from_list(gil: Python<'_>, bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { - let mut result_vec: Vec = vec![]; +fn serde_value_from_list(_gil: Python<'_>, bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { + let py_list = bind_value.downcast::().map_err(|e| { + RustPSQLDriverError::PyToRustValueConversionError(format!( + "Parameter must be a list, but it's not: {e}" + )) + })?; - let params = bind_value.extract::>>()?; + let mut result_vec: Vec = Vec::with_capacity(py_list.len()); - for inner in params { - let inner_bind = inner.bind(gil); - if inner_bind.is_instance_of::() { - let python_dto = from_python_untyped(inner_bind)?; + for item in py_list.iter() { + if item.is_instance_of::() { + let python_dto = from_python_untyped(&item)?; result_vec.push(python_dto.to_serde_value()?); - } else if inner_bind.is_instance_of::() { - let serde_value = build_serde_value(inner.bind(gil))?; + } else if item.is_instance_of::() { + let serde_value = build_serde_value(&item)?; result_vec.push(serde_value); } else { return Err(RustPSQLDriverError::PyToRustValueConversionError( - "PyJSON must have dicts.".to_string(), + "Items in JSON array must be dicts or lists.".to_string(), )); } } + Ok(json!(result_vec)) } @@ -86,19 +90,18 @@ fn serde_value_from_dict(bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { )) })?; - let mut serde_map: Map = Map::new(); + let dict_len = dict.len(); + let mut serde_map: Map = Map::with_capacity(dict_len); - for dict_item in dict.items() { - let py_list = dict_item.downcast::().map_err(|error| { + for (key, value) in dict.iter() { + let key_str = key.extract::().map_err(|error| { RustPSQLDriverError::PyToRustValueConversionError(format!( - "Cannot cast to list: {error}" + "Cannot extract dict key as string: {error}" )) })?; - let key = py_list.get_item(0)?.extract::()?; - let value = from_python_untyped(&py_list.get_item(1)?)?; - - serde_map.insert(key, value.to_serde_value()?); + let value_dto = 
from_python_untyped(&value)?; + serde_map.insert(key_str, value_dto.to_serde_value()?); } Ok(Value::Object(serde_map)) @@ -131,12 +134,10 @@ pub fn build_serde_value(value: &Bound<'_, PyAny>) -> PSQLPyResult { /// May return error if cannot create serde value. pub fn pythondto_array_to_serde(array: Option>) -> PSQLPyResult { match array { - Some(array) => inner_pythondto_array_to_serde( - array.dimensions(), - array.iter().collect::>().as_slice(), - 0, - 0, - ), + Some(array) => { + let data: Vec = array.iter().cloned().collect(); + inner_pythondto_array_to_serde(array.dimensions(), &data, 0, 0) + } None => Ok(Value::Null), } } @@ -145,41 +146,49 @@ pub fn pythondto_array_to_serde(array: Option>) -> PSQLPyResult #[allow(clippy::cast_sign_loss)] fn inner_pythondto_array_to_serde( dimensions: &[Dimension], - data: &[&PythonDTO], + data: &[PythonDTO], dimension_index: usize, - mut lower_bound: usize, + data_offset: usize, ) -> PSQLPyResult { - let current_dimension = dimensions.get(dimension_index); - - if let Some(current_dimension) = current_dimension { - let possible_next_dimension = dimensions.get(dimension_index + 1); - match possible_next_dimension { - Some(next_dimension) => { - let mut final_list: Value = Value::Array(vec![]); - - for _ in 0..current_dimension.len as usize { - if dimensions.get(dimension_index + 1).is_some() { - let inner_pylist = inner_pythondto_array_to_serde( - dimensions, - &data[lower_bound..next_dimension.len as usize + lower_bound], - dimension_index + 1, - 0, - )?; - match final_list { - Value::Array(ref mut array) => array.push(inner_pylist), - _ => unreachable!(), - } - lower_bound += next_dimension.len as usize; - } - } - - return Ok(final_list); - } - None => { - return data.iter().map(|x| x.to_serde_value()).collect(); - } + if dimension_index >= dimensions.len() || data_offset >= data.len() { + return Ok(Value::Array(vec![])); + } + + let current_dimension = &dimensions[dimension_index]; + let current_len = current_dimension.len as usize; + + if dimension_index + 1 >= dimensions.len() { + let end_offset = (data_offset + current_len).min(data.len()); + let slice = &data[data_offset..end_offset]; + + let mut result_values = Vec::with_capacity(slice.len()); + for item in slice { + result_values.push(item.to_serde_value()?); } + + return Ok(Value::Array(result_values)); + } + + let mut final_array = Vec::with_capacity(current_len); + + let sub_array_size = dimensions[dimension_index + 1..] 
+ .iter() + .map(|d| d.len as usize) + .product::(); + + let mut current_offset = data_offset; + + for _ in 0..current_len { + if current_offset >= data.len() { + break; + } + + let inner_value = + inner_pythondto_array_to_serde(dimensions, data, dimension_index + 1, current_offset)?; + + final_array.push(inner_value); + current_offset += sub_array_size; } - Ok(Value::Array(vec![])) + Ok(Value::Array(final_array)) } diff --git a/src/value_converter/to_python.rs b/src/value_converter/to_python.rs index abc734c8..cf0f6d35 100644 --- a/src/value_converter/to_python.rs +++ b/src/value_converter/to_python.rs @@ -95,13 +95,9 @@ fn postgres_array_to_py<'py, T: IntoPyObject<'py> + Clone>( array: Option>, ) -> Option> { array.map(|array| { - inner_postgres_array_to_py( - py, - array.dimensions(), - array.iter().cloned().collect::>(), - 0, - 0, - ) + // Collect data once instead of creating copies in recursion + let data: Vec = array.iter().cloned().collect(); + inner_postgres_array_to_py(py, array.dimensions(), &data, 0, 0) }) } @@ -110,44 +106,60 @@ fn postgres_array_to_py<'py, T: IntoPyObject<'py> + Clone>( fn inner_postgres_array_to_py<'py, T>( py: Python<'py>, dimensions: &[Dimension], - data: Vec, + data: &[T], dimension_index: usize, - mut lower_bound: usize, + data_offset: usize, ) -> Py where T: IntoPyObject<'py> + Clone, { - let current_dimension = dimensions.get(dimension_index); - - if let Some(current_dimension) = current_dimension { - let possible_next_dimension = dimensions.get(dimension_index + 1); - match possible_next_dimension { - Some(next_dimension) => { - let final_list = PyList::empty(py); - - for _ in 0..current_dimension.len as usize { - if dimensions.get(dimension_index + 1).is_some() { - let inner_pylist = inner_postgres_array_to_py( - py, - dimensions, - data[lower_bound..next_dimension.len as usize + lower_bound].to_vec(), - dimension_index + 1, - 0, - ); - final_list.append(inner_pylist).unwrap(); - lower_bound += next_dimension.len as usize; - } - } - - return final_list.unbind(); - } - None => { - return PyList::new(py, data).unwrap().unbind(); // TODO unwrap is unsafe - } + // Check bounds early + if dimension_index >= dimensions.len() || data_offset >= data.len() { + return PyList::empty(py).unbind(); + } + + let current_dimension = &dimensions[dimension_index]; + let current_len = current_dimension.len as usize; + + // If this is the last dimension, create a list with the actual data + if dimension_index + 1 >= dimensions.len() { + let end_offset = (data_offset + current_len).min(data.len()); + let slice = &data[data_offset..end_offset]; + + // Create Python list more efficiently + return match PyList::new(py, slice.iter().cloned()) { + Ok(list) => list.unbind(), + Err(_) => PyList::empty(py).unbind(), + }; + } + + // For multi-dimensional arrays, recursively create nested lists + let final_list = PyList::empty(py); + + // Calculate the size of each sub-array + let sub_array_size = dimensions[dimension_index + 1..] 
+            .iter()
+            .map(|d| d.len as usize)
+            .product::<usize>();
+
+    let mut current_offset = data_offset;
+
+    for _ in 0..current_len {
+        if current_offset >= data.len() {
+            break;
+        }
+
+        let inner_list =
+            inner_postgres_array_to_py(py, dimensions, data, dimension_index + 1, current_offset);
+
+        if final_list.append(inner_list).is_err() {
+            break;
+        }
+
+        current_offset += sub_array_size;
     }
 
-    PyList::empty(py).unbind()
+    final_list.unbind()
 }
 
 #[allow(clippy::too_many_lines)]
@@ -604,7 +616,7 @@ pub fn raw_bytes_data_process(
         if let Ok(Some(py_encoder_func)) = py_encoder_func {
             return Ok(py_encoder_func
-                .call((raw_bytes_data.to_vec(),), None)?
+                .call1((PyBytes::new(py, raw_bytes_data),))?
                 .unbind());
         }
     }
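A minimal end-to-end sketch of the `as_tuple` behavior this patch settles on: rows come back as plain value tuples with no column names, matching the updated stubs and tests. The DSN and connection details are placeholders, not part of the patch.

```python
import asyncio
from typing import Any, Dict, List, Tuple

from psqlpy import ConnectionPool


async def main() -> None:
    # Placeholder DSN; point it at any reachable PostgreSQL instance.
    pool = ConnectionPool(dsn="postgres://postgres:postgres@localhost:5432/postgres")
    connection = await pool.connection()

    query_result = await connection.execute("SELECT 1 AS id, 'alice' AS name")

    # Default: list of dicts keyed by column name.
    dict_rows: List[Dict[str, Any]] = query_result.result()
    assert dict_rows[0]["id"] == 1

    # With as_tuple=True: flat value tuples per row (the old shape was
    # a tuple of (column_name, value) pairs per row).
    tuple_rows: List[Tuple[Any, ...]] = query_result.result(as_tuple=True)
    assert tuple_rows[0] == (1, "alice")


asyncio.run(main())
```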
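`execute_many` keeps its Python-facing signature; the Rust rewrite only changes how the statement is handled internally (prepared once, then reused for every parameter set, instead of re-prepared per set). A usage sketch, assuming a hypothetical `users(name TEXT)` table:

```python
async def insert_users(connection) -> None:
    # One inner list per execution of the statement.
    await connection.execute_many(
        "INSERT INTO users (name) VALUES ($1)",
        [["alice"], ["bob"], ["carol"]],
        prepared=True,  # prepared once up front, then reused
    )
```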
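The offset-based array recursion is observable from Python when selecting multidimensional arrays; a sketch reusing a connection like the one above:

```python
async def fetch_matrix(connection) -> None:
    # The fixed recursion walks sub-arrays by offset, so a 2x2 array
    # comes back as properly nested lists.
    query_result = await connection.execute("SELECT ARRAY[[1, 2], [3, 4]] AS matrix")
    row = query_result.result()[0]
    assert row["matrix"] == [[1, 2], [3, 4]]
```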
