diff --git a/docs/components/results.md b/docs/components/results.md index 15bc3690..8a28a72e 100644 --- a/docs/components/results.md +++ b/docs/components/results.md @@ -14,7 +14,9 @@ Currently there are two results: ### Result #### Parameters + - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md) +- `as_tuple`: return rows as tuples instead of dicts Get the result as a list of dicts @@ -27,12 +29,19 @@ async def main() -> None: [], ) - result: List[Dict[str, Any]] = query_result.result() + # Result as dict + list_dict_result: List[Dict[str, Any]] = query_result.result() + + # Result as tuple + list_tuple_result: List[Tuple[Any, ...]] = query_result.result( + as_tuple=True, + ) ``` ### As class #### Parameters + - `as_class`: Custom class from Python. - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md) @@ -61,6 +70,7 @@ async def main() -> None: ### Row Factory #### Parameters + - `row_factory`: custom callable object. - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md) @@ -71,7 +81,9 @@ async def main() -> None: ### Result #### Parameters + - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md) +- `as_tuple`: return the row as a tuple instead of a dict Get the result as a dict @@ -84,12 +96,19 @@ async def main() -> None: [100], ) - result: Dict[str, Any] = query_result.result() + # Result as dict + dict_result: Dict[str, Any] = query_result.result() + + # Result as tuple + tuple_result: Tuple[Any, ...] = query_result.result( + as_tuple=True, + ) ``` ### As class #### Parameters + - `as_class`: Custom class from Python. - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md) @@ -117,6 +136,7 @@ async def main() -> None: ### Row Factory #### Parameters + - `row_factory`: custom callable object. - `custom_decoders`: custom decoders for unsupported types. [Read more](/usage/types/advanced_type_usage.md) diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index ddb74de1..17a2d482 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -1,4 +1,5 @@ import types +import typing from enum import Enum from io import BytesIO from ipaddress import IPv4Address, IPv6Address @@ -18,15 +19,38 @@ ParamsT: TypeAlias = Sequence[Any] | Mapping[str, Any] | None class QueryResult: """Result.""" + @typing.overload def result( self: Self, + as_tuple: typing.Literal[None] = None, custom_decoders: dict[str, Callable[[bytes], Any]] | None = None, - ) -> list[dict[Any, Any]]: - """Return result from database as a list of dicts. + ) -> list[dict[str, Any]]: ... + @typing.overload + def result( + self: Self, + as_tuple: typing.Literal[False], + custom_decoders: dict[str, Callable[[bytes], Any]] | None = None, + ) -> list[dict[str, Any]]: ... + @typing.overload + def result( + self: Self, + as_tuple: typing.Literal[True], + custom_decoders: dict[str, Callable[[bytes], Any]] | None = None, + ) -> list[tuple[typing.Any, ...]]: ... + @typing.overload + def result( + self: Self, + custom_decoders: dict[str, Callable[[bytes], Any]] | None = None, + as_tuple: bool | None = None, + ) -> list[dict[str, Any]]: + """Return result from database. + + By default it returns the result as a list of dicts. `custom_decoders` must be used when you use PostgreSQL Type which isn't supported, read more in our docs. 
""" + def as_class( self: Self, as_class: Callable[..., _CustomClass], @@ -60,6 +84,7 @@ class QueryResult: ) ``` """ + def row_factory( self, row_factory: Callable[[dict[str, Any]], _RowFactoryRV], @@ -84,15 +109,38 @@ class QueryResult: class SingleQueryResult: """Single result.""" + @typing.overload + def result( + self: Self, + as_tuple: typing.Literal[None] = None, + custom_decoders: dict[str, Callable[[bytes], Any]] | None = None, + ) -> dict[str, Any]: ... + @typing.overload + def result( + self: Self, + as_tuple: typing.Literal[False], + custom_decoders: dict[str, Callable[[bytes], Any]] | None = None, + ) -> dict[str, Any]: ... + @typing.overload + def result( + self: Self, + as_tuple: typing.Literal[True], + custom_decoders: dict[str, Callable[[bytes], Any]] | None = None, + ) -> tuple[typing.Any, ...]: ... + @typing.overload def result( self: Self, custom_decoders: dict[str, Callable[[bytes], Any]] | None = None, + as_tuple: bool | None = None, ) -> dict[Any, Any]: - """Return result from database as a dict. + """Return result from database. + + By default it returns result as a dict. `custom_decoders` must be used when you use PostgreSQL Type which isn't supported, read more in our docs. """ + def as_class( self: Self, as_class: Callable[..., _CustomClass], @@ -129,6 +177,7 @@ class SingleQueryResult: ) ``` """ + def row_factory( self, row_factory: Callable[[dict[str, Any]], _RowFactoryRV], @@ -283,11 +332,13 @@ class Cursor: Execute DECLARE command for the cursor. """ + def close(self: Self) -> None: """Close the cursor. Execute CLOSE command for the cursor. """ + async def execute( self: Self, querystring: str, @@ -298,10 +349,13 @@ class Cursor: Method should be used instead of context manager and `start` method. """ + async def fetchone(self: Self) -> QueryResult: """Return next one row from the cursor.""" + async def fetchmany(self: Self, size: int | None = None) -> QueryResult: """Return rows from the cursor.""" + async def fetchall(self: Self, size: int | None = None) -> QueryResult: """Return all remaining rows from the cursor.""" @@ -334,6 +388,7 @@ class Transaction: `begin()` can be called only once per transaction. """ + async def commit(self: Self) -> None: """Commit the transaction. @@ -341,6 +396,7 @@ class Transaction: `commit()` can be called only once per transaction. """ + async def rollback(self: Self) -> None: """Rollback all queries in the transaction. @@ -361,6 +417,7 @@ class Transaction: await transaction.rollback() ``` """ + async def execute( self: Self, querystring: str, @@ -398,6 +455,7 @@ class Transaction: await transaction.commit() ``` """ + async def execute_batch( self: Self, querystring: str, @@ -413,6 +471,7 @@ class Transaction: ### Parameters: - `querystring`: querystrings separated by semicolons. """ + async def execute_many( self: Self, querystring: str, @@ -471,6 +530,7 @@ class Transaction: - `prepared`: should the querystring be prepared before the request. By default any querystring will be prepared. """ + async def fetch_row( self: Self, querystring: str, @@ -510,6 +570,7 @@ class Transaction: await transaction.commit() ``` """ + async def fetch_val( self: Self, querystring: str, @@ -550,6 +611,7 @@ class Transaction: ) ``` """ + async def pipeline( self, queries: list[tuple[str, list[Any] | None]], @@ -614,6 +676,7 @@ class Transaction: ) ``` """ + async def create_savepoint(self: Self, savepoint_name: str) -> None: """Create new savepoint. 
@@ -642,6 +705,7 @@ class Transaction: await transaction.rollback_savepoint("my_savepoint") ``` """ + async def rollback_savepoint(self: Self, savepoint_name: str) -> None: """ROLLBACK to the specified `savepoint_name`. @@ -667,6 +731,7 @@ class Transaction: await transaction.rollback_savepoint("my_savepoint") ``` """ + async def release_savepoint(self: Self, savepoint_name: str) -> None: """Execute ROLLBACK TO SAVEPOINT. @@ -691,6 +756,7 @@ class Transaction: await transaction.release_savepoint ``` """ + def cursor( self: Self, querystring: str, @@ -734,6 +800,7 @@ class Transaction: await cursor.close() ``` """ + async def binary_copy_to_table( self: Self, source: bytes | bytearray | Buffer | BytesIO, @@ -815,6 +882,7 @@ class Connection: Return representation of prepared statement. """ + async def execute( self: Self, querystring: str, @@ -851,6 +919,7 @@ class Connection: dict_result: List[Dict[Any, Any]] = query_result.result() ``` """ + async def execute_batch( self: Self, querystring: str, @@ -866,6 +935,7 @@ class Connection: ### Parameters: - `querystring`: querystrings separated by semicolons. """ + async def execute_many( self: Self, querystring: str, @@ -919,6 +989,7 @@ class Connection: - `prepared`: should the querystring be prepared before the request. By default any querystring will be prepared. """ + async def fetch_row( self: Self, querystring: str, @@ -955,6 +1026,7 @@ class Connection: dict_result: Dict[Any, Any] = query_result.result() ``` """ + async def fetch_val( self: Self, querystring: str, @@ -994,6 +1066,7 @@ class Connection: ) ``` """ + def transaction( self, isolation_level: IsolationLevel | None = None, @@ -1007,6 +1080,7 @@ class Connection: - `read_variant`: configure read variant of the transaction. - `deferrable`: configure deferrable of the transaction. """ + def cursor( self: Self, querystring: str, @@ -1045,6 +1119,7 @@ class Connection: ... # do something with this result. ``` """ + def close(self: Self) -> None: """Return connection back to the pool. @@ -1189,6 +1264,7 @@ class ConnectionPool: - `ca_file`: Loads trusted root certificates from a file. The file should contain a sequence of PEM-formatted CA certificates. """ + def __iter__(self: Self) -> Self: ... def __enter__(self: Self) -> Self: ... def __exit__( @@ -1203,6 +1279,7 @@ class ConnectionPool: ### Returns `ConnectionPoolStatus` """ + def resize(self: Self, new_max_size: int) -> None: """Resize the connection pool. @@ -1212,11 +1289,13 @@ class ConnectionPool: ### Parameters: - `new_max_size`: new size for the connection pool. """ + async def connection(self: Self) -> Connection: """Create new connection. It acquires new connection from the database pool. """ + def acquire(self: Self) -> Connection: """Create new connection for async context manager. @@ -1234,6 +1313,7 @@ class ConnectionPool: res = await connection.execute(...) ``` """ + def listener(self: Self) -> Listener: """Create new listener.""" @@ -1345,6 +1425,7 @@ class ConnectionPoolBuilder: def __init__(self: Self) -> None: """Initialize new instance of `ConnectionPoolBuilder`.""" + def build(self: Self) -> ConnectionPool: """ Build `ConnectionPool`. @@ -1352,6 +1433,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPool` """ + def max_pool_size(self: Self, pool_size: int) -> Self: """ Set maximum connection pool size. 
@@ -1362,6 +1444,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def conn_recycling_method( self: Self, conn_recycling_method: ConnRecyclingMethod, @@ -1377,6 +1460,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def user(self: Self, user: str) -> Self: """ Set username to `PostgreSQL`. @@ -1387,6 +1471,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def password(self: Self, password: str) -> Self: """ Set password for `PostgreSQL`. @@ -1397,6 +1482,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def dbname(self: Self, dbname: str) -> Self: """ Set database name for the `PostgreSQL`. @@ -1407,6 +1493,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def options(self: Self, options: str) -> Self: """ Set command line options used to configure the server. @@ -1417,6 +1504,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def application_name(self: Self, application_name: str) -> Self: """ Set the value of the `application_name` runtime parameter. @@ -1427,6 +1515,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def ssl_mode(self: Self, ssl_mode: SslMode) -> Self: """ Set the SSL configuration. @@ -1437,6 +1526,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def ca_file(self: Self, ca_file: str) -> Self: """ Set ca_file for SSL. @@ -1447,6 +1537,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def host(self: Self, host: str) -> Self: """ Add a host to the configuration. @@ -1464,6 +1555,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def hostaddr(self: Self, hostaddr: IPv4Address | IPv6Address) -> Self: """ Add a hostaddr to the configuration. @@ -1479,6 +1571,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def port(self: Self, port: int) -> Self: """ Add a port to the configuration. @@ -1495,6 +1588,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def connect_timeout(self: Self, connect_timeout: int) -> Self: """ Set the timeout applied to socket-level connection attempts. @@ -1509,6 +1603,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def tcp_user_timeout(self: Self, tcp_user_timeout: int) -> Self: """ Set the TCP user timeout. @@ -1524,6 +1619,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def target_session_attrs( self: Self, target_session_attrs: TargetSessionAttrs, @@ -1541,6 +1637,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def load_balance_hosts( self: Self, load_balance_hosts: LoadBalanceHosts, @@ -1556,6 +1653,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def keepalives( self: Self, keepalives: bool, @@ -1573,6 +1671,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def keepalives_idle( self: Self, keepalives_idle: int, @@ -1591,6 +1690,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def keepalives_interval( self: Self, keepalives_interval: int, @@ -1610,6 +1710,7 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + def keepalives_retries( self: Self, keepalives_retries: int, @@ -1702,11 +1803,13 @@ class Listener: Each listener MUST be started up. """ + async def shutdown(self: Self) -> None: """Shutdown the listener. 
Abort listen and release underlying connection. """ + async def add_callback( self: Self, channel: str, @@ -1769,7 +1872,9 @@ class Column: class PreparedStatement: async def execute(self: Self) -> QueryResult: """Execute prepared statement.""" + def cursor(self: Self) -> Cursor: """Create new server-side cursor based on prepared statement.""" + def columns(self: Self) -> list[Column]: """Return information about statement columns.""" diff --git a/python/tests/conftest.py b/python/tests/conftest.py index efb7f6e3..31cb31e1 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -134,21 +134,24 @@ async def create_default_data_for_tests( table_name: str, number_database_records: int, ) -> AsyncGenerator[None, None]: - connection = await psql_pool.connection() - await connection.execute( - f"CREATE TABLE {table_name} (id SERIAL, name VARCHAR(255))", - ) - - for table_id in range(1, number_database_records + 1): - new_name = random_string() + async with psql_pool.acquire() as connection: await connection.execute( - querystring=f"INSERT INTO {table_name} VALUES ($1, $2)", - parameters=[table_id, new_name], + f"CREATE TABLE {table_name} (id SERIAL, name VARCHAR(255))", ) + + for table_id in range(1, number_database_records + 1): + new_name = random_string() + await connection.execute( + querystring=f"INSERT INTO {table_name} VALUES ($1, $2)", + parameters=[table_id, new_name], + ) + yield - await connection.execute( - f"DROP TABLE {table_name}", - ) + + async with psql_pool.acquire() as connection: + await connection.execute( + f"DROP TABLE {table_name}", + ) @pytest.fixture @@ -156,17 +159,19 @@ async def create_table_for_listener_tests( psql_pool: ConnectionPool, listener_table_name: str, ) -> AsyncGenerator[None, None]: - connection = await psql_pool.connection() - await connection.execute( - f"CREATE TABLE {listener_table_name}" - f"(id SERIAL, payload VARCHAR(255)," - f"channel VARCHAR(255), process_id INT)", - ) + async with psql_pool.acquire() as connection: + await connection.execute( + f"CREATE TABLE {listener_table_name}" + f"(id SERIAL, payload VARCHAR(255)," + f"channel VARCHAR(255), process_id INT)", + ) yield - await connection.execute( - f"DROP TABLE {listener_table_name}", - ) + + async with psql_pool.acquire() as connection: + await connection.execute( + f"DROP TABLE {listener_table_name}", + ) @pytest.fixture @@ -174,16 +179,18 @@ async def create_table_for_map_parameters_test( psql_pool: ConnectionPool, map_parameters_table_name: str, ) -> AsyncGenerator[None, None]: - connection = await psql_pool.connection() - await connection.execute( - f"CREATE TABLE {map_parameters_table_name}" - "(id SERIAL, name VARCHAR(255),surname VARCHAR(255), age SMALLINT)", - ) + async with psql_pool.acquire() as connection: + await connection.execute( + f"CREATE TABLE {map_parameters_table_name}" + "(id SERIAL, name VARCHAR(255),surname VARCHAR(255), age SMALLINT)", + ) yield - await connection.execute( - f"DROP TABLE {map_parameters_table_name}", - ) + + async with psql_pool.acquire() as connection: + await connection.execute( + f"DROP TABLE {map_parameters_table_name}", + ) @pytest.fixture @@ -191,12 +198,12 @@ async def test_cursor( psql_pool: ConnectionPool, table_name: str, ) -> AsyncGenerator[Cursor, None]: - connection = await psql_pool.connection() - transaction = connection.transaction() - await transaction.begin() - cursor = transaction.cursor( - querystring=f"SELECT * FROM {table_name}", - ) - await cursor.start() - yield cursor - await transaction.commit() + async with 
psql_pool.acquire() as connection: + transaction = connection.transaction() + await transaction.begin() + cursor = transaction.cursor( + querystring=f"SELECT * FROM {table_name}", + ) + await cursor.start() + yield cursor + await transaction.commit() diff --git a/python/tests/test_query_results.py b/python/tests/test_query_results.py new file mode 100644 index 00000000..ff136fb4 --- /dev/null +++ b/python/tests/test_query_results.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import pytest +from psqlpy import ConnectionPool, QueryResult, SingleQueryResult + +pytestmark = pytest.mark.anyio + + +async def test_result_as_dict( + psql_pool: ConnectionPool, + table_name: str, +) -> None: + """Test that QueryResult returns rows as dicts by default.""" + connection = await psql_pool.connection() + + conn_result = await connection.execute( + querystring=f"SELECT * FROM {table_name}", + ) + result_list_dicts = conn_result.result() + single_dict_row = result_list_dicts[0] + + assert isinstance(conn_result, QueryResult) + assert isinstance(single_dict_row, dict) + assert single_dict_row.get("id") + + +async def test_result_as_tuple( + psql_pool: ConnectionPool, + table_name: str, +) -> None: + """Test that QueryResult returns rows as tuples with as_tuple=True.""" + connection = await psql_pool.connection() + + conn_result = await connection.execute( + querystring=f"SELECT * FROM {table_name}", + ) + result_tuple = conn_result.result(as_tuple=True) + single_tuple_row = result_tuple[0] + + assert isinstance(conn_result, QueryResult) + assert isinstance(single_tuple_row, tuple) + assert single_tuple_row[0] == 1 + + +async def test_single_result_as_dict( + psql_pool: ConnectionPool, + table_name: str, +) -> None: + """Test that SingleQueryResult returns a dict by default.""" + connection = await psql_pool.connection() + + conn_result = await connection.fetch_row( + querystring=f"SELECT * FROM {table_name} LIMIT 1", + ) + result_dict = conn_result.result() + + assert isinstance(conn_result, SingleQueryResult) + assert isinstance(result_dict, dict) + assert result_dict.get("id") + + +async def test_single_result_as_tuple( + psql_pool: ConnectionPool, + table_name: str, +) -> None: + """Test that SingleQueryResult returns a tuple with as_tuple=True.""" + connection = await psql_pool.connection() + + conn_result = await connection.fetch_row( + querystring=f"SELECT * FROM {table_name} LIMIT 1", + ) + result_tuple = conn_result.result(as_tuple=True) + + assert isinstance(conn_result, SingleQueryResult) + assert isinstance(result_tuple, tuple) + assert result_tuple[0] == 1 diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index ce2f05ed..122201ef 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -646,6 +646,35 @@ def point_encoder(point_bytes: bytes) -> str: # noqa: ARG001 assert result[0]["geo_point"] == "Just An Example" +async def test_custom_decoder_as_tuple_result( + psql_pool: ConnectionPool, +) -> None: + def point_encoder(point_bytes: bytes) -> str: # noqa: ARG001 + return "Just An Example" + + async with psql_pool.acquire() as conn: + await conn.execute("DROP TABLE IF EXISTS for_test") + await conn.execute( + "CREATE TABLE for_test (geo_point POINT)", + ) + + await conn.execute( + "INSERT INTO for_test VALUES ('(1, 1)')", + ) + + qs_result = await conn.execute( + "SELECT * FROM for_test", + ) + result = qs_result.result( + custom_decoders={ + "geo_point": point_encoder, + }, + as_tuple=True, + ) + + assert result[0][0] == "Just An Example" + + async 
def test_row_factory_query_result( psql_pool: ConnectionPool, table_name: str, diff --git a/src/connection/impls.rs b/src/connection/impls.rs index 62d9a830..fd6dabec 100644 --- a/src/connection/impls.rs +++ b/src/connection/impls.rs @@ -367,6 +367,31 @@ impl PSQLPyConnection { Ok(PSQLDriverPyQueryResult::new(result)) } + /// Execute raw querystring without parameters. + /// + /// # Errors + /// May return error if there is some problem with DB communication. + pub async fn execute_no_params( + &self, + querystring: String, + prepared: Option, + ) -> PSQLPyResult { + let prepared = prepared.unwrap_or(true); + let result = if prepared { + self.query(&querystring, &[]).await + } else { + self.query_typed(&querystring, &[]).await + }; + + let return_result = result.map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot execute query, error - {err}" + )) + })?; + + Ok(PSQLDriverPyQueryResult::new(return_result)) + } + /// Execute raw query with parameters. /// /// # Errors @@ -409,42 +434,62 @@ impl PSQLPyConnection { parameters: Option>>, prepared: Option, ) -> PSQLPyResult<()> { - let mut statements: Vec = vec![]; - if let Some(parameters) = parameters { - for vec_of_py_any in parameters { - // TODO: Fix multiple qs creation - let statement = - StatementBuilder::new(&querystring, &Some(vec_of_py_any), self, prepared) - .build() - .await?; - - statements.push(statement); - } - } + let Some(parameters) = parameters else { + return Ok(()); + }; let prepared = prepared.unwrap_or(true); - for statement in statements { - let querystring_result = if prepared { - let prepared_stmt = &self.prepare(statement.raw_query(), true).await; - if let Err(error) = prepared_stmt { - return Err(RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement in execute_many, operation rolled back {error}", - ))); - } - self.query( - &self.prepare(statement.raw_query(), true).await?, - &statement.params(), - ) - .await - } else { - self.query(statement.raw_query(), &statement.params()).await - }; + let mut statements: Vec = Vec::with_capacity(parameters.len()); + + for param_set in parameters { + let statement = + StatementBuilder::new(&querystring, &Some(param_set), self, Some(prepared)) + .build() + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot build statement in execute_many: {err}" + )) + })?; + statements.push(statement); + } - if let Err(error) = querystring_result { - return Err(RustPSQLDriverError::ConnectionExecuteError(format!( - "Error occured in `execute_many` statement: {error}" - ))); + if statements.is_empty() { + return Ok(()); + } + + if prepared { + let first_statement = &statements[0]; + let prepared_stmt = self + .prepare(first_statement.raw_query(), true) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot prepare statement in execute_many: {err}" + )) + })?; + + // Execute all statements using the same prepared statement + for statement in statements { + self.query(&prepared_stmt, &statement.params()) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Error occurred in `execute_many` statement: {err}" + )) + })?; + } + } else { + // Execute each statement without preparation + for statement in statements { + self.query(statement.raw_query(), &statement.params()) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Error occurred in `execute_many` statement: {err}" + )) + })?; } } diff --git 
a/src/driver/common.rs b/src/driver/common.rs index 3c22517a..b2ff6a52 100644 --- a/src/driver/common.rs +++ b/src/driver/common.rs @@ -16,7 +16,7 @@ use crate::{ use bytes::BytesMut; use futures_util::pin_mut; -use pyo3::{buffer::PyBuffer, PyErr, Python}; +use pyo3::{buffer::PyBuffer, Python}; use tokio_postgres::binary_copy::BinaryCopyInWriter; use crate::format_helpers::quote_ident; @@ -77,19 +77,13 @@ macro_rules! impl_config_py_methods { #[cfg(not(unix))] #[getter] fn hosts(&self) -> Vec { - let mut hosts_vec = vec![]; - - let hosts = self.pg_config.get_hosts(); - for host in hosts { - match host { - Host::Tcp(host) => { - hosts_vec.push(host.to_string()); - } - _ => unreachable!(), - } - } - - hosts_vec + self.pg_config + .get_hosts() + .iter() + .map(|host| match host { + Host::Tcp(host) => host.to_string(), + }) + .collect() } #[getter] @@ -255,50 +249,70 @@ macro_rules! impl_binary_copy_method { columns: Option>, schema_name: Option, ) -> PSQLPyResult { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone()); - let mut table_name = quote_ident(&table_name); - if let Some(schema_name) = schema_name { - table_name = format!("{}.{}", quote_ident(&schema_name), table_name); - } - - let mut formated_columns = String::default(); - if let Some(columns) = columns { - formated_columns = format!("({})", columns.join(", ")); - } + let (db_client, mut bytes_mut) = + Python::with_gil(|gil| -> PSQLPyResult<(Option<_>, BytesMut)> { + let db_client = self_.borrow(gil).conn.clone(); + + let Some(db_client) = db_client else { + return Ok((None, BytesMut::new())); + }; + + let data_bytes_mut = + if let Ok(py_buffer) = source.extract::>(gil) { + let buffer_len = py_buffer.len_bytes(); + let mut bytes_mut = BytesMut::zeroed(buffer_len); + + py_buffer.copy_to_slice(gil, &mut bytes_mut[..])?; + bytes_mut + } else if let Ok(py_bytes) = source.call_method0(gil, "getvalue") { + if let Ok(bytes_vec) = py_bytes.extract::>(gil) { + let bytes_mut = BytesMut::from(&bytes_vec[..]); + bytes_mut + } else { + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "source must be bytes or support Buffer protocol".into(), + )); + } + } else { + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "source must be bytes or support Buffer protocol".into(), + )); + }; + + Ok((Some(db_client), data_bytes_mut)) + })?; - let copy_qs = - format!("COPY {table_name}{formated_columns} FROM STDIN (FORMAT binary)"); + let Some(db_client) = db_client else { + return Ok(0); + }; - if let Some(db_client) = db_client { - let mut psql_bytes: BytesMut = Python::with_gil(|gil| { - let possible_py_buffer: Result, PyErr> = - source.extract::>(gil); - if let Ok(py_buffer) = possible_py_buffer { - let vec_buf = py_buffer.to_vec(gil)?; - return Ok(BytesMut::from(vec_buf.as_slice())); - } + let full_table_name = match schema_name { + Some(schema) => { + format!("{}.{}", quote_ident(&schema), quote_ident(&table_name)) + } + None => quote_ident(&table_name), + }; - if let Ok(py_bytes) = source.call_method0(gil, "getvalue") { - if let Ok(bytes) = py_bytes.extract::>(gil) { - return Ok(BytesMut::from(bytes.as_slice())); - } - } + let copy_qs = match columns { + Some(ref cols) if !cols.is_empty() => { + format!( + "COPY {}({}) FROM STDIN (FORMAT binary)", + full_table_name, + cols.join(", ") + ) + } + _ => format!("COPY {} FROM STDIN (FORMAT binary)", full_table_name), + }; - Err(RustPSQLDriverError::PyToRustValueConversionError( - "source must be bytes or support Buffer protocol".into(), - )) - })?; + 
let read_conn_g = db_client.read().await; + let sink = read_conn_g.copy_in(&copy_qs).await?; + let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]); + pin_mut!(writer); - let read_conn_g = db_client.read().await; - let sink = read_conn_g.copy_in(&copy_qs).await?; - let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]); - pin_mut!(writer); - writer.as_mut().write_raw_bytes(&mut psql_bytes).await?; - let rows_created = writer.as_mut().finish_empty().await?; - return Ok(rows_created); - } + writer.as_mut().write_raw_bytes(&mut bytes_mut).await?; + let rows_created = writer.as_mut().finish_empty().await?; - Ok(0) + Ok(rows_created) } } }; diff --git a/src/driver/connection.rs b/src/driver/connection.rs index a89f3edd..915c9667 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -237,8 +237,13 @@ impl Connection { if let Some(db_client) = db_client { let read_conn_g = db_client.read().await; - let res = read_conn_g.execute(querystring, parameters, prepared).await; - return res; + return { + if parameters.is_some() { + read_conn_g.execute(querystring, parameters, prepared).await + } else { + read_conn_g.execute_no_params(querystring, prepared).await + } + }; } Err(RustPSQLDriverError::ConnectionClosedError) @@ -318,7 +323,13 @@ impl Connection { if let Some(db_client) = db_client { let read_conn_g = db_client.read().await; - return read_conn_g.execute(querystring, parameters, prepared).await; + return { + if parameters.is_some() { + read_conn_g.execute(querystring, parameters, prepared).await + } else { + read_conn_g.execute_no_params(querystring, prepared).await + } + }; } Err(RustPSQLDriverError::ConnectionClosedError) diff --git a/src/query_result.rs b/src/query_result.rs index b17acad9..46047848 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -1,4 +1,9 @@ -use pyo3::{prelude::*, pyclass, pymethods, types::PyDict, IntoPyObjectExt, Py, PyAny, Python}; +use pyo3::{ + prelude::*, + pyclass, pymethods, + types::{PyDict, PyTuple}, + IntoPyObjectExt, Py, PyAny, Python, +}; use tokio_postgres::Row; use crate::{exceptions::rust_errors::PSQLPyResult, value_converter::to_python::postgres_to_py}; @@ -15,7 +20,7 @@ fn row_to_dict<'a>( py: Python<'a>, postgres_row: &'a Row, custom_decoders: &Option<Py<PyDict>>, -) -> PSQLPyResult> { let python_dict = PyDict::new(py); for (column_idx, column) in postgres_row.columns().iter().enumerate() { let python_type = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?; @@ -24,6 +29,30 @@ Ok(python_dict) } +/// Convert postgres `Row` into Python Tuple. +/// +/// # Errors +/// +/// May return Err Result if can not convert +/// postgres type to python or build +/// the resulting python tuple. +#[allow(clippy::ref_option)] +fn row_to_tuple<'a>( + py: Python<'a>, + postgres_row: &'a Row, + custom_decoders: &Option<Py<PyDict>>, +) -> PSQLPyResult<Bound<'a, PyTuple>> { + let columns = postgres_row.columns(); + let mut tuple_items = Vec::with_capacity(columns.len()); + + for (column_idx, column) in columns.iter().enumerate() { + let python_value = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?; + tuple_items.push(python_value); + } + + Ok(PyTuple::new(py, tuple_items)?) +} + #[pyclass(name = "QueryResult")] #[allow(clippy::module_name_repetitions)] pub struct PSQLDriverPyQueryResult { @@ -56,18 +85,29 @@ impl PSQLDriverPyQueryResult { /// May return Err Result if can not convert /// postgres type to python or set new key-value pair /// in python dict. 
- #[pyo3(signature = (custom_decoders=None))] + #[pyo3(signature = (custom_decoders=None, as_tuple=None))] #[allow(clippy::needless_pass_by_value)] pub fn result( &self, py: Python<'_>, custom_decoders: Option>, + as_tuple: Option, ) -> PSQLPyResult> { - let mut result: Vec> = vec![]; + let as_tuple = as_tuple.unwrap_or(false); + + if as_tuple { + let mut tuple_rows: Vec> = vec![]; + for row in &self.inner { + tuple_rows.push(row_to_tuple(py, row, &custom_decoders)?); + } + return Ok(tuple_rows.into_py_any(py)?); + } + + let mut dict_rows: Vec> = vec![]; for row in &self.inner { - result.push(row_to_dict(py, row, &custom_decoders)?); + dict_rows.push(row_to_dict(py, row, &custom_decoders)?); } - Ok(result.into_py_any(py)?) + Ok(dict_rows.into_py_any(py)?) } /// Convert result from database to any class passed from Python. @@ -143,12 +183,19 @@ impl PSQLDriverSinglePyQueryResult { /// postgres type to python, can not set new key-value pair /// in python dict or there are no result. #[allow(clippy::needless_pass_by_value)] - #[pyo3(signature = (custom_decoders=None))] + #[pyo3(signature = (custom_decoders=None, as_tuple=None))] pub fn result( &self, py: Python<'_>, custom_decoders: Option>, + as_tuple: Option, ) -> PSQLPyResult> { + let as_tuple = as_tuple.unwrap_or(false); + + if as_tuple { + return Ok(row_to_tuple(py, &self.inner, &custom_decoders)?.into_py_any(py)?); + } + Ok(row_to_dict(py, &self.inner, &custom_decoders)?.into_py_any(py)?) } diff --git a/src/value_converter/dto/funcs.rs b/src/value_converter/dto/funcs.rs index eec045e0..e869966c 100644 --- a/src/value_converter/dto/funcs.rs +++ b/src/value_converter/dto/funcs.rs @@ -4,7 +4,7 @@ use postgres_types::Type; pub fn array_type_to_single_type(array_type: &Type) -> Type { match *array_type { Type::BOOL_ARRAY => Type::BOOL, - Type::UUID_ARRAY => Type::UUID_ARRAY, + Type::UUID_ARRAY => Type::UUID, Type::VARCHAR_ARRAY => Type::VARCHAR, Type::TEXT_ARRAY => Type::TEXT, Type::INT2_ARRAY => Type::INT2, diff --git a/src/value_converter/models/serde_value.rs b/src/value_converter/models/serde_value.rs index 222ffe56..ebe95f58 100644 --- a/src/value_converter/models/serde_value.rs +++ b/src/value_converter/models/serde_value.rs @@ -3,8 +3,8 @@ use postgres_types::FromSql; use serde_json::{json, Map, Value}; use pyo3::{ - types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyTuple}, - Bound, FromPyObject, IntoPyObject, Py, PyAny, PyResult, Python, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods}, + Bound, FromPyObject, IntoPyObject, PyAny, PyResult, Python, }; use tokio_postgres::types::Type; @@ -37,7 +37,7 @@ impl<'py> IntoPyObject<'py> for InternalSerdeValue { type Error = RustPSQLDriverError; fn into_pyobject(self, py: Python<'py>) -> Result { - match build_python_from_serde_value(py, self.0.clone()) { + match build_python_from_serde_value(py, self.0) { Ok(ok_value) => Ok(ok_value.bind(py).clone()), Err(err) => Err(err), } @@ -57,25 +57,29 @@ impl<'a> FromSql<'a> for InternalSerdeValue { } } -fn serde_value_from_list(gil: Python<'_>, bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { - let mut result_vec: Vec = vec![]; +fn serde_value_from_list(_gil: Python<'_>, bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { + let py_list = bind_value.downcast::().map_err(|e| { + RustPSQLDriverError::PyToRustValueConversionError(format!( + "Parameter must be a list, but it's not: {e}" + )) + })?; - let params = bind_value.extract::>>()?; + let mut result_vec: Vec = Vec::with_capacity(py_list.len()); - for inner in params { - let 
inner_bind = inner.bind(gil); - if inner_bind.is_instance_of::() { - let python_dto = from_python_untyped(inner_bind)?; + for item in py_list.iter() { + if item.is_instance_of::() { + let python_dto = from_python_untyped(&item)?; result_vec.push(python_dto.to_serde_value()?); - } else if inner_bind.is_instance_of::() { - let serde_value = build_serde_value(inner.bind(gil))?; + } else if item.is_instance_of::() { + let serde_value = build_serde_value(&item)?; result_vec.push(serde_value); } else { return Err(RustPSQLDriverError::PyToRustValueConversionError( - "PyJSON must have dicts.".to_string(), + "Items in JSON array must be dicts or lists.".to_string(), )); } } + Ok(json!(result_vec)) } @@ -86,19 +90,18 @@ fn serde_value_from_dict(bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { )) })?; - let mut serde_map: Map = Map::new(); + let dict_len = dict.len(); + let mut serde_map: Map = Map::with_capacity(dict_len); - for dict_item in dict.items() { - let py_list = dict_item.downcast::().map_err(|error| { + for (key, value) in dict.iter() { + let key_str = key.extract::().map_err(|error| { RustPSQLDriverError::PyToRustValueConversionError(format!( - "Cannot cast to list: {error}" + "Cannot extract dict key as string: {error}" )) })?; - let key = py_list.get_item(0)?.extract::()?; - let value = from_python_untyped(&py_list.get_item(1)?)?; - - serde_map.insert(key, value.to_serde_value()?); + let value_dto = from_python_untyped(&value)?; + serde_map.insert(key_str, value_dto.to_serde_value()?); } Ok(Value::Object(serde_map)) @@ -131,12 +134,10 @@ pub fn build_serde_value(value: &Bound<'_, PyAny>) -> PSQLPyResult { /// May return error if cannot create serde value. pub fn pythondto_array_to_serde(array: Option>) -> PSQLPyResult { match array { - Some(array) => inner_pythondto_array_to_serde( - array.dimensions(), - array.iter().collect::>().as_slice(), - 0, - 0, - ), + Some(array) => { + let data: Vec = array.iter().cloned().collect(); + inner_pythondto_array_to_serde(array.dimensions(), &data, 0, 0) + } None => Ok(Value::Null), } } @@ -145,41 +146,49 @@ pub fn pythondto_array_to_serde(array: Option>) -> PSQLPyResult #[allow(clippy::cast_sign_loss)] fn inner_pythondto_array_to_serde( dimensions: &[Dimension], - data: &[&PythonDTO], + data: &[PythonDTO], dimension_index: usize, - mut lower_bound: usize, + data_offset: usize, ) -> PSQLPyResult { - let current_dimension = dimensions.get(dimension_index); - - if let Some(current_dimension) = current_dimension { - let possible_next_dimension = dimensions.get(dimension_index + 1); - match possible_next_dimension { - Some(next_dimension) => { - let mut final_list: Value = Value::Array(vec![]); - - for _ in 0..current_dimension.len as usize { - if dimensions.get(dimension_index + 1).is_some() { - let inner_pylist = inner_pythondto_array_to_serde( - dimensions, - &data[lower_bound..next_dimension.len as usize + lower_bound], - dimension_index + 1, - 0, - )?; - match final_list { - Value::Array(ref mut array) => array.push(inner_pylist), - _ => unreachable!(), - } - lower_bound += next_dimension.len as usize; - } - } - - return Ok(final_list); - } - None => { - return data.iter().map(|x| x.to_serde_value()).collect(); - } + if dimension_index >= dimensions.len() || data_offset >= data.len() { + return Ok(Value::Array(vec![])); + } + + let current_dimension = &dimensions[dimension_index]; + let current_len = current_dimension.len as usize; + + if dimension_index + 1 >= dimensions.len() { + let end_offset = (data_offset + current_len).min(data.len()); + let 
slice = &data[data_offset..end_offset]; + + let mut result_values = Vec::with_capacity(slice.len()); + for item in slice { + result_values.push(item.to_serde_value()?); } + + return Ok(Value::Array(result_values)); + } + + let mut final_array = Vec::with_capacity(current_len); + + let sub_array_size = dimensions[dimension_index + 1..] + .iter() + .map(|d| d.len as usize) + .product::(); + + let mut current_offset = data_offset; + + for _ in 0..current_len { + if current_offset >= data.len() { + break; + } + + let inner_value = + inner_pythondto_array_to_serde(dimensions, data, dimension_index + 1, current_offset)?; + + final_array.push(inner_value); + current_offset += sub_array_size; } - Ok(Value::Array(vec![])) + Ok(Value::Array(final_array)) } diff --git a/src/value_converter/to_python.rs b/src/value_converter/to_python.rs index abc734c8..cf0f6d35 100644 --- a/src/value_converter/to_python.rs +++ b/src/value_converter/to_python.rs @@ -95,13 +95,9 @@ fn postgres_array_to_py<'py, T: IntoPyObject<'py> + Clone>( array: Option>, ) -> Option> { array.map(|array| { - inner_postgres_array_to_py( - py, - array.dimensions(), - array.iter().cloned().collect::>(), - 0, - 0, - ) + // Collect data once instead of creating copies in recursion + let data: Vec = array.iter().cloned().collect(); + inner_postgres_array_to_py(py, array.dimensions(), &data, 0, 0) }) } @@ -110,44 +106,60 @@ fn postgres_array_to_py<'py, T: IntoPyObject<'py> + Clone>( fn inner_postgres_array_to_py<'py, T>( py: Python<'py>, dimensions: &[Dimension], - data: Vec, + data: &[T], dimension_index: usize, - mut lower_bound: usize, + data_offset: usize, ) -> Py where T: IntoPyObject<'py> + Clone, { - let current_dimension = dimensions.get(dimension_index); - - if let Some(current_dimension) = current_dimension { - let possible_next_dimension = dimensions.get(dimension_index + 1); - match possible_next_dimension { - Some(next_dimension) => { - let final_list = PyList::empty(py); - - for _ in 0..current_dimension.len as usize { - if dimensions.get(dimension_index + 1).is_some() { - let inner_pylist = inner_postgres_array_to_py( - py, - dimensions, - data[lower_bound..next_dimension.len as usize + lower_bound].to_vec(), - dimension_index + 1, - 0, - ); - final_list.append(inner_pylist).unwrap(); - lower_bound += next_dimension.len as usize; - } - } - - return final_list.unbind(); - } - None => { - return PyList::new(py, data).unwrap().unbind(); // TODO unwrap is unsafe - } + // Check bounds early + if dimension_index >= dimensions.len() || data_offset >= data.len() { + return PyList::empty(py).unbind(); + } + + let current_dimension = &dimensions[dimension_index]; + let current_len = current_dimension.len as usize; + + // If this is the last dimension, create a list with the actual data + if dimension_index + 1 >= dimensions.len() { + let end_offset = (data_offset + current_len).min(data.len()); + let slice = &data[data_offset..end_offset]; + + // Create Python list more efficiently + return match PyList::new(py, slice.iter().cloned()) { + Ok(list) => list.unbind(), + Err(_) => PyList::empty(py).unbind(), + }; + } + + // For multi-dimensional arrays, recursively create nested lists + let final_list = PyList::empty(py); + + // Calculate the size of each sub-array + let sub_array_size = dimensions[dimension_index + 1..] 
+ .iter() + .map(|d| d.len as usize) + .product::(); + + let mut current_offset = data_offset; + + for _ in 0..current_len { + if current_offset >= data.len() { + break; } + + let inner_list = + inner_postgres_array_to_py(py, dimensions, data, dimension_index + 1, current_offset); + + if final_list.append(inner_list).is_err() { + break; + } + + current_offset += sub_array_size; } - PyList::empty(py).unbind() + final_list.unbind() } #[allow(clippy::too_many_lines)] @@ -604,7 +616,7 @@ pub fn raw_bytes_data_process( if let Ok(Some(py_encoder_func)) = py_encoder_func { return Ok(py_encoder_func - .call((raw_bytes_data.to_vec(),), None)? + .call1((PyBytes::new(py, raw_bytes_data),))? .unbind()); } } pFad - Phonifier reborn
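
Usage note: a minimal end-to-end sketch of the `as_tuple` flag this patch introduces. The DSN, pool size, and `users` table below are placeholders rather than anything from the patch; the `result(as_tuple=True)` calls follow the updated stubs and docs above.

```python
import asyncio
from typing import Any

from psqlpy import ConnectionPool


async def main() -> None:
    # Placeholder connection settings, for illustration only.
    db_pool = ConnectionPool(
        dsn="postgres://postgres:postgres@localhost:5432/postgres",
        max_db_pool_size=2,
    )

    async with db_pool.acquire() as connection:
        query_result = await connection.execute("SELECT id, name FROM users")

        # Default behaviour: one dict per row, keyed by column name.
        dict_rows: list[dict[str, Any]] = query_result.result()

        # New behaviour: one headless tuple per row, columns in SELECT order.
        tuple_rows: list[tuple[Any, ...]] = query_result.result(as_tuple=True)

        # SingleQueryResult gains the same flag.
        single_row = await connection.fetch_row("SELECT id, name FROM users LIMIT 1")
        row_tuple: tuple[Any, ...] = single_row.result(as_tuple=True)


asyncio.run(main())
```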
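The flag also composes with `custom_decoders`, as `test_custom_decoder_as_tuple_result` above exercises: decoders stay keyed by column name, and each decoded value keeps its positional slot in the tuple. A sketch under the same placeholder assumptions (`for_test` mirrors the test table; `point_decoder` is hypothetical):

```python
from typing import Any

from psqlpy import ConnectionPool


def point_decoder(point_bytes: bytes) -> str:
    # Stand-in decoder for the otherwise unsupported POINT type; the patch
    # now hands decoders a `bytes` object (see the PyBytes change in
    # to_python.rs above).
    return point_bytes.hex()


async def fetch_points(db_pool: ConnectionPool) -> list[tuple[Any, ...]]:
    async with db_pool.acquire() as connection:
        query_result = await connection.execute("SELECT geo_point FROM for_test")

    # Rows are already materialized; decoders run per column, and
    # `as_tuple` only changes the row container.
    return query_result.result(
        custom_decoders={"geo_point": point_decoder},
        as_tuple=True,
    )
```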

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy