Skip to content

Commit d48163a

Browse files
committed
more optimizations
1 parent ba63201 commit d48163a

File tree

1 file changed

+59
-39
lines changed

1 file changed

+59
-39
lines changed

src/driver/common.rs

Lines changed: 59 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use crate::{
1616

1717
use bytes::BytesMut;
1818
use futures_util::pin_mut;
19-
use pyo3::{buffer::PyBuffer, PyErr, Python};
19+
use pyo3::{buffer::PyBuffer, Python};
2020
use tokio_postgres::binary_copy::BinaryCopyInWriter;
2121

2222
use crate::format_helpers::quote_ident;
@@ -249,50 +249,70 @@ macro_rules! impl_binary_copy_method {
249249
columns: Option<Vec<String>>,
250250
schema_name: Option<String>,
251251
) -> PSQLPyResult<u64> {
252-
let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone());
253-
let mut table_name = quote_ident(&table_name);
254-
if let Some(schema_name) = schema_name {
255-
table_name = format!("{}.{}", quote_ident(&schema_name), table_name);
256-
}
257-
258-
let mut formated_columns = String::default();
259-
if let Some(columns) = columns {
260-
formated_columns = format!("({})", columns.join(", "));
261-
}
252+
let (db_client, mut bytes_mut) =
253+
Python::with_gil(|gil| -> PSQLPyResult<(Option<_>, BytesMut)> {
254+
let db_client = self_.borrow(gil).conn.clone();
255+
256+
let Some(db_client) = db_client else {
257+
return Ok((None, BytesMut::new()));
258+
};
259+
260+
let data_bytes_mut =
261+
if let Ok(py_buffer) = source.extract::<PyBuffer<u8>>(gil) {
262+
let buffer_len = py_buffer.len_bytes();
263+
let mut bytes_mut = BytesMut::zeroed(buffer_len);
264+
265+
py_buffer.copy_to_slice(gil, &mut bytes_mut[..])?;
266+
bytes_mut
267+
} else if let Ok(py_bytes) = source.call_method0(gil, "getvalue") {
268+
if let Ok(bytes_vec) = py_bytes.extract::<Vec<u8>>(gil) {
269+
let bytes_mut = BytesMut::from(&bytes_vec[..]);
270+
bytes_mut
271+
} else {
272+
return Err(RustPSQLDriverError::PyToRustValueConversionError(
273+
"source must be bytes or support Buffer protocol".into(),
274+
));
275+
}
276+
} else {
277+
return Err(RustPSQLDriverError::PyToRustValueConversionError(
278+
"source must be bytes or support Buffer protocol".into(),
279+
));
280+
};
281+
282+
Ok((Some(db_client), data_bytes_mut))
283+
})?;
262284

263-
let copy_qs =
264-
format!("COPY {table_name}{formated_columns} FROM STDIN (FORMAT binary)");
285+
let Some(db_client) = db_client else {
286+
return Ok(0);
287+
};
265288

266-
if let Some(db_client) = db_client {
267-
let mut psql_bytes: BytesMut = Python::with_gil(|gil| {
268-
let possible_py_buffer: Result<PyBuffer<u8>, PyErr> =
269-
source.extract::<PyBuffer<u8>>(gil);
270-
if let Ok(py_buffer) = possible_py_buffer {
271-
let vec_buf = py_buffer.to_vec(gil)?;
272-
return Ok(BytesMut::from(vec_buf.as_slice()));
273-
}
289+
let full_table_name = match schema_name {
290+
Some(schema) => {
291+
format!("{}.{}", quote_ident(&schema), quote_ident(&table_name))
292+
}
293+
None => quote_ident(&table_name),
294+
};
274295

275-
if let Ok(py_bytes) = source.call_method0(gil, "getvalue") {
276-
if let Ok(bytes) = py_bytes.extract::<Vec<u8>>(gil) {
277-
return Ok(BytesMut::from(bytes.as_slice()));
278-
}
279-
}
296+
let copy_qs = match columns {
297+
Some(ref cols) if !cols.is_empty() => {
298+
format!(
299+
"COPY {}({}) FROM STDIN (FORMAT binary)",
300+
full_table_name,
301+
cols.join(", ")
302+
)
303+
}
304+
_ => format!("COPY {} FROM STDIN (FORMAT binary)", full_table_name),
305+
};
280306

281-
Err(RustPSQLDriverError::PyToRustValueConversionError(
282-
"source must be bytes or support Buffer protocol".into(),
283-
))
284-
})?;
307+
let read_conn_g = db_client.read().await;
308+
let sink = read_conn_g.copy_in(&copy_qs).await?;
309+
let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]);
310+
pin_mut!(writer);
285311

286-
let read_conn_g = db_client.read().await;
287-
let sink = read_conn_g.copy_in(&copy_qs).await?;
288-
let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]);
289-
pin_mut!(writer);
290-
writer.as_mut().write_raw_bytes(&mut psql_bytes).await?;
291-
let rows_created = writer.as_mut().finish_empty().await?;
292-
return Ok(rows_created);
293-
}
312+
writer.as_mut().write_raw_bytes(&mut bytes_mut).await?;
313+
let rows_created = writer.as_mut().finish_empty().await?;
294314

295-
Ok(0)
315+
Ok(rows_created)
296316
}
297317
}
298318
};

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy