diff --git a/commons/zenoh-protocol/src/core/wire_expr.rs b/commons/zenoh-protocol/src/core/wire_expr.rs index 3681863ca..7a70f1006 100644 --- a/commons/zenoh-protocol/src/core/wire_expr.rs +++ b/commons/zenoh-protocol/src/core/wire_expr.rs @@ -17,7 +17,7 @@ use alloc::{ borrow::Cow, string::{String, ToString}, }; -use core::{convert::TryInto, fmt}; +use core::{convert::TryInto, fmt, sync::atomic::AtomicU16}; use zenoh_keyexpr::{keyexpr, OwnedKeyExpr}; use zenoh_result::{bail, ZResult}; @@ -28,6 +28,7 @@ use crate::network::Mapping; pub type ExprId = u16; pub type ExprLen = u16; +pub type AtomicExprId = AtomicU16; pub const EMPTY_EXPR_ID: ExprId = 0; /// A zenoh **resource** is represented by a pair composed by a **key** and a diff --git a/commons/zenoh-protocol/src/network/mod.rs b/commons/zenoh-protocol/src/network/mod.rs index 4e38ceda2..ed23b0337 100644 --- a/commons/zenoh-protocol/src/network/mod.rs +++ b/commons/zenoh-protocol/src/network/mod.rs @@ -27,7 +27,7 @@ pub use declare::{ pub use interest::Interest; pub use oam::Oam; pub use push::Push; -pub use request::{Request, RequestId}; +pub use request::{AtomicRequestId, Request, RequestId}; pub use response::{Response, ResponseFinal}; use crate::core::{CongestionControl, Priority, Reliability}; diff --git a/commons/zenoh-protocol/src/network/request.rs b/commons/zenoh-protocol/src/network/request.rs index 019e68095..3fd9eb221 100644 --- a/commons/zenoh-protocol/src/network/request.rs +++ b/commons/zenoh-protocol/src/network/request.rs @@ -11,10 +11,13 @@ // Contributors: // ZettaScale Zenoh Team, // +use core::sync::atomic::AtomicU32; + use crate::{core::WireExpr, zenoh::RequestBody}; /// The resolution of a RequestId pub type RequestId = u32; +pub type AtomicRequestId = AtomicU32; pub mod flag { pub const N: u8 = 1 << 5; // 0x20 Named if N==1 then the key expr has name/suffix diff --git a/zenoh/src/api/key_expr.rs b/zenoh/src/api/key_expr.rs index 86e188022..2a3c775bf 100644 --- a/zenoh/src/api/key_expr.rs +++ b/zenoh/src/api/key_expr.rs @@ -618,8 +618,7 @@ impl Wait for KeyExprUndeclaration<'_> { }; tracing::trace!("undeclare_keyexpr({:?})", expr_id); let mut state = zwrite!(session.0.state); - assert_ne!(expr_id, 0, "0 is not a valid keyexpr id"); - state.local_resources.remove(expr_id as usize); + state.local_resources.remove(&expr_id); let primitives = state.primitives()?; drop(state); diff --git a/zenoh/src/api/session.rs b/zenoh/src/api/session.rs index 99af898b4..0c01bffdb 100644 --- a/zenoh/src/api/session.rs +++ b/zenoh/src/api/session.rs @@ -25,7 +25,6 @@ use std::{ time::{Duration, SystemTime, UNIX_EPOCH}, }; -use slab::Slab; use tracing::{error, info, trace, warn}; use uhlc::Timestamp; #[cfg(feature = "internal")] @@ -43,7 +42,8 @@ use zenoh_protocol::network::{ use zenoh_protocol::{ core::{ key_expr::{keyexpr, OwnedKeyExpr}, - CongestionControl, EntityId, ExprId, Parameters, Reliability, WireExpr, EMPTY_EXPR_ID, + AtomicExprId, CongestionControl, EntityId, ExprId, Parameters, Reliability, WireExpr, + EMPTY_EXPR_ID, }, network::{ self, @@ -53,8 +53,8 @@ use zenoh_protocol::{ UndeclareSubscriber, }, interest::{InterestMode, InterestOptions}, - push, request, DeclareFinal, Interest, Mapping, Push, Request, RequestId, Response, - ResponseFinal, + push, request, AtomicRequestId, DeclareFinal, Interest, Mapping, Push, Request, RequestId, + Response, ResponseFinal, }, zenoh::{ query::{self, ext::QueryBodyType}, @@ -121,7 +121,11 @@ zconfigurable! { pub(crate) struct SessionState { pub(crate) primitives: Option>, // @TODO replace with MaybeUninit ?? - pub(crate) local_resources: Slab, + pub(crate) expr_id_counter: AtomicExprId, // @TODO: manage rollover and uniqueness + pub(crate) qid_counter: AtomicRequestId, + #[cfg(feature = "unstable")] + pub(crate) liveliness_qid_counter: AtomicRequestId, + pub(crate) local_resources: HashMap, pub(crate) remote_resources: HashMap, #[cfg(feature = "unstable")] pub(crate) remote_subscribers: HashMap>, @@ -136,9 +140,9 @@ pub(crate) struct SessionState { pub(crate) tokens: HashMap>, #[cfg(feature = "unstable")] pub(crate) matching_listeners: HashMap>, - pub(crate) queries: Slab, + pub(crate) queries: HashMap, #[cfg(feature = "unstable")] - pub(crate) liveliness_queries: Slab, + pub(crate) liveliness_queries: HashMap, pub(crate) aggregated_subscribers: Vec, pub(crate) aggregated_publishers: Vec, } @@ -148,12 +152,13 @@ impl SessionState { aggregated_subscribers: Vec, aggregated_publishers: Vec, ) -> SessionState { - // Note: local_resources start at 1 because 0 is reserved for NO_RESOURCE - let mut local_resources = Slab::new(); - local_resources.insert(Resource::Prefix { prefix: "".into() }); SessionState { primitives: None, - local_resources, + expr_id_counter: AtomicExprId::new(1), // Note: start at 1 because 0 is reserved for NO_RESOURCE + qid_counter: AtomicRequestId::new(0), + #[cfg(feature = "unstable")] + liveliness_qid_counter: AtomicRequestId::new(0), + local_resources: HashMap::new(), remote_resources: HashMap::new(), #[cfg(feature = "unstable")] remote_subscribers: HashMap::new(), @@ -168,9 +173,9 @@ impl SessionState { tokens: HashMap::new(), #[cfg(feature = "unstable")] matching_listeners: HashMap::new(), - queries: Slab::new(), + queries: HashMap::new(), #[cfg(feature = "unstable")] - liveliness_queries: Slab::new(), + liveliness_queries: HashMap::new(), aggregated_subscribers, aggregated_publishers, } @@ -188,16 +193,13 @@ impl SessionState { #[inline] fn get_local_res(&self, id: &ExprId) -> Option<&Resource> { - if *id == 0 { - return None; - } - self.local_resources.get(*id as usize) + self.local_resources.get(id) } #[inline] fn get_remote_res(&self, id: &ExprId, mapping: Mapping) -> Option<&Resource> { match mapping { - Mapping::Receiver => self.get_local_res(id), + Mapping::Receiver => self.local_resources.get(id), Mapping::Sender => self.remote_resources.get(id), } } @@ -1127,11 +1129,11 @@ impl SessionInner { match state .local_resources .iter() - .skip(1) // skip NO_RESOURCE - .find(|(_, res)| res.name() == prefix) + .find(|(_expr_id, res)| res.name() == prefix) { - Some((expr_id, _res)) => Ok(expr_id as ExprId), + Some((expr_id, _res)) => Ok(*expr_id), None => { + let expr_id = state.expr_id_counter.fetch_add(1, Ordering::SeqCst); let mut res = Resource::new(Box::from(prefix)); if let Resource::Node(res_node) = &mut res { for kind in [ @@ -1145,10 +1147,7 @@ impl SessionInner { } } } - if state.local_resources.vacant_key() > ExprId::MAX as usize { - bail!("too many keyexprs declared"); - } - let expr_id = state.local_resources.insert(res) as ExprId; + state.local_resources.insert(expr_id, res); drop(state); primitives.send_declare(Declare { interest_id: None, @@ -1333,9 +1332,8 @@ impl SessionInner { .insert(sub_state.id, sub_state.clone()); for res in state .local_resources - .iter_mut() - .skip(1) // skip NO_RESOURCE - .filter_map(|(_, res)|res.as_node_mut()) + .values_mut() + .filter_map(Resource::as_node_mut) { if key_expr.intersects(&res.key_expr) { res.subscribers_mut(SubscriberKind::Subscriber) @@ -1413,9 +1411,8 @@ impl SessionInner { trace!("undeclare_subscriber({:?})", sub_state); for res in state .local_resources - .iter_mut() - .skip(1) // skip NO_RESOURCE - .filter_map(|(_, res)|res.as_node_mut()) + .values_mut() + .filter_map(Resource::as_node_mut) { res.subscribers_mut(kind) .retain(|sub| sub.id != sub_state.id); @@ -1607,9 +1604,8 @@ impl SessionInner { for res in state .local_resources - .iter_mut() - .skip(1) // skip NO_RESOURCE - .filter_map(|(_, res)|res.as_node_mut()) + .values_mut() + .filter_map(Resource::as_node_mut) { if key_expr.intersects(&res.key_expr) { res.subscribers_mut(SubscriberKind::LivelinessSubscriber) @@ -2069,22 +2065,12 @@ impl SessionInner { ConsolidationMode::Auto => ConsolidationMode::Latest, mode => mode, }; + let qid = state.qid_counter.fetch_add(1, Ordering::SeqCst); let nb_final = match destination { Locality::Any => 2, _ => 1, }; - let wexpr = key_expr.to_wire(self).to_owned(); - let qid = state.queries.insert(QueryState { - nb_final, - key_expr: key_expr.clone().into_owned(), - parameters: parameters.clone().into_owned(), - reception_mode: consolidation, - replies: (consolidation != ConsolidationMode::None).then(HashMap::new), - callback, - }) as RequestId; - tracing::trace!("Register query {} (nb_final = {})", qid, nb_final); - let token = self.task_controller.get_cancellation_token(); self.task_controller .spawn_with_rt(zenoh_runtime::ZRuntime::Net, { @@ -2095,7 +2081,7 @@ impl SessionInner { tokio::select! { _ = tokio::time::sleep(timeout) => { let mut state = zwrite!(session.state); - if let Some(query) = state.queries.try_remove(qid as usize) { + if let Some(query) = state.queries.remove(&qid) { std::mem::drop(state); tracing::debug!("Timeout on query {}! Send error and close.", qid); if query.reception_mode == ConsolidationMode::Latest { @@ -2115,6 +2101,20 @@ impl SessionInner { } }); + tracing::trace!("Register query {} (nb_final = {})", qid, nb_final); + let wexpr = key_expr.to_wire(self).to_owned(); + state.queries.insert( + qid, + QueryState { + nb_final, + key_expr: key_expr.clone().into_owned(), + parameters: parameters.clone().into_owned(), + reception_mode: consolidation, + replies: (consolidation != ConsolidationMode::None).then(HashMap::new), + callback, + }, + ); + let primitives = state.primitives()?; drop(state); @@ -2176,13 +2176,7 @@ impl SessionInner { ) -> ZResult<()> { tracing::trace!("liveliness.get({}, {:?})", key_expr, timeout); let mut state = zwrite!(self.state); - - let wexpr = key_expr.to_wire(self).to_owned(); - let id = state - .liveliness_queries - .insert(LivelinessQueryState { callback }) as InterestId; - tracing::trace!("Register liveliness query {}", id); - + let id = state.liveliness_qid_counter.fetch_add(1, Ordering::SeqCst); let token = self.task_controller.get_cancellation_token(); self.task_controller .spawn_with_rt(zenoh_runtime::ZRuntime::Net, { @@ -2192,7 +2186,7 @@ impl SessionInner { tokio::select! { _ = tokio::time::sleep(timeout) => { let mut state = zwrite!(session.state); - if let Some(query) = state.liveliness_queries.try_remove(id as usize) { + if let Some(query) = state.liveliness_queries.remove(&id) { std::mem::drop(state); tracing::debug!("Timeout on liveliness query {}! Send error and close.", id); query.callback.call(Reply { @@ -2207,6 +2201,12 @@ impl SessionInner { } }); + tracing::trace!("Register liveliness query {}", id); + let wexpr = key_expr.to_wire(self).to_owned(); + state + .liveliness_queries + .insert(id, LivelinessQueryState { callback }); + let primitives = state.primitives()?; drop(state); @@ -2402,8 +2402,7 @@ impl Primitives for WeakSession { { Ok(key_expr) => { if let Some(interest_id) = msg.interest_id { - if let Some(query) = state.liveliness_queries.get(interest_id as usize) - { + if let Some(query) = state.liveliness_queries.get(&interest_id) { let reply = Reply { result: Ok(Sample { key_expr, @@ -2517,7 +2516,7 @@ impl Primitives for WeakSession { #[cfg(feature = "unstable")] if let Some(interest_id) = msg.interest_id { let mut state = zwrite!(self.state); - let _ = state.liveliness_queries.try_remove(interest_id as usize); + let _ = state.liveliness_queries.remove(&interest_id); } } } @@ -2593,7 +2592,7 @@ impl Primitives for WeakSession { if state.primitives.is_none() { return; // Session closing or closed } - match state.queries.get_mut(msg.rid as usize) { + match state.queries.get_mut(&msg.rid) { Some(query) => { let callback = query.callback.clone(); std::mem::drop(state); @@ -2624,7 +2623,7 @@ impl Primitives for WeakSession { return; } }; - match state.queries.get_mut(msg.rid as usize) { + match state.queries.get_mut(&msg.rid) { Some(query) => { let c = zcondfeat!("unstable", !query.parameters.reply_key_expr_any(), true); @@ -2795,11 +2794,11 @@ impl Primitives for WeakSession { if state.primitives.is_none() { return; // Session closing or closed } - match state.queries.get_mut(msg.rid as usize) { + match state.queries.get_mut(&msg.rid) { Some(query) => { query.nb_final -= 1; if query.nb_final == 0 { - let query = state.queries.try_remove(msg.rid as usize).unwrap(); + let query = state.queries.remove(&msg.rid).unwrap(); std::mem::drop(state); if query.reception_mode == ConsolidationMode::Latest { for (_, reply) in query.replies.unwrap().into_iter() {