diff --git a/Makefile b/Makefile index ec16c50b9dc2..9f16ba8ebb70 100644 --- a/Makefile +++ b/Makefile @@ -407,8 +407,8 @@ ensure-prisma-present: echo "⚠️ ../prisma diverges from prisma/prisma main branch. Test results might diverge from those in CI ⚠️ "; \ fi \ else \ - echo "git clone --depth=1 https://github.com/prisma/prisma.git --branch=$(DRIVER_ADAPTERS_BRANCH) ../prisma"; \ - git clone --depth=1 https://github.com/prisma/prisma.git --branch=$(DRIVER_ADAPTERS_BRANCH) "../prisma" && echo "Prisma repository has been cloned to ../prisma"; \ + echo "git clone --depth=1 https://github.com/LucianBuzzo/prisma.git --branch=lucianbuzzo/nested-rollbacks ../prisma"; \ + git clone --depth=1 https://github.com/LucianBuzzo/prisma.git --branch=lucianbuzzo/nested-rollbacks "../prisma" && echo "Prisma repository has been cloned to ../prisma"; \ fi; # Quick schema validation of whatever you have in the dev_datamodel.prisma file. diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index 7383e503d0ab..3579114a3648 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -18,7 +18,10 @@ use futures::lock::Mutex; use std::{ convert::TryFrom, future::Future, - sync::atomic::{AtomicBool, Ordering}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, time::Duration, }; use tiberius::*; @@ -45,11 +48,13 @@ impl TransactionCapable for Mssql { .or(self.url.query_params.transaction_isolation_level) .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + let opts = TransactionOptions::new( + isolation, + self.requires_isolation_first(), + self.transaction_depth.clone(), + ); - Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, - )) + Ok(Box::new(DefaultTransaction::new(self, opts).await?)) } } @@ -60,6 +65,7 @@ pub struct Mssql { url: MssqlUrl, socket_timeout: Option, is_healthy: AtomicBool, + transaction_depth: Arc>, } impl Mssql { @@ -91,6 +97,7 @@ impl Mssql { url, socket_timeout, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(Mutex::new(0)), }; if let Some(isolation) = this.url.transaction_isolation_level() { @@ -243,8 +250,41 @@ impl Queryable for Mssql { Ok(()) } - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVE TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN TRAN".to_string() + }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + // MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested + // transaction we just continue onwards + let ret = if depth > 1 { + " ".to_string() + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; } fn requires_isolation_first(&self) -> bool { diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index b4b23ab94cb8..6465db6684ef 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -23,7 +23,10 @@ use mysql_async::{ }; use std::{ future::Future, - sync::atomic::{AtomicBool, Ordering}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, time::Duration, }; use tokio::sync::Mutex; @@ -76,6 +79,7 @@ pub struct Mysql { socket_timeout: Option, is_healthy: AtomicBool, statement_cache: Mutex>, + transaction_depth: Arc>, } impl Mysql { @@ -89,6 +93,7 @@ impl Mysql { statement_cache: Mutex::new(url.cache()), url, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -345,4 +350,36 @@ impl Queryable for Mysql { fn requires_isolation_first(&self) -> bool { true } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 805ba13a6021..2a5fee9450f4 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -33,7 +33,10 @@ use std::{ fmt::{Debug, Display}, fs, future::Future, - sync::atomic::{AtomicBool, Ordering}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, time::Duration, }; use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; @@ -61,6 +64,7 @@ pub struct PostgreSql { is_healthy: AtomicBool, is_cockroachdb: bool, is_materialize: bool, + transaction_depth: Arc>, } /// Key uniquely representing an SQL statement in the prepared statements cache. @@ -289,6 +293,7 @@ impl PostgreSql { is_healthy: AtomicBool::new(true), is_cockroachdb, is_materialize, + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -763,6 +768,39 @@ impl Queryable for PostgreSql { fn requires_isolation_first(&self) -> bool { false } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + println!("pg connector: Transaction depth: {}", depth); + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } /// Sorted list of CockroachDB's reserved keywords. diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index 5f0fd54dad6b..8aed583a7a0e 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -90,8 +90,36 @@ pub trait Queryable: Send + Sync { } /// Statement to begin a transaction - fn begin_statement(&self) -> &'static str { - "BEGIN" + async fn begin_statement(&self, depth: i32) -> String { + println!("connector: Transaction depth: {}", depth); + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; } /// Sets the transaction isolation level to given value. @@ -120,10 +148,14 @@ macro_rules! impl_default_TransactionCapable { &'a self, isolation: Option, ) -> crate::Result> { - let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first()); + let opts = crate::connector::TransactionOptions::new( + isolation, + self.requires_isolation_first(), + self.transaction_depth.clone(), + ); Ok(Box::new( - crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?, + crate::connector::DefaultTransaction::new(self, opts).await?, )) } } diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index abcec7410a67..902c5fb91cbc 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -17,7 +17,7 @@ use crate::{ visitor::{self, Visitor}, }; use async_trait::async_trait; -use std::convert::TryFrom; +use std::{convert::TryFrom, sync::Arc}; use tokio::sync::Mutex; /// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. @@ -27,6 +27,7 @@ pub use rusqlite; /// A connector interface for the SQLite database pub struct Sqlite { pub(crate) client: Mutex, + transaction_depth: Arc>, } impl TryFrom<&str> for Sqlite { @@ -64,7 +65,10 @@ impl TryFrom<&str> for Sqlite { let client = Mutex::new(conn); - Ok(Sqlite { client }) + Ok(Sqlite { + client, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), + }) } } @@ -79,6 +83,7 @@ impl Sqlite { Ok(Sqlite { client: Mutex::new(client), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -181,12 +186,44 @@ impl Queryable for Sqlite { false } - fn begin_statement(&self) -> &'static str { + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); // From https://sqlite.org/isolation.html: // `BEGIN IMMEDIATE` avoids possible `SQLITE_BUSY_SNAPSHOT` that arise when another connection jumps ahead in line. // The BEGIN IMMEDIATE command goes ahead and starts a write transaction, and thus blocks all other writers. // If the BEGIN IMMEDIATE operation succeeds, then no subsequent operations in that transaction will ever fail with an SQLITE_BUSY error. - "BEGIN IMMEDIATE" + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN IMMEDIATE".to_string() + }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; } } diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index df4084883e80..20e1ee6b0298 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -4,18 +4,22 @@ use crate::{ error::{Error, ErrorKind}, }; use async_trait::async_trait; +use futures::lock::Mutex; use metrics::{decrement_gauge, increment_gauge}; -use std::{fmt, str::FromStr}; +use std::{fmt, str::FromStr, sync::Arc}; extern crate metrics as metrics; #[async_trait] pub trait Transaction: Queryable { /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()>; + async fn begin(&mut self) -> crate::Result<()>; + + /// Commit the changes to the database and consume the transaction. + async fn commit(&mut self) -> crate::Result; /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()>; + async fn rollback(&mut self) -> crate::Result; /// workaround for lack of upcasting between traits https://github.com/rust-lang/rust/issues/65991 fn as_queryable(&self) -> &dyn Queryable; @@ -27,6 +31,9 @@ pub(crate) struct TransactionOptions { /// Whether or not to put the isolation level `SET` before or after the `BEGIN`. pub(crate) isolation_first: bool, + + /// The depth of the transaction, used to determine the nested transaction statements. + pub depth: Arc>, } /// A default representation of an SQL database transaction. If not commited, a @@ -36,15 +43,18 @@ pub(crate) struct TransactionOptions { /// transaction object will panic. pub struct DefaultTransaction<'a> { pub inner: &'a dyn Queryable, + pub depth: Arc>, } impl<'a> DefaultTransaction<'a> { pub(crate) async fn new( inner: &'a dyn Queryable, - begin_stmt: &str, tx_opts: TransactionOptions, ) -> crate::Result> { - let this = Self { inner }; + let mut this = Self { + inner, + depth: tx_opts.depth, + }; if tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -52,7 +62,7 @@ impl<'a> DefaultTransaction<'a> { } } - inner.raw_cmd(begin_stmt).await?; + this.begin().await?; if !tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -62,27 +72,63 @@ impl<'a> DefaultTransaction<'a> { inner.server_reset_query(&this).await?; - increment_gauge!("prisma_client_queries_active", 1.0); Ok(this) } } #[async_trait] impl<'a> Transaction for DefaultTransaction<'a> { + async fn begin(&mut self) -> crate::Result<()> { + increment_gauge!("prisma_client_queries_active", 1.0); + + let mut depth_guard = self.depth.lock().await; + + // Modify the depth value through the MutexGuard + *depth_guard += 1; + + let st_depth = *depth_guard; + + let begin_statement = self.inner.begin_statement(st_depth).await; + + self.inner.raw_cmd(&begin_statement).await?; + + Ok(()) + } + /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()> { + async fn commit(&mut self) -> crate::Result { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("COMMIT").await?; - Ok(()) + let mut depth_guard = self.depth.lock().await; + + let st_depth = *depth_guard; + + let commit_statement = self.inner.commit_statement(st_depth).await; + + self.inner.raw_cmd(&commit_statement).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + Ok(*depth_guard) } /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()> { + async fn rollback(&mut self) -> crate::Result { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("ROLLBACK").await?; - Ok(()) + let mut depth_guard = self.depth.lock().await; + + let st_depth = *depth_guard; + + let rollback_statement = self.inner.rollback_statement(st_depth).await; + + self.inner.raw_cmd(&rollback_statement).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + Ok(*depth_guard) } fn as_queryable(&self) -> &dyn Queryable { @@ -194,10 +240,11 @@ impl FromStr for IsolationLevel { } } impl TransactionOptions { - pub fn new(isolation_level: Option, isolation_first: bool) -> Self { + pub fn new(isolation_level: Option, isolation_first: bool, depth: Arc>) -> Self { Self { isolation_level, isolation_first, + depth, } } } diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 381f0c824149..9bacf46d4214 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -507,7 +507,10 @@ impl Quaint { } }; - Ok(PooledConnection { inner }) + Ok(PooledConnection { + inner, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), + }) } /// Info about the connection and underlying database. diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index 7533dffcfcc5..e20bf7a341a8 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -10,12 +10,15 @@ use crate::{ error::Error, }; use async_trait::async_trait; +use futures::lock::Mutex; use mobc::{Connection as MobcPooled, Manager}; +use std::sync::Arc; /// A connection from the pool. Implements /// [Queryable](connector/trait.Queryable.html). pub struct PooledConnection { pub(crate) inner: MobcPooled, + pub transaction_depth: Arc>, } impl_default_TransactionCapable!(PooledConnection); @@ -66,8 +69,16 @@ impl Queryable for PooledConnection { self.inner.server_reset_query(tx).await } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/single.rs b/quaint/src/single.rs index cbf460c41509..968066538674 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -5,6 +5,7 @@ use crate::{ connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, }; use async_trait::async_trait; +use futures::lock::Mutex; use std::{fmt, sync::Arc}; #[cfg(feature = "sqlite-native")] @@ -18,6 +19,7 @@ use crate::connector::NativeConnectionInfo; pub struct Quaint { inner: Arc, connection_info: Arc, + transaction_depth: Arc>, } impl fmt::Debug for Quaint { @@ -165,7 +167,11 @@ impl Quaint { let connection_info = Arc::new(ConnectionInfo::from_url(url_str)?); Self::log_start(&connection_info); - Ok(Self { inner, connection_info }) + Ok(Self { + inner, + connection_info, + transaction_depth: Arc::new(Mutex::new(0)), + }) } #[cfg(feature = "sqlite-native")] @@ -178,6 +184,7 @@ impl Quaint { connection_info: Arc::new(ConnectionInfo::Native(NativeConnectionInfo::InMemorySqlite { db_name: DEFAULT_SQLITE_DATABASE.to_owned(), })), + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -237,8 +244,16 @@ impl Queryable for Quaint { self.inner.is_healthy() } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/tests/query.rs b/quaint/src/tests/query.rs index 6e83297a9a75..914921f1c990 100644 --- a/quaint/src/tests/query.rs +++ b/quaint/src/tests/query.rs @@ -64,7 +64,7 @@ async fn select_star_from(api: &mut dyn TestApi) -> crate::Result<()> { async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { let table = api.create_temp_table("value int").await?; - let tx = api.conn().start_transaction(None).await?; + let mut tx = api.conn().start_transaction(None).await?; let insert = Insert::single_into(&table).value("value", 10); let rows_affected = tx.execute(insert.into()).await?; @@ -75,6 +75,20 @@ async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { assert_eq!(Value::int32(10), res[0]); + // Check that nested transactions are also rolled back, even at multiple levels deep + let mut tx_inner = api.conn().start_transaction(None).await?; + let inner_insert1 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected1 = tx.execute(inner_insert1.into()).await?; + assert_eq!(1, inner_rows_affected1); + + let mut tx_inner2 = api.conn().start_transaction(None).await?; + let inner_insert2 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected2 = tx.execute(inner_insert2.into()).await?; + assert_eq!(1, inner_rows_affected2); + tx_inner2.commit().await?; + + tx_inner.commit().await?; + tx.rollback().await?; let select = Select::from_table(&table).column("value"); diff --git a/quaint/src/tests/query/error.rs b/quaint/src/tests/query/error.rs index 399866bd4a3b..424d2a0348ea 100644 --- a/quaint/src/tests/query/error.rs +++ b/quaint/src/tests/query/error.rs @@ -456,7 +456,7 @@ async fn concurrent_transaction_conflict(api: &mut dyn TestApi) -> crate::Result let conn1 = api.create_additional_connection().await?; let conn2 = api.create_additional_connection().await?; - let tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; + let mut tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; let tx2 = conn2.start_transaction(Some(IsolationLevel::Serializable)).await?; tx1.query(Select::from_table(&table).into()).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs index b8827a3bf009..d156f5a50845 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs @@ -8,7 +8,7 @@ mod interactive_tx { #[connector_test] async fn basic_commit_workflow(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -35,7 +35,7 @@ mod interactive_tx { #[connector_test] async fn basic_rollback_workflow(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -63,7 +63,7 @@ mod interactive_tx { #[connector_test] async fn tx_expiration_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one second. - let tx_id = runner.start_tx(5000, 1000, None).await?; + let tx_id = runner.start_tx(5000, 1000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -85,7 +85,6 @@ mod interactive_tx { let error = res.err().unwrap(); let known_err = error.as_known().unwrap(); - println!("KNOWN ERROR {known_err:?}"); assert_eq!(known_err.error_code, Cow::Borrowed("P2028")); assert!(known_err @@ -108,7 +107,7 @@ mod interactive_tx { #[connector_test] async fn no_auto_rollback(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); // Row is created @@ -135,7 +134,7 @@ mod interactive_tx { #[connector_test(only(Postgres))] async fn raw_queries(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -164,7 +163,7 @@ mod interactive_tx { #[connector_test] async fn batch_queries_success(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); let queries = vec![ @@ -190,7 +189,7 @@ mod interactive_tx { #[connector_test] async fn batch_queries_rollback(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); let queries = vec![ @@ -216,7 +215,7 @@ mod interactive_tx { #[connector_test(exclude(Sqlite("cfd1")))] async fn batch_queries_failure(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); // One dup key, will cause failure of the batch. @@ -259,7 +258,7 @@ mod interactive_tx { #[connector_test] async fn tx_expiration_failure_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one seconds. - let tx_id = runner.start_tx(5000, 1000, None).await?; + let tx_id = runner.start_tx(5000, 1000, None, None).await?; runner.set_active_tx(tx_id.clone()); // Row is created @@ -328,10 +327,10 @@ mod interactive_tx { #[connector_test(exclude(Sqlite))] async fn multiple_tx(mut runner: Runner) -> TestResult<()> { // First transaction. - let tx_id_a = runner.start_tx(2000, 2000, None).await?; + let tx_id_a = runner.start_tx(2000, 2000, None, None).await?; // Second transaction. - let tx_id_b = runner.start_tx(2000, 2000, None).await?; + let tx_id_b = runner.start_tx(2000, 2000, None, None).await?; // Execute on first transaction. runner.set_active_tx(tx_id_a.clone()); @@ -379,10 +378,10 @@ mod interactive_tx { ); // First transaction. - let tx_id_a = runner.start_tx(5000, 5000, Some("Serializable".into())).await?; + let tx_id_a = runner.start_tx(5000, 5000, Some("Serializable".into()), None).await?; // Second transaction. - let tx_id_b = runner.start_tx(5000, 5000, Some("Serializable".into())).await?; + let tx_id_b = runner.start_tx(5000, 5000, Some("Serializable".into()), None).await?; // Read on first transaction. runner.set_active_tx(tx_id_a.clone()); @@ -421,7 +420,7 @@ mod interactive_tx { #[connector_test] async fn double_commit(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -456,9 +455,81 @@ mod interactive_tx { Ok(()) } + #[connector_test(only(Postgres))] + async fn nested_commit_workflow(mut runner: Runner) -> TestResult<()> { + // Start the outer transaction + let outer_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(outer_tx_id.clone()); + + // Start the inner transaction + let inner_tx_id = runner.start_tx(5000, 5000, None, Some(outer_tx_id.clone())).await?; + runner.set_active_tx(inner_tx_id.clone()); + + // Perform operations in the inner transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 1 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":1}}}"### + ); + + let res = runner.commit_tx(inner_tx_id).await?; + assert!(res.is_ok()); + + // Perform operations in the outer transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 2 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":2}}}"### + ); + + let res = runner.commit_tx(outer_tx_id).await?; + assert!(res.is_ok()); + + Ok(()) + } + + #[connector_test(only(Postgres))] + async fn nested_commit_rollback_workflow(mut runner: Runner) -> TestResult<()> { + // Start the outer transaction + let outer_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(outer_tx_id.clone()); + + // Start the inner transaction + let inner_tx_id = runner.start_tx(5000, 5000, None, Some(outer_tx_id.clone())).await?; + runner.set_active_tx(inner_tx_id.clone()); + + // Perform operations in the inner transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 1 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":1}}}"### + ); + + let res = runner.commit_tx(inner_tx_id).await?; + assert!(res.is_ok()); + + // Perform operations in the outer transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 2 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":2}}}"### + ); + + // Now rollback the outer transaction + let res = runner.rollback_tx(outer_tx_id).await?; + assert!(res.is_ok()); + + // Assert that no records were written to the DB + let result_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(result_tx_id.clone()); + insta::assert_snapshot!( + run_query!(&runner, r#"query { findManyTestModel { id field }}"#), + @r###"{"data":{"findManyTestModel":[]}}"### + ); + let _ = runner.commit_tx(result_tx_id).await?; + + Ok(()) + } + #[connector_test] async fn double_rollback(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -495,7 +566,7 @@ mod interactive_tx { #[connector_test] async fn commit_after_rollback(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -532,7 +603,7 @@ mod interactive_tx { #[connector_test] async fn rollback_after_commit(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -575,7 +646,9 @@ mod itx_isolation { // All (SQL) connectors support serializable. #[connector_test(exclude(MongoDb, Sqlite("cfd1")))] async fn basic_serializable(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("Serializable".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -597,7 +670,9 @@ mod itx_isolation { #[connector_test(exclude(MongoDb, Sqlite("cfd1")))] async fn casing_doesnt_matter(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; @@ -608,13 +683,17 @@ mod itx_isolation { #[connector_test(only(Postgres))] async fn spacing_doesnt_matter(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Repeatable Read".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("Repeatable Read".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; assert!(res.is_ok()); - let tx_id = runner.start_tx(5000, 5000, Some("RepeatableRead".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("RepeatableRead".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; @@ -625,7 +704,7 @@ mod itx_isolation { #[connector_test(exclude(MongoDb))] async fn invalid_isolation(runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("test".to_owned())).await; + let tx_id = runner.start_tx(5000, 5000, Some("test".to_owned()), None).await; match tx_id { Ok(_) => panic!("Expected invalid isolation level string to throw an error, but it succeeded instead."), @@ -638,7 +717,7 @@ mod itx_isolation { // Mongo doesn't support isolation levels. #[connector_test(only(MongoDb))] async fn mongo_failure(runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await; + let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned()), None).await; match tx_id { Ok(_) => panic!("Expected mongo to throw an unsupported error, but it succeeded instead."), diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs index 2a1cf89e9d3b..18cabcf7cca9 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs @@ -48,7 +48,7 @@ mod metrics { #[connector_test] async fn metrics_tx_do_not_go_negative(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -65,7 +65,7 @@ mod metrics { let active_transactions = utils::metrics::get_gauge(&json, PRISMA_CLIENT_QUERIES_ACTIVE); assert_eq!(active_transactions, 0.0); - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs index a9b6c4395760..49ea6597ff6b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs @@ -90,7 +90,7 @@ mod mongodb { } async fn start_itx(runner: &mut Runner) -> TestResult { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); Ok(tx_id) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs index e026a90016bd..733a5b12e01a 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs @@ -83,7 +83,7 @@ impl Actor { response_sender.send(Response::Query(result)).await.unwrap(); } Message::BeginTransaction => { - let response = with_logs(runner.start_tx(10000, 10000, None), log_tx.clone()).await; + let response = with_logs(runner.start_tx(10000, 10000, None, None), log_tx.clone()).await; response_sender.send(Response::Tx(response)).await.unwrap(); } Message::RollbackTransaction(tx_id) => { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs index 313a29cdacf4..0b1e3244e420 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs @@ -5,7 +5,9 @@ mod sqlite { #[connector_test] async fn close_tx_on_error(runner: Runner) -> TestResult<()> { // Try to open a transaction with unsupported isolation error in SQLite. - let result = runner.start_tx(2000, 5000, Some("ReadUncommitted".to_owned())).await; + let result = runner + .start_tx(2000, 5000, Some("ReadUncommitted".to_owned()), None) + .await; assert!(result.is_err()); // Without the changes from https://github.com/prisma/prisma-engines/pull/4286 or @@ -16,7 +18,7 @@ mod sqlite { // IMMEDIATE if we had control over SQLite transaction type here, as that would not rely on // both transactions using the same connection if we were to pool multiple SQLite // connections in the future. - let tx = runner.start_tx(2000, 5000, None).await?; + let tx = runner.start_tx(2000, 5000, None, None).await?; runner.rollback_tx(tx).await?.unwrap(); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs index e5808ace7fcc..77861b80d7ac 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs @@ -509,8 +509,9 @@ impl Runner { max_acquisition_millis: u64, valid_for_millis: u64, isolation_level: Option, + new_tx_id: Option, ) -> TestResult { - let tx_opts = TransactionOptions::new(max_acquisition_millis, valid_for_millis, isolation_level); + let tx_opts = TransactionOptions::new(max_acquisition_millis, valid_for_millis, isolation_level, new_tx_id); match &self.executor { RunnerExecutor::Builtin(executor) => { let id = executor diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs index 2fe5d840fa1f..d2baa1bf5970 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs @@ -39,17 +39,21 @@ impl<'conn> MongoDbTransaction<'conn> { #[async_trait] impl<'conn> Transaction for MongoDbTransaction<'conn> { - async fn commit(&mut self) -> connector_interface::Result<()> { + async fn begin(&mut self) -> connector_interface::Result<()> { + Ok(()) + } + + async fn commit(&mut self) -> connector_interface::Result { decrement_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); utils::commit_with_retry(&mut self.connection.session) .await .map_err(|err| MongoError::from(err).into_connector_error())?; - Ok(()) + Ok(0) } - async fn rollback(&mut self) -> connector_interface::Result<()> { + async fn rollback(&mut self) -> connector_interface::Result { decrement_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); self.connection @@ -58,7 +62,7 @@ impl<'conn> Transaction for MongoDbTransaction<'conn> { .await .map_err(|err| MongoError::from(err).into_connector_error())?; - Ok(()) + Ok(0) } async fn version(&self) -> Option { diff --git a/query-engine/connectors/query-connector/src/interface.rs b/query-engine/connectors/query-connector/src/interface.rs index cbdafcaeeee3..2c41e674c3ce 100644 --- a/query-engine/connectors/query-connector/src/interface.rs +++ b/query-engine/connectors/query-connector/src/interface.rs @@ -32,8 +32,9 @@ pub trait Connection: ConnectionLike { #[async_trait] pub trait Transaction: ConnectionLike { - async fn commit(&mut self) -> crate::Result<()>; - async fn rollback(&mut self) -> crate::Result<()>; + async fn begin(&mut self) -> crate::Result<()>; + async fn commit(&mut self) -> crate::Result; + async fn rollback(&mut self) -> crate::Result; async fn version(&self) -> Option; diff --git a/query-engine/connectors/sql-query-connector/src/database/transaction.rs b/query-engine/connectors/sql-query-connector/src/database/transaction.rs index 263c541f6b42..75e58483d12e 100644 --- a/query-engine/connectors/sql-query-connector/src/database/transaction.rs +++ b/query-engine/connectors/sql-query-connector/src/database/transaction.rs @@ -37,19 +37,26 @@ impl<'tx> ConnectionLike for SqlConnectorTransaction<'tx> {} #[async_trait] impl<'tx> Transaction for SqlConnectorTransaction<'tx> { - async fn commit(&mut self) -> connector::Result<()> { + async fn begin(&mut self) -> connector::Result<()> { + catch(&self.connection_info, async { + self.inner.begin().await.map_err(SqlError::from) + }) + .await + } + + async fn commit(&mut self) -> connector::Result { catch(&self.connection_info, async { self.inner.commit().await.map_err(SqlError::from) }) .await } - async fn rollback(&mut self) -> connector::Result<()> { + async fn rollback(&mut self) -> connector::Result { catch(&self.connection_info, async { let res = self.inner.rollback().await.map_err(SqlError::from); match res { - Err(SqlError::TransactionAlreadyClosed(_)) | Err(SqlError::RollbackWithoutBegin) => Ok(()), + Err(SqlError::TransactionAlreadyClosed(_)) | Err(SqlError::RollbackWithoutBegin) => Ok(0), _ => res, } }) diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index fee7bc68fe7b..2316267c2345 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -73,17 +73,21 @@ pub struct TransactionOptions { /// An optional pre-defined transaction id. Some value might be provided in case we want to generate /// a new id at the beginning of the transaction - #[serde(skip)] pub new_tx_id: Option, } impl TransactionOptions { - pub fn new(max_acquisition_millis: u64, valid_for_millis: u64, isolation_level: Option) -> Self { + pub fn new( + max_acquisition_millis: u64, + valid_for_millis: u64, + isolation_level: Option, + new_tx_id: Option, + ) -> Self { Self { max_acquisition_millis, valid_for_millis, isolation_level, - new_tx_id: None, + new_tx_id, } } diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs index e6c1c7fbd1dc..37dae7e57332 100644 --- a/query-engine/core/src/interactive_transactions/actor_manager.rs +++ b/query-engine/core/src/interactive_transactions/actor_manager.rs @@ -72,19 +72,27 @@ impl TransactionActorManager { timeout: Duration, engine_protocol: EngineProtocol, ) -> crate::Result<()> { - let client = spawn_itx_actor( - query_schema.clone(), - tx_id.clone(), - conn, - isolation_level, - timeout, - CHANNEL_SIZE, - self.send_done.clone(), - engine_protocol, - ) - .await?; - - self.clients.write().await.insert(tx_id, client); + // Only create a client if there is no client for this transaction yet. + // otherwise, begin a new transaction/savepoint for the existing client. + if !self.clients.read().await.contains_key(&tx_id) { + let client = spawn_itx_actor( + query_schema.clone(), + tx_id.clone(), + conn, + isolation_level, + timeout, + CHANNEL_SIZE, + self.send_done.clone(), + engine_protocol, + ) + .await?; + + self.clients.write().await.insert(tx_id, client); + } else { + let client = self.get_client(&tx_id, "begin").await?; + client.begin().await?; + } + Ok(()) } diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 86ebd5c13b84..4038e62e5a68 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -66,15 +66,39 @@ impl<'a> ITXServer<'a> { let _ = op.respond_to.send(TxOpResponse::Batch(result)); RunState::Continue } + TxOpRequestMsg::Begin => { + let _result = self.begin().await; + let _ = op.respond_to.send(TxOpResponse::Begin(())); + RunState::Continue + } TxOpRequestMsg::Commit => { let resp = self.commit().await; + let resp_value = match &resp { + Ok(val) => *val, + Err(_) => 0, + }; + let _ = op.respond_to.send(TxOpResponse::Committed(resp)); - RunState::Finished + + if resp_value > 0 { + RunState::Continue + } else { + RunState::Finished + } } TxOpRequestMsg::Rollback => { let resp = self.rollback(false).await; + let resp_value = match &resp { + Ok(val) => *val, + Err(_) => 0, + }; let _ = op.respond_to.send(TxOpResponse::RolledBack(resp)); - RunState::Finished + + if resp_value > 0 { + RunState::Continue + } else { + RunState::Finished + } } } } @@ -118,32 +142,46 @@ impl<'a> ITXServer<'a> { .await } - pub(crate) async fn commit(&mut self) -> crate::Result<()> { + pub(crate) async fn begin(&mut self) -> crate::Result<()> { if let CachedTx::Open(_) = self.cached_tx { let open_tx = self.cached_tx.as_open()?; - trace!("[{}] committing.", self.id.to_string()); - open_tx.commit().await?; - self.cached_tx = CachedTx::Committed; + trace!("[{}] beginning.", self.id.to_string()); + open_tx.begin().await?; } Ok(()) } - pub(crate) async fn rollback(&mut self, was_timeout: bool) -> crate::Result<()> { + pub(crate) async fn commit(&mut self) -> crate::Result { + if let CachedTx::Open(_) = self.cached_tx { + let open_tx = self.cached_tx.as_open()?; + trace!("[{}] committing.", self.id.to_string()); + let depth = open_tx.commit().await?; + if depth == 0 { + self.cached_tx = CachedTx::Committed; + } + return Ok(depth); + } + + Ok(0) + } + + pub(crate) async fn rollback(&mut self, was_timeout: bool) -> crate::Result { debug!("[{}] rolling back, was timed out = {was_timeout}", self.name()); if let CachedTx::Open(_) = self.cached_tx { let open_tx = self.cached_tx.as_open()?; - open_tx.rollback().await?; + let depth = open_tx.rollback().await?; if was_timeout { trace!("[{}] Expired Rolling back", self.id.to_string()); self.cached_tx = CachedTx::Expired; - } else { + } else if depth == 0 { self.cached_tx = CachedTx::RolledBack; trace!("[{}] Rolling back", self.id.to_string()); } + return Ok(depth); } - Ok(()) + Ok(0) } pub(crate) fn name(&self) -> String { @@ -158,7 +196,12 @@ pub struct ITXClient { } impl ITXClient { - pub(crate) async fn commit(&self) -> crate::Result<()> { + pub async fn begin(&self) -> crate::Result<()> { + self.send_and_receive(TxOpRequestMsg::Begin).await?; + Ok(()) + } + + pub(crate) async fn commit(&self) -> crate::Result { let msg = self.send_and_receive(TxOpRequestMsg::Commit).await?; if let TxOpResponse::Committed(resp) = msg { @@ -169,7 +212,7 @@ impl ITXClient { } } - pub(crate) async fn rollback(&self) -> crate::Result<()> { + pub(crate) async fn rollback(&self) -> crate::Result { let msg = self.send_and_receive(TxOpRequestMsg::Rollback).await?; if let TxOpResponse::RolledBack(resp) = msg { diff --git a/query-engine/core/src/interactive_transactions/messages.rs b/query-engine/core/src/interactive_transactions/messages.rs index 0dba2c096a8a..a61a2887ef8c 100644 --- a/query-engine/core/src/interactive_transactions/messages.rs +++ b/query-engine/core/src/interactive_transactions/messages.rs @@ -6,6 +6,7 @@ use tokio::sync::oneshot; pub enum TxOpRequestMsg { Commit, Rollback, + Begin, Single(Operation, Option), Batch(Vec, Option), } @@ -18,6 +19,7 @@ pub struct TxOpRequest { impl Display for TxOpRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.msg { + TxOpRequestMsg::Begin => write!(f, "Begin"), TxOpRequestMsg::Commit => write!(f, "Commit"), TxOpRequestMsg::Rollback => write!(f, "Rollback"), TxOpRequestMsg::Single(..) => write!(f, "Single"), @@ -28,8 +30,9 @@ impl Display for TxOpRequest { #[derive(Debug)] pub enum TxOpResponse { - Committed(crate::Result<()>), - RolledBack(crate::Result<()>), + Begin(()), + Committed(crate::Result), + RolledBack(crate::Result), Single(crate::Result), Batch(crate::Result>>), } @@ -37,6 +40,7 @@ pub enum TxOpResponse { impl Display for TxOpResponse { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Self::Begin(..) => write!(f, "Begin"), Self::Committed(..) => write!(f, "Committed"), Self::RolledBack(..) => write!(f, "RolledBack"), Self::Single(..) => write!(f, "Single"), diff --git a/query-engine/core/src/interactive_transactions/mod.rs b/query-engine/core/src/interactive_transactions/mod.rs index a0aed069a879..a430e2db0edb 100644 --- a/query-engine/core/src/interactive_transactions/mod.rs +++ b/query-engine/core/src/interactive_transactions/mod.rs @@ -1,7 +1,7 @@ use crate::CoreError; use connector::Transaction; use crosstarget_utils::time::ElapsedTimeCounter; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::fmt::Display; use tokio::time::Duration; @@ -39,7 +39,7 @@ pub(crate) use messages::*; /// the TransactionActorManager can reply with a helpful error message which explains that no operation can be performed on a closed transaction /// rather than an error message stating that the transaction does not exist. -#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize)] +#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Serialize)] pub struct TxId(String); const MINIMUM_TX_ID_LENGTH: usize = 24; diff --git a/query-engine/driver-adapters/executor/src/bench.ts b/query-engine/driver-adapters/executor/src/bench.ts index 14923f69cf9e..82dff0604943 100644 --- a/query-engine/driver-adapters/executor/src/bench.ts +++ b/query-engine/driver-adapters/executor/src/bench.ts @@ -47,6 +47,7 @@ async function main(): Promise { const { recorder, replayer, recordings } = recording(withErrorCapturing); // We exercise the queries recording them + // @ts-ignore await recordQueries(recorder, datamodel, prismaQueries); // Dump recordings if requested @@ -59,6 +60,7 @@ async function main(): Promise { // Then we benchmark the execution of the queries but instead of hitting the DB // we fetch results from the recordings, thus isolating the performance // of the engine + driver adapter code from that of the DB IO. + // @ts-ignore await benchMarkQueries(replayer, datamodel, prismaQueries); } diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs index 1c2a5a68240b..99415984ebbb 100644 --- a/query-engine/driver-adapters/src/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -49,6 +49,9 @@ pub(crate) struct TransactionProxy { /// transaction options options: TransactionOptions, + /// begin transaction + pub begin: AdapterMethod<(), ()>, + /// commit transaction commit: AdapterMethod<(), ()>, @@ -133,11 +136,13 @@ impl TransactionContextProxy { impl TransactionProxy { pub fn new(js_transaction: &JsObject) -> JsResult { let commit = get_named_property(js_transaction, "commit")?; + let begin = get_named_property(js_transaction, "begin")?; let rollback = get_named_property(js_transaction, "rollback")?; let options = get_named_property(js_transaction, "options")?; let options = from_js_value::(options); Ok(Self { + begin, commit, rollback, options, @@ -149,6 +154,11 @@ impl TransactionProxy { &self.options } + pub fn begin(&self) -> UnsafeFuture> + '_> { + self.closed.store(true, Ordering::Relaxed); + UnsafeFuture(self.begin.call_as_async(())) + } + /// Commits the transaction via the driver adapter. /// /// ## Cancellation safety diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index dde42c5fb420..06ad4b499abd 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -5,7 +5,7 @@ use crate::JsObject; use super::conversion; use crate::send_future::UnsafeFuture; use async_trait::async_trait; -use futures::Future; +use futures::{lock::Mutex, Future}; use quaint::connector::{DescribedQuery, ExternalConnectionInfo, ExternalConnector}; use quaint::{ connector::{metrics, IsolationLevel, Transaction}, @@ -13,6 +13,7 @@ use quaint::{ prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, visitor::{self, Visitor}, }; +use std::sync::Arc; use tracing::{info_span, Instrument}; /// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the @@ -227,6 +228,7 @@ impl JsBaseQueryable { pub struct JsQueryable { inner: JsBaseQueryable, driver_proxy: DriverProxy, + pub transaction_depth: Arc>, } impl std::fmt::Display for JsQueryable { @@ -324,24 +326,31 @@ impl JsQueryable { // 3. Spawn a transaction from the context. let tx = tx_ctx.start_transaction().await?; - let begin_stmt = tx.begin_statement(); - let tx_opts = tx.options(); + { + let mut depth_guard = tx.depth.lock().await; + *depth_guard += 1; - if tx_opts.use_phantom_query { - let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); - tx.raw_phantom_cmd(begin_stmt.as_str()).await?; - } else { - tx.raw_cmd(begin_stmt).await?; - } + let st_depth = *depth_guard; - // 4. Set the isolation level (if specified) if we didn't do it before. - if !requires_isolation_first { - if let Some(isolation) = isolation { - tx.set_tx_isolation_level(isolation).await?; + let begin_stmt = tx.begin_statement(st_depth).await; + let tx_opts = tx.options(); + + if tx_opts.use_phantom_query { + let begin_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); + tx.raw_phantom_cmd(begin_stmt.as_str()).await?; + } else { + tx.raw_cmd(&begin_stmt).await?; } - } - self.server_reset_query(tx.as_ref()).await?; + // 4. Set the isolation level (if specified) if we didn't do it before. + if !requires_isolation_first { + if let Some(isolation) = isolation { + tx.set_tx_isolation_level(isolation).await?; + } + } + + self.server_reset_query(tx.as_ref()).await?; + } Ok(tx) } @@ -364,5 +373,6 @@ pub fn from_js(driver: JsObject) -> JsQueryable { JsQueryable { inner: JsBaseQueryable::new(common), driver_proxy, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), } } diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index 8d124bd4da0a..4d0ddb389a8d 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,12 +1,14 @@ use std::future::Future; use async_trait::async_trait; +use futures::lock::Mutex; use metrics::decrement_gauge; use quaint::{ connector::{DescribedQuery, IsolationLevel, Transaction as QuaintTransaction}, prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; +use std::sync::Arc; use crate::proxy::{TransactionContextProxy, TransactionOptions, TransactionProxy}; use crate::{proxy::CommonProxy, queryable::JsBaseQueryable, send_future::UnsafeFuture}; @@ -86,11 +88,16 @@ impl Queryable for JsTransactionContext { pub(crate) struct JsTransaction { tx_proxy: TransactionProxy, inner: JsBaseQueryable, + pub depth: Arc>, } impl JsTransaction { pub(crate) fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self { - Self { inner, tx_proxy } + Self { + inner, + tx_proxy, + depth: Arc::new(futures::lock::Mutex::new(0)), + } } pub fn options(&self) -> &TransactionOptions { @@ -105,36 +112,69 @@ impl JsTransaction { #[async_trait] impl QuaintTransaction for JsTransaction { - async fn commit(&self) -> quaint::Result<()> { + async fn begin(&mut self) -> quaint::Result<()> { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let commit_stmt = "COMMIT"; + let mut depth_guard = self.depth.lock().await; + // Modify the depth value through the MutexGuard + *depth_guard += 1; + + let begin_stmt = self.begin_statement(*depth_guard).await; if self.options().use_phantom_query { - let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); + let commit_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); self.raw_phantom_cmd(commit_stmt.as_str()).await?; } else { - self.inner.raw_cmd(commit_stmt).await?; + self.inner.raw_cmd(&begin_stmt).await?; } - UnsafeFuture(self.tx_proxy.commit()).await + println!("JsTransaction begin: incrementing depth_guard to: {}", *depth_guard); + + UnsafeFuture(self.tx_proxy.begin()).await } + async fn commit(&mut self) -> quaint::Result { + // increment of this gauge is done in DriverProxy::startTransaction + decrement_gauge!("prisma_client_queries_active", 1.0); + + let mut depth_guard = self.depth.lock().await; + let commit_stmt = self.commit_statement(*depth_guard).await; + + if self.options().use_phantom_query { + let commit_stmt = JsBaseQueryable::phantom_query_message(&commit_stmt); + self.raw_phantom_cmd(commit_stmt.as_str()).await?; + } else { + self.inner.raw_cmd(&commit_stmt).await?; + } + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; - async fn rollback(&self) -> quaint::Result<()> { + let _ = UnsafeFuture(self.tx_proxy.commit()).await; + + Ok(*depth_guard) + } + + async fn rollback(&mut self) -> quaint::Result { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let rollback_stmt = "ROLLBACK"; + let mut depth_guard = self.depth.lock().await; + let rollback_stmt = self.rollback_statement(*depth_guard).await; if self.options().use_phantom_query { - let rollback_stmt = JsBaseQueryable::phantom_query_message(rollback_stmt); + let rollback_stmt = JsBaseQueryable::phantom_query_message(&rollback_stmt); self.raw_phantom_cmd(rollback_stmt.as_str()).await?; } else { - self.inner.raw_cmd(rollback_stmt).await?; + self.inner.raw_cmd(&rollback_stmt).await?; } - UnsafeFuture(self.tx_proxy.rollback()).await + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + let _ = UnsafeFuture(self.tx_proxy.rollback()).await; + + Ok(*depth_guard) } fn as_queryable(&self) -> &dyn Queryable { @@ -191,6 +231,18 @@ impl Queryable for JsTransaction { fn requires_isolation_first(&self) -> bool { self.inner.requires_isolation_first() } + + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await + } } #[cfg(target_arch = "wasm32")] diff --git a/query-engine/query-engine/src/server/mod.rs b/query-engine/query-engine/src/server/mod.rs index 01b61a07b6b4..1fc9789f865b 100644 --- a/query-engine/query-engine/src/server/mod.rs +++ b/query-engine/query-engine/src/server/mod.rs @@ -282,7 +282,11 @@ async fn transaction_start_handler(cx: Arc, req: Request) - let body_start = req.into_body(); let full_body = hyper::body::to_bytes(body_start).await?; let mut tx_opts: TransactionOptions = serde_json::from_slice(full_body.as_ref()).unwrap(); - let tx_id = tx_opts.with_new_transaction_id(); + let tx_id = if tx_opts.new_tx_id.is_none() { + tx_opts.with_new_transaction_id() + } else { + tx_opts.new_tx_id.clone().unwrap() + }; // This is the span we use to instrument the execution of a transaction. This span will be open // during the tx execution, and held in the ITXServer for that transaction (see ITXServer])