Skip to content

Commit d80e558

Browse files
authored
Merge pull request #137 from psqlpy-python/support_dbapi
Added python Portal class and logic
2 parents 61a272e + e60aad9 commit d80e558

File tree

8 files changed

+337
-113
lines changed

8 files changed

+337
-113
lines changed

src/connection/impls.rs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
1+
use std::sync::{Arc, RwLock};
2+
13
use bytes::Buf;
24
use pyo3::{PyAny, Python};
3-
use tokio_postgres::{CopyInSink, Row, Statement, ToStatement};
5+
use tokio_postgres::{CopyInSink, Portal as tp_Portal, Row, Statement, ToStatement};
46

57
use crate::{
8+
driver::portal::Portal,
69
exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError},
710
options::{IsolationLevel, ReadVariant},
811
query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult},
912
statement::{statement::PsqlpyStatement, statement_builder::StatementBuilder},
13+
transaction::structs::PSQLPyTransaction,
1014
value_converter::to_python::postgres_to_py,
1115
};
1216

17+
use deadpool_postgres::Transaction as dp_Transaction;
18+
use tokio_postgres::Transaction as tp_Transaction;
19+
1320
use super::{
1421
structs::{PSQLPyConnection, PoolConnection, SingleConnection},
1522
traits::{CloseTransaction, Connection, Cursor, StartTransaction, Transaction},
@@ -516,4 +523,42 @@ impl PSQLPyConnection {
516523
}
517524
}
518525
}
526+
527+
pub async fn transaction(&mut self) -> PSQLPyResult<PSQLPyTransaction> {
528+
match self {
529+
PSQLPyConnection::PoolConn(conn) => {
530+
let transaction = unsafe {
531+
std::mem::transmute::<dp_Transaction<'_>, dp_Transaction<'static>>(
532+
conn.connection.transaction().await?,
533+
)
534+
};
535+
Ok(PSQLPyTransaction::PoolTransaction(transaction))
536+
}
537+
PSQLPyConnection::SingleConnection(conn) => {
538+
let transaction = unsafe {
539+
std::mem::transmute::<tp_Transaction<'_>, tp_Transaction<'static>>(
540+
conn.connection.transaction().await?,
541+
)
542+
};
543+
Ok(PSQLPyTransaction::SingleTransaction(transaction))
544+
}
545+
}
546+
}
547+
548+
pub async fn portal(
549+
&mut self,
550+
querystring: String,
551+
parameters: Option<pyo3::Py<PyAny>>,
552+
) -> PSQLPyResult<(PSQLPyTransaction, tp_Portal)> {
553+
let statement = StatementBuilder::new(querystring, parameters, self, Some(false))
554+
.build()
555+
.await?;
556+
557+
let transaction = self.transaction().await?;
558+
let inner_portal = transaction
559+
.portal(statement.raw_query(), &statement.params())
560+
.await?;
561+
562+
Ok((transaction, inner_portal))
563+
}
519564
}

src/driver/connection.rs

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ use crate::{
1818
runtime::tokio_runtime,
1919
};
2020

21-
use super::{connection_pool::connect_pool, cursor::Cursor, transaction::Transaction};
21+
use super::{
22+
connection_pool::connect_pool, cursor::Cursor, portal::Portal, transaction::Transaction,
23+
};
2224

2325
/// Make new connection pool.
2426
///
@@ -396,17 +398,16 @@ impl Connection {
396398
read_variant: Option<ReadVariant>,
397399
deferrable: Option<bool>,
398400
) -> PSQLPyResult<Transaction> {
399-
if let Some(db_client) = &self.db_client {
400-
return Ok(Transaction::new(
401-
Some(db_client.clone()),
402-
self.pg_config.clone(),
403-
isolation_level,
404-
read_variant,
405-
deferrable,
406-
));
407-
}
408-
409-
Err(RustPSQLDriverError::ConnectionClosedError)
401+
let Some(conn) = &self.db_client else {
402+
return Err(RustPSQLDriverError::ConnectionClosedError);
403+
};
404+
Ok(Transaction::new(
405+
Some(conn.clone()),
406+
self.pg_config.clone(),
407+
isolation_level,
408+
read_variant,
409+
deferrable,
410+
))
410411
}
411412

412413
/// Create new cursor object.
@@ -428,19 +429,39 @@ impl Connection {
428429
scroll: Option<bool>,
429430
prepared: Option<bool>,
430431
) -> PSQLPyResult<Cursor> {
431-
if let Some(db_client) = &self.db_client {
432-
return Ok(Cursor::new(
433-
db_client.clone(),
434-
self.pg_config.clone(),
435-
querystring,
436-
parameters,
437-
fetch_number.unwrap_or(10),
438-
scroll,
439-
prepared,
440-
));
441-
}
432+
let Some(conn) = &self.db_client else {
433+
return Err(RustPSQLDriverError::ConnectionClosedError);
434+
};
435+
436+
Ok(Cursor::new(
437+
conn.clone(),
438+
self.pg_config.clone(),
439+
querystring,
440+
parameters,
441+
fetch_number.unwrap_or(10),
442+
scroll,
443+
prepared,
444+
))
445+
}
442446

443-
Err(RustPSQLDriverError::ConnectionClosedError)
447+
#[pyo3(signature = (
448+
querystring,
449+
parameters=None,
450+
fetch_number=None,
451+
))]
452+
pub fn portal(
453+
&self,
454+
querystring: String,
455+
parameters: Option<Py<PyAny>>,
456+
fetch_number: Option<i32>,
457+
) -> PSQLPyResult<Portal> {
458+
println!("{:?}", fetch_number);
459+
Ok(Portal::new(
460+
self.db_client.clone(),
461+
querystring,
462+
parameters,
463+
fetch_number,
464+
))
444465
}
445466

446467
#[allow(clippy::needless_pass_by_value)]

0 commit comments

Comments
 (0)