@@ -16,7 +16,7 @@ use crate::{
16
16
17
17
use bytes:: BytesMut ;
18
18
use futures_util:: pin_mut;
19
- use pyo3:: { buffer:: PyBuffer , PyErr , Python } ;
19
+ use pyo3:: { buffer:: PyBuffer , Python } ;
20
20
use tokio_postgres:: binary_copy:: BinaryCopyInWriter ;
21
21
22
22
use crate :: format_helpers:: quote_ident;
@@ -249,50 +249,70 @@ macro_rules! impl_binary_copy_method {
249
249
columns: Option <Vec <String >>,
250
250
schema_name: Option <String >,
251
251
) -> PSQLPyResult <u64 > {
252
- let db_client = pyo3:: Python :: with_gil( |gil| self_. borrow( gil) . conn. clone( ) ) ;
253
- let mut table_name = quote_ident( & table_name) ;
254
- if let Some ( schema_name) = schema_name {
255
- table_name = format!( "{}.{}" , quote_ident( & schema_name) , table_name) ;
256
- }
257
-
258
- let mut formated_columns = String :: default ( ) ;
259
- if let Some ( columns) = columns {
260
- formated_columns = format!( "({})" , columns. join( ", " ) ) ;
261
- }
252
+ let ( db_client, mut bytes_mut) =
253
+ Python :: with_gil( |gil| -> PSQLPyResult <( Option <_>, BytesMut ) > {
254
+ let db_client = self_. borrow( gil) . conn. clone( ) ;
255
+
256
+ let Some ( db_client) = db_client else {
257
+ return Ok ( ( None , BytesMut :: new( ) ) ) ;
258
+ } ;
259
+
260
+ let data_bytes_mut =
261
+ if let Ok ( py_buffer) = source. extract:: <PyBuffer <u8 >>( gil) {
262
+ let buffer_len = py_buffer. len_bytes( ) ;
263
+ let mut bytes_mut = BytesMut :: zeroed( buffer_len) ;
264
+
265
+ py_buffer. copy_to_slice( gil, & mut bytes_mut[ ..] ) ?;
266
+ bytes_mut
267
+ } else if let Ok ( py_bytes) = source. call_method0( gil, "getvalue" ) {
268
+ if let Ok ( bytes_vec) = py_bytes. extract:: <Vec <u8 >>( gil) {
269
+ let bytes_mut = BytesMut :: from( & bytes_vec[ ..] ) ;
270
+ bytes_mut
271
+ } else {
272
+ return Err ( RustPSQLDriverError :: PyToRustValueConversionError (
273
+ "source must be bytes or support Buffer protocol" . into( ) ,
274
+ ) ) ;
275
+ }
276
+ } else {
277
+ return Err ( RustPSQLDriverError :: PyToRustValueConversionError (
278
+ "source must be bytes or support Buffer protocol" . into( ) ,
279
+ ) ) ;
280
+ } ;
281
+
282
+ Ok ( ( Some ( db_client) , data_bytes_mut) )
283
+ } ) ?;
262
284
263
- let copy_qs =
264
- format!( "COPY {table_name}{formated_columns} FROM STDIN (FORMAT binary)" ) ;
285
+ let Some ( db_client) = db_client else {
286
+ return Ok ( 0 ) ;
287
+ } ;
265
288
266
- if let Some ( db_client) = db_client {
267
- let mut psql_bytes: BytesMut = Python :: with_gil( |gil| {
268
- let possible_py_buffer: Result <PyBuffer <u8 >, PyErr > =
269
- source. extract:: <PyBuffer <u8 >>( gil) ;
270
- if let Ok ( py_buffer) = possible_py_buffer {
271
- let vec_buf = py_buffer. to_vec( gil) ?;
272
- return Ok ( BytesMut :: from( vec_buf. as_slice( ) ) ) ;
273
- }
289
+ let full_table_name = match schema_name {
290
+ Some ( schema) => {
291
+ format!( "{}.{}" , quote_ident( & schema) , quote_ident( & table_name) )
292
+ }
293
+ None => quote_ident( & table_name) ,
294
+ } ;
274
295
275
- if let Ok ( py_bytes) = source. call_method0( gil, "getvalue" ) {
276
- if let Ok ( bytes) = py_bytes. extract:: <Vec <u8 >>( gil) {
277
- return Ok ( BytesMut :: from( bytes. as_slice( ) ) ) ;
278
- }
279
- }
296
+ let copy_qs = match columns {
297
+ Some ( ref cols) if !cols. is_empty( ) => {
298
+ format!(
299
+ "COPY {}({}) FROM STDIN (FORMAT binary)" ,
300
+ full_table_name,
301
+ cols. join( ", " )
302
+ )
303
+ }
304
+ _ => format!( "COPY {} FROM STDIN (FORMAT binary)" , full_table_name) ,
305
+ } ;
280
306
281
- Err ( RustPSQLDriverError :: PyToRustValueConversionError (
282
- "source must be bytes or support Buffer protocol" . into ( ) ,
283
- ) )
284
- } ) ? ;
307
+ let read_conn_g = db_client . read ( ) . await ;
308
+ let sink = read_conn_g . copy_in ( & copy_qs ) . await ? ;
309
+ let writer = BinaryCopyInWriter :: new_empty_buffer ( sink , & [ ] ) ;
310
+ pin_mut! ( writer ) ;
285
311
286
- let read_conn_g = db_client. read( ) . await ;
287
- let sink = read_conn_g. copy_in( & copy_qs) . await ?;
288
- let writer = BinaryCopyInWriter :: new_empty_buffer( sink, & [ ] ) ;
289
- pin_mut!( writer) ;
290
- writer. as_mut( ) . write_raw_bytes( & mut psql_bytes) . await ?;
291
- let rows_created = writer. as_mut( ) . finish_empty( ) . await ?;
292
- return Ok ( rows_created) ;
293
- }
312
+ writer. as_mut( ) . write_raw_bytes( & mut bytes_mut) . await ?;
313
+ let rows_created = writer. as_mut( ) . finish_empty( ) . await ?;
294
314
295
- Ok ( 0 )
315
+ Ok ( rows_created )
296
316
}
297
317
}
298
318
} ;
0 commit comments