From 71d3bd8f0c20b95a6c3a0772e04f53ec649daf20 Mon Sep 17 00:00:00 2001 From: Lucian Buzzo Date: Sun, 15 Oct 2023 13:58:28 +0100 Subject: [PATCH] feat: add support for nested transaction rollbacks via savepoints in sql This is my first OSS contribution for a Rust project, so I'm sure I've made some stupid mistakes, but I think it should mostly work :) This change adds a mutable depth counter, that can track how many levels deep a transaction is, and uses savepoints to implement correct rollback behaviour. Previously, once a nested transaction was complete, it would be saved with `COMMIT`, meaning that even if the outer transaction was rolled back, the operations in the inner transaction would persist. With this change, if the outer transaction gets rolled back, then all inner transactions will also be rolled back. Different flavours of SQL servers have different syntax for handling savepoints, so I've had to add new methods to the `Queryable` trait for getting the commit and rollback statements. These are both parameterized by the current depth. I've additionally had to modify the `begin_statement` method to accept a depth parameter, as it will need to conditionally create a savepoint. When opening a transaction via the transaction server, you can now pass the prior transaction ID to re-use the existing transaction, incrementing the depth. Signed-off-by: Lucian Buzzo --- Makefile | 4 +- quaint/src/connector/mssql/native/mod.rs | 54 +++++++- quaint/src/connector/mysql/native/mod.rs | 39 +++++- quaint/src/connector/postgres/native/mod.rs | 40 +++++- quaint/src/connector/queryable.rs | 40 +++++- quaint/src/connector/sqlite/native/mod.rs | 45 ++++++- quaint/src/connector/transaction.rs | 75 +++++++++-- quaint/src/pooled.rs | 5 +- quaint/src/pooled/manager.rs | 15 ++- quaint/src/single.rs | 21 ++- quaint/src/tests/query.rs | 16 ++- quaint/src/tests/query/error.rs | 2 +- .../tests/new/interactive_tx.rs | 127 ++++++++++++++---- .../query-engine-tests/tests/new/metrics.rs | 4 +- .../tests/new/regressions/prisma_13405.rs | 2 +- .../tests/new/regressions/prisma_15607.rs | 2 +- .../new/regressions/prisma_engines_4286.rs | 6 +- .../query-tests-setup/src/runner/mod.rs | 3 +- .../src/interface/transaction.rs | 12 +- .../query-connector/src/interface.rs | 5 +- .../src/database/transaction.rs | 13 +- query-engine/core/src/executor/mod.rs | 10 +- .../interactive_transactions/actor_manager.rs | 34 +++-- .../src/interactive_transactions/actors.rs | 67 +++++++-- .../src/interactive_transactions/messages.rs | 8 +- .../core/src/interactive_transactions/mod.rs | 4 +- .../driver-adapters/executor/src/bench.ts | 2 + query-engine/driver-adapters/src/proxy.rs | 10 ++ query-engine/driver-adapters/src/queryable.rs | 40 +++--- .../driver-adapters/src/transaction.rs | 74 ++++++++-- query-engine/query-engine/src/server/mod.rs | 6 +- 31 files changed, 645 insertions(+), 140 deletions(-) diff --git a/Makefile b/Makefile index ec16c50b9dc2..9f16ba8ebb70 100644 --- a/Makefile +++ b/Makefile @@ -407,8 +407,8 @@ ensure-prisma-present: echo "⚠️ ../prisma diverges from prisma/prisma main branch. 
Test results might diverge from those in CI ⚠️ "; \ fi \ else \ - echo "git clone --depth=1 https://github.com/prisma/prisma.git --branch=$(DRIVER_ADAPTERS_BRANCH) ../prisma"; \ - git clone --depth=1 https://github.com/prisma/prisma.git --branch=$(DRIVER_ADAPTERS_BRANCH) "../prisma" && echo "Prisma repository has been cloned to ../prisma"; \ + echo "git clone --depth=1 https://github.com/LucianBuzzo/prisma.git --branch=lucianbuzzo/nested-rollbacks ../prisma"; \ + git clone --depth=1 https://github.com/LucianBuzzo/prisma.git --branch=lucianbuzzo/nested-rollbacks "../prisma" && echo "Prisma repository has been cloned to ../prisma"; \ fi; # Quick schema validation of whatever you have in the dev_datamodel.prisma file. diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index 7383e503d0ab..3579114a3648 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -18,7 +18,10 @@ use futures::lock::Mutex; use std::{ convert::TryFrom, future::Future, - sync::atomic::{AtomicBool, Ordering}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, time::Duration, }; use tiberius::*; @@ -45,11 +48,13 @@ impl TransactionCapable for Mssql { .or(self.url.query_params.transaction_isolation_level) .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + let opts = TransactionOptions::new( + isolation, + self.requires_isolation_first(), + self.transaction_depth.clone(), + ); - Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, - )) + Ok(Box::new(DefaultTransaction::new(self, opts).await?)) } } @@ -60,6 +65,7 @@ pub struct Mssql { url: MssqlUrl, socket_timeout: Option, is_healthy: AtomicBool, + transaction_depth: Arc>, } impl Mssql { @@ -91,6 +97,7 @@ impl Mssql { url, socket_timeout, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(Mutex::new(0)), }; if let Some(isolation) = this.url.transaction_isolation_level() { @@ -243,8 +250,41 @@ impl Queryable for Mssql { Ok(()) } - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVE TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN TRAN".to_string() + }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + // MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested + // transaction we just continue onwards + let ret = if depth > 1 { + " ".to_string() + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; } fn requires_isolation_first(&self) -> bool { diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index b4b23ab94cb8..6465db6684ef 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -23,7 +23,10 @@ use mysql_async::{ }; use std::{ future::Future, - sync::atomic::{AtomicBool, Ordering}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, time::Duration, }; use tokio::sync::Mutex; @@ -76,6 +79,7 @@ pub struct Mysql { 
socket_timeout: Option, is_healthy: AtomicBool, statement_cache: Mutex>, + transaction_depth: Arc>, } impl Mysql { @@ -89,6 +93,7 @@ impl Mysql { statement_cache: Mutex::new(url.cache()), url, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -345,4 +350,36 @@ impl Queryable for Mysql { fn requires_isolation_first(&self) -> bool { true } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 805ba13a6021..2a5fee9450f4 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -33,7 +33,10 @@ use std::{ fmt::{Debug, Display}, fs, future::Future, - sync::atomic::{AtomicBool, Ordering}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, time::Duration, }; use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; @@ -61,6 +64,7 @@ pub struct PostgreSql { is_healthy: AtomicBool, is_cockroachdb: bool, is_materialize: bool, + transaction_depth: Arc>, } /// Key uniquely representing an SQL statement in the prepared statements cache. @@ -289,6 +293,7 @@ impl PostgreSql { is_healthy: AtomicBool::new(true), is_cockroachdb, is_materialize, + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -763,6 +768,39 @@ impl Queryable for PostgreSql { fn requires_isolation_first(&self) -> bool { false } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + println!("pg connector: Transaction depth: {}", depth); + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } /// Sorted list of CockroachDB's reserved keywords. 
diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index 5f0fd54dad6b..8aed583a7a0e 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -90,8 +90,36 @@ pub trait Queryable: Send + Sync { } /// Statement to begin a transaction - fn begin_statement(&self) -> &'static str { - "BEGIN" + async fn begin_statement(&self, depth: i32) -> String { + println!("connector: Transaction depth: {}", depth); + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; } /// Sets the transaction isolation level to given value. @@ -120,10 +148,14 @@ macro_rules! impl_default_TransactionCapable { &'a self, isolation: Option, ) -> crate::Result> { - let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first()); + let opts = crate::connector::TransactionOptions::new( + isolation, + self.requires_isolation_first(), + self.transaction_depth.clone(), + ); Ok(Box::new( - crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?, + crate::connector::DefaultTransaction::new(self, opts).await?, )) } } diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index abcec7410a67..902c5fb91cbc 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -17,7 +17,7 @@ use crate::{ visitor::{self, Visitor}, }; use async_trait::async_trait; -use std::convert::TryFrom; +use std::{convert::TryFrom, sync::Arc}; use tokio::sync::Mutex; /// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. @@ -27,6 +27,7 @@ pub use rusqlite; /// A connector interface for the SQLite database pub struct Sqlite { pub(crate) client: Mutex, + transaction_depth: Arc>, } impl TryFrom<&str> for Sqlite { @@ -64,7 +65,10 @@ impl TryFrom<&str> for Sqlite { let client = Mutex::new(conn); - Ok(Sqlite { client }) + Ok(Sqlite { + client, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), + }) } } @@ -79,6 +83,7 @@ impl Sqlite { Ok(Sqlite { client: Mutex::new(client), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -181,12 +186,44 @@ impl Queryable for Sqlite { false } - fn begin_statement(&self) -> &'static str { + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); // From https://sqlite.org/isolation.html: // `BEGIN IMMEDIATE` avoids possible `SQLITE_BUSY_SNAPSHOT` that arise when another connection jumps ahead in line. // The BEGIN IMMEDIATE command goes ahead and starts a write transaction, and thus blocks all other writers. // If the BEGIN IMMEDIATE operation succeeds, then no subsequent operations in that transaction will ever fail with an SQLITE_BUSY error. 
- "BEGIN IMMEDIATE" + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN IMMEDIATE".to_string() + }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; } } diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index df4084883e80..20e1ee6b0298 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -4,18 +4,22 @@ use crate::{ error::{Error, ErrorKind}, }; use async_trait::async_trait; +use futures::lock::Mutex; use metrics::{decrement_gauge, increment_gauge}; -use std::{fmt, str::FromStr}; +use std::{fmt, str::FromStr, sync::Arc}; extern crate metrics as metrics; #[async_trait] pub trait Transaction: Queryable { /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()>; + async fn begin(&mut self) -> crate::Result<()>; + + /// Commit the changes to the database and consume the transaction. + async fn commit(&mut self) -> crate::Result; /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()>; + async fn rollback(&mut self) -> crate::Result; /// workaround for lack of upcasting between traits https://github.com/rust-lang/rust/issues/65991 fn as_queryable(&self) -> &dyn Queryable; @@ -27,6 +31,9 @@ pub(crate) struct TransactionOptions { /// Whether or not to put the isolation level `SET` before or after the `BEGIN`. pub(crate) isolation_first: bool, + + /// The depth of the transaction, used to determine the nested transaction statements. + pub depth: Arc>, } /// A default representation of an SQL database transaction. If not commited, a @@ -36,15 +43,18 @@ pub(crate) struct TransactionOptions { /// transaction object will panic. pub struct DefaultTransaction<'a> { pub inner: &'a dyn Queryable, + pub depth: Arc>, } impl<'a> DefaultTransaction<'a> { pub(crate) async fn new( inner: &'a dyn Queryable, - begin_stmt: &str, tx_opts: TransactionOptions, ) -> crate::Result> { - let this = Self { inner }; + let mut this = Self { + inner, + depth: tx_opts.depth, + }; if tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -52,7 +62,7 @@ impl<'a> DefaultTransaction<'a> { } } - inner.raw_cmd(begin_stmt).await?; + this.begin().await?; if !tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -62,27 +72,63 @@ impl<'a> DefaultTransaction<'a> { inner.server_reset_query(&this).await?; - increment_gauge!("prisma_client_queries_active", 1.0); Ok(this) } } #[async_trait] impl<'a> Transaction for DefaultTransaction<'a> { + async fn begin(&mut self) -> crate::Result<()> { + increment_gauge!("prisma_client_queries_active", 1.0); + + let mut depth_guard = self.depth.lock().await; + + // Modify the depth value through the MutexGuard + *depth_guard += 1; + + let st_depth = *depth_guard; + + let begin_statement = self.inner.begin_statement(st_depth).await; + + self.inner.raw_cmd(&begin_statement).await?; + + Ok(()) + } + /// Commit the changes to the database and consume the transaction. 
- async fn commit(&self) -> crate::Result<()> { + async fn commit(&mut self) -> crate::Result { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("COMMIT").await?; - Ok(()) + let mut depth_guard = self.depth.lock().await; + + let st_depth = *depth_guard; + + let commit_statement = self.inner.commit_statement(st_depth).await; + + self.inner.raw_cmd(&commit_statement).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + Ok(*depth_guard) } /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()> { + async fn rollback(&mut self) -> crate::Result { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("ROLLBACK").await?; - Ok(()) + let mut depth_guard = self.depth.lock().await; + + let st_depth = *depth_guard; + + let rollback_statement = self.inner.rollback_statement(st_depth).await; + + self.inner.raw_cmd(&rollback_statement).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + Ok(*depth_guard) } fn as_queryable(&self) -> &dyn Queryable { @@ -194,10 +240,11 @@ impl FromStr for IsolationLevel { } } impl TransactionOptions { - pub fn new(isolation_level: Option, isolation_first: bool) -> Self { + pub fn new(isolation_level: Option, isolation_first: bool, depth: Arc>) -> Self { Self { isolation_level, isolation_first, + depth, } } } diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 381f0c824149..9bacf46d4214 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -507,7 +507,10 @@ impl Quaint { } }; - Ok(PooledConnection { inner }) + Ok(PooledConnection { + inner, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), + }) } /// Info about the connection and underlying database. diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index 7533dffcfcc5..e20bf7a341a8 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -10,12 +10,15 @@ use crate::{ error::Error, }; use async_trait::async_trait; +use futures::lock::Mutex; use mobc::{Connection as MobcPooled, Manager}; +use std::sync::Arc; /// A connection from the pool. Implements /// [Queryable](connector/trait.Queryable.html). 
pub struct PooledConnection { pub(crate) inner: MobcPooled, + pub transaction_depth: Arc>, } impl_default_TransactionCapable!(PooledConnection); @@ -66,8 +69,16 @@ impl Queryable for PooledConnection { self.inner.server_reset_query(tx).await } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/single.rs b/quaint/src/single.rs index cbf460c41509..968066538674 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -5,6 +5,7 @@ use crate::{ connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, }; use async_trait::async_trait; +use futures::lock::Mutex; use std::{fmt, sync::Arc}; #[cfg(feature = "sqlite-native")] @@ -18,6 +19,7 @@ use crate::connector::NativeConnectionInfo; pub struct Quaint { inner: Arc, connection_info: Arc, + transaction_depth: Arc>, } impl fmt::Debug for Quaint { @@ -165,7 +167,11 @@ impl Quaint { let connection_info = Arc::new(ConnectionInfo::from_url(url_str)?); Self::log_start(&connection_info); - Ok(Self { inner, connection_info }) + Ok(Self { + inner, + connection_info, + transaction_depth: Arc::new(Mutex::new(0)), + }) } #[cfg(feature = "sqlite-native")] @@ -178,6 +184,7 @@ impl Quaint { connection_info: Arc::new(ConnectionInfo::Native(NativeConnectionInfo::InMemorySqlite { db_name: DEFAULT_SQLITE_DATABASE.to_owned(), })), + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -237,8 +244,16 @@ impl Queryable for Quaint { self.inner.is_healthy() } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/tests/query.rs b/quaint/src/tests/query.rs index 6e83297a9a75..914921f1c990 100644 --- a/quaint/src/tests/query.rs +++ b/quaint/src/tests/query.rs @@ -64,7 +64,7 @@ async fn select_star_from(api: &mut dyn TestApi) -> crate::Result<()> { async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { let table = api.create_temp_table("value int").await?; - let tx = api.conn().start_transaction(None).await?; + let mut tx = api.conn().start_transaction(None).await?; let insert = Insert::single_into(&table).value("value", 10); let rows_affected = tx.execute(insert.into()).await?; @@ -75,6 +75,20 @@ async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { assert_eq!(Value::int32(10), res[0]); + // Check that nested transactions are also rolled back, even at multiple levels deep + let mut tx_inner = api.conn().start_transaction(None).await?; + let inner_insert1 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected1 = tx.execute(inner_insert1.into()).await?; + assert_eq!(1, inner_rows_affected1); + + let mut tx_inner2 = 
api.conn().start_transaction(None).await?; + let inner_insert2 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected2 = tx.execute(inner_insert2.into()).await?; + assert_eq!(1, inner_rows_affected2); + tx_inner2.commit().await?; + + tx_inner.commit().await?; + tx.rollback().await?; let select = Select::from_table(&table).column("value"); diff --git a/quaint/src/tests/query/error.rs b/quaint/src/tests/query/error.rs index 399866bd4a3b..424d2a0348ea 100644 --- a/quaint/src/tests/query/error.rs +++ b/quaint/src/tests/query/error.rs @@ -456,7 +456,7 @@ async fn concurrent_transaction_conflict(api: &mut dyn TestApi) -> crate::Result let conn1 = api.create_additional_connection().await?; let conn2 = api.create_additional_connection().await?; - let tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; + let mut tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; let tx2 = conn2.start_transaction(Some(IsolationLevel::Serializable)).await?; tx1.query(Select::from_table(&table).into()).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs index b8827a3bf009..d156f5a50845 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs @@ -8,7 +8,7 @@ mod interactive_tx { #[connector_test] async fn basic_commit_workflow(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -35,7 +35,7 @@ mod interactive_tx { #[connector_test] async fn basic_rollback_workflow(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -63,7 +63,7 @@ mod interactive_tx { #[connector_test] async fn tx_expiration_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one second. - let tx_id = runner.start_tx(5000, 1000, None).await?; + let tx_id = runner.start_tx(5000, 1000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -85,7 +85,6 @@ mod interactive_tx { let error = res.err().unwrap(); let known_err = error.as_known().unwrap(); - println!("KNOWN ERROR {known_err:?}"); assert_eq!(known_err.error_code, Cow::Borrowed("P2028")); assert!(known_err @@ -108,7 +107,7 @@ mod interactive_tx { #[connector_test] async fn no_auto_rollback(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); // Row is created @@ -135,7 +134,7 @@ mod interactive_tx { #[connector_test(only(Postgres))] async fn raw_queries(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -164,7 +163,7 @@ mod interactive_tx { #[connector_test] async fn batch_queries_success(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. 
- let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); let queries = vec![ @@ -190,7 +189,7 @@ mod interactive_tx { #[connector_test] async fn batch_queries_rollback(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); let queries = vec![ @@ -216,7 +215,7 @@ mod interactive_tx { #[connector_test(exclude(Sqlite("cfd1")))] async fn batch_queries_failure(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); // One dup key, will cause failure of the batch. @@ -259,7 +258,7 @@ mod interactive_tx { #[connector_test] async fn tx_expiration_failure_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one seconds. - let tx_id = runner.start_tx(5000, 1000, None).await?; + let tx_id = runner.start_tx(5000, 1000, None, None).await?; runner.set_active_tx(tx_id.clone()); // Row is created @@ -328,10 +327,10 @@ mod interactive_tx { #[connector_test(exclude(Sqlite))] async fn multiple_tx(mut runner: Runner) -> TestResult<()> { // First transaction. - let tx_id_a = runner.start_tx(2000, 2000, None).await?; + let tx_id_a = runner.start_tx(2000, 2000, None, None).await?; // Second transaction. - let tx_id_b = runner.start_tx(2000, 2000, None).await?; + let tx_id_b = runner.start_tx(2000, 2000, None, None).await?; // Execute on first transaction. runner.set_active_tx(tx_id_a.clone()); @@ -379,10 +378,10 @@ mod interactive_tx { ); // First transaction. - let tx_id_a = runner.start_tx(5000, 5000, Some("Serializable".into())).await?; + let tx_id_a = runner.start_tx(5000, 5000, Some("Serializable".into()), None).await?; // Second transaction. - let tx_id_b = runner.start_tx(5000, 5000, Some("Serializable".into())).await?; + let tx_id_b = runner.start_tx(5000, 5000, Some("Serializable".into()), None).await?; // Read on first transaction. 
runner.set_active_tx(tx_id_a.clone()); @@ -421,7 +420,7 @@ mod interactive_tx { #[connector_test] async fn double_commit(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -456,9 +455,81 @@ mod interactive_tx { Ok(()) } + #[connector_test(only(Postgres))] + async fn nested_commit_workflow(mut runner: Runner) -> TestResult<()> { + // Start the outer transaction + let outer_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(outer_tx_id.clone()); + + // Start the inner transaction + let inner_tx_id = runner.start_tx(5000, 5000, None, Some(outer_tx_id.clone())).await?; + runner.set_active_tx(inner_tx_id.clone()); + + // Perform operations in the inner transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 1 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":1}}}"### + ); + + let res = runner.commit_tx(inner_tx_id).await?; + assert!(res.is_ok()); + + // Perform operations in the outer transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 2 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":2}}}"### + ); + + let res = runner.commit_tx(outer_tx_id).await?; + assert!(res.is_ok()); + + Ok(()) + } + + #[connector_test(only(Postgres))] + async fn nested_commit_rollback_workflow(mut runner: Runner) -> TestResult<()> { + // Start the outer transaction + let outer_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(outer_tx_id.clone()); + + // Start the inner transaction + let inner_tx_id = runner.start_tx(5000, 5000, None, Some(outer_tx_id.clone())).await?; + runner.set_active_tx(inner_tx_id.clone()); + + // Perform operations in the inner transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 1 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":1}}}"### + ); + + let res = runner.commit_tx(inner_tx_id).await?; + assert!(res.is_ok()); + + // Perform operations in the outer transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 2 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":2}}}"### + ); + + // Now rollback the outer transaction + let res = runner.rollback_tx(outer_tx_id).await?; + assert!(res.is_ok()); + + // Assert that no records were written to the DB + let result_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(result_tx_id.clone()); + insta::assert_snapshot!( + run_query!(&runner, r#"query { findManyTestModel { id field }}"#), + @r###"{"data":{"findManyTestModel":[]}}"### + ); + let _ = runner.commit_tx(result_tx_id).await?; + + Ok(()) + } + #[connector_test] async fn double_rollback(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -495,7 +566,7 @@ mod interactive_tx { #[connector_test] async fn commit_after_rollback(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -532,7 +603,7 @@ mod interactive_tx { #[connector_test] 
async fn rollback_after_commit(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -575,7 +646,9 @@ mod itx_isolation { // All (SQL) connectors support serializable. #[connector_test(exclude(MongoDb, Sqlite("cfd1")))] async fn basic_serializable(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("Serializable".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -597,7 +670,9 @@ mod itx_isolation { #[connector_test(exclude(MongoDb, Sqlite("cfd1")))] async fn casing_doesnt_matter(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; @@ -608,13 +683,17 @@ mod itx_isolation { #[connector_test(only(Postgres))] async fn spacing_doesnt_matter(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Repeatable Read".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("Repeatable Read".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; assert!(res.is_ok()); - let tx_id = runner.start_tx(5000, 5000, Some("RepeatableRead".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("RepeatableRead".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; @@ -625,7 +704,7 @@ mod itx_isolation { #[connector_test(exclude(MongoDb))] async fn invalid_isolation(runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("test".to_owned())).await; + let tx_id = runner.start_tx(5000, 5000, Some("test".to_owned()), None).await; match tx_id { Ok(_) => panic!("Expected invalid isolation level string to throw an error, but it succeeded instead."), @@ -638,7 +717,7 @@ mod itx_isolation { // Mongo doesn't support isolation levels. 
#[connector_test(only(MongoDb))] async fn mongo_failure(runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await; + let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned()), None).await; match tx_id { Ok(_) => panic!("Expected mongo to throw an unsupported error, but it succeeded instead."), diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs index 2a1cf89e9d3b..18cabcf7cca9 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs @@ -48,7 +48,7 @@ mod metrics { #[connector_test] async fn metrics_tx_do_not_go_negative(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -65,7 +65,7 @@ mod metrics { let active_transactions = utils::metrics::get_gauge(&json, PRISMA_CLIENT_QUERIES_ACTIVE); assert_eq!(active_transactions, 0.0); - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs index a9b6c4395760..49ea6597ff6b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs @@ -90,7 +90,7 @@ mod mongodb { } async fn start_itx(runner: &mut Runner) -> TestResult { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); Ok(tx_id) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs index e026a90016bd..733a5b12e01a 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs @@ -83,7 +83,7 @@ impl Actor { response_sender.send(Response::Query(result)).await.unwrap(); } Message::BeginTransaction => { - let response = with_logs(runner.start_tx(10000, 10000, None), log_tx.clone()).await; + let response = with_logs(runner.start_tx(10000, 10000, None, None), log_tx.clone()).await; response_sender.send(Response::Tx(response)).await.unwrap(); } Message::RollbackTransaction(tx_id) => { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs index 313a29cdacf4..0b1e3244e420 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs @@ -5,7 +5,9 @@ mod sqlite { #[connector_test] async fn close_tx_on_error(runner: Runner) -> TestResult<()> { // Try to open a transaction with unsupported isolation error 
in SQLite. - let result = runner.start_tx(2000, 5000, Some("ReadUncommitted".to_owned())).await; + let result = runner + .start_tx(2000, 5000, Some("ReadUncommitted".to_owned()), None) + .await; assert!(result.is_err()); // Without the changes from https://github.com/prisma/prisma-engines/pull/4286 or @@ -16,7 +18,7 @@ mod sqlite { // IMMEDIATE if we had control over SQLite transaction type here, as that would not rely on // both transactions using the same connection if we were to pool multiple SQLite // connections in the future. - let tx = runner.start_tx(2000, 5000, None).await?; + let tx = runner.start_tx(2000, 5000, None, None).await?; runner.rollback_tx(tx).await?.unwrap(); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs index e5808ace7fcc..77861b80d7ac 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs @@ -509,8 +509,9 @@ impl Runner { max_acquisition_millis: u64, valid_for_millis: u64, isolation_level: Option, + new_tx_id: Option, ) -> TestResult { - let tx_opts = TransactionOptions::new(max_acquisition_millis, valid_for_millis, isolation_level); + let tx_opts = TransactionOptions::new(max_acquisition_millis, valid_for_millis, isolation_level, new_tx_id); match &self.executor { RunnerExecutor::Builtin(executor) => { let id = executor diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs index 2fe5d840fa1f..d2baa1bf5970 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs @@ -39,17 +39,21 @@ impl<'conn> MongoDbTransaction<'conn> { #[async_trait] impl<'conn> Transaction for MongoDbTransaction<'conn> { - async fn commit(&mut self) -> connector_interface::Result<()> { + async fn begin(&mut self) -> connector_interface::Result<()> { + Ok(()) + } + + async fn commit(&mut self) -> connector_interface::Result { decrement_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); utils::commit_with_retry(&mut self.connection.session) .await .map_err(|err| MongoError::from(err).into_connector_error())?; - Ok(()) + Ok(0) } - async fn rollback(&mut self) -> connector_interface::Result<()> { + async fn rollback(&mut self) -> connector_interface::Result { decrement_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); self.connection @@ -58,7 +62,7 @@ impl<'conn> Transaction for MongoDbTransaction<'conn> { .await .map_err(|err| MongoError::from(err).into_connector_error())?; - Ok(()) + Ok(0) } async fn version(&self) -> Option { diff --git a/query-engine/connectors/query-connector/src/interface.rs b/query-engine/connectors/query-connector/src/interface.rs index cbdafcaeeee3..2c41e674c3ce 100644 --- a/query-engine/connectors/query-connector/src/interface.rs +++ b/query-engine/connectors/query-connector/src/interface.rs @@ -32,8 +32,9 @@ pub trait Connection: ConnectionLike { #[async_trait] pub trait Transaction: ConnectionLike { - async fn commit(&mut self) -> crate::Result<()>; - async fn rollback(&mut self) -> crate::Result<()>; + async fn begin(&mut self) -> crate::Result<()>; + async fn commit(&mut self) -> crate::Result; + async fn rollback(&mut self) -> crate::Result; async fn version(&self) -> Option; diff --git 
a/query-engine/connectors/sql-query-connector/src/database/transaction.rs b/query-engine/connectors/sql-query-connector/src/database/transaction.rs index 263c541f6b42..75e58483d12e 100644 --- a/query-engine/connectors/sql-query-connector/src/database/transaction.rs +++ b/query-engine/connectors/sql-query-connector/src/database/transaction.rs @@ -37,19 +37,26 @@ impl<'tx> ConnectionLike for SqlConnectorTransaction<'tx> {} #[async_trait] impl<'tx> Transaction for SqlConnectorTransaction<'tx> { - async fn commit(&mut self) -> connector::Result<()> { + async fn begin(&mut self) -> connector::Result<()> { + catch(&self.connection_info, async { + self.inner.begin().await.map_err(SqlError::from) + }) + .await + } + + async fn commit(&mut self) -> connector::Result { catch(&self.connection_info, async { self.inner.commit().await.map_err(SqlError::from) }) .await } - async fn rollback(&mut self) -> connector::Result<()> { + async fn rollback(&mut self) -> connector::Result { catch(&self.connection_info, async { let res = self.inner.rollback().await.map_err(SqlError::from); match res { - Err(SqlError::TransactionAlreadyClosed(_)) | Err(SqlError::RollbackWithoutBegin) => Ok(()), + Err(SqlError::TransactionAlreadyClosed(_)) | Err(SqlError::RollbackWithoutBegin) => Ok(0), _ => res, } }) diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index fee7bc68fe7b..2316267c2345 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -73,17 +73,21 @@ pub struct TransactionOptions { /// An optional pre-defined transaction id. Some value might be provided in case we want to generate /// a new id at the beginning of the transaction - #[serde(skip)] pub new_tx_id: Option, } impl TransactionOptions { - pub fn new(max_acquisition_millis: u64, valid_for_millis: u64, isolation_level: Option) -> Self { + pub fn new( + max_acquisition_millis: u64, + valid_for_millis: u64, + isolation_level: Option, + new_tx_id: Option, + ) -> Self { Self { max_acquisition_millis, valid_for_millis, isolation_level, - new_tx_id: None, + new_tx_id, } } diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs index e6c1c7fbd1dc..37dae7e57332 100644 --- a/query-engine/core/src/interactive_transactions/actor_manager.rs +++ b/query-engine/core/src/interactive_transactions/actor_manager.rs @@ -72,19 +72,27 @@ impl TransactionActorManager { timeout: Duration, engine_protocol: EngineProtocol, ) -> crate::Result<()> { - let client = spawn_itx_actor( - query_schema.clone(), - tx_id.clone(), - conn, - isolation_level, - timeout, - CHANNEL_SIZE, - self.send_done.clone(), - engine_protocol, - ) - .await?; - - self.clients.write().await.insert(tx_id, client); + // Only create a client if there is no client for this transaction yet. + // otherwise, begin a new transaction/savepoint for the existing client. 
+ if !self.clients.read().await.contains_key(&tx_id) { + let client = spawn_itx_actor( + query_schema.clone(), + tx_id.clone(), + conn, + isolation_level, + timeout, + CHANNEL_SIZE, + self.send_done.clone(), + engine_protocol, + ) + .await?; + + self.clients.write().await.insert(tx_id, client); + } else { + let client = self.get_client(&tx_id, "begin").await?; + client.begin().await?; + } + Ok(()) } diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 86ebd5c13b84..4038e62e5a68 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -66,15 +66,39 @@ impl<'a> ITXServer<'a> { let _ = op.respond_to.send(TxOpResponse::Batch(result)); RunState::Continue } + TxOpRequestMsg::Begin => { + let _result = self.begin().await; + let _ = op.respond_to.send(TxOpResponse::Begin(())); + RunState::Continue + } TxOpRequestMsg::Commit => { let resp = self.commit().await; + let resp_value = match &resp { + Ok(val) => *val, + Err(_) => 0, + }; + let _ = op.respond_to.send(TxOpResponse::Committed(resp)); - RunState::Finished + + if resp_value > 0 { + RunState::Continue + } else { + RunState::Finished + } } TxOpRequestMsg::Rollback => { let resp = self.rollback(false).await; + let resp_value = match &resp { + Ok(val) => *val, + Err(_) => 0, + }; let _ = op.respond_to.send(TxOpResponse::RolledBack(resp)); - RunState::Finished + + if resp_value > 0 { + RunState::Continue + } else { + RunState::Finished + } } } } @@ -118,32 +142,46 @@ impl<'a> ITXServer<'a> { .await } - pub(crate) async fn commit(&mut self) -> crate::Result<()> { + pub(crate) async fn begin(&mut self) -> crate::Result<()> { if let CachedTx::Open(_) = self.cached_tx { let open_tx = self.cached_tx.as_open()?; - trace!("[{}] committing.", self.id.to_string()); - open_tx.commit().await?; - self.cached_tx = CachedTx::Committed; + trace!("[{}] beginning.", self.id.to_string()); + open_tx.begin().await?; } Ok(()) } - pub(crate) async fn rollback(&mut self, was_timeout: bool) -> crate::Result<()> { + pub(crate) async fn commit(&mut self) -> crate::Result { + if let CachedTx::Open(_) = self.cached_tx { + let open_tx = self.cached_tx.as_open()?; + trace!("[{}] committing.", self.id.to_string()); + let depth = open_tx.commit().await?; + if depth == 0 { + self.cached_tx = CachedTx::Committed; + } + return Ok(depth); + } + + Ok(0) + } + + pub(crate) async fn rollback(&mut self, was_timeout: bool) -> crate::Result { debug!("[{}] rolling back, was timed out = {was_timeout}", self.name()); if let CachedTx::Open(_) = self.cached_tx { let open_tx = self.cached_tx.as_open()?; - open_tx.rollback().await?; + let depth = open_tx.rollback().await?; if was_timeout { trace!("[{}] Expired Rolling back", self.id.to_string()); self.cached_tx = CachedTx::Expired; - } else { + } else if depth == 0 { self.cached_tx = CachedTx::RolledBack; trace!("[{}] Rolling back", self.id.to_string()); } + return Ok(depth); } - Ok(()) + Ok(0) } pub(crate) fn name(&self) -> String { @@ -158,7 +196,12 @@ pub struct ITXClient { } impl ITXClient { - pub(crate) async fn commit(&self) -> crate::Result<()> { + pub async fn begin(&self) -> crate::Result<()> { + self.send_and_receive(TxOpRequestMsg::Begin).await?; + Ok(()) + } + + pub(crate) async fn commit(&self) -> crate::Result { let msg = self.send_and_receive(TxOpRequestMsg::Commit).await?; if let TxOpResponse::Committed(resp) = msg { @@ -169,7 +212,7 @@ impl ITXClient { } } - pub(crate) 
async fn rollback(&self) -> crate::Result<()> { + pub(crate) async fn rollback(&self) -> crate::Result { let msg = self.send_and_receive(TxOpRequestMsg::Rollback).await?; if let TxOpResponse::RolledBack(resp) = msg { diff --git a/query-engine/core/src/interactive_transactions/messages.rs b/query-engine/core/src/interactive_transactions/messages.rs index 0dba2c096a8a..a61a2887ef8c 100644 --- a/query-engine/core/src/interactive_transactions/messages.rs +++ b/query-engine/core/src/interactive_transactions/messages.rs @@ -6,6 +6,7 @@ use tokio::sync::oneshot; pub enum TxOpRequestMsg { Commit, Rollback, + Begin, Single(Operation, Option), Batch(Vec, Option), } @@ -18,6 +19,7 @@ pub struct TxOpRequest { impl Display for TxOpRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.msg { + TxOpRequestMsg::Begin => write!(f, "Begin"), TxOpRequestMsg::Commit => write!(f, "Commit"), TxOpRequestMsg::Rollback => write!(f, "Rollback"), TxOpRequestMsg::Single(..) => write!(f, "Single"), @@ -28,8 +30,9 @@ impl Display for TxOpRequest { #[derive(Debug)] pub enum TxOpResponse { - Committed(crate::Result<()>), - RolledBack(crate::Result<()>), + Begin(()), + Committed(crate::Result), + RolledBack(crate::Result), Single(crate::Result), Batch(crate::Result>>), } @@ -37,6 +40,7 @@ pub enum TxOpResponse { impl Display for TxOpResponse { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Self::Begin(..) => write!(f, "Begin"), Self::Committed(..) => write!(f, "Committed"), Self::RolledBack(..) => write!(f, "RolledBack"), Self::Single(..) => write!(f, "Single"), diff --git a/query-engine/core/src/interactive_transactions/mod.rs b/query-engine/core/src/interactive_transactions/mod.rs index a0aed069a879..a430e2db0edb 100644 --- a/query-engine/core/src/interactive_transactions/mod.rs +++ b/query-engine/core/src/interactive_transactions/mod.rs @@ -1,7 +1,7 @@ use crate::CoreError; use connector::Transaction; use crosstarget_utils::time::ElapsedTimeCounter; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::fmt::Display; use tokio::time::Duration; @@ -39,7 +39,7 @@ pub(crate) use messages::*; /// the TransactionActorManager can reply with a helpful error message which explains that no operation can be performed on a closed transaction /// rather than an error message stating that the transaction does not exist. -#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize)] +#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Serialize)] pub struct TxId(String); const MINIMUM_TX_ID_LENGTH: usize = 24; diff --git a/query-engine/driver-adapters/executor/src/bench.ts b/query-engine/driver-adapters/executor/src/bench.ts index 14923f69cf9e..82dff0604943 100644 --- a/query-engine/driver-adapters/executor/src/bench.ts +++ b/query-engine/driver-adapters/executor/src/bench.ts @@ -47,6 +47,7 @@ async function main(): Promise { const { recorder, replayer, recordings } = recording(withErrorCapturing); // We exercise the queries recording them + // @ts-ignore await recordQueries(recorder, datamodel, prismaQueries); // Dump recordings if requested @@ -59,6 +60,7 @@ async function main(): Promise { // Then we benchmark the execution of the queries but instead of hitting the DB // we fetch results from the recordings, thus isolating the performance // of the engine + driver adapter code from that of the DB IO. 
+ // @ts-ignore await benchMarkQueries(replayer, datamodel, prismaQueries); } diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs index 1c2a5a68240b..99415984ebbb 100644 --- a/query-engine/driver-adapters/src/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -49,6 +49,9 @@ pub(crate) struct TransactionProxy { /// transaction options options: TransactionOptions, + /// begin transaction + pub begin: AdapterMethod<(), ()>, + /// commit transaction commit: AdapterMethod<(), ()>, @@ -133,11 +136,13 @@ impl TransactionContextProxy { impl TransactionProxy { pub fn new(js_transaction: &JsObject) -> JsResult { let commit = get_named_property(js_transaction, "commit")?; + let begin = get_named_property(js_transaction, "begin")?; let rollback = get_named_property(js_transaction, "rollback")?; let options = get_named_property(js_transaction, "options")?; let options = from_js_value::(options); Ok(Self { + begin, commit, rollback, options, @@ -149,6 +154,11 @@ impl TransactionProxy { &self.options } + pub fn begin(&self) -> UnsafeFuture> + '_> { + self.closed.store(true, Ordering::Relaxed); + UnsafeFuture(self.begin.call_as_async(())) + } + /// Commits the transaction via the driver adapter. /// /// ## Cancellation safety diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index dde42c5fb420..06ad4b499abd 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -5,7 +5,7 @@ use crate::JsObject; use super::conversion; use crate::send_future::UnsafeFuture; use async_trait::async_trait; -use futures::Future; +use futures::{lock::Mutex, Future}; use quaint::connector::{DescribedQuery, ExternalConnectionInfo, ExternalConnector}; use quaint::{ connector::{metrics, IsolationLevel, Transaction}, @@ -13,6 +13,7 @@ use quaint::{ prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, visitor::{self, Visitor}, }; +use std::sync::Arc; use tracing::{info_span, Instrument}; /// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the @@ -227,6 +228,7 @@ impl JsBaseQueryable { pub struct JsQueryable { inner: JsBaseQueryable, driver_proxy: DriverProxy, + pub transaction_depth: Arc>, } impl std::fmt::Display for JsQueryable { @@ -324,24 +326,31 @@ impl JsQueryable { // 3. Spawn a transaction from the context. let tx = tx_ctx.start_transaction().await?; - let begin_stmt = tx.begin_statement(); - let tx_opts = tx.options(); + { + let mut depth_guard = tx.depth.lock().await; + *depth_guard += 1; - if tx_opts.use_phantom_query { - let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); - tx.raw_phantom_cmd(begin_stmt.as_str()).await?; - } else { - tx.raw_cmd(begin_stmt).await?; - } + let st_depth = *depth_guard; - // 4. Set the isolation level (if specified) if we didn't do it before. - if !requires_isolation_first { - if let Some(isolation) = isolation { - tx.set_tx_isolation_level(isolation).await?; + let begin_stmt = tx.begin_statement(st_depth).await; + let tx_opts = tx.options(); + + if tx_opts.use_phantom_query { + let begin_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); + tx.raw_phantom_cmd(begin_stmt.as_str()).await?; + } else { + tx.raw_cmd(&begin_stmt).await?; } - } - self.server_reset_query(tx.as_ref()).await?; + // 4. Set the isolation level (if specified) if we didn't do it before. 
+ if !requires_isolation_first { + if let Some(isolation) = isolation { + tx.set_tx_isolation_level(isolation).await?; + } + } + + self.server_reset_query(tx.as_ref()).await?; + } Ok(tx) } @@ -364,5 +373,6 @@ pub fn from_js(driver: JsObject) -> JsQueryable { JsQueryable { inner: JsBaseQueryable::new(common), driver_proxy, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), } } diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index 8d124bd4da0a..4d0ddb389a8d 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,12 +1,14 @@ use std::future::Future; use async_trait::async_trait; +use futures::lock::Mutex; use metrics::decrement_gauge; use quaint::{ connector::{DescribedQuery, IsolationLevel, Transaction as QuaintTransaction}, prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; +use std::sync::Arc; use crate::proxy::{TransactionContextProxy, TransactionOptions, TransactionProxy}; use crate::{proxy::CommonProxy, queryable::JsBaseQueryable, send_future::UnsafeFuture}; @@ -86,11 +88,16 @@ impl Queryable for JsTransactionContext { pub(crate) struct JsTransaction { tx_proxy: TransactionProxy, inner: JsBaseQueryable, + pub depth: Arc>, } impl JsTransaction { pub(crate) fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self { - Self { inner, tx_proxy } + Self { + inner, + tx_proxy, + depth: Arc::new(futures::lock::Mutex::new(0)), + } } pub fn options(&self) -> &TransactionOptions { @@ -105,36 +112,69 @@ impl JsTransaction { #[async_trait] impl QuaintTransaction for JsTransaction { - async fn commit(&self) -> quaint::Result<()> { + async fn begin(&mut self) -> quaint::Result<()> { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let commit_stmt = "COMMIT"; + let mut depth_guard = self.depth.lock().await; + // Modify the depth value through the MutexGuard + *depth_guard += 1; + + let begin_stmt = self.begin_statement(*depth_guard).await; if self.options().use_phantom_query { - let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); + let commit_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); self.raw_phantom_cmd(commit_stmt.as_str()).await?; } else { - self.inner.raw_cmd(commit_stmt).await?; + self.inner.raw_cmd(&begin_stmt).await?; } - UnsafeFuture(self.tx_proxy.commit()).await + println!("JsTransaction begin: incrementing depth_guard to: {}", *depth_guard); + + UnsafeFuture(self.tx_proxy.begin()).await } + async fn commit(&mut self) -> quaint::Result { + // increment of this gauge is done in DriverProxy::startTransaction + decrement_gauge!("prisma_client_queries_active", 1.0); + + let mut depth_guard = self.depth.lock().await; + let commit_stmt = self.commit_statement(*depth_guard).await; + + if self.options().use_phantom_query { + let commit_stmt = JsBaseQueryable::phantom_query_message(&commit_stmt); + self.raw_phantom_cmd(commit_stmt.as_str()).await?; + } else { + self.inner.raw_cmd(&commit_stmt).await?; + } + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; - async fn rollback(&self) -> quaint::Result<()> { + let _ = UnsafeFuture(self.tx_proxy.commit()).await; + + Ok(*depth_guard) + } + + async fn rollback(&mut self) -> quaint::Result { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let rollback_stmt = "ROLLBACK"; + let mut depth_guard 
= self.depth.lock().await; + let rollback_stmt = self.rollback_statement(*depth_guard).await; if self.options().use_phantom_query { - let rollback_stmt = JsBaseQueryable::phantom_query_message(rollback_stmt); + let rollback_stmt = JsBaseQueryable::phantom_query_message(&rollback_stmt); self.raw_phantom_cmd(rollback_stmt.as_str()).await?; } else { - self.inner.raw_cmd(rollback_stmt).await?; + self.inner.raw_cmd(&rollback_stmt).await?; } - UnsafeFuture(self.tx_proxy.rollback()).await + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + let _ = UnsafeFuture(self.tx_proxy.rollback()).await; + + Ok(*depth_guard) } fn as_queryable(&self) -> &dyn Queryable { @@ -191,6 +231,18 @@ impl Queryable for JsTransaction { fn requires_isolation_first(&self) -> bool { self.inner.requires_isolation_first() } + + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await + } } #[cfg(target_arch = "wasm32")] diff --git a/query-engine/query-engine/src/server/mod.rs b/query-engine/query-engine/src/server/mod.rs index 01b61a07b6b4..1fc9789f865b 100644 --- a/query-engine/query-engine/src/server/mod.rs +++ b/query-engine/query-engine/src/server/mod.rs @@ -282,7 +282,11 @@ async fn transaction_start_handler(cx: Arc, req: Request) - let body_start = req.into_body(); let full_body = hyper::body::to_bytes(body_start).await?; let mut tx_opts: TransactionOptions = serde_json::from_slice(full_body.as_ref()).unwrap(); - let tx_id = tx_opts.with_new_transaction_id(); + let tx_id = if tx_opts.new_tx_id.is_none() { + tx_opts.with_new_transaction_id() + } else { + tx_opts.new_tx_id.clone().unwrap() + }; // This is the span we use to instrument the execution of a transaction. This span will be open // during the tx execution, and held in the ITXServer for that transaction (see ITXServer])