diff --git a/src/redis/generic.rs b/src/redis/generic.rs index a4c4c84..9aebb9e 100644 --- a/src/redis/generic.rs +++ b/src/redis/generic.rs @@ -13,7 +13,7 @@ use std::ops; pub struct Generic { pub(crate) cache: Option, pub(crate) key: String, - client: redis::Client, + pub(crate) client: redis::Client, } impl Generic @@ -124,11 +124,6 @@ where self.cache.as_ref().unwrap() } - pub fn acquire_mut(&mut self) -> &mut T { - self.cache = self.try_get(); - self.cache.as_mut().unwrap() - } - fn try_get(&self) -> Option { let mut conn = self.get_conn(); let res: RedisResult = conn.get(&self.key); diff --git a/src/redis/lock.rs b/src/redis/lock.rs index 783a03a..e832b79 100644 --- a/src/redis/lock.rs +++ b/src/redis/lock.rs @@ -13,6 +13,8 @@ pub enum LockError { UnlockFailed, #[error("No connection to Redis available")] NoConnection, + #[error("Lock expired with id #{0}")] + LockExpired(usize), #[error("Error by Redis")] Redis(#[from] redis::RedisError), } @@ -40,22 +42,23 @@ impl From for LockNum { /// 2. The timeout in seconds, /// 3. The value to store. const LOCK_SCRIPT: &str = r#" - local val = redis.call("get", ARGV[1]) - if redis.call("exists", ARGV[1]) or val == false or val == ARGV[3] then - redis.call("setex", ARGV[1], ARGV[2], ARGV[3]) + local val = redis.call("get", ARGV[1] .. "_lock") + if val == false or val == ARGV[3] then + redis.call("setex", ARGV[1] .. "_lock", ARGV[2], ARGV[3]) return 1 end return 0"#; /// The drop script. /// It is used to drop a value in Redis, so that only the instance that locked it can drop it. +/// /// Takes 2 Arguments: /// 1. The key of the value to drop, /// 2. The value to check. const DROP_SCRIPT: &str = r#" - local val = redis.call("get", ARGV[1]) - if val == ARGV[2] then - redis.call("del", ARGV[1]) + local current_lock = redis.call("get", ARGV[1] .. "_lock") + if current_lock == ARGV[2] then + redis.call("del", ARGV[1] .. "_lock") return 1 end return 0"#; @@ -65,12 +68,41 @@ const DROP_SCRIPT: &str = r#" /// It is a very simple counter that is stored in Redis and returns all numbers only once. /// /// Takes 1 Argument: -/// 1. The key of the field to increment and return. +/// 1. The key of the value to lock. const UUID_SCRIPT: &str = r#" -redis.call("incr", ARGV[1]) -local val = redis.call("get", ARGV[1]) +redis.call("incr", ARGV[1] .. "_uuids") +local val = redis.call("get", ARGV[1] .. "_uuids") return val"#; +/// The store script. +/// It is used to store a value in Redis with a lock. +/// +/// Takes 3 Arguments: +/// 1. The key of the value to store, +/// 2. The uuid of the lock object, +/// 3. The value to store. +const STORE_SCRIPT: &str = r#" +local current_lock = redis.call("get", ARGV[1] .. "_lock") +if current_lock == ARGV[2] then + redis.call("set", ARGV[1], ARGV[3]) + return 1 +end +return 0"#; + +/// The load script. +/// It is used to load a value from Redis with a lock. +/// +/// Takes 2 Arguments: +/// 1. The key of the value to load, +/// 2. The uuid of the lock. +const LOAD_SCRIPT: &str = r#" +local current_lock = redis.call("get", ARGV[1] .. "_lock") +if current_lock == ARGV[2] then + local val = redis.call("get", ARGV[1]) + return val +end +return nil"#; + /// The RedisMutex struct. /// It is used to lock a value in Redis, so that only one instance can access it at a time. /// You have to use RedisGeneric as the data type. @@ -79,27 +111,27 @@ return val"#; /// The lock is released when the guard is dropped or it expires. /// The default expiration time is 1000ms. If you need more time, use the [Guard::expand()] function. pub struct Mutex { - client: redis::Client, conn: Option, data: Generic, - key: String, uuid: usize, } -impl Mutex { - pub fn new(client: redis::Client, data: Generic) -> Self { - let mut conn = client +impl Mutex +where + T: Serialize + DeserializeOwned, +{ + pub fn new(data: Generic) -> Self { + let mut conn = data + .client .get_connection() .expect("Failed to get connection to Redis"); let uuid = redis::Script::new(UUID_SCRIPT) - .arg(format!("uuid_{}", data.key)) + .arg(&data.key) .invoke::(&mut conn) .expect("Failed to get uuid"); Self { - client, - key: format!("lock_{}", data.key), data, conn: Some(conn), uuid, @@ -118,6 +150,58 @@ impl Mutex { /// this function will block until the lock is released, which will be happen after the lock /// expires (1000ms). /// If you need to extend this time, you can use the [Guard::expand()] function. + /// + /// # Example + /// ``` + /// use dtypes::redis::Di32 as i32; + /// use dtypes::redis::Mutex; + /// use std::thread::scope; + /// + /// let client = redis::Client::open("redis://localhost:6379").unwrap(); + /// let client2 = client.clone(); + /// + /// scope(|s| { + /// let t1 = s.spawn(move || { + /// let mut i32 = i32::new("test_add_example1", client2); + /// let mut lock = Mutex::new(i32); + /// let mut guard = lock.lock().unwrap(); + /// guard.store(2).expect("TODO: panic message"); + /// assert_eq!(*guard, 2); + /// }); + /// { + /// let mut i32 = i32::new("test_add_example1", client); + /// let mut lock = Mutex::new(i32); + /// let mut guard = lock.lock().unwrap(); + /// guard.store(1).expect("Failed to store value"); + /// assert_eq!(*guard, 1); + /// } + /// t1.join().expect("Failed to join thread1"); + /// }); + /// ``` + /// + /// It does not allow any deadlocks, because the lock will automatically release after some time. + /// So you have to check for errors, if you want to handle them. + /// + /// Beware: Your CPU can anytime switch to another thread, so you have to check for errors! + /// But if you are brave enough, you can drop the result and hope for the best. + /// + /// # Example + /// ``` + /// use std::thread::sleep; + /// use dtypes::redis::Di32 as i32; + /// use dtypes::redis::Mutex; + /// + /// let client = redis::Client::open("redis://localhost:6379").unwrap(); + /// let mut i32 = i32::new("test_add_example2", client.clone()); + /// i32.store(1); + /// assert_eq!(i32.acquire(), &1); + /// let mut lock = Mutex::new(i32); + /// + /// let mut guard = lock.lock().unwrap(); + /// sleep(std::time::Duration::from_millis(1000)); + /// let res = guard.store(3); + /// assert!(res.is_err(), "{:?}", res); + /// ``` pub fn lock(&mut self) -> Result, LockError> { let mut conn = match self.conn.take() { Some(conn) => conn, @@ -131,14 +215,13 @@ impl Mutex { while LockNum::from( lock_cmd - .arg(&self.key) - .arg(1000) + .arg(&self.data.key) + .arg(1) .arg(&self.uuid.to_string()) .invoke::(&mut conn) .expect("Failed to lock. You should not see this!"), ) == LockNum::Fail { - println!("waiting for lock"); std::hint::spin_loop(); } @@ -171,7 +254,10 @@ pub struct Guard<'a, T> { expanded: bool, } -impl<'a, T> Guard<'a, T> { +impl<'a, T> Guard<'a, T> +where + T: Serialize + DeserializeOwned, +{ fn new(lock: &'a mut Mutex) -> Result { Ok(Self { lock, @@ -190,10 +276,61 @@ impl<'a, T> Guard<'a, T> { } let conn = self.lock.conn.as_mut().expect("Connection should be there"); - let expand = redis::Cmd::expire(&self.lock.key, 2000); + let expand = redis::Cmd::expire(format!("{}_lock", &self.lock.data.key), 2); expand.execute(conn); self.expanded = true; } + + /// Stores the value in Redis. + /// This function blocks until the value is stored. + /// Disables the store operation of the guarded value. + pub fn store(&mut self, value: T) -> Result<(), LockError> + where + T: Serialize, + { + let conn = self.lock.conn.as_mut().ok_or(LockError::NoConnection)?; + let script = redis::Script::new(STORE_SCRIPT); + let result: i8 = script + .arg(&self.lock.data.key) + .arg(self.lock.uuid) + .arg(serde_json::to_string(&value).expect("Failed to serialize value")) + .invoke(conn) + .expect("Failed to store value. You should not see this!"); + if result == 0 { + return Err(LockError::LockExpired(self.lock.uuid)); + } + self.lock.data.cache = Some(value); + Ok(()) + } + + /// Loads the value from Redis. + /// This function blocks until the value is loaded. + /// Shadows the load operation of the guarded value. + pub fn acquire(&mut self) -> &T { + self.lock.data.cache = self.try_get(); + self.lock.data.cache.as_ref().unwrap() + } + + fn try_get(&mut self) -> Option { + let conn = self + .lock + .conn + .as_mut() + .ok_or(LockError::NoConnection) + .expect("Connection should be there"); + let script = redis::Script::new(LOAD_SCRIPT); + let result: Option = script + .arg(&self.lock.data.key) + .arg(self.lock.uuid) + .invoke(conn) + .expect("Failed to load value. You should not see this!"); + let result = result?; + + if result == "nil" { + return None; + } + Some(serde_json::from_str(&result).expect("Failed to deserialize value")) + } } impl Deref for Guard<'_, T> @@ -222,11 +359,9 @@ impl Drop for Guard<'_, T> { fn drop(&mut self) { let conn = self.lock.conn.as_mut().expect("Connection should be there"); let script = redis::Script::new(DROP_SCRIPT); - let key = &self.lock.key; - let uuid = &self.lock.uuid; script - .arg(key) - .arg(uuid.to_string()) + .arg(&self.lock.data.key) + .arg(self.lock.uuid) .invoke::<()>(conn) .expect("Failed to drop lock. You should not see this!"); } @@ -240,20 +375,21 @@ mod tests { #[test] fn test_create_lock() { let client = redis::Client::open("redis://localhost:6379").unwrap(); - let i32 = Di32::new("test_add_locking", client.clone()); - let i32_2 = Di32::new("test_add_locking", client.clone()); - let mut lock: Mutex = Mutex::new(client.clone(), i32); - let mut lock2: Mutex = Mutex::new(client, i32_2); + let client2 = client.clone(); thread::scope(|s| { let t1 = s.spawn(move || { + let i32_2 = Di32::new("test_add_locking", client2.clone()); + let mut lock2: Mutex = Mutex::new(i32_2); let mut guard = lock2.lock().unwrap(); - guard.store(2); + guard.store(2).expect("TODO: panic message"); assert_eq!(*guard, 2); }); { + let i32 = Di32::new("test_add_locking", client.clone()); + let mut lock: Mutex = Mutex::new(i32); let mut guard = lock.lock().unwrap(); - guard.store(1); + guard.store(1).expect("TODO: panic message"); assert_eq!(*guard, 1); } t1.join().expect("Failed to join thread1");