From 5d44cb024606dbc3301ed7045f32af27df2bab8a Mon Sep 17 00:00:00 2001 From: Mateusz Kwapich Date: Thu, 21 Nov 2024 11:23:11 -0800 Subject: [PATCH] introduce TryStreamStats Summary: For try streams we want to also count the errors. Reviewed By: andreacampi Differential Revision: D66248375 fbshipit-source-id: e52719cc80f0f9c761b2badc8a9b0601f5c8f104 --- shed/futures_stats/src/futures03.rs | 176 ++++++++++++++++++++++++++++ shed/futures_stats/src/lib.rs | 14 +++ 2 files changed, 190 insertions(+) diff --git a/shed/futures_stats/src/futures03.rs b/shed/futures_stats/src/futures03.rs index 7f14d36b5..a84fbfc06 100644 --- a/shed/futures_stats/src/futures03.rs +++ b/shed/futures_stats/src/futures03.rs @@ -18,10 +18,12 @@ use futures::future::TryFuture; use futures::stream::Stream; use futures::task::Context; use futures::task::Poll; +use futures::TryStream; use futures_ext::future::CancelData; use super::FutureStats; use super::StreamStats; +use crate::TryStreamStats; /// A Future that gathers some basic statistics for inner Future. /// This structure's main usage is by calling [TimedFutureExt::timed]. @@ -234,6 +236,88 @@ where } } +/// A Stream that gathers some basic statistics for inner TryStream. +/// This structure's main usage is by calling [TimedTryStreamExt::try_timed]. +pub struct TimedTryStream +where + S: TryStream + Sized, + C: FnOnce(TryStreamStats), +{ + callback: Option, + inner: TimedStream ()>, + error_count: usize, + first_error_position: Option, +} +impl TimedTryStream +where + S: TryStream, + C: FnOnce(TryStreamStats), +{ + fn new(stream: S, callback: C) -> Self { + TimedTryStream { + callback: Some(callback), + inner: TimedStream::new(stream, None), + error_count: 0, + first_error_position: None, + } + } + + fn gen_stats(&self) -> TryStreamStats { + TryStreamStats { + stream_stats: self.inner.gen_stats(), + error_count: self.error_count, + first_error_position: self.first_error_position, + } + } + + fn run_callback(&mut self) { + if let Some(callback) = self.callback.take() { + let stats = self.gen_stats(); + callback(stats) + } + } +} + +impl Stream for TimedTryStream +where + S: Stream>, + C: FnOnce(TryStreamStats), +{ + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = unsafe { self.get_unchecked_mut() }; + + let poll = unsafe { Pin::new_unchecked(&mut this.inner).poll_next(cx) }; + match poll { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(item)) => { + if item.is_err() { + this.error_count += 1; + if this.first_error_position.is_none() { + this.first_error_position = Some(this.inner.count - 1) + } + } + Poll::Ready(Some(item)) + } + Poll::Ready(None) => { + this.run_callback(); + Poll::Ready(None) + } + } + } +} + +impl Drop for TimedTryStream +where + S: TryStream, + C: FnOnce(TryStreamStats), +{ + fn drop(&mut self) { + self.run_callback(); + } +} + /// A trait that provides the `timed` method to [futures::Future] for gathering stats pub trait TimedFutureExt: Future + Sized { /// Combinator that returns a future that will gather some statistics and @@ -316,6 +400,42 @@ pub trait TimedStreamExt: Stream + Sized { impl TimedStreamExt for T {} +/// A trait that provides the `try_timed` method to [futures::TryStream] for gathering stats +pub trait TimedTryStreamExt: TryStream + Sized { + /// Combinator that returns a stream that will gather some statistics and + /// pass them for inspection to the provided callback when the stream + /// completes. + /// + /// Comparered to [TimedStreamExt::timed], this method collects the stats + /// about errors encountered in the stream. + /// + /// # Examples + /// + /// ``` + /// use futures::stream::TryStreamExt; + /// use futures::stream::{self}; + /// use futures_stats::TimedTryStreamExt; + /// + /// # futures::executor::block_on(async { + /// let out = stream::iter([Ok(1), Ok(2), Err(3)]) + /// .try_timed(|stats| { + /// assert_eq!(stats.error_count, 1); + /// }) + /// .try_collect::>() + /// .await; + /// assert!(out.is_err()); + /// # }); + /// ``` + fn try_timed(self, callback: C) -> TimedTryStream + where + C: FnOnce(TryStreamStats), + { + TimedTryStream::new(self, callback) + } +} + +impl TimedTryStreamExt for T {} + #[cfg(test)] mod tests { use std::sync::atomic::AtomicBool; @@ -326,6 +446,7 @@ mod tests { use futures::stream; use futures::stream::StreamExt; + use futures::TryStreamExt; use futures_ext::FbFutureExt; use super::*; @@ -407,4 +528,59 @@ mod tests { drop(s); assert!(callback_called.load(Ordering::SeqCst)); } + + #[tokio::test] + async fn test_try_timed_stream() { + let callback_called = Arc::new(AtomicBool::new(false)); + let out = stream::iter([ + Ok(0), + Err("Integer overflow".to_owned()), + Ok(1), + Ok(2), + Err("Rounding error".to_owned()), + Err("Unit conversion error".to_owned()), + ]) + .try_timed({ + let callback_called = callback_called.clone(); + move |stats: TryStreamStats| { + assert_eq!(stats.stream_stats.count, 6); + assert_eq!(stats.error_count, 3); + assert_eq!(stats.first_error_position, Some(1)); + assert!(stats.stream_stats.completed); + callback_called.store(true, Ordering::SeqCst); + } + }) + .collect::>>() + .await; + assert_eq!(out[2], Ok(1)); + assert!(callback_called.load(Ordering::SeqCst)); + } + + #[tokio::test] + async fn test_cancel_try_timed_stream() { + let callback_called = Arc::new(AtomicBool::new(false)); + let out = stream::iter([ + Ok(0), + Err("Integer overflow".to_owned()), + Ok(1), + Ok(2), + Err("Rounding error".to_owned()), + Err("Unit conversion error".to_owned()), + ]) + .try_timed({ + let callback_called = callback_called.clone(); + move |stats: TryStreamStats| { + assert_eq!(stats.stream_stats.count, 2); + assert_eq!(stats.error_count, 1); + assert_eq!(stats.first_error_position, Some(1)); + assert!(!stats.stream_stats.completed); + callback_called.store(true, Ordering::SeqCst); + } + }) + // Try collect will drop the stream after first failure + .try_collect::>() + .await; + assert!(out.is_err()); + assert!(callback_called.load(Ordering::SeqCst)); + } } diff --git a/shed/futures_stats/src/lib.rs b/shed/futures_stats/src/lib.rs index 42ac62321..e24bf986b 100644 --- a/shed/futures_stats/src/lib.rs +++ b/shed/futures_stats/src/lib.rs @@ -23,6 +23,7 @@ pub mod futures03; pub use futures03::TimedFutureExt; pub use futures03::TimedStreamExt; pub use futures03::TimedTryFutureExt; +pub use futures03::TimedTryStreamExt; /// A structure that holds some basic statistics for Future. #[derive(Clone, Debug)] @@ -74,3 +75,16 @@ pub struct StreamStats { /// Whether the stream was polled to completion. pub completed: bool, } + +/// A structure that holds some basic statistics for Stream. +#[derive(Clone, Debug)] +pub struct TryStreamStats { + /// All the stats that are not try-stream specific. + pub stream_stats: StreamStats, + + /// Number of errors in the stream. + pub error_count: usize, + + /// Number of elements in the stream that were emitted before first error + pub first_error_position: Option, +}