Skip to content

Commit

Permalink
fixes locks and make more examples and tests
Browse files Browse the repository at this point in the history
found a fatal bug in expiration and SCRIPT usage.
  • Loading branch information
Heiss committed Sep 22, 2023
1 parent d49b42b commit 38e2f81
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 38 deletions.
7 changes: 1 addition & 6 deletions src/redis/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::ops;
pub struct Generic<T> {
pub(crate) cache: Option<T>,
pub(crate) key: String,
client: redis::Client,
pub(crate) client: redis::Client,
}

impl<T> Generic<T>
Expand Down Expand Up @@ -124,11 +124,6 @@ where
self.cache.as_ref().unwrap()
}

pub fn acquire_mut(&mut self) -> &mut T {
self.cache = self.try_get();
self.cache.as_mut().unwrap()
}

fn try_get(&self) -> Option<T> {
let mut conn = self.get_conn();
let res: RedisResult<String> = conn.get(&self.key);
Expand Down
200 changes: 168 additions & 32 deletions src/redis/lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ pub enum LockError {
UnlockFailed,
#[error("No connection to Redis available")]
NoConnection,
#[error("Lock expired with id #{0}")]
LockExpired(usize),
#[error("Error by Redis")]
Redis(#[from] redis::RedisError),
}
Expand Down Expand Up @@ -40,22 +42,23 @@ impl From<i8> for LockNum {
/// 2. The timeout in seconds,
/// 3. The value to store.
const LOCK_SCRIPT: &str = r#"
local val = redis.call("get", ARGV[1])
if redis.call("exists", ARGV[1]) or val == false or val == ARGV[3] then
redis.call("setex", ARGV[1], ARGV[2], ARGV[3])
local val = redis.call("get", ARGV[1] .. "_lock")
if val == false or val == ARGV[3] then
redis.call("setex", ARGV[1] .. "_lock", ARGV[2], ARGV[3])
return 1
end
return 0"#;

/// The drop script.
/// It is used to drop a value in Redis, so that only the instance that locked it can drop it.
///
/// Takes 2 Arguments:
/// 1. The key of the value to drop,
/// 2. The value to check.
const DROP_SCRIPT: &str = r#"
local val = redis.call("get", ARGV[1])
if val == ARGV[2] then
redis.call("del", ARGV[1])
local current_lock = redis.call("get", ARGV[1] .. "_lock")
if current_lock == ARGV[2] then
redis.call("del", ARGV[1] .. "_lock")
return 1
end
return 0"#;
Expand All @@ -65,12 +68,41 @@ const DROP_SCRIPT: &str = r#"
/// It is a very simple counter that is stored in Redis and returns all numbers only once.
///
/// Takes 1 Argument:
/// 1. The key of the field to increment and return.
/// 1. The key of the value to lock.
const UUID_SCRIPT: &str = r#"
redis.call("incr", ARGV[1])
local val = redis.call("get", ARGV[1])
redis.call("incr", ARGV[1] .. "_uuids")
local val = redis.call("get", ARGV[1] .. "_uuids")
return val"#;

/// The store script.
/// It is used to store a value in Redis with a lock.
///
/// Takes 3 Arguments:
/// 1. The key of the value to store,
/// 2. The uuid of the lock object,
/// 3. The value to store.
const STORE_SCRIPT: &str = r#"
local current_lock = redis.call("get", ARGV[1] .. "_lock")
if current_lock == ARGV[2] then
redis.call("set", ARGV[1], ARGV[3])
return 1
end
return 0"#;

/// The load script.
/// It is used to load a value from Redis with a lock.
///
/// Takes 2 Arguments:
/// 1. The key of the value to load,
/// 2. The uuid of the lock.
const LOAD_SCRIPT: &str = r#"
local current_lock = redis.call("get", ARGV[1] .. "_lock")
if current_lock == ARGV[2] then
local val = redis.call("get", ARGV[1])
return val
end
return nil"#;

/// The RedisMutex struct.
/// It is used to lock a value in Redis, so that only one instance can access it at a time.
/// You have to use RedisGeneric as the data type.
Expand All @@ -79,27 +111,27 @@ return val"#;
/// The lock is released when the guard is dropped or it expires.
/// The default expiration time is 1000ms. If you need more time, use the [Guard::expand()] function.
pub struct Mutex<T> {
client: redis::Client,
conn: Option<redis::Connection>,
data: Generic<T>,
key: String,
uuid: usize,
}

impl<T> Mutex<T> {
pub fn new(client: redis::Client, data: Generic<T>) -> Self {
let mut conn = client
impl<T> Mutex<T>
where
T: Serialize + DeserializeOwned,
{
pub fn new(data: Generic<T>) -> Self {
let mut conn = data
.client
.get_connection()
.expect("Failed to get connection to Redis");

let uuid = redis::Script::new(UUID_SCRIPT)
.arg(format!("uuid_{}", data.key))
.arg(&data.key)
.invoke::<usize>(&mut conn)
.expect("Failed to get uuid");

Self {
client,
key: format!("lock_{}", data.key),
data,
conn: Some(conn),
uuid,
Expand All @@ -118,6 +150,58 @@ impl<T> Mutex<T> {
/// this function will block until the lock is released, which will be happen after the lock
/// expires (1000ms).
/// If you need to extend this time, you can use the [Guard::expand()] function.
///
/// # Example
/// ```
/// use dtypes::redis::Di32 as i32;
/// use dtypes::redis::Mutex;
/// use std::thread::scope;
///
/// let client = redis::Client::open("redis://localhost:6379").unwrap();
/// let client2 = client.clone();
///
/// scope(|s| {
/// let t1 = s.spawn(move || {
/// let mut i32 = i32::new("test_add_example1", client2);
/// let mut lock = Mutex::new(i32);
/// let mut guard = lock.lock().unwrap();
/// guard.store(2).expect("TODO: panic message");
/// assert_eq!(*guard, 2);
/// });
/// {
/// let mut i32 = i32::new("test_add_example1", client);
/// let mut lock = Mutex::new(i32);
/// let mut guard = lock.lock().unwrap();
/// guard.store(1).expect("Failed to store value");
/// assert_eq!(*guard, 1);
/// }
/// t1.join().expect("Failed to join thread1");
/// });
/// ```
///
/// It does not allow any deadlocks, because the lock will automatically release after some time.
/// So you have to check for errors, if you want to handle them.
///
/// Beware: Your CPU can anytime switch to another thread, so you have to check for errors!
/// But if you are brave enough, you can drop the result and hope for the best.
///
/// # Example
/// ```
/// use std::thread::sleep;
/// use dtypes::redis::Di32 as i32;
/// use dtypes::redis::Mutex;
///
/// let client = redis::Client::open("redis://localhost:6379").unwrap();
/// let mut i32 = i32::new("test_add_example2", client.clone());
/// i32.store(1);
/// assert_eq!(i32.acquire(), &1);
/// let mut lock = Mutex::new(i32);
///
/// let mut guard = lock.lock().unwrap();
/// sleep(std::time::Duration::from_millis(1000));
/// let res = guard.store(3);
/// assert!(res.is_err(), "{:?}", res);
/// ```
pub fn lock(&mut self) -> Result<Guard<T>, LockError> {
let mut conn = match self.conn.take() {
Some(conn) => conn,
Expand All @@ -131,14 +215,13 @@ impl<T> Mutex<T> {

while LockNum::from(
lock_cmd
.arg(&self.key)
.arg(1000)
.arg(&self.data.key)
.arg(1)
.arg(&self.uuid.to_string())
.invoke::<i8>(&mut conn)
.expect("Failed to lock. You should not see this!"),
) == LockNum::Fail
{
println!("waiting for lock");
std::hint::spin_loop();
}

Expand Down Expand Up @@ -171,7 +254,10 @@ pub struct Guard<'a, T> {
expanded: bool,
}

impl<'a, T> Guard<'a, T> {
impl<'a, T> Guard<'a, T>
where
T: Serialize + DeserializeOwned,
{
fn new(lock: &'a mut Mutex<T>) -> Result<Self, LockError> {
Ok(Self {
lock,
Expand All @@ -190,10 +276,61 @@ impl<'a, T> Guard<'a, T> {
}

let conn = self.lock.conn.as_mut().expect("Connection should be there");
let expand = redis::Cmd::expire(&self.lock.key, 2000);
let expand = redis::Cmd::expire(format!("{}_lock", &self.lock.data.key), 2);
expand.execute(conn);
self.expanded = true;
}

/// Stores the value in Redis.
/// This function blocks until the value is stored.
/// Disables the store operation of the guarded value.
pub fn store(&mut self, value: T) -> Result<(), LockError>
where
T: Serialize,
{
let conn = self.lock.conn.as_mut().ok_or(LockError::NoConnection)?;
let script = redis::Script::new(STORE_SCRIPT);
let result: i8 = script
.arg(&self.lock.data.key)
.arg(self.lock.uuid)
.arg(serde_json::to_string(&value).expect("Failed to serialize value"))
.invoke(conn)
.expect("Failed to store value. You should not see this!");
if result == 0 {
return Err(LockError::LockExpired(self.lock.uuid));
}
self.lock.data.cache = Some(value);
Ok(())
}

/// Loads the value from Redis.
/// This function blocks until the value is loaded.
/// Shadows the load operation of the guarded value.
pub fn acquire(&mut self) -> &T {
self.lock.data.cache = self.try_get();
self.lock.data.cache.as_ref().unwrap()
}

fn try_get(&mut self) -> Option<T> {
let conn = self
.lock
.conn
.as_mut()
.ok_or(LockError::NoConnection)
.expect("Connection should be there");
let script = redis::Script::new(LOAD_SCRIPT);
let result: Option<String> = script
.arg(&self.lock.data.key)
.arg(self.lock.uuid)
.invoke(conn)
.expect("Failed to load value. You should not see this!");
let result = result?;

if result == "nil" {
return None;
}
Some(serde_json::from_str(&result).expect("Failed to deserialize value"))
}
}

impl<T> Deref for Guard<'_, T>
Expand Down Expand Up @@ -222,11 +359,9 @@ impl<T> Drop for Guard<'_, T> {
fn drop(&mut self) {
let conn = self.lock.conn.as_mut().expect("Connection should be there");
let script = redis::Script::new(DROP_SCRIPT);
let key = &self.lock.key;
let uuid = &self.lock.uuid;
script
.arg(key)
.arg(uuid.to_string())
.arg(&self.lock.data.key)
.arg(self.lock.uuid)
.invoke::<()>(conn)
.expect("Failed to drop lock. You should not see this!");
}
Expand All @@ -240,20 +375,21 @@ mod tests {
#[test]
fn test_create_lock() {
let client = redis::Client::open("redis://localhost:6379").unwrap();
let i32 = Di32::new("test_add_locking", client.clone());
let i32_2 = Di32::new("test_add_locking", client.clone());
let mut lock: Mutex<i32> = Mutex::new(client.clone(), i32);
let mut lock2: Mutex<i32> = Mutex::new(client, i32_2);
let client2 = client.clone();

thread::scope(|s| {
let t1 = s.spawn(move || {
let i32_2 = Di32::new("test_add_locking", client2.clone());
let mut lock2: Mutex<i32> = Mutex::new(i32_2);
let mut guard = lock2.lock().unwrap();
guard.store(2);
guard.store(2).expect("TODO: panic message");
assert_eq!(*guard, 2);
});
{
let i32 = Di32::new("test_add_locking", client.clone());
let mut lock: Mutex<i32> = Mutex::new(i32);
let mut guard = lock.lock().unwrap();
guard.store(1);
guard.store(1).expect("TODO: panic message");
assert_eq!(*guard, 1);
}
t1.join().expect("Failed to join thread1");
Expand Down

0 comments on commit 38e2f81

Please sign in to comment.