diff --git a/CHANGES.md b/CHANGES.md index 086e2a5..9ad7f79 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [4.3.0] - 2024-11-04 + +* Use updated Service trait + ## [4.2.1] - 2024-11-01 * Better rediness error handling diff --git a/Cargo.toml b/Cargo.toml index 94b9901..8e7bb0b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-mqtt" -version = "4.2.1" +version = "4.3.0" authors = ["ntex contributors "] description = "Client and Server framework for MQTT v5 and v3.1.1 protocols" documentation = "https://docs.rs/ntex-mqtt" @@ -17,8 +17,8 @@ features = ["ntex/tokio"] [dependencies] ntex-io = "2" ntex-net = "2" -ntex-util = "2" -ntex-service = "3.2.1" +ntex-util = "2.5" +ntex-service = "3.3" ntex-bytes = "0.1" ntex-codec = "0.6" ntex-router = "0.5" diff --git a/src/inflight.rs b/src/inflight.rs index 404daee..207cddf 100644 --- a/src/inflight.rs +++ b/src/inflight.rs @@ -2,7 +2,7 @@ use std::{cell::Cell, future::poll_fn, rc::Rc, task::Context, task::Poll}; use ntex_service::{Middleware, Service, ServiceCtx}; -use ntex_util::task::LocalWaker; +use ntex_util::{future::join, future::select, task::LocalWaker}; /// Trait for types that could be sized pub trait SizedRequest { @@ -73,15 +73,21 @@ where type Response = S::Response; type Error = S::Error; - ntex_service::forward_shutdown!(service); - #[inline] async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), S::Error> { - ctx.ready(&self.service).await?; + if !self.count.is_available() { + let (_, res) = join(self.count.available(), ctx.ready(&self.service)).await; + res + } else { + ctx.ready(&self.service).await + } + } - // check if we have capacity - self.count.available().await; - Ok(()) + #[inline] + async fn not_ready(&self) { + if self.count.is_available() { + select(self.count.unavailable(), self.service.not_ready()).await; + } } #[inline] @@ -92,6 +98,8 @@ where drop(task_guard); result } + + ntex_service::forward_shutdown!(service); } struct Counter(Rc); @@ -119,9 +127,18 @@ impl Counter { CounterGuard::new(size, self.0.clone()) } + fn is_available(&self) -> bool { + (self.0.max_cap == 0 || self.0.cur_cap.get() < self.0.max_cap) + && (self.0.max_size == 0 || self.0.cur_size.get() <= self.0.max_size) + } + async fn available(&self) { poll_fn(|cx| if self.0.available(cx) { Poll::Ready(()) } else { Poll::Pending }).await } + + async fn unavailable(&self) { + poll_fn(|cx| if self.0.available(cx) { Poll::Pending } else { Poll::Ready(()) }).await + } } struct CounterGuard(u32, Rc); @@ -143,8 +160,14 @@ impl Drop for CounterGuard { impl CounterInner { fn inc(&self, size: u32) { - self.cur_cap.set(self.cur_cap.get() + 1); - self.cur_size.set(self.cur_size.get() + size as usize); + let cur_cap = self.cur_cap.get() + 1; + self.cur_cap.set(cur_cap); + let cur_size = self.cur_size.get() + size as usize; + self.cur_size.set(cur_size); + + if cur_cap == self.max_cap || cur_size >= self.max_size { + self.task.wake(); + } } fn dec(&self, size: u32) { @@ -161,12 +184,12 @@ impl CounterInner { } fn available(&self, cx: &Context<'_>) -> bool { + self.task.register(cx.waker()); if (self.max_cap == 0 || self.cur_cap.get() < self.max_cap) && (self.max_size == 0 || self.cur_size.get() <= self.max_size) { true } else { - self.task.register(cx.waker()); false } } @@ -261,13 +284,14 @@ mod tests { .await } - async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> { + async fn call(&self, _: (), ctx: ServiceCtx<'_, Self>) -> Result<(), ()> { let fut = sleep(self.dur); self.cnt.set(true); self.waker.wake(); let _ = fut.await; self.cnt.set(false); + self.waker.wake(); Ok::<_, ()>(()) } } @@ -286,6 +310,7 @@ mod tests { })) .bind(); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); + assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending); let srv2 = srv.clone(); ntex_util::spawn(async move { @@ -300,6 +325,7 @@ mod tests { let _ = poll_fn(|cx| srv2.poll_ready(cx)).await; let _ = tx.send(()); }); + assert_eq!(poll_fn(|cx| srv.poll_ready(cx)).await, Ok(())); let _ = rx.await; } diff --git a/src/io.rs b/src/io.rs index 490f067..ef544e4 100644 --- a/src/io.rs +++ b/src/io.rs @@ -28,11 +28,12 @@ pin_project_lite::pin_project! { bitflags::bitflags! { #[derive(Copy, Clone, Eq, PartialEq, Debug)] struct Flags: u8 { - const READY_ERR = 0b00001; - const IO_ERR = 0b00010; - const KA_ENABLED = 0b00100; - const KA_TIMEOUT = 0b01000; - const READ_TIMEOUT = 0b10000; + const READY_ERR = 0b000001; + const IO_ERR = 0b000010; + const KA_ENABLED = 0b000100; + const KA_TIMEOUT = 0b001000; + const READ_TIMEOUT = 0b010000; + const READY = 0b100000; } } @@ -426,26 +427,39 @@ where } } + fn check_error(&mut self) -> PollService { + // check for errors + let mut state = self.state.borrow_mut(); + if let Some(err) = state.error.take() { + log::trace!("{}: Error occured, stopping dispatcher", self.io.tag()); + self.st = IoDispatcherState::Stop; + match err { + IoDispatcherError::Encoder(err) => { + PollService::Item(DispatchItem::EncoderError(err)) + } + IoDispatcherError::Service(err) => { + state.error = Some(IoDispatcherError::Service(err)); + PollService::Continue + } + } + } else { + PollService::Ready + } + } + fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll> { + // check service readiness + if self.flags.contains(Flags::READY) { + if self.service.poll_not_ready(cx).is_pending() { + return Poll::Ready(self.check_error()); + } + self.flags.remove(Flags::READY); + } + match self.service.poll_ready(cx) { Poll::Ready(Ok(_)) => { - // check for errors - let mut state = self.state.borrow_mut(); - Poll::Ready(if let Some(err) = state.error.take() { - log::trace!("{}: Error occured, stopping dispatcher", self.io.tag()); - self.st = IoDispatcherState::Stop; - match err { - IoDispatcherError::Encoder(err) => { - PollService::Item(DispatchItem::EncoderError(err)) - } - IoDispatcherError::Service(err) => { - state.error = Some(IoDispatcherError::Service(err)); - PollService::Continue - } - } - } else { - PollService::Ready - }) + self.flags.insert(Flags::READY); + Poll::Ready(self.check_error()) } // pause io read task Poll::Pending => { diff --git a/src/server.rs b/src/server.rs index 4c8d9a3..6f27ca3 100644 --- a/src/server.rs +++ b/src/server.rs @@ -233,6 +233,11 @@ where ready2 } + #[inline] + async fn not_ready(&self) { + select(self.handlers.0.not_ready(), self.handlers.1.not_ready()).await; + } + #[inline] async fn shutdown(&self) { self.handlers.0.shutdown().await; @@ -296,6 +301,11 @@ where Service::::ready(self, ctx).await } + #[inline] + async fn not_ready(&self) { + Service::::not_ready(self).await + } + #[inline] async fn shutdown(&self) { Service::::shutdown(self).await diff --git a/src/v3/client/dispatcher.rs b/src/v3/client/dispatcher.rs index b5701a7..2e9ee6b 100644 --- a/src/v3/client/dispatcher.rs +++ b/src/v3/client/dispatcher.rs @@ -2,7 +2,7 @@ use std::{cell::RefCell, marker::PhantomData, num::NonZeroU16, rc::Rc}; use ntex_io::DispatchItem; use ntex_service::{Pipeline, Service, ServiceCtx}; -use ntex_util::future::{join, Either}; +use ntex_util::future::{join, select, Either}; use ntex_util::{services::inflight::InFlightService, HashSet}; use crate::error::{HandshakeError, MqttError, ProtocolError}; @@ -86,6 +86,11 @@ where } } + #[inline] + async fn not_ready(&self) { + select(self.publish.not_ready(), self.inner.control.not_ready()).await; + } + async fn shutdown(&self) { self.inner.sink.close(); let _ = Pipeline::new(&self.inner.control).call(Control::closed()).await; diff --git a/src/v3/dispatcher.rs b/src/v3/dispatcher.rs index 7e41673..a69c392 100644 --- a/src/v3/dispatcher.rs +++ b/src/v3/dispatcher.rs @@ -4,7 +4,7 @@ use ntex_io::DispatchItem; use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory}; use ntex_util::services::buffer::{BufferService, BufferServiceError}; use ntex_util::services::inflight::InFlightService; -use ntex_util::{future::join, HashSet}; +use ntex_util::{future::join, future::select, HashSet}; use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::types::QoS; @@ -145,6 +145,11 @@ where } } + #[inline] + async fn not_ready(&self) { + select(self.publish.not_ready(), self.inner.control.not_ready()).await; + } + async fn shutdown(&self) { self.inner.sink.close(); let _ = Pipeline::new(&self.inner.control).call(Control::closed()).await; diff --git a/src/v3/router.rs b/src/v3/router.rs index 40a504e..7200240 100644 --- a/src/v3/router.rs +++ b/src/v3/router.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::{future::poll_fn, future::Future, pin::Pin, rc::Rc, task::Poll}; use ntex_router::{IntoPattern, RouterBuilder}; use ntex_service::boxed::{self, BoxService, BoxServiceFactory}; @@ -116,6 +116,25 @@ impl Service for RouterService { ctx.ready(&self.default).await } + #[inline] + async fn not_ready(&self) { + let mut futs = Vec::with_capacity(self.handlers.len() + 1); + for hnd in &self.handlers { + futs.push(Box::pin(hnd.not_ready())); + } + futs.push(Box::pin(self.default.not_ready())); + + poll_fn(|cx| { + for hnd in &mut futs { + if Pin::new(hnd).poll(cx).is_ready() { + return Poll::Ready(()); + } + } + Poll::Pending + }) + .await; + } + async fn call( &self, mut req: Publish, diff --git a/src/v5/client/dispatcher.rs b/src/v5/client/dispatcher.rs index dc7498c..3104015 100644 --- a/src/v5/client/dispatcher.rs +++ b/src/v5/client/dispatcher.rs @@ -3,7 +3,7 @@ use std::{cell::RefCell, marker::PhantomData, num::NonZeroU16, rc::Rc}; use ntex_bytes::ByteString; use ntex_io::DispatchItem; use ntex_service::{Pipeline, Service, ServiceCtx}; -use ntex_util::{future::join, future::Either, HashMap, HashSet}; +use ntex_util::{future::join, future::select, future::Either, HashMap, HashSet}; use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::types::packet_type; @@ -114,6 +114,11 @@ where } } + #[inline] + async fn not_ready(&self) { + select(self.publish.not_ready(), self.inner.control.not_ready()).await; + } + async fn shutdown(&self) { self.inner.sink.drop_sink(); let _ = Pipeline::new(&self.inner.control).call(Control::closed()).await; diff --git a/src/v5/dispatcher.rs b/src/v5/dispatcher.rs index e6cd4a0..bd65945 100644 --- a/src/v5/dispatcher.rs +++ b/src/v5/dispatcher.rs @@ -5,7 +5,7 @@ use ntex_io::DispatchItem; use ntex_service::{self as service, Pipeline, Service, ServiceCtx, ServiceFactory}; use ntex_util::services::inflight::InFlightService; use ntex_util::services::{buffer::BufferService, buffer::BufferServiceError}; -use ntex_util::{future::join, HashMap, HashSet}; +use ntex_util::{future::join, future::select, HashMap, HashSet}; use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::types::QoS; @@ -155,6 +155,11 @@ where } } + #[inline] + async fn not_ready(&self) { + select(self.publish.not_ready(), self.inner.control.not_ready()).await; + } + async fn shutdown(&self) { self.inner.sink.drop_sink(); let _ = Pipeline::new(&self.inner.control).call(Control::closed()).await; diff --git a/src/v5/router.rs b/src/v5/router.rs index 7aab22d..3dd4e30 100644 --- a/src/v5/router.rs +++ b/src/v5/router.rs @@ -1,4 +1,5 @@ -use std::{cell::RefCell, num::NonZeroU16, rc::Rc}; +use std::future::{poll_fn, Future}; +use std::{cell::RefCell, num::NonZeroU16, pin::Pin, rc::Rc, task::Poll}; use ntex_bytes::ByteString; use ntex_router::{IntoPattern, Path, RouterBuilder}; @@ -130,6 +131,25 @@ impl Service for RouterService { ctx.ready(&self.default).await } + #[inline] + async fn not_ready(&self) { + let mut futs = Vec::with_capacity(self.handlers.len() + 1); + for hnd in &self.handlers { + futs.push(Box::pin(hnd.not_ready())); + } + futs.push(Box::pin(self.default.not_ready())); + + poll_fn(|cx| { + for hnd in &mut futs { + if Pin::new(hnd).poll(cx).is_ready() { + return Poll::Ready(()); + } + } + Poll::Pending + }) + .await; + } + #[allow(clippy::await_holding_refcell_ref)] async fn call( &self,