Run un-readiness check in separate task (#185)
fafhrd91 authored Nov 10, 2024
1 parent e5e2906 commit 15008af
Showing 3 changed files with 108 additions and 56 deletions.
6 changes: 6 additions & 0 deletions CHANGES.md
@@ -1,5 +1,11 @@
# Changes

## [4.4.0] - 2024-11-10

* Check service readiness once per decoded item

* Run un-readiness check in separate task

## [4.3.1] - 2024-11-05

* Do not rely on not_ready(), always check service readiness
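Not part of the commit itself: a minimal, self-contained Rust sketch of the pattern the two 4.4.0 entries above describe. Readiness is cached in a Cell and consulted once per decoded item, while a separately spawned task waits for the service to become un-ready and invalidates the cache; the real task in src/io.rs below additionally wakes the dispatcher through a LocalWaker. tokio's current-thread runtime stands in for ntex-rt here, and all names are illustrative.

use std::{cell::Cell, rc::Rc, time::Duration};
use tokio::{task, time::sleep};

struct Shared {
    // Cached readiness, consulted once per decoded item.
    ready: Cell<bool>,
}

// Stand-in for the spawned `not_ready` task: wait until the (simulated)
// service becomes un-ready, then invalidate the cached flag.
async fn un_readiness_task(shared: Rc<Shared>) {
    sleep(Duration::from_millis(30)).await; // pretend poll_not_ready() resolved
    shared.ready.set(false);
}

// Stand-in for the dispatcher loop: trust the cached flag when it is set,
// otherwise re-run the readiness check before handling the next item.
async fn dispatcher(shared: Rc<Shared>) {
    for item in 0..5u32 {
        if shared.ready.get() {
            println!("item {item}: using cached readiness");
        } else {
            println!("item {item}: re-polling service readiness");
            shared.ready.set(true); // pretend poll_ready() returned Ready(Ok(()))
        }
        sleep(Duration::from_millis(20)).await; // pretend to handle the item
    }
}

fn main() {
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_time()
        .build()
        .unwrap();
    // Rc/Cell are not Send, so everything runs on a LocalSet, mirroring the
    // single-threaded dispatcher.
    let local = task::LocalSet::new();
    local.block_on(&rt, async {
        let shared = Rc::new(Shared { ready: Cell::new(true) });
        task::spawn_local(un_readiness_task(shared.clone()));
        dispatcher(shared).await;
    });
}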
5 changes: 3 additions & 2 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "ntex-mqtt"
version = "4.3.1"
version = "4.4.0"
authors = ["ntex contributors <[email protected]>"]
description = "Client and Server framework for MQTT v5 and v3.1.1 protocols"
documentation = "https://docs.rs/ntex-mqtt"
@@ -18,10 +18,11 @@ features = ["ntex/tokio"]
ntex-io = "2"
ntex-net = "2"
ntex-util = "2.5"
ntex-service = "3.3"
ntex-service = "3.3.3"
ntex-bytes = "0.1"
ntex-codec = "0.6"
ntex-router = "0.5"
ntex-rt = "0.4"
bitflags = "2"
log = "0.4"
pin-project-lite = "0.2"
153 changes: 99 additions & 54 deletions src/io.rs
@@ -1,13 +1,14 @@
//! Framed transport dispatcher
use std::future::{poll_fn, Future};
use std::task::{ready, Context, Poll};
use std::{cell::RefCell, collections::VecDeque, future::Future, pin::Pin, rc::Rc};
use std::{cell::Cell, cell::RefCell, collections::VecDeque, pin::Pin, rc::Rc};

use ntex_codec::{Decoder, Encoder};
use ntex_io::{
Decoded, DispatchItem, DispatcherConfig, IoBoxed, IoRef, IoStatusUpdate, RecvError,
};
use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
use ntex_util::time::Seconds;
use ntex_util::{task::LocalWaker, time::Seconds};

type Response<U> = <U as Encoder>::Item;

@@ -28,12 +29,13 @@ pin_project_lite::pin_project! {
bitflags::bitflags! {
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
struct Flags: u8 {
const READY_ERR = 0b000001;
const IO_ERR = 0b000010;
const KA_ENABLED = 0b000100;
const KA_TIMEOUT = 0b001000;
const READ_TIMEOUT = 0b010000;
const READY = 0b100000;
const READY_ERR = 0b0000001;
const IO_ERR = 0b0000010;
const KA_ENABLED = 0b0000100;
const KA_TIMEOUT = 0b0001000;
const READ_TIMEOUT = 0b0010000;
const READY = 0b0100000;
const READY_TASK = 0b1000000;
}
}

@@ -43,7 +45,7 @@ struct DispatcherInner<S: Service<DispatchItem<U>>, U: Encoder + Decoder + 'stat
codec: U,
service: PipelineBinding<S, DispatchItem<U>>,
st: IoDispatcherState,
state: Rc<RefCell<DispatcherState<S, U>>>,
state: Rc<DispatcherState<S, U>>,
config: DispatcherConfig,
read_remains: u32,
read_remains_prev: u32,
@@ -55,9 +57,11 @@ struct DispatcherInner<S: Service<DispatchItem<U>>, U: Encoder + Decoder + 'stat
}

struct DispatcherState<S: Service<DispatchItem<U>>, U: Encoder + Decoder> {
error: Option<IoDispatcherError<S::Error, <U as Encoder>::Error>>,
base: usize,
queue: VecDeque<ServiceResult<Result<S::Response, S::Error>>>,
error: Cell<Option<IoDispatcherError<S::Error, <U as Encoder>::Error>>>,
base: Cell<usize>,
ready: Cell<bool>,
queue: RefCell<VecDeque<ServiceResult<Result<S::Response, S::Error>>>>,
waker: LocalWaker,
}

enum ServiceResult<T> {
@@ -116,11 +120,13 @@
// register keepalive timer
io.set_disconnect_timeout(config.disconnect_timeout());

let state = Rc::new(RefCell::new(DispatcherState {
error: None,
base: 0,
queue: VecDeque::new(),
}));
let state = Rc::new(DispatcherState {
error: Cell::new(None),
base: Cell::new(0),
ready: Cell::new(false),
queue: RefCell::new(VecDeque::new()),
waker: LocalWaker::default(),
});
let keepalive_timeout = config.keepalive_timeout();

Dispatcher {
@@ -169,53 +175,54 @@
<U as Encoder>::Item: 'static,
{
fn handle_result(
&mut self,
&self,
item: Result<S::Response, S::Error>,
response_idx: usize,
io: &IoRef,
codec: &U,
wake: bool,
) {
let idx = response_idx.wrapping_sub(self.base);
let mut queue = self.queue.borrow_mut();
let idx = response_idx.wrapping_sub(self.base.get());

// handle first response
if idx == 0 {
let _ = self.queue.pop_front();
self.base = self.base.wrapping_add(1);
let _ = queue.pop_front();
self.base.set(self.base.get().wrapping_add(1));
match item {
Err(err) => {
self.error = Some(err.into());
self.error.set(Some(err.into()));
}
Ok(Some(item)) => {
if let Err(err) = io.encode(item, codec) {
self.error = Some(IoDispatcherError::Encoder(err));
self.error.set(Some(IoDispatcherError::Encoder(err)));
}
}
Ok(None) => (),
}

// check remaining response
while let Some(item) = self.queue.front_mut().and_then(|v| v.take()) {
let _ = self.queue.pop_front();
self.base = self.base.wrapping_add(1);
while let Some(item) = queue.front_mut().and_then(|v| v.take()) {
let _ = queue.pop_front();
self.base.set(self.base.get().wrapping_add(1));
match item {
Err(err) => {
self.error = Some(err.into());
self.error.set(Some(err.into()));
}
Ok(Some(item)) => {
if let Err(err) = io.encode(item, codec) {
self.error = Some(IoDispatcherError::Encoder(err));
self.error.set(Some(IoDispatcherError::Encoder(err)));
}
}
Ok(None) => (),
}
}

if wake && self.queue.is_empty() {
if wake && queue.is_empty() {
io.wake()
}
} else {
self.queue[idx] = ServiceResult::Ready(item);
queue[idx] = ServiceResult::Ready(item);
}
}
}
@@ -232,10 +239,12 @@
let mut this = self.as_mut().project();
let inner = &mut this.inner;

inner.state.waker.register(cx.waker());

// handle service response future
if let Some(fut) = inner.response.as_mut() {
if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
inner.state.borrow_mut().handle_result(
inner.state.handle_result(
item,
inner.response_idx,
inner.io.as_ref(),
Expand All @@ -246,6 +255,12 @@ where
}
}

// start ready task
if !inner.flags.contains(Flags::READY_TASK) {
inner.flags.insert(Flags::READY_TASK);
ntex_rt::spawn(not_ready(inner.state.clone(), inner.service.clone()));
}

loop {
match inner.st {
IoDispatcherState::Processing => {
@@ -295,6 +310,7 @@
PollService::Continue => continue,
};

inner.state.ready.set(false);
inner.call_service(cx, item);
}
// handle write back-pressure
@@ -328,7 +344,7 @@
}
}

if inner.state.borrow().queue.is_empty() {
if inner.state.queue.borrow().is_empty() {
if inner.io.poll_shutdown(cx).is_ready() {
log::trace!("{}: io shutdown completed", inner.io.tag());
inner.st = IoDispatcherState::Shutdown;
@@ -361,7 +377,7 @@

Poll::Ready(
if let Some(IoDispatcherError::Service(err)) =
inner.state.borrow_mut().error.take()
inner.state.error.take()
{
Err(err)
} else {
@@ -384,61 +400,60 @@
<U as Encoder>::Item: 'static,
{
fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>) {
let mut state = self.state.borrow_mut();
let mut fut = self.service.call_nowait(item);
let mut queue = self.state.queue.borrow_mut();

// optimize first call
if self.response.is_none() {
if let Poll::Ready(res) = Pin::new(&mut fut).poll(cx) {
// check if current result is only response
if state.queue.is_empty() {
if queue.is_empty() {
match res {
Err(err) => {
state.error = Some(err.into());
self.state.error.set(Some(err.into()));
}
Ok(Some(item)) => {
if let Err(err) = self.io.encode(item, &self.codec) {
state.error = Some(IoDispatcherError::Encoder(err));
self.state.error.set(Some(IoDispatcherError::Encoder(err)));
}
}
Ok(None) => (),
}
} else {
self.response_idx = state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Ready(res));
queue.push_back(ServiceResult::Ready(res));
self.response_idx = self.state.base.get().wrapping_add(queue.len());
}
} else {
self.response = Some(fut);
self.response_idx = state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Pending);
self.response_idx = self.state.base.get().wrapping_add(queue.len());
queue.push_back(ServiceResult::Pending);
}
} else {
let response_idx = state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Pending);
let response_idx = self.state.base.get().wrapping_add(queue.len());
queue.push_back(ServiceResult::Pending);

let st = self.io.get_ref();
let codec = self.codec.clone();
let state = self.state.clone();

ntex_util::spawn(async move {
let item = fut.await;
state.borrow_mut().handle_result(item, response_idx, &st, &codec, true);
state.handle_result(item, response_idx, &st, &codec, true);
});
}
}

fn check_error(&mut self) -> PollService<U> {
// check for errors
let mut state = self.state.borrow_mut();
if let Some(err) = state.error.take() {
if let Some(err) = self.state.error.take() {
log::trace!("{}: Error occured, stopping dispatcher", self.io.tag());
self.st = IoDispatcherState::Stop;
match err {
IoDispatcherError::Encoder(err) => {
PollService::Item(DispatchItem::EncoderError(err))
}
IoDispatcherError::Service(err) => {
state.error = Some(IoDispatcherError::Service(err));
self.state.error.set(Some(IoDispatcherError::Service(err)));
PollService::Continue
}
}
@@ -448,9 +463,13 @@
}

fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
if self.state.ready.get() {
return Poll::Ready(self.check_error());
}

match self.service.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
let _ = self.service.poll_not_ready(cx);
self.state.ready.set(true);
Poll::Ready(self.check_error())
}
// pause io read task
@@ -498,7 +517,7 @@
log::error!("{}: Service readiness check failed, stopping", self.io.tag());
self.st = IoDispatcherState::Stop;
self.flags.insert(Flags::READY_ERR);
self.state.borrow_mut().error = Some(IoDispatcherError::Service(err));
self.state.error.set(Some(IoDispatcherError::Service(err)));
Poll::Ready(PollService::Item(DispatchItem::Disconnect(None)))
}
}
@@ -576,6 +595,30 @@
}
}

async fn not_ready<S, U>(
slf: Rc<DispatcherState<S, U>>,
pl: PipelineBinding<S, DispatchItem<U>>,
) where
S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Encoder + Decoder + 'static,
{
loop {
if !pl.is_shutdown() {
if let Err(err) = poll_fn(|cx| pl.poll_ready(cx)).await {
slf.error.set(Some(IoDispatcherError::Service(err)));
break;
}
if !pl.is_shutdown() {
poll_fn(|cx| pl.poll_not_ready(cx)).await;
slf.ready.set(false);
slf.waker.wake();
continue;
}
}
break;
}
}

#[cfg(test)]
mod tests {
use std::{cell::Cell, io, sync::Arc, sync::Mutex};
@@ -616,11 +659,13 @@ mod tests {
let keepalive_timeout = config.keepalive_timeout();
let rio = io.get_ref();

let state = Rc::new(RefCell::new(DispatcherState {
error: None,
base: 0,
queue: VecDeque::new(),
}));
let state = Rc::new(DispatcherState {
error: Cell::new(None),
base: Cell::new(0),
ready: Cell::new(false),
waker: LocalWaker::default(),
queue: RefCell::new(VecDeque::new()),
});

(
Dispatcher {
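Also not part of the commit: a small std-only illustration of why the diff above replaces Rc<RefCell<DispatcherState>> with an Rc<DispatcherState> whose fields use Cell and RefCell individually. With one RefCell around the whole state, any code reached while the dispatcher already holds a mutable borrow (such as the newly spawned un-readiness task touching the same shared state) would need a second borrow_mut() of the whole struct and panic at runtime; per-field interior mutability lets each side borrow only what it needs. The names below are illustrative, not taken from ntex.

use std::{cell::Cell, cell::RefCell, collections::VecDeque, rc::Rc};

struct State {
    ready: Cell<bool>,
    queue: RefCell<VecDeque<u32>>,
}

// Only touches the readiness flag; never needs the queue borrow.
fn mark_un_ready(state: &Rc<State>) {
    state.ready.set(false);
}

fn main() {
    let state = Rc::new(State {
        ready: Cell::new(true),
        queue: RefCell::new(VecDeque::new()),
    });

    // The "dispatcher" keeps the response queue mutably borrowed...
    let mut queue = state.queue.borrow_mut();
    queue.push_back(1);

    // ...while other code flips the readiness flag. With a single
    // Rc<RefCell<State>> this call would require a second borrow_mut()
    // of the whole state and panic at runtime.
    mark_un_ready(&state);

    println!("ready = {}, queued = {}", state.ready.get(), queue.len());
}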
