Skip to content

Commit

Permalink
use separate methods for handling savepoints
Browse files Browse the repository at this point in the history
  • Loading branch information
LucianBuzzo committed Nov 12, 2024
1 parent f1e3846 commit b8ca794
Show file tree
Hide file tree
Showing 18 changed files with 367 additions and 214 deletions.
36 changes: 14 additions & 22 deletions quaint/src/connector/mssql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,32 +244,24 @@ impl Queryable for Mssql {
}

/// Statement to begin a transaction
fn begin_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("SAVE TRANSACTION savepoint{depth}"))
} else {
Cow::Borrowed("BEGIN TRAN")
}
fn begin_statement(&self) -> &'static str {
"BEGIN TRAN"
}

/// Statement to commit a transaction
fn commit_statement(&self, depth: u32) -> Cow<'static, str> {
// MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested
// transaction we just continue onwards
if depth > 1 {
Cow::Owned("".to_string())
} else {
Cow::Borrowed("COMMIT")
}
/// Statement to create a savepoint
fn create_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("SAVE TRANSACTION savepoint{depth}"))
}

/// Statement to rollback a transaction
fn rollback_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("ROLLBACK TRANSACTION savepoint{depth}"))
} else {
Cow::Borrowed("ROLLBACK")
}
// MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested
// transaction we just continue onwards
fn release_savepoint_statement(&self, _depth: u32) -> Cow<'static, str> {
Cow::Owned("".to_string())
}

/// Statement to rollback to a savepoint
fn rollback_to_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("ROLLBACK TRANSACTION savepoint{depth}"))
}

fn requires_isolation_first(&self) -> bool {
Expand Down
33 changes: 13 additions & 20 deletions quaint/src/connector/mysql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,29 +350,22 @@ impl Queryable for Mysql {
}

/// Statement to begin a transaction
fn begin_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("BEGIN")
}
fn begin_statement(&self) -> &'static str {
"BEGIN"
}

/// Statement to commit a transaction
fn commit_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("COMMIT")
}
/// Statement to create a savepoint
fn create_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
}

/// Statement to rollback a transaction
fn rollback_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("ROLLBACK TO savepoint{depth}"))
} else {
Cow::Borrowed("ROLLBACK")
}
/// Statement to release a savepoint
fn release_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
}

/// Statement to rollback to a savepoint
fn rollback_to_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("ROLLBACK TO savepoint{depth}"))
}
}
33 changes: 13 additions & 20 deletions quaint/src/connector/postgres/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -809,30 +809,23 @@ impl Queryable for PostgreSql {
}

/// Statement to begin a transaction
fn begin_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("BEGIN")
}
fn begin_statement(&self) -> &'static str {
"BEGIN"
}

/// Statement to commit a transaction
fn commit_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("COMMIT")
}
/// Statement to create a savepoint
fn create_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
}

/// Statement to rollback a transaction
fn rollback_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("ROLLBACK")
}
/// Statement to release a savepoint
fn release_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
}

/// Statement to rollback to a savepoint
fn rollback_to_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}"))
}
}

Expand Down
33 changes: 13 additions & 20 deletions quaint/src/connector/queryable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,30 +92,23 @@ pub trait Queryable: Send + Sync {
}

/// Statement to begin a transaction
fn begin_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("BEGIN")
}
fn begin_statement(&self) -> &'static str {
"BEGIN"
}

/// Statement to commit a transaction
fn commit_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("COMMIT")
}
/// Statement to create a savepoint in a transaction
fn create_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
}

/// Statement to rollback a transaction
fn rollback_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("ROLLBACK")
}
/// Statement to release a savepoint in a transaction
fn release_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
}

/// Statement to rollback to a savepoint in a transaction
fn rollback_to_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}"))
}

/// Sets the transaction isolation level to given value.
Expand Down
33 changes: 13 additions & 20 deletions quaint/src/connector/sqlite/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,34 +185,27 @@ impl Queryable for Sqlite {
}

/// Statement to begin a transaction
fn begin_statement(&self, depth: u32) -> Cow<'static, str> {
fn begin_statement(&self) -> &'static str {
// From https://sqlite.org/isolation.html:
// `BEGIN IMMEDIATE` avoids possible `SQLITE_BUSY_SNAPSHOT` that arise when another connection jumps ahead in line.
// The BEGIN IMMEDIATE command goes ahead and starts a write transaction, and thus blocks all other writers.
// If the BEGIN IMMEDIATE operation succeeds, then no subsequent operations in that transaction will ever fail with an SQLITE_BUSY error.
if depth > 1 {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("BEGIN IMMEDIATE")
}
"BEGIN IMMEDIATE"
}

/// Statement to commit a transaction
fn commit_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("COMMIT")
}
/// Statement to create a savepoint
fn create_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
}

/// Statement to rollback a transaction
fn rollback_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("ROLLBACK TO savepoint{depth}"))
} else {
Cow::Borrowed("ROLLBACK")
}
/// Statement to release a savepoint
fn release_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
}

/// Statement to rollback to a savepoint
fn rollback_to_savepoint_statement(&self, depth: u32) -> Cow<'static, str> {
Cow::Owned(format!("ROLLBACK TO savepoint{depth}"))
}
}

Expand Down
120 changes: 85 additions & 35 deletions quaint/src/connector/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,25 @@ use std::{

#[async_trait]
pub trait Transaction: Queryable {
fn depth(&self) -> u32;

/// Start a new transaction or nested transaction via savepoint.
async fn begin(&mut self) -> crate::Result<()>;

/// Commit the changes to the database and consume the transaction.
async fn commit(&mut self) -> crate::Result<u32>;
async fn commit(&mut self) -> crate::Result<()>;

/// Rolls back the changes to the database.
async fn rollback(&mut self) -> crate::Result<u32>;
async fn rollback(&mut self) -> crate::Result<()>;

/// Creates a savepoint in the transaction.
async fn create_savepoint(&mut self) -> crate::Result<()>;

/// Releases a savepoint in the transaction.
async fn release_savepoint(&mut self) -> crate::Result<()>;

/// Rolls back to a savepoint in the transaction.
async fn rollback_to_savepoint(&mut self) -> crate::Result<()>;

/// workaround for lack of upcasting between traits https://github.com/rust-lang/rust/issues/65991
fn as_queryable(&self) -> &dyn Queryable;
Expand Down Expand Up @@ -79,71 +90,110 @@ impl<'a> DefaultTransaction<'a> {

#[async_trait]
impl<'a> Transaction for DefaultTransaction<'a> {
fn depth(&self) -> u32 {
*self.depth.lock().unwrap()
}

async fn begin(&mut self) -> crate::Result<()> {
let current_depth = {
// Lock the mutex in it's own scope to ensure its dropped before the await
{
let mut depth = self.depth.lock().unwrap();
*depth += 1;
*depth
};
}

let begin_statement = self.inner.begin_statement(current_depth);
let begin_statement = self.inner.begin_statement();

self.inner.raw_cmd(&begin_statement).await?;
self.inner.raw_cmd(begin_statement).await?;

Ok(())
}

/// Commit the changes to the database and consume the transaction.
async fn commit(&mut self) -> crate::Result<u32> {
// Lock the mutex and get the depth value
let depth_val = {
let depth = self.depth.lock().unwrap();
*depth
};

async fn commit(&mut self) -> crate::Result<()> {
// Perform the asynchronous operation without holding the lock
let commit_statement = self.inner.commit_statement(depth_val);
self.inner.raw_cmd(&commit_statement).await?;
self.inner.raw_cmd("COMMIT").await?;

// Lock the mutex again to modify the depth
let new_depth = {
// Lock the mutex to modify the depth
let mut depth = self.depth.lock().unwrap();
*depth -= 1;

self.gauge.decrement();

Ok(())
}

/// Rolls back the changes to the database.
async fn rollback(&mut self) -> crate::Result<()> {
self.inner.raw_cmd("ROLLBACK").await?;

// Lock the mutex to modify the depth
let mut depth = self.depth.lock().unwrap();
*depth -= 1;

self.gauge.decrement();

Ok(())
}

/// Creates a savepoint in the transaction
async fn create_savepoint(&mut self) -> crate::Result<()> {
let current_depth = {
let mut depth = self.depth.lock().unwrap();
*depth -= 1;
*depth += 1;
*depth
};

if new_depth == 0 {
self.gauge.decrement();
}

Ok(new_depth)
let create_savepoint_statement = self.inner.create_savepoint_statement(current_depth);
self.inner.raw_cmd(&create_savepoint_statement).await?;
Ok(())
}

/// Rolls back the changes to the database.
async fn rollback(&mut self) -> crate::Result<u32> {
/// Releases a savepoint in the transaction
async fn release_savepoint(&mut self) -> crate::Result<()> {
// Lock the mutex and get the depth value
let depth_val = {
let depth = self.depth.lock().unwrap();
*depth
};

// Perform the asynchronous operation without holding the lock
let rollback_statement = self.inner.rollback_statement(depth_val);
if depth_val == 0 {
panic!(
"No savepoint to release in transaction, make sure to call create_savepoint before release_savepoint"
);
}

self.inner.raw_cmd(&rollback_statement).await?;
// Perform the asynchronous operation without holding the lock
let release_savepoint_statement = self.inner.release_savepoint_statement(depth_val);
self.inner.raw_cmd(&release_savepoint_statement).await?;

// Lock the mutex again to modify the depth
let new_depth = {
let mut depth = self.depth.lock().unwrap();
*depth -= 1;
let mut depth = self.depth.lock().unwrap();
*depth -= 1;

Ok(())
}

/// Rollback to savepoint in the transaction
async fn rollback_to_savepoint(&mut self) -> crate::Result<()> {
let depth_val = {
let depth = self.depth.lock().unwrap();
*depth
};

if new_depth == 0 {
self.gauge.decrement();
if depth_val == 0 {
panic!(
"No savepoint to rollback to in transaction, make sure to call create_savepoint before rollback_to_savepoint"
);
}

Ok(new_depth)
let rollback_to_savepoint_statement = self.inner.rollback_to_savepoint_statement(depth_val);
self.inner.raw_cmd(&rollback_to_savepoint_statement).await?;

// Lock the mutex again to modify the depth
let mut depth = self.depth.lock().unwrap();
*depth -= 1;

Ok(())
}

fn as_queryable(&self) -> &dyn Queryable {
Expand Down
Loading

0 comments on commit b8ca794

Please sign in to comment.