Skip to content

Commit

Permalink
feat: add support for nested transaction rollbacks via savepoints in sql
Browse files Browse the repository at this point in the history
This is my first OSS contribution for a Rust project, so I'm sure I've
made some stupid mistakes, but I think it should mostly work :)

This change adds a mutable depth counter, that can track how many levels
deep a transaction is, and uses savepoints to implement correct rollback
behaviour. Previously, once a nested transaction was complete, it would
be saved with `COMMIT`, meaning that even if the outer transaction was
rolled back, the operations in the inner transaction would persist. With
this change, if the outer transaction gets rolled back, then all inner
transactions will also be rolled back.

Different flavours of SQL servers have different syntax for handling
savepoints, so I've had to add new methods to the `Queryable` trait for
getting the commit and rollback statements. These are both parameterized
by the current depth.

I've additionally had to modify the `begin_statement` method to accept a depth
parameter, as it will need to conditionally create a savepoint.

When opening a transaction via the transaction server, you can now pass
the prior transaction ID to re-use the existing transaction,
incrementing the depth.

Signed-off-by: Lucian Buzzo <[email protected]>
  • Loading branch information
LucianBuzzo committed Sep 26, 2024
1 parent 031f4d3 commit 71d3bd8
Show file tree
Hide file tree
Showing 31 changed files with 645 additions and 140 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,8 @@ ensure-prisma-present:
echo "⚠️ ../prisma diverges from prisma/prisma main branch. Test results might diverge from those in CI ⚠️ "; \
fi \
else \
echo "git clone --depth=1 https://github.com/prisma/prisma.git --branch=$(DRIVER_ADAPTERS_BRANCH) ../prisma"; \
git clone --depth=1 https://github.com/prisma/prisma.git --branch=$(DRIVER_ADAPTERS_BRANCH) "../prisma" && echo "Prisma repository has been cloned to ../prisma"; \
echo "git clone --depth=1 https://github.com/LucianBuzzo/prisma.git --branch=lucianbuzzo/nested-rollbacks ../prisma"; \
git clone --depth=1 https://github.com/LucianBuzzo/prisma.git --branch=lucianbuzzo/nested-rollbacks "../prisma" && echo "Prisma repository has been cloned to ../prisma"; \
fi;

# Quick schema validation of whatever you have in the dev_datamodel.prisma file.
Expand Down
54 changes: 47 additions & 7 deletions quaint/src/connector/mssql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use futures::lock::Mutex;
use std::{
convert::TryFrom,
future::Future,
sync::atomic::{AtomicBool, Ordering},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tiberius::*;
Expand All @@ -45,11 +48,13 @@ impl TransactionCapable for Mssql {
.or(self.url.query_params.transaction_isolation_level)
.or(Some(SQL_SERVER_DEFAULT_ISOLATION));

let opts = TransactionOptions::new(isolation, self.requires_isolation_first());
let opts = TransactionOptions::new(
isolation,
self.requires_isolation_first(),
self.transaction_depth.clone(),
);

Ok(Box::new(
DefaultTransaction::new(self, self.begin_statement(), opts).await?,
))
Ok(Box::new(DefaultTransaction::new(self, opts).await?))
}
}

Expand All @@ -60,6 +65,7 @@ pub struct Mssql {
url: MssqlUrl,
socket_timeout: Option<Duration>,
is_healthy: AtomicBool,
transaction_depth: Arc<Mutex<i32>>,
}

impl Mssql {
Expand Down Expand Up @@ -91,6 +97,7 @@ impl Mssql {
url,
socket_timeout,
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(Mutex::new(0)),
};

if let Some(isolation) = this.url.transaction_isolation_level() {
Expand Down Expand Up @@ -243,8 +250,41 @@ impl Queryable for Mssql {
Ok(())
}

fn begin_statement(&self) -> &'static str {
"BEGIN TRAN"
/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVE TRANSACTION savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"BEGIN TRAN".to_string()
};

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
// MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested
// transaction we just continue onwards
let ret = if depth > 1 {
" ".to_string()
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TRANSACTION savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}

fn requires_isolation_first(&self) -> bool {
Expand Down
39 changes: 38 additions & 1 deletion quaint/src/connector/mysql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ use mysql_async::{
};
use std::{
future::Future,
sync::atomic::{AtomicBool, Ordering},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio::sync::Mutex;
Expand Down Expand Up @@ -76,6 +79,7 @@ pub struct Mysql {
socket_timeout: Option<Duration>,
is_healthy: AtomicBool,
statement_cache: Mutex<LruCache<String, my::Statement>>,
transaction_depth: Arc<futures::lock::Mutex<i32>>,
}

impl Mysql {
Expand All @@ -89,6 +93,7 @@ impl Mysql {
statement_cache: Mutex::new(url.cache()),
url,
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}

Expand Down Expand Up @@ -345,4 +350,36 @@ impl Queryable for Mysql {
fn requires_isolation_first(&self) -> bool {
true
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}
}
40 changes: 39 additions & 1 deletion quaint/src/connector/postgres/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ use std::{
fmt::{Debug, Display},
fs,
future::Future,
sync::atomic::{AtomicBool, Ordering},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio_postgres::{config::ChannelBinding, Client, Config, Statement};
Expand Down Expand Up @@ -61,6 +64,7 @@ pub struct PostgreSql {
is_healthy: AtomicBool,
is_cockroachdb: bool,
is_materialize: bool,
transaction_depth: Arc<Mutex<i32>>,
}

/// Key uniquely representing an SQL statement in the prepared statements cache.
Expand Down Expand Up @@ -289,6 +293,7 @@ impl PostgreSql {
is_healthy: AtomicBool::new(true),
is_cockroachdb,
is_materialize,
transaction_depth: Arc::new(Mutex::new(0)),
})
}

Expand Down Expand Up @@ -763,6 +768,39 @@ impl Queryable for PostgreSql {
fn requires_isolation_first(&self) -> bool {
false
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
println!("pg connector: Transaction depth: {}", depth);
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}
}

/// Sorted list of CockroachDB's reserved keywords.
Expand Down
40 changes: 36 additions & 4 deletions quaint/src/connector/queryable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,36 @@ pub trait Queryable: Send + Sync {
}

/// Statement to begin a transaction
fn begin_statement(&self) -> &'static str {
"BEGIN"
async fn begin_statement(&self, depth: i32) -> String {
println!("connector: Transaction depth: {}", depth);
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}

/// Sets the transaction isolation level to given value.
Expand Down Expand Up @@ -120,10 +148,14 @@ macro_rules! impl_default_TransactionCapable {
&'a self,
isolation: Option<IsolationLevel>,
) -> crate::Result<Box<dyn crate::connector::Transaction + 'a>> {
let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first());
let opts = crate::connector::TransactionOptions::new(
isolation,
self.requires_isolation_first(),
self.transaction_depth.clone(),
);

Ok(Box::new(
crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?,
crate::connector::DefaultTransaction::new(self, opts).await?,
))
}
}
Expand Down
45 changes: 41 additions & 4 deletions quaint/src/connector/sqlite/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
visitor::{self, Visitor},
};
use async_trait::async_trait;
use std::convert::TryFrom;
use std::{convert::TryFrom, sync::Arc};
use tokio::sync::Mutex;

/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature.
Expand All @@ -27,6 +27,7 @@ pub use rusqlite;
/// A connector interface for the SQLite database
pub struct Sqlite {
pub(crate) client: Mutex<rusqlite::Connection>,
transaction_depth: Arc<futures::lock::Mutex<i32>>,
}

impl TryFrom<&str> for Sqlite {
Expand Down Expand Up @@ -64,7 +65,10 @@ impl TryFrom<&str> for Sqlite {

let client = Mutex::new(conn);

Ok(Sqlite { client })
Ok(Sqlite {
client,
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}
}

Expand All @@ -79,6 +83,7 @@ impl Sqlite {

Ok(Sqlite {
client: Mutex::new(client),
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}

Expand Down Expand Up @@ -181,12 +186,44 @@ impl Queryable for Sqlite {
false
}

fn begin_statement(&self) -> &'static str {
/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
// From https://sqlite.org/isolation.html:
// `BEGIN IMMEDIATE` avoids possible `SQLITE_BUSY_SNAPSHOT` that arise when another connection jumps ahead in line.
// The BEGIN IMMEDIATE command goes ahead and starts a write transaction, and thus blocks all other writers.
// If the BEGIN IMMEDIATE operation succeeds, then no subsequent operations in that transaction will ever fail with an SQLITE_BUSY error.
"BEGIN IMMEDIATE"
let ret = if depth > 1 {
savepoint_stmt
} else {
"BEGIN IMMEDIATE".to_string()
};

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}
}

Expand Down
Loading

0 comments on commit 71d3bd8

Please sign in to comment.