Skip to content

Commit

Permalink
Add sync::Barrier (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
Heiss authored May 3, 2024
1 parent 0997a86 commit 9fd9317
Show file tree
Hide file tree
Showing 3 changed files with 320 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

## 0.2.4 - 2023-10-29

- add Barrier type, which implements a barrier for sync based on Redis

## 0.2.3 - 2023-10-29

- add RwLock Type, which implements a reader-writer lock based on Redis
Expand Down
314 changes: 314 additions & 0 deletions src/redis/barrier.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
/// The waiting script.
/// Is is used to indicate, if there is a thread waiting for the barrier.
/// Returns 1 if #num thread waiting >= #num threads that should wait. Otherwise 0.
/// If the thread is the leader, it returns 2.
/// Needs to be used n a loop to update expiration time to signal your wait.
///
/// Takes 4 arguments:
/// 1. The key of the barrier.
/// 2. The id of the barrier itself.
/// 3. The number of threads that should wait for the barrier.
/// 4. The timeout in seconds.
const WAITING_SCRIPT: &str = r#"
redis.call("set", ARGV[1] .. ":waiting:" .. ARGV[2], 1, "EX", ARGV[4])
local leader_id = redis.call("get", ARGV[1] .. ":leader")
if leader_id then
if leader_id == ARGV[2] then
return 2
end
return 1
end
local count = 0
local cursor = "0"
repeat
local res = redis.call("scan", cursor, "MATCH", ARGV[1] .. ":waiting:*", "COUNT", ARGV[3] + 1)
if next(res[2]) ~= nil then
count = count + #res[2]
end
cursor = res[1]
until cursor == "0"
if count < tonumber(ARGV[3]) then
return 0
end
if not leader_id then
if redis.call("set", ARGV[1] .. ":leader" , ARGV[2], "EX", ARGV[4], "NX") then
return 2
end
end
return 1
"#;

/// The reset script.
/// It is used to reset the barrier, so you can reuse it.
/// Essentially it deletes all keys that are used by the barrier.
///
/// Takes 1 Argument:
/// 1. The key of the value to lock.
/// 2. The uuid of the barrier.
/// 3. The number of threads that should wait for the barrier.
const RESET_SCRIPT: &str = r#"
redis.call("del", ARGV[1] .. ":waiting:" .. ARGV[2])
local count = 0
local cursor = "0"
repeat
local res = redis.call("scan", cursor, "MATCH", ARGV[1] .. ":waiting:*", "COUNT", ARGV[3] + 1)
if next(res[2]) ~= nil then
count = count + #res[2]
end
cursor = res[1]
until cursor == "0"
-- if it is the last barrier, delete the leader and uuids key
if count == 0 then
redis.call("del", ARGV[1] .. ":leader")
redis.call("del", ARGV[1] .. ":uuids")
end
"#;

/// The uuid script.
/// It is used to generate a uuid for the barrier.
/// It is a very simple counter that is stored in Redis and returns all numbers only once.
///
/// Takes 1 Argument:
/// 1. The key of the value to lock.
const UUID_SCRIPT: &str = r#"
redis.call("incr", ARGV[1] .. ":uuids")
local val = redis.call("get", ARGV[1] .. ":uuids")
return val
"#;

pub struct Barrier {
uuid: usize,
num: usize,
key: String,
_client: redis::Client,
conn: Option<redis::Connection>,
}

#[derive(PartialEq)]
enum RedisBarrierStatus {
Waiting,
Leader,
Done,
}

impl From<u8> for RedisBarrierStatus {
fn from(val: u8) -> Self {
match val {
0 => RedisBarrierStatus::Waiting,
1 => RedisBarrierStatus::Done,
2 => RedisBarrierStatus::Leader,
_ => panic!("Invalid RedisBarrierStatus"),
}
}
}

/// A `BarrierWaitResult` is returned by [`Barrier::wait()`] when all systems
/// in the [`Barrier`] have rendezvoused.
///
/// # Examples
///
/// ```
/// use dtypes::redis::Barrier;
///
/// let client = redis::Client::open("redis://localhost:6379").unwrap();
/// let mut barrier = Barrier::new(1, "barrier_doc_test", client);
/// let barrier_wait_result = barrier.wait();
/// ```
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
/// Returns `true` if this thread is the "leader thread" for the call to
/// [`Barrier::wait()`].
///
/// Only one thread will have `true` returned from their result, all other
/// threads will have `false` returned.
///
/// # Examples
///
/// ```
/// use dtypes::redis::Barrier;
///
/// let client = redis::Client::open("redis://localhost:6379").unwrap();
/// let mut barrier = Barrier::new(1, "barrier_doc_test", client);
/// let barrier_wait_result = barrier.wait();
/// println!("{:?}", barrier_wait_result.is_leader());
/// ```
pub fn is_leader(&self) -> bool {
self.0
}
}

enum BarrierError {

Check warning on line 150 in src/redis/barrier.rs

View workflow job for this annotation

GitHub Actions / clippy

enum `BarrierError` is never used

warning: enum `BarrierError` is never used --> src/redis/barrier.rs:150:6 | 150 | enum BarrierError { | ^^^^^^^^^^^^ | = note: `#[warn(dead_code)]` on by default
RedisError(redis::RedisError),
}

impl Barrier {
pub fn new(num: usize, key: &str, client: redis::Client) -> Self {
let mut conn = client.get_connection().unwrap();

let uuid = redis::Script::new(UUID_SCRIPT)
.arg(&key)

Check warning on line 159 in src/redis/barrier.rs

View workflow job for this annotation

GitHub Actions / clippy

the borrowed expression implements the required traits

warning: the borrowed expression implements the required traits --> src/redis/barrier.rs:159:18 | 159 | .arg(&key) | ^^^^ help: change this to: `key` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_borrows_for_generic_args = note: `#[warn(clippy::needless_borrows_for_generic_args)]` on by default
.arg(&num)

Check warning on line 160 in src/redis/barrier.rs

View workflow job for this annotation

GitHub Actions / clippy

the borrowed expression implements the required traits

warning: the borrowed expression implements the required traits --> src/redis/barrier.rs:160:18 | 160 | .arg(&num) | ^^^^ help: change this to: `num` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_borrows_for_generic_args
.invoke::<usize>(&mut conn)
.expect("Failed to create barrier");

Barrier {
uuid: uuid,

Check warning on line 165 in src/redis/barrier.rs

View workflow job for this annotation

GitHub Actions / clippy

redundant field names in struct initialization

warning: redundant field names in struct initialization --> src/redis/barrier.rs:165:13 | 165 | uuid: uuid, | ^^^^^^^^^^ help: replace it with: `uuid` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#redundant_field_names = note: `#[warn(clippy::redundant_field_names)]` on by default
num,
key: key.to_string(),
_client: client,
conn: Some(conn),
}
}

/// Blocks the current thread until all threads have rendezvoused here.
///
/// Barriers are re-usable after all threads have rendezvoused once, and can
/// be used continuously.
///
/// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
/// returns `true` from [`BarrierWaitResult::is_leader()`] when returning
/// from this function, and all other threads will receive a result that
/// will return `false` from [`BarrierWaitResult::is_leader()`].
///
/// The barrier needs to be mutable, because it guarantees that the barrier is only used once in thread.
/// If you want to synchronize threads, you need to create a new barrier for each thread, so it has its own uuid.
///
/// # Examples
///
/// ```
/// use dtypes::redis::Barrier;
/// use std::thread;
///
/// let n = 10;
/// let mut handles = Vec::with_capacity(n);
/// let client = redis::Client::open("redis://localhost:6379").unwrap();
/// for _ in 0..n {
/// // The same messages will be printed together.
/// // You will NOT see any interleaving.
/// let mut barrier = Barrier::new(n, "barrier_doc_test2", client.clone());
/// handles.push(thread::spawn(move|| {
/// println!("before wait");
/// barrier.wait();
/// println!("after wait");
/// }));
/// }
/// // Wait for other threads to finish.
/// for handle in handles {
/// handle.join().unwrap();
/// }
/// ```
pub fn wait(&mut self) -> BarrierWaitResult {
let mut conn = self.conn.take().unwrap();
let timeout = 2;

let mut status = RedisBarrierStatus::Waiting;
while status == RedisBarrierStatus::Waiting {
status = redis::Script::new(WAITING_SCRIPT)
.arg(&self.key)
.arg(self.uuid)
.arg(self.num)
.arg(timeout)
.invoke::<u8>(&mut conn)
.expect("Failed to wait for barrier")
.into();
}
self.conn = Some(conn);

if status == RedisBarrierStatus::Leader {
BarrierWaitResult(true)
} else {
BarrierWaitResult(false)
}
}
}

impl Drop for Barrier {
fn drop(&mut self) {
let mut conn = self.conn.take().unwrap();
redis::Script::new(RESET_SCRIPT)
.arg(&self.key)
.arg(self.uuid)
.arg(self.num)
.invoke::<()>(&mut conn)
.expect("Failed to reset barrier");
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use std::thread::sleep;

#[test]
fn test_barrier_leader() {
let client = redis::Client::open("redis://localhost:6379").unwrap();
let mut barrier = Barrier::new(1, "barrier_test_leader", client);
let barrier_wait_result = barrier.wait();
assert!(barrier_wait_result.is_leader());
}

#[test]
fn test_barrier_not_leader() {
let client = redis::Client::open("redis://localhost:6379").unwrap();

let mut barrier = Barrier::new(2, "barrier_test_notleader", client.clone());

let h1 = thread::spawn(move || {
let mut barrier = Barrier::new(2, "barrier_test_notleader", client);
let barrier_wait_result = barrier.wait();
assert!(!barrier_wait_result.is_leader());
});

let h2 = thread::spawn(move || {
sleep(std::time::Duration::from_millis(1000));
let barrier_wait_result = barrier.wait();
assert!(barrier_wait_result.is_leader());
});

h1.join().unwrap();
h2.join().unwrap();
}

#[test]
fn test_barrier_slow_check() {
let n = 10;
let mut handles = Vec::with_capacity(n);
let client = redis::Client::open("redis://localhost:6379").unwrap();
for _ in 0..n {
// The same messages will be printed together.
// You will NOT see any interleaving.
let mut barrier = Barrier::new(n, "barrier_doc_test2", client.clone());
handles.push(thread::spawn(move || barrier.wait().is_leader()));
}
// Wait for other threads to finish.
assert_eq!(
handles
.into_iter()
.map(|h| h.join().unwrap())
.map(|x| if x { 1 } else { 0 })
.sum::<i32>(),
1
);
}

#[test]
fn test_barrier_reuse() {
let client = redis::Client::open("redis://localhost:6379").unwrap();

let mut barrier = Barrier::new(1, "barrier_test_reuse", client.clone());
barrier.wait();
let mut barrier = Barrier::new(1, "barrier_test_reuse", client.clone());
barrier.wait();
}
}
2 changes: 2 additions & 0 deletions src/redis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
//!
//! It is possible to implement your own complex types by implementing the [BackedType](crate::BackedType) trait.
//! But it should not be needed as long as your type implements some or all of the various [Ops](https://doc.rust-lang.org/std/ops/index.html) traits.
mod barrier;
mod bool_type;
mod clock;
mod generic;
Expand All @@ -53,6 +54,7 @@ mod string;

pub(crate) use helper::apply_operator;

pub use barrier::{Barrier, BarrierWaitResult};
pub use bool_type::TBool as Dbool;
pub use clock::ClockOrdered;
pub use generic::Generic;
Expand Down

0 comments on commit 9fd9317

Please sign in to comment.