diff --git a/examples/spawn.rs b/examples/spawn.rs new file mode 100644 index 0000000..2a6d49d --- /dev/null +++ b/examples/spawn.rs @@ -0,0 +1,58 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! This example shows how to spawn tasks with `await_tree::spawn` that are automatically registered +//! to the current registry of the scope. + +use std::time::Duration; + +use await_tree::{Config, InstrumentAwait, Registry}; +use futures::future::pending; +use tokio::time::sleep; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct Actor(usize); + +async fn actor(i: usize) { + // Since we're already inside the scope of a registered/instrumented task, we can directly spawn + // new tasks with `await_tree::spawn` to also register them in the same registry. + await_tree::spawn_anonymous(format!("background task {i}"), async { + pending::<()>().await; + }) + .instrument_await("waiting for background task") + .await + .unwrap(); +} + +#[tokio::main] +async fn main() { + let registry = Registry::new(Config::default()); + + for i in 0..3 { + let root = registry.register(Actor(i), format!("actor {i}")); + tokio::spawn(root.instrument(actor(i))); + } + + sleep(Duration::from_secs(1)).await; + + for (_actor, tree) in registry.collect::() { + // actor 0 [1.004s] + // waiting for background task [1.004s] + println!("{tree}"); + } + for tree in registry.collect_anonymous() { + // background task 0 [1.004s] + println!("{tree}"); + } +} diff --git a/src/context.rs b/src/context.rs index d1c8efb..695143b 100644 --- a/src/context.rs +++ b/src/context.rs @@ -14,12 +14,12 @@ use std::fmt::{Debug, Write}; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::Arc; use indextree::{Arena, NodeId}; use itertools::Itertools; use parking_lot::{Mutex, MutexGuard}; +use crate::root::current_context; use crate::Span; /// Node in the span tree. @@ -42,11 +42,13 @@ impl SpanNode { } } -/// The id of an await-tree context. We will check the id recorded in the instrumented future -/// against the current task-local context before trying to update the tree. - -// Also used as the key for anonymous trees in the registry. -// Intentionally made private to prevent users from reusing the same id when registering a new tree. +/// The id of an await-tree context. +/// +/// We will check the id recorded in the instrumented future against the current task-local context +/// before trying to update the tree. +/// +/// Also used as the key for anonymous trees in the registry. Intentionally made private to prevent +/// users from reusing the same id when registering a new tree. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub(crate) struct ContextId(u64); @@ -253,17 +255,9 @@ impl TreeContext { } } -tokio::task_local! { - pub(crate) static CONTEXT: Arc -} - -pub(crate) fn context() -> Option> { - CONTEXT.try_with(Arc::clone).ok() -} - /// Get the await-tree of current task. Returns `None` if we're not instrumented. /// /// This is useful if you want to check which component or runtime task is calling this function. pub fn current_tree() -> Option { - context().map(|c| c.tree().clone()) + current_context().map(|c| c.tree().clone()) } diff --git a/src/future.rs b/src/future.rs index 940b063..f165843 100644 --- a/src/future.rs +++ b/src/future.rs @@ -19,7 +19,8 @@ use std::task::Poll; use indextree::NodeId; use pin_project::{pin_project, pinned_drop}; -use crate::context::{context, ContextId}; +use crate::context::ContextId; +use crate::root::current_context; use crate::Span; enum State { @@ -57,7 +58,7 @@ impl Future for Instrumented { fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { let this = self.project(); - let context = context(); + let context = current_context(); let (context, this_node) = match this.state { State::Initial(span) => { @@ -140,7 +141,7 @@ impl PinnedDrop for Instrumented { State::Polled { this_node, this_context_id, - } => match context() { + } => match current_context() { // Context correct Some(c) if c.id() == *this_context_id => { c.tree().remove_and_detach(*this_node); diff --git a/src/lib.rs b/src/lib.rs index 7ea9ef0..81362a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,11 +22,15 @@ mod context; mod future; mod obj_utils; mod registry; +mod root; +mod spawn; pub use context::current_tree; use flexstr::SharedStr; pub use future::Instrumented; -pub use registry::{AnyKey, Config, ConfigBuilder, ConfigBuilderError, Key, Registry, TreeRoot}; +pub use registry::{AnyKey, Config, ConfigBuilder, ConfigBuilderError, Key, Registry}; +pub use root::TreeRoot; +pub use spawn::{spawn, spawn_anonymous}; /// A cheaply cloneable span in the await-tree. #[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord)] diff --git a/src/registry.rs b/src/registry.rs index 9c0b02f..aa48f78 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -14,7 +14,6 @@ use std::any::Any; use std::fmt::Debug; -use std::future::Future; use std::hash::Hash; use std::sync::{Arc, Weak}; @@ -22,9 +21,9 @@ use derive_builder::Builder; use parking_lot::RwLock; use weak_table::WeakValueHashMap; -use crate::context::{ContextId, Tree, TreeContext, CONTEXT}; +use crate::context::{ContextId, Tree, TreeContext}; use crate::obj_utils::{DynEq, DynHash}; -use crate::Span; +use crate::{Span, TreeRoot}; /// Configuration for an await-tree registry, which affects the behavior of all await-trees in the /// registry. @@ -42,20 +41,6 @@ impl Default for Config { } } -/// The root of an await-tree. -pub struct TreeRoot { - context: Arc, - #[allow(dead_code)] - registry: Weak, -} - -impl TreeRoot { - /// Instrument the given future with the context of this tree root. - pub async fn instrument(self, future: F) -> F::Output { - CONTEXT.scope(self.context, future).await - } -} - /// A key that can be used to identify a task and its await-tree in the [`Registry`]. /// /// All thread-safe types that can be used as a key of a hash map are automatically implemented with @@ -103,7 +88,6 @@ impl AnyKey { type Contexts = RwLock>>; -#[derive(Debug)] struct RegistryCore { contexts: Contexts, config: Config, @@ -112,9 +96,16 @@ struct RegistryCore { /// The registry of multiple await-trees. /// /// Can be cheaply cloned to share the same registry. -#[derive(Debug)] pub struct Registry(Arc); +impl Debug for Registry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Registry") + .field("config", self.config()) + .finish_non_exhaustive() + } +} + impl Clone for Registry { fn clone(&self) -> Self { Self(Arc::clone(&self.0)) @@ -150,7 +141,7 @@ impl Registry { TreeRoot { context, - registry: Arc::downgrade(&self.0), + registry: WeakRegistry(Arc::downgrade(&self.0)), } } @@ -227,6 +218,14 @@ impl Registry { } } +pub(crate) struct WeakRegistry(Weak); + +impl WeakRegistry { + pub fn upgrade(&self) -> Option { + self.0.upgrade().map(Registry) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/root.rs b/src/root.rs new file mode 100644 index 0000000..0d12aed --- /dev/null +++ b/src/root.rs @@ -0,0 +1,45 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::future::Future; +use std::sync::Arc; + +use crate::context::TreeContext; +use crate::registry::WeakRegistry; +use crate::Registry; + +/// The root of an await-tree. +pub struct TreeRoot { + pub(crate) context: Arc, + pub(crate) registry: WeakRegistry, +} + +tokio::task_local! { + pub(crate) static ROOT: TreeRoot +} + +pub(crate) fn current_context() -> Option> { + ROOT.try_with(|r| r.context.clone()).ok() +} + +pub(crate) fn current_registry() -> Option { + ROOT.try_with(|r| r.registry.upgrade()).ok().flatten() +} + +impl TreeRoot { + /// Instrument the given future with the context of this tree root. + pub async fn instrument(self, future: F) -> F::Output { + ROOT.scope(self, future).await + } +} diff --git a/src/spawn.rs b/src/spawn.rs new file mode 100644 index 0000000..4e41f26 --- /dev/null +++ b/src/spawn.rs @@ -0,0 +1,59 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TODO: should we consider exposing `current_registry` +// so that users can not only spawn tasks but also get and collect trees? + +// TODO: should we support "global registry" for users to quick start? + +use std::future::Future; + +use tokio::task::JoinHandle; + +use crate::root::current_registry; +use crate::{Key, Span}; + +/// Spawns a new asynchronous task instrumented with the given root [`Span`], returning a +/// [`JoinHandle`] for it. +/// +/// The spawned task will be registered in the current [`Registry`](crate::Registry) with the given +/// [`Key`], if it exists. Otherwise, this is equivalent to [`tokio::spawn`]. +pub fn spawn(key: impl Key, root_span: impl Into, future: T) -> JoinHandle +where + T: Future + Send + 'static, + T::Output: Send + 'static, +{ + if let Some(registry) = current_registry() { + tokio::spawn(registry.register(key, root_span).instrument(future)) + } else { + tokio::spawn(future) + } +} + +/// Spawns a new asynchronous task instrumented with the given root [`Span`], returning a +/// [`JoinHandle`] for it. +/// +/// The spawned task will be registered in the current [`Registry`](crate::Registry) anonymously, if +/// it exists. Otherwise, this is equivalent to [`tokio::spawn`]. +pub fn spawn_anonymous(root_span: impl Into, future: T) -> JoinHandle +where + T: Future + Send + 'static, + T::Output: Send + 'static, +{ + if let Some(registry) = current_registry() { + tokio::spawn(registry.register_anonymous(root_span).instrument(future)) + } else { + tokio::spawn(future) + } +} diff --git a/src/tests.rs b/src/tests.rs index cc99260..c33fd17 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -12,181 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use futures::future::{join_all, poll_fn, select_all}; -use futures::{pin_mut, FutureExt, Stream, StreamExt}; -use itertools::Itertools; +#![cfg(test)] -use crate::context::context; -use crate::{Config, InstrumentAwait, Registry}; - -async fn sleep(time: u64) { - tokio::time::sleep(std::time::Duration::from_millis(time)).await; - println!("slept {time}ms"); -} - -async fn sleep_nested() { - join_all([ - sleep(1500).instrument_await("sleep nested 1500"), - sleep(2500).instrument_await("sleep nested 2500"), - ]) - .await; -} - -async fn multi_sleep() { - sleep(400).await; - - sleep(800) - .instrument_await("sleep another in multi sleep") - .await; -} - -fn stream1() -> impl Stream { - use futures::stream::{iter, once}; - - iter(std::iter::repeat_with(|| { - once(async { - sleep(150).await; - }) - })) - .flatten() -} - -fn stream2() -> impl Stream { - use futures::stream::{iter, once}; - - iter([ - once(async { - sleep(444).await; - }) - .boxed(), - once(async { - join_all([ - sleep(400).instrument_await("sleep nested 400"), - sleep(600).instrument_await("sleep nested 600"), - ]) - .await; - }) - .boxed(), - ]) - .flatten() -} - -async fn hello() { - async move { - // Join - join_all([ - sleep(1000) - .boxed() - .instrument_await(format!("sleep {}", 1000)), - sleep(2000).boxed().instrument_await("sleep 2000"), - sleep_nested().boxed().instrument_await("sleep nested"), - multi_sleep().boxed().instrument_await("multi sleep"), - ]) - .await; - - // Join another - join_all([ - sleep(1200).instrument_await("sleep 1200"), - sleep(2200).instrument_await("sleep 2200"), - ]) - .await; - - // Cancel - select_all([ - sleep(666).boxed().instrument_await("sleep 666"), - sleep_nested() - .boxed() - .instrument_await("sleep nested (should be cancelled)"), - ]) - .await; - - // Check whether cleaned up - sleep(233).instrument_await("sleep 233").await; - - // Check stream next drop - { - let mut stream1 = stream1().fuse().boxed(); - let mut stream2 = stream2().fuse().boxed(); - let mut count = 0; - - 'outer: loop { - tokio::select! { - _ = stream1.next().instrument_await(format!("stream1 next {count}")) => {}, - r = stream2.next().instrument_await(format!("stream2 next {count}")) => { - if r.is_none() { break 'outer } - }, - } - sleep(50) - .instrument_await(format!("sleep before next stream poll: {count}")) - .await; - count += 1; - } - } - - // Check whether cleaned up - sleep(233).instrument_await("sleep 233").await; - - // TODO: add tests on sending the future to another task or context. - } - .instrument_await("hello") - .await; - - // Aborted futures have been cleaned up. There should only be a single active node of root. - assert_eq!(context().unwrap().tree().active_node_count(), 1); -} - -#[tokio::test] -async fn test_await_tree() { - let registry = Registry::new(Config::default()); - let root = registry.register((), "actor 233"); - - let fut = root.instrument(hello()); - pin_mut!(fut); - - let expected_counts = vec![ - (1, 0), - (8, 0), - (9, 0), - (8, 0), - (6, 0), - (5, 0), - (4, 0), - (4, 0), - (3, 0), - (6, 0), - (3, 0), - (4, 0), - (3, 0), - (4, 0), - (3, 0), - (4, 0), - (3, 0), - (6, 0), - (5, 2), - (6, 0), - (5, 2), - (6, 0), - (5, 0), - (4, 1), - (5, 0), - (3, 0), - (3, 0), - ]; - let mut actual_counts = vec![]; - - poll_fn(|cx| { - let tree = registry - .collect::<()>() - .into_iter() - .exactly_one() - .ok() - .unwrap() - .1; - println!("{tree}"); - actual_counts.push((tree.active_node_count(), tree.detached_node_count())); - fut.poll_unpin(cx) - }) - .await; - - assert_eq!(actual_counts, expected_counts); -} +mod functionality; +mod spawn; diff --git a/src/tests/functionality.rs b/src/tests/functionality.rs new file mode 100644 index 0000000..5e271fd --- /dev/null +++ b/src/tests/functionality.rs @@ -0,0 +1,192 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use futures::future::{join_all, poll_fn, select_all}; +use futures::{pin_mut, FutureExt, Stream, StreamExt}; +use itertools::Itertools; + +use crate::root::current_context; +use crate::{Config, InstrumentAwait, Registry}; + +async fn sleep(time: u64) { + tokio::time::sleep(std::time::Duration::from_millis(time)).await; + println!("slept {time}ms"); +} + +async fn sleep_nested() { + join_all([ + sleep(1500).instrument_await("sleep nested 1500"), + sleep(2500).instrument_await("sleep nested 2500"), + ]) + .await; +} + +async fn multi_sleep() { + sleep(400).await; + + sleep(800) + .instrument_await("sleep another in multi sleep") + .await; +} + +fn stream1() -> impl Stream { + use futures::stream::{iter, once}; + + iter(std::iter::repeat_with(|| { + once(async { + sleep(150).await; + }) + })) + .flatten() +} + +fn stream2() -> impl Stream { + use futures::stream::{iter, once}; + + iter([ + once(async { + sleep(444).await; + }) + .boxed(), + once(async { + join_all([ + sleep(400).instrument_await("sleep nested 400"), + sleep(600).instrument_await("sleep nested 600"), + ]) + .await; + }) + .boxed(), + ]) + .flatten() +} + +async fn hello() { + async move { + // Join + join_all([ + sleep(1000) + .boxed() + .instrument_await(format!("sleep {}", 1000)), + sleep(2000).boxed().instrument_await("sleep 2000"), + sleep_nested().boxed().instrument_await("sleep nested"), + multi_sleep().boxed().instrument_await("multi sleep"), + ]) + .await; + + // Join another + join_all([ + sleep(1200).instrument_await("sleep 1200"), + sleep(2200).instrument_await("sleep 2200"), + ]) + .await; + + // Cancel + select_all([ + sleep(666).boxed().instrument_await("sleep 666"), + sleep_nested() + .boxed() + .instrument_await("sleep nested (should be cancelled)"), + ]) + .await; + + // Check whether cleaned up + sleep(233).instrument_await("sleep 233").await; + + // Check stream next drop + { + let mut stream1 = stream1().fuse().boxed(); + let mut stream2 = stream2().fuse().boxed(); + let mut count = 0; + + 'outer: loop { + tokio::select! { + _ = stream1.next().instrument_await(format!("stream1 next {count}")) => {}, + r = stream2.next().instrument_await(format!("stream2 next {count}")) => { + if r.is_none() { break 'outer } + }, + } + sleep(50) + .instrument_await(format!("sleep before next stream poll: {count}")) + .await; + count += 1; + } + } + + // Check whether cleaned up + sleep(233).instrument_await("sleep 233").await; + + // TODO: add tests on sending the future to another task or context. + } + .instrument_await("hello") + .await; + + // Aborted futures have been cleaned up. There should only be a single active node of root. + assert_eq!(current_context().unwrap().tree().active_node_count(), 1); +} + +#[tokio::test] +async fn test_await_tree() { + let registry = Registry::new(Config::default()); + let root = registry.register((), "actor 233"); + + let fut = root.instrument(hello()); + pin_mut!(fut); + + let expected_counts = vec![ + (1, 0), + (8, 0), + (9, 0), + (8, 0), + (6, 0), + (5, 0), + (4, 0), + (4, 0), + (3, 0), + (6, 0), + (3, 0), + (4, 0), + (3, 0), + (4, 0), + (3, 0), + (4, 0), + (3, 0), + (6, 0), + (5, 2), + (6, 0), + (5, 2), + (6, 0), + (5, 0), + (4, 1), + (5, 0), + (3, 0), + (3, 0), + ]; + let mut actual_counts = vec![]; + + poll_fn(|cx| { + let tree = registry + .collect::<()>() + .into_iter() + .exactly_one() + .ok() + .unwrap() + .1; + println!("{tree}"); + actual_counts.push((tree.active_node_count(), tree.detached_node_count())); + fut.poll_unpin(cx) + }) + .await; + + assert_eq!(actual_counts, expected_counts); +} diff --git a/src/tests/spawn.rs b/src/tests/spawn.rs new file mode 100644 index 0000000..e35c54b --- /dev/null +++ b/src/tests/spawn.rs @@ -0,0 +1,45 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::time::Duration; + +use futures::future::pending; +use tokio::time::sleep; + +use crate::{Config, InstrumentAwait, Registry}; + +#[tokio::test] +async fn main() { + let registry = Registry::new(Config::default()); + + tokio::spawn(registry.register((), "root").instrument(async { + crate::spawn_anonymous("child", async { + crate::spawn_anonymous("grandson", async { + pending::<()>().await; + }) + .instrument_await("wait for grandson") + .await + .unwrap() + }) + .instrument_await("wait for child") + .await + .unwrap() + })); + + sleep(Duration::from_secs(1)).await; + + assert_eq!(registry.collect::<()>().len(), 1); + assert_eq!(registry.collect_anonymous().len(), 2); + assert_eq!(registry.collect_all().len(), 3); +}