diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index 869d242d066e..18d9d0f9c615 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -244,32 +244,24 @@ impl Queryable for Mssql { } /// Statement to begin a transaction - fn begin_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("SAVE TRANSACTION savepoint{depth}")) - } else { - Cow::Borrowed("BEGIN TRAN") - } + fn begin_statement(&self) -> &'static str { + "BEGIN TRAN" } - /// Statement to commit a transaction - fn commit_statement(&self, depth: u32) -> Cow<'static, str> { - // MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested - // transaction we just continue onwards - if depth > 1 { - Cow::Owned("".to_string()) - } else { - Cow::Borrowed("COMMIT") - } + /// Statement to create a savepoint + fn create_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("SAVE TRANSACTION savepoint{depth}")) } - /// Statement to rollback a transaction - fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("ROLLBACK TRANSACTION savepoint{depth}")) - } else { - Cow::Borrowed("ROLLBACK") - } + // MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested + // transaction we just continue onwards + fn release_savepoint_statement(&self, _depth: u32) -> Cow<'static, str> { + Cow::Owned("".to_string()) + } + + /// Statement to rollback to a savepoint + fn rollback_to_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("ROLLBACK TRANSACTION savepoint{depth}")) } fn requires_isolation_first(&self) -> bool { diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 94dd25f5c0e2..ccd403c5499c 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -350,29 +350,22 @@ impl Queryable for Mysql { } /// Statement to begin a transaction - fn begin_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("SAVEPOINT savepoint{depth}")) - } else { - Cow::Borrowed("BEGIN") - } + fn begin_statement(&self) -> &'static str { + "BEGIN" } - /// Statement to commit a transaction - fn commit_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}")) - } else { - Cow::Borrowed("COMMIT") - } + /// Statement to create a savepoint + fn create_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("SAVEPOINT savepoint{depth}")) } - /// Statement to rollback a transaction - fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("ROLLBACK TO savepoint{depth}")) - } else { - Cow::Borrowed("ROLLBACK") - } + /// Statement to release a savepoint + fn release_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}")) + } + + /// Statement to rollback to a savepoint + fn rollback_to_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("ROLLBACK TO savepoint{depth}")) } } diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 8cd1e1a07577..7e31f19fd085 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -809,30 +809,23 @@ impl Queryable for PostgreSql { } /// Statement to begin a transaction - fn begin_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("SAVEPOINT savepoint{depth}")) - } else { - Cow::Borrowed("BEGIN") - } + fn begin_statement(&self) -> &'static str { + "BEGIN" } - /// Statement to commit a transaction - fn commit_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}")) - } else { - Cow::Borrowed("COMMIT") - } + /// Statement to create a savepoint + fn create_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("SAVEPOINT savepoint{depth}")) } - /// Statement to rollback a transaction - fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}")) - } else { - Cow::Borrowed("ROLLBACK") - } + /// Statement to release a savepoint + fn release_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}")) + } + + /// Statement to rollback to a savepoint + fn rollback_to_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}")) } } diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index b34e42866d85..1023884921e2 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -92,30 +92,23 @@ pub trait Queryable: Send + Sync { } /// Statement to begin a transaction - fn begin_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("SAVEPOINT savepoint{depth}")) - } else { - Cow::Borrowed("BEGIN") - } + fn begin_statement(&self) -> &'static str { + "BEGIN" } - /// Statement to commit a transaction - fn commit_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}")) - } else { - Cow::Borrowed("COMMIT") - } + /// Statement to create a savepoint in a transaction + fn create_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("SAVEPOINT savepoint{depth}")) } - /// Statement to rollback a transaction - fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}")) - } else { - Cow::Borrowed("ROLLBACK") - } + /// Statement to release a savepoint in a transaction + fn release_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}")) + } + + /// Statement to rollback to a savepoint in a transaction + fn rollback_to_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}")) } /// Sets the transaction isolation level to given value. diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index d4cf610af6f5..037a98cf6d8e 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -185,34 +185,27 @@ impl Queryable for Sqlite { } /// Statement to begin a transaction - fn begin_statement(&self, depth: u32) -> Cow<'static, str> { + fn begin_statement(&self) -> &'static str { // From https://sqlite.org/isolation.html: // `BEGIN IMMEDIATE` avoids possible `SQLITE_BUSY_SNAPSHOT` that arise when another connection jumps ahead in line. // The BEGIN IMMEDIATE command goes ahead and starts a write transaction, and thus blocks all other writers. // If the BEGIN IMMEDIATE operation succeeds, then no subsequent operations in that transaction will ever fail with an SQLITE_BUSY error. - if depth > 1 { - Cow::Owned(format!("SAVEPOINT savepoint{depth}")) - } else { - Cow::Borrowed("BEGIN IMMEDIATE") - } + "BEGIN IMMEDIATE" } - /// Statement to commit a transaction - fn commit_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}")) - } else { - Cow::Borrowed("COMMIT") - } + /// Statement to create a savepoint + fn create_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("SAVEPOINT savepoint{depth}")) } - /// Statement to rollback a transaction - fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { - if depth > 1 { - Cow::Owned(format!("ROLLBACK TO savepoint{depth}")) - } else { - Cow::Borrowed("ROLLBACK") - } + /// Statement to release a savepoint + fn release_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}")) + } + + /// Statement to rollback to a savepoint + fn rollback_to_savepoint_statement(&self, depth: u32) -> Cow<'static, str> { + Cow::Owned(format!("ROLLBACK TO savepoint{depth}")) } } diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index 461ce6610ee2..0a3fbaa50a9c 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -14,14 +14,25 @@ use std::{ #[async_trait] pub trait Transaction: Queryable { + fn depth(&self) -> u32; + /// Start a new transaction or nested transaction via savepoint. async fn begin(&mut self) -> crate::Result<()>; /// Commit the changes to the database and consume the transaction. - async fn commit(&mut self) -> crate::Result; + async fn commit(&mut self) -> crate::Result<()>; /// Rolls back the changes to the database. - async fn rollback(&mut self) -> crate::Result; + async fn rollback(&mut self) -> crate::Result<()>; + + /// Creates a savepoint in the transaction. + async fn create_savepoint(&mut self) -> crate::Result<()>; + + /// Releases a savepoint in the transaction. + async fn release_savepoint(&mut self) -> crate::Result<()>; + + /// Rolls back to a savepoint in the transaction. + async fn rollback_to_savepoint(&mut self) -> crate::Result<()>; /// workaround for lack of upcasting between traits https://github.com/rust-lang/rust/issues/65991 fn as_queryable(&self) -> &dyn Queryable; @@ -79,71 +90,110 @@ impl<'a> DefaultTransaction<'a> { #[async_trait] impl<'a> Transaction for DefaultTransaction<'a> { + fn depth(&self) -> u32 { + *self.depth.lock().unwrap() + } + async fn begin(&mut self) -> crate::Result<()> { - let current_depth = { + // Lock the mutex in it's own scope to ensure its dropped before the await + { let mut depth = self.depth.lock().unwrap(); *depth += 1; - *depth - }; + } - let begin_statement = self.inner.begin_statement(current_depth); + let begin_statement = self.inner.begin_statement(); - self.inner.raw_cmd(&begin_statement).await?; + self.inner.raw_cmd(begin_statement).await?; Ok(()) } /// Commit the changes to the database and consume the transaction. - async fn commit(&mut self) -> crate::Result { - // Lock the mutex and get the depth value - let depth_val = { - let depth = self.depth.lock().unwrap(); - *depth - }; - + async fn commit(&mut self) -> crate::Result<()> { // Perform the asynchronous operation without holding the lock - let commit_statement = self.inner.commit_statement(depth_val); - self.inner.raw_cmd(&commit_statement).await?; + self.inner.raw_cmd("COMMIT").await?; - // Lock the mutex again to modify the depth - let new_depth = { + // Lock the mutex to modify the depth + let mut depth = self.depth.lock().unwrap(); + *depth -= 1; + + self.gauge.decrement(); + + Ok(()) + } + + /// Rolls back the changes to the database. + async fn rollback(&mut self) -> crate::Result<()> { + self.inner.raw_cmd("ROLLBACK").await?; + + // Lock the mutex to modify the depth + let mut depth = self.depth.lock().unwrap(); + *depth -= 1; + + self.gauge.decrement(); + + Ok(()) + } + + /// Creates a savepoint in the transaction + async fn create_savepoint(&mut self) -> crate::Result<()> { + let current_depth = { let mut depth = self.depth.lock().unwrap(); - *depth -= 1; + *depth += 1; *depth }; - if new_depth == 0 { - self.gauge.decrement(); - } - - Ok(new_depth) + let create_savepoint_statement = self.inner.create_savepoint_statement(current_depth); + self.inner.raw_cmd(&create_savepoint_statement).await?; + Ok(()) } - /// Rolls back the changes to the database. - async fn rollback(&mut self) -> crate::Result { + /// Releases a savepoint in the transaction + async fn release_savepoint(&mut self) -> crate::Result<()> { // Lock the mutex and get the depth value let depth_val = { let depth = self.depth.lock().unwrap(); *depth }; - // Perform the asynchronous operation without holding the lock - let rollback_statement = self.inner.rollback_statement(depth_val); + if depth_val == 0 { + panic!( + "No savepoint to release in transaction, make sure to call create_savepoint before release_savepoint" + ); + } - self.inner.raw_cmd(&rollback_statement).await?; + // Perform the asynchronous operation without holding the lock + let release_savepoint_statement = self.inner.release_savepoint_statement(depth_val); + self.inner.raw_cmd(&release_savepoint_statement).await?; // Lock the mutex again to modify the depth - let new_depth = { - let mut depth = self.depth.lock().unwrap(); - *depth -= 1; + let mut depth = self.depth.lock().unwrap(); + *depth -= 1; + + Ok(()) + } + + /// Rollback to savepoint in the transaction + async fn rollback_to_savepoint(&mut self) -> crate::Result<()> { + let depth_val = { + let depth = self.depth.lock().unwrap(); *depth }; - if new_depth == 0 { - self.gauge.decrement(); + if depth_val == 0 { + panic!( + "No savepoint to rollback to in transaction, make sure to call create_savepoint before rollback_to_savepoint" + ); } - Ok(new_depth) + let rollback_to_savepoint_statement = self.inner.rollback_to_savepoint_statement(depth_val); + self.inner.raw_cmd(&rollback_to_savepoint_statement).await?; + + // Lock the mutex again to modify the depth + let mut depth = self.depth.lock().unwrap(); + *depth -= 1; + + Ok(()) } fn as_queryable(&self) -> &dyn Queryable { diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index 22bea71778a3..bf4d50eeea87 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -3,7 +3,6 @@ use std::future::Future; use async_trait::async_trait; use mobc::{Connection as MobcPooled, Manager}; use prisma_metrics::WithMetricsInstrumentation; -use std::borrow::Cow; use tracing_futures::WithSubscriber; #[cfg(feature = "mssql-native")] @@ -72,16 +71,8 @@ impl Queryable for PooledConnection { self.inner.server_reset_query(tx).await } - fn begin_statement(&self, depth: u32) -> Cow<'static, str> { - self.inner.begin_statement(depth) - } - - fn commit_statement(&self, depth: u32) -> Cow<'static, str> { - self.inner.commit_statement(depth) - } - - fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { - self.inner.rollback_statement(depth) + fn begin_statement(&self) -> &'static str { + self.inner.begin_statement() } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 9953d021ae02..13be8c4bc857 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -5,7 +5,7 @@ use crate::{ connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, }; use async_trait::async_trait; -use std::{borrow::Cow, fmt, sync::Arc}; +use std::{fmt, sync::Arc}; #[cfg(feature = "sqlite-native")] use std::convert::TryFrom; @@ -238,16 +238,8 @@ impl Queryable for Quaint { self.inner.is_healthy() } - fn begin_statement(&self, depth: u32) -> Cow<'static, str> { - self.inner.begin_statement(depth) - } - - fn commit_statement(&self, depth: u32) -> Cow<'static, str> { - self.inner.commit_statement(depth) - } - - fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { - self.inner.rollback_statement(depth) + fn begin_statement(&self) -> &'static str { + self.inner.begin_statement() } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/tests/query.rs b/quaint/src/tests/query.rs index 5c7c96360529..db65691e9c09 100644 --- a/quaint/src/tests/query.rs +++ b/quaint/src/tests/query.rs @@ -76,19 +76,19 @@ async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { assert_eq!(Value::int32(10), res[0]); // Check that nested transactions are also rolled back, even at multiple levels deep - tx.begin().await?; + tx.create_savepoint().await?; let inner_insert1 = Insert::single_into(&table).value("value", 20); let inner_rows_affected1 = tx.execute(inner_insert1.into()).await?; assert_eq!(1, inner_rows_affected1); // Open another nested transaction - tx.begin().await?; + tx.create_savepoint().await?; let inner_insert2 = Insert::single_into(&table).value("value", 20); let inner_rows_affected2 = tx.execute(inner_insert2.into()).await?; assert_eq!(1, inner_rows_affected2); - tx.commit().await?; + tx.release_savepoint().await?; - tx.commit().await?; + tx.release_savepoint().await?; tx.rollback().await?; diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs index 163c78c8a393..a1b6c3c4e3af 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs @@ -48,17 +48,17 @@ impl<'conn> Transaction for MongoDbTransaction<'conn> { Ok(()) } - async fn commit(&mut self) -> connector_interface::Result { + async fn commit(&mut self) -> connector_interface::Result<()> { self.gauge.decrement(); utils::commit_with_retry(&mut self.connection.session) .await .map_err(|err| MongoError::from(err).into_connector_error())?; - Ok(0) + Ok(()) } - async fn rollback(&mut self) -> connector_interface::Result { + async fn rollback(&mut self) -> connector_interface::Result<()> { self.gauge.decrement(); self.connection @@ -67,7 +67,24 @@ impl<'conn> Transaction for MongoDbTransaction<'conn> { .await .map_err(|err| MongoError::from(err).into_connector_error())?; - Ok(0) + Ok(()) + } + + fn depth(&self) -> u32 { + 0 + } + + /// MongoDB does not support savepoints/nested transactions. + async fn create_savepoint(&mut self) -> connector_interface::Result<()> { + Err(MongoError::Unsupported("MongoDB does not support savepoints".into()).into_connector_error()) + } + + async fn release_savepoint(&mut self) -> connector_interface::Result<()> { + Err(MongoError::Unsupported("MongoDB does not support savepoints".into()).into_connector_error()) + } + + async fn rollback_to_savepoint(&mut self) -> connector_interface::Result<()> { + Err(MongoError::Unsupported("MongoDB does not support savepoints".into()).into_connector_error()) } async fn version(&self) -> Option { diff --git a/query-engine/connectors/query-connector/src/interface.rs b/query-engine/connectors/query-connector/src/interface.rs index fccb9aaaccfa..60e816962541 100644 --- a/query-engine/connectors/query-connector/src/interface.rs +++ b/query-engine/connectors/query-connector/src/interface.rs @@ -37,8 +37,12 @@ pub trait Connection: ConnectionLike { #[async_trait] pub trait Transaction: ConnectionLike { async fn begin(&mut self) -> crate::Result<()>; - async fn commit(&mut self) -> crate::Result; - async fn rollback(&mut self) -> crate::Result; + async fn commit(&mut self) -> crate::Result<()>; + async fn rollback(&mut self) -> crate::Result<()>; + async fn create_savepoint(&mut self) -> crate::Result<()>; + async fn release_savepoint(&mut self) -> crate::Result<()>; + async fn rollback_to_savepoint(&mut self) -> crate::Result<()>; + fn depth(&self) -> u32; async fn version(&self) -> Option; diff --git a/query-engine/connectors/sql-query-connector/src/database/transaction.rs b/query-engine/connectors/sql-query-connector/src/database/transaction.rs index ce7e102b92ad..07a043c80b9d 100644 --- a/query-engine/connectors/sql-query-connector/src/database/transaction.rs +++ b/query-engine/connectors/sql-query-connector/src/database/transaction.rs @@ -38,6 +38,10 @@ impl<'tx> ConnectionLike for SqlConnectorTransaction<'tx> {} #[async_trait] impl<'tx> Transaction for SqlConnectorTransaction<'tx> { + fn depth(&self) -> u32 { + self.inner.depth() + } + async fn begin(&mut self) -> connector::Result<()> { catch(&self.connection_info, async { self.inner.begin().await.map_err(SqlError::from) @@ -45,25 +49,46 @@ impl<'tx> Transaction for SqlConnectorTransaction<'tx> { .await } - async fn commit(&mut self) -> connector::Result { + async fn commit(&mut self) -> connector::Result<()> { catch(&self.connection_info, async { self.inner.commit().await.map_err(SqlError::from) }) .await } - async fn rollback(&mut self) -> connector::Result { + async fn rollback(&mut self) -> connector::Result<()> { catch(&self.connection_info, async { let res = self.inner.rollback().await.map_err(SqlError::from); match res { - Err(SqlError::TransactionAlreadyClosed(_)) | Err(SqlError::RollbackWithoutBegin) => Ok(0), + Err(SqlError::TransactionAlreadyClosed(_)) | Err(SqlError::RollbackWithoutBegin) => Ok(()), _ => res, } }) .await } + async fn create_savepoint(&mut self) -> connector::Result<()> { + catch(&self.connection_info, async { + self.inner.create_savepoint().await.map_err(SqlError::from) + }) + .await + } + + async fn release_savepoint(&mut self) -> connector::Result<()> { + catch(&self.connection_info, async { + self.inner.release_savepoint().await.map_err(SqlError::from) + }) + .await + } + + async fn rollback_to_savepoint(&mut self) -> connector::Result<()> { + catch(&self.connection_info, async { + self.inner.rollback_to_savepoint().await.map_err(SqlError::from) + }) + .await + } + async fn version(&self) -> Option { self.connection_info.version().map(|v| v.to_string()) } diff --git a/query-engine/core/src/executor/interpreting_executor.rs b/query-engine/core/src/executor/interpreting_executor.rs index 0f73c99d2ddc..2e391461c718 100644 --- a/query-engine/core/src/executor/interpreting_executor.rs +++ b/query-engine/core/src/executor/interpreting_executor.rs @@ -189,7 +189,7 @@ where self.itx_manager.commit_tx(&tx_id).await } - async fn rollback_tx(&self, tx_id: TxId) -> crate::Result { + async fn rollback_tx(&self, tx_id: TxId) -> crate::Result<()> { self.itx_manager.rollback_tx(&tx_id).await } } diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index 0146a94eb93e..25933d46e2ee 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -118,5 +118,5 @@ pub trait TransactionManager { async fn commit_tx(&self, tx_id: TxId) -> crate::Result<()>; /// Rolls back a transaction. - async fn rollback_tx(&self, tx_id: TxId) -> crate::Result; + async fn rollback_tx(&self, tx_id: TxId) -> crate::Result<()>; } diff --git a/query-engine/core/src/interactive_transactions/manager.rs b/query-engine/core/src/interactive_transactions/manager.rs index 654a9696e8f5..823e45367ea7 100644 --- a/query-engine/core/src/interactive_transactions/manager.rs +++ b/query-engine/core/src/interactive_transactions/manager.rs @@ -112,7 +112,14 @@ impl ItxManager { // Only create a client if there is no client for this transaction yet. // otherwise, begin a new transaction/savepoint for the existing client. if self.transactions.read().await.contains_key(&tx_id) { - let _ = self.get_transaction(&tx_id, "begin").await?.lock().await.begin().await; + let transaction_entry = self.get_transaction(&tx_id, "begin").await?; + let mut tx = transaction_entry.lock().await; + // If the transaction is already open, we need to create a savepoint. + if tx.depth() > 0 { + tx.create_savepoint().await?; + } else { + tx.begin().await?; + } } else { // This task notifies the task spawned in `new()` method that the timeout for this // transaction has expired. @@ -184,15 +191,24 @@ impl ItxManager { } pub async fn commit_tx(&self, tx_id: &TxId) -> crate::Result<()> { - self.get_transaction(tx_id, "commit").await?.lock().await.commit().await + let transaction_entry = self.get_transaction(tx_id, "commit").await?; + let mut tx = transaction_entry.lock().await; + let depth = tx.depth(); + if depth > 1 { + tx.release_savepoint().await + } else { + tx.commit().await + } } - pub async fn rollback_tx(&self, tx_id: &TxId) -> crate::Result { - self.get_transaction(tx_id, "rollback") - .await? - .lock() - .await - .rollback(false) - .await + pub async fn rollback_tx(&self, tx_id: &TxId) -> crate::Result<()> { + let transaction_entry = self.get_transaction(tx_id, "rollback").await?; + let mut tx = transaction_entry.lock().await; + let depth = tx.depth(); + if depth > 1 { + tx.rollback_to_savepoint().await + } else { + tx.rollback(false).await + } } } diff --git a/query-engine/core/src/interactive_transactions/transaction.rs b/query-engine/core/src/interactive_transactions/transaction.rs index 505ff1b622ee..27bdd8828b4c 100644 --- a/query-engine/core/src/interactive_transactions/transaction.rs +++ b/query-engine/core/src/interactive_transactions/transaction.rs @@ -197,6 +197,13 @@ impl InteractiveTransaction { }) } + pub fn depth(&mut self) -> u32 { + match self.state.as_open("depth") { + Ok(state) => state.depth(), + Err(_) => 0, + } + } + pub async fn begin(&mut self) -> crate::Result<()> { tx_timeout!(self, "begin", async { let name = self.name(); @@ -223,9 +230,7 @@ impl InteractiveTransaction { match conn.commit().instrument(span).await { Ok(depth) => { debug!(?depth, ?name, "transaction committed"); - if depth == 0 { - self.state = TransactionState::Committed; - } + self.state = TransactionState::Committed; Ok(()) } Err(err) => { @@ -240,7 +245,7 @@ impl InteractiveTransaction { }) } - pub async fn rollback(&mut self, was_timeout: bool) -> crate::Result { + pub async fn rollback(&mut self, was_timeout: bool) -> crate::Result<()> { let name = self.name(); let conn = self.state.as_open("rollback")?; let span = info_span!("prisma:engine:itx_rollback", user_facing = true); @@ -265,6 +270,58 @@ impl InteractiveTransaction { result.map_err(<_>::into) } + pub async fn create_savepoint(&mut self) -> crate::Result<()> { + tx_timeout!(self, "create savepoint", async { + let name = self.name(); + let conn = self.state.as_open("create_savepoint")?; + let span = info_span!("prisma:engine:itx_create_savepoint", user_facing = true); + + if let Err(err) = conn.create_savepoint().instrument(span).await { + error!(?err, ?name, "transaction failed to create savepoint"); + let _ = self.rollback(false).await; + Err(err.into()) + } else { + debug!(?name, "savepoint created"); + Ok(()) + } + }) + } + + pub async fn release_savepoint(&mut self) -> crate::Result<()> { + tx_timeout!(self, "release savepoint", async { + let name = self.name(); + let conn = self.state.as_open("release_savepoint")?; + let span = info_span!("prisma:engine:itx_release_savepoint", user_facing = true); + + match conn.release_savepoint().instrument(span).await { + Ok(()) => { + debug!(?name, "savepoint released"); + Ok(()) + } + Err(err) => { + error!(?err, ?name, "transaction failed to release savepoint"); + let _ = self.rollback(false).await; + Err(err.into()) + } + } + }) + } + + pub async fn rollback_to_savepoint(&mut self) -> crate::Result<()> { + let name = self.name(); + let conn = self.state.as_open("rollback_to_savepoint")?; + let span = info_span!("prisma:engine:itx_rollback_to_savepoint", user_facing = true); + + let result = conn.rollback_to_savepoint().instrument(span).await; + if let Err(err) = &result { + error!(?err, ?name, "transaction failed to roll back to savepoint"); + } else { + debug!(?name, "transaction rolled back to savepoint"); + } + + result.map_err(<_>::into) + } + pub fn as_closed(&self) -> Option { self.state.as_closed() } diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index d159a2f0de6c..acf52668d8ce 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -332,14 +332,14 @@ impl JsQueryable { tx.depth += 1; - let begin_stmt = tx.begin_statement(tx.depth); + let begin_stmt = tx.begin_statement(); let tx_opts = tx.options(); if tx_opts.use_phantom_query { - let begin_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); + let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); tx.raw_phantom_cmd(begin_stmt.as_str()).await?; } else { - tx.raw_cmd(&begin_stmt).await?; + tx.raw_cmd(begin_stmt).await?; } // 4. Set the isolation level (if specified) if we didn't do it before. diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index dfc3a920c90a..ccfc4339c129 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, future::Future}; +use std::future::Future; use async_trait::async_trait; use prisma_metrics::gauge; @@ -117,34 +117,39 @@ impl JsTransaction { #[async_trait] impl QuaintTransaction for JsTransaction { + fn depth(&self) -> u32 { + self.depth + } + async fn begin(&mut self) -> quaint::Result<()> { // increment of this gauge is done in DriverProxy::startTransaction gauge!("prisma_client_queries_active").decrement(1.0); self.depth += 1; - let begin_stmt = self.begin_statement(self.depth); + let begin_stmt = self.begin_statement(); if self.options().use_phantom_query { - let commit_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); + let commit_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); self.raw_phantom_cmd(commit_stmt.as_str()).await?; } else { - self.inner.raw_cmd(&begin_stmt).await?; + self.inner.raw_cmd(begin_stmt).await?; } UnsafeFuture(self.tx_proxy.begin()).await } - async fn commit(&mut self) -> quaint::Result { + + async fn commit(&mut self) -> quaint::Result<()> { // increment of this gauge is done in DriverProxy::startTransaction gauge!("prisma_client_queries_active").decrement(1.0); - let commit_stmt = self.commit_statement(self.depth); + let commit_stmt = "COMMIT"; if self.options().use_phantom_query { - let commit_stmt = JsBaseQueryable::phantom_query_message(&commit_stmt); + let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); self.raw_phantom_cmd(commit_stmt.as_str()).await?; } else { - self.inner.raw_cmd(&commit_stmt).await?; + self.inner.raw_cmd(commit_stmt).await?; } let _ = UnsafeFuture(self.tx_proxy.commit()).await; @@ -152,17 +157,20 @@ impl QuaintTransaction for JsTransaction { // Modify the depth value self.depth -= 1; - Ok(self.depth) + Ok(()) } - async fn rollback(&mut self) -> quaint::Result { - let rollback_stmt = self.rollback_statement(self.depth); + async fn rollback(&mut self) -> quaint::Result<()> { + // increment of this gauge is done in DriverProxy::startTransaction + gauge!("prisma_client_queries_active").decrement(1.0); + + let rollback_stmt = "ROLLBACK"; if self.options().use_phantom_query { - let rollback_stmt = JsBaseQueryable::phantom_query_message(&rollback_stmt); + let rollback_stmt = JsBaseQueryable::phantom_query_message(rollback_stmt); self.raw_phantom_cmd(rollback_stmt.as_str()).await?; } else { - self.inner.raw_cmd(&rollback_stmt).await?; + self.inner.raw_cmd(rollback_stmt).await?; } let _ = UnsafeFuture(self.tx_proxy.rollback()).await; @@ -170,7 +178,44 @@ impl QuaintTransaction for JsTransaction { // Modify the depth value self.depth -= 1; - Ok(self.depth) + Ok(()) + } + + async fn create_savepoint(&mut self) -> quaint::Result<()> { + let create_savepoint_statement = self.create_savepoint_statement(self.depth); + if self.options().use_phantom_query { + let create_savepoint_statement = JsBaseQueryable::phantom_query_message(&create_savepoint_statement); + self.raw_phantom_cmd(create_savepoint_statement.as_str()).await?; + } else { + self.inner.raw_cmd(&create_savepoint_statement).await?; + } + + Ok(()) + } + + async fn release_savepoint(&mut self) -> quaint::Result<()> { + let release_savepoint_statement = self.release_savepoint_statement(self.depth); + if self.options().use_phantom_query { + let release_savepoint_statement = JsBaseQueryable::phantom_query_message(&release_savepoint_statement); + self.raw_phantom_cmd(release_savepoint_statement.as_str()).await?; + } else { + self.inner.raw_cmd(&release_savepoint_statement).await?; + } + + Ok(()) + } + + async fn rollback_to_savepoint(&mut self) -> quaint::Result<()> { + let rollback_to_savepoint_statement = self.rollback_to_savepoint_statement(self.depth); + if self.options().use_phantom_query { + let rollback_to_savepoint_statement = + JsBaseQueryable::phantom_query_message(&rollback_to_savepoint_statement); + self.raw_phantom_cmd(rollback_to_savepoint_statement.as_str()).await?; + } else { + self.inner.raw_cmd(&rollback_to_savepoint_statement).await?; + } + + Ok(()) } fn as_queryable(&self) -> &dyn Queryable { @@ -228,16 +273,8 @@ impl Queryable for JsTransaction { self.inner.requires_isolation_first() } - fn begin_statement(&self, depth: u32) -> Cow<'static, str> { - self.inner.begin_statement(depth) - } - - fn commit_statement(&self, depth: u32) -> Cow<'static, str> { - self.inner.commit_statement(depth) - } - - fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { - self.inner.rollback_statement(depth) + fn begin_statement(&self) -> &'static str { + self.inner.begin_statement() } }