Skip to content

Commit d48163a

Browse files
committed
more optimizations
1 parent ba63201 commit d48163a

File tree

1 file changed

+59
-39
lines changed

1 file changed

+59
-39
lines changed

src/driver/common.rs

Lines changed: 59 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use crate::{
1616

1717
use bytes::BytesMut;
1818
use futures_util::pin_mut;
19-
use pyo3::{buffer::PyBuffer, PyErr, Python};
19+
use pyo3::{buffer::PyBuffer, Python};
2020
use tokio_postgres::binary_copy::BinaryCopyInWriter;
2121

2222
use crate::format_helpers::quote_ident;
@@ -249,50 +249,70 @@ macro_rules! impl_binary_copy_method {
249249
columns: Option<Vec<String>>,
250250
schema_name: Option<String>,
251251
) -> PSQLPyResult<u64> {
252-
let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone());
253-
let mut table_name = quote_ident(&table_name);
254-
if let Some(schema_name) = schema_name {
255-
table_name = format!("{}.{}", quote_ident(&schema_name), table_name);
256-
}
257-
258-
let mut formated_columns = String::default();
259-
if let Some(columns) = columns {
260-
formated_columns = format!("({})", columns.join(", "));
261-
}
252+
let (db_client, mut bytes_mut) =
253+
Python::with_gil(|gil| -> PSQLPyResult<(Option<_>, BytesMut)> {
254+
let db_client = self_.borrow(gil).conn.clone();
255+
256+
let Some(db_client) = db_client else {
257+
return Ok((None, BytesMut::new()));
258+
};
259+
260+
let data_bytes_mut =
261+
if let Ok(py_buffer) = source.extract::<PyBuffer<u8>>(gil) {
262+
let buffer_len = py_buffer.len_bytes();
263+
let mut bytes_mut = BytesMut::zeroed(buffer_len);
264+
265+
py_buffer.copy_to_slice(gil, &mut bytes_mut[..])?;
266+
bytes_mut
267+
} else if let Ok(py_bytes) = source.call_method0(gil, "getvalue") {
268+
if let Ok(bytes_vec) = py_bytes.extract::<Vec<u8>>(gil) {
269+
let bytes_mut = BytesMut::from(&bytes_vec[..]);
270+
bytes_mut
271+
} else {
272+
return Err(RustPSQLDriverError::PyToRustValueConversionError(
273+
"source must be bytes or support Buffer protocol".into(),
274+
));
275+
}
276+
} else {
277+
return Err(RustPSQLDriverError::PyToRustValueConversionError(
278+
"source must be bytes or support Buffer protocol".into(),
279+
));
280+
};
281+
282+
Ok((Some(db_client), data_bytes_mut))
283+
})?;
262284

263-
let copy_qs =
264-
format!("COPY {table_name}{formated_columns} FROM STDIN (FORMAT binary)");
285+
let Some(db_client) = db_client else {
286+
return Ok(0);
287+
};
265288

266-
if let Some(db_client) = db_client {
267-
let mut psql_bytes: BytesMut = Python::with_gil(|gil| {
268-
let possible_py_buffer: Result<PyBuffer<u8>, PyErr> =
269-
source.extract::<PyBuffer<u8>>(gil);
270-
if let Ok(py_buffer) = possible_py_buffer {
271-
let vec_buf = py_buffer.to_vec(gil)?;
272-
return Ok(BytesMut::from(vec_buf.as_slice()));
273-
}
289+
let full_table_name = match schema_name {
290+
Some(schema) => {
291+
format!("{}.{}", quote_ident(&schema), quote_ident(&table_name))
292+
}
293+
None => quote_ident(&table_name),
294+
};
274295

275-
if let Ok(py_bytes) = source.call_method0(gil, "getvalue") {
276-
if let Ok(bytes) = py_bytes.extract::<Vec<u8>>(gil) {
277-
return Ok(BytesMut::from(bytes.as_slice()));
278-
}
279-
}
296+
let copy_qs = match columns {
297+
Some(ref cols) if !cols.is_empty() => {
298+
format!(
299+
"COPY {}({}) FROM STDIN (FORMAT binary)",
300+
full_table_name,
301+
cols.join(", ")
302+
)
303+
}
304+
_ => format!("COPY {} FROM STDIN (FORMAT binary)", full_table_name),
305+
};
280306

281-
Err(RustPSQLDriverError::PyToRustValueConversionError(
282-
"source must be bytes or support Buffer protocol".into(),
283-
))
284-
})?;
307+
let read_conn_g = db_client.read().await;
308+
let sink = read_conn_g.copy_in(&copy_qs).await?;
309+
let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]);
310+
pin_mut!(writer);
285311

286-
let read_conn_g = db_client.read().await;
287-
let sink = read_conn_g.copy_in(&copy_qs).await?;
288-
let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]);
289-
pin_mut!(writer);
290-
writer.as_mut().write_raw_bytes(&mut psql_bytes).await?;
291-
let rows_created = writer.as_mut().finish_empty().await?;
292-
return Ok(rows_created);
293-
}
312+
writer.as_mut().write_raw_bytes(&mut bytes_mut).await?;
313+
let rows_created = writer.as_mut().finish_empty().await?;
294314

295-
Ok(0)
315+
Ok(rows_created)
296316
}
297317
}
298318
};

0 commit comments

Comments
 (0)