Skip to content

Commit

Permalink
Efficiently remove MaybeDone from tuple::join
Browse files Browse the repository at this point in the history
  • Loading branch information
matheus-consoli committed Nov 22, 2022
1 parent 8a672db commit 03c26b2
Showing 1 changed file with 35 additions and 48 deletions.
83 changes: 35 additions & 48 deletions src/future/join/tuple.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::Join as JoinTrait;
use crate::utils::{PollArray, RandomGenerator, WakerArray};
use crate::utils::PollArray;

use core::fmt::{self, Debug};
use core::future::{Future, IntoFuture};
Expand All @@ -9,17 +9,30 @@ use core::task::{Context, Poll};

use pin_project::pin_project;

macro_rules! poll_future {
($fut_idx:tt, $iteration:ident, $this:ident, $outputs:ident, $futures:ident . $fut_member:ident, $cx:ident) => {
/// Generates the `poll` call for every `Future` inside `$futures`.
// This is implemented as a tt-muncher of the future name `$($F:ident)`
// and the future index `$($rest)`, taking advantage that we only support
// tuples up to 12 elements
//
// # References
// TT Muncher: https://veykril.github.io/tlborm/decl-macros/patterns/tt-muncher.html
macro_rules! poll {
(@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, $fut_name:ident $($F:ident)* | $fut_idx:tt $($rest:tt)*) => {
if $fut_idx == $iteration {
if let Poll::Ready(value) =
unsafe { Pin::new_unchecked(&mut $futures.$fut_member) }.poll(&mut $cx)
{
$this.outputs.$fut_member.write(value);
if let Poll::Ready(value) = $futures.$fut_name.as_mut().poll($cx) {
$this.outputs.$fut_idx.write(value);
*$this.completed += 1;
$this.state[$fut_idx].set_consumed();
}
}
poll!(@inner $iteration, $this, $futures, $cx, $($F)* | $($rest)*);
};

// base condition, no more futures to poll
(@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, | $($rest:tt)*) => {};

($iteration:ident, $this:ident, $futures:ident, $cx:ident, $LEN:ident, $($F:ident,)+) => {
poll!(@inner $iteration, $this, $futures, $cx, $($F)+ | 0 1 2 3 4 5 6 7 8 9 10 11);
};
}

Expand Down Expand Up @@ -62,14 +75,10 @@ macro_rules! impl_join_tuple {
};
($mod_name:ident $StructName:ident $($F:ident)+) => {
mod $mod_name {
use core::mem::MaybeUninit;
use core::future::Future;

#[pin_project::pin_project]
pub(super) struct Futures<$($F,)+> { $(#[pin] pub(super) $F: $F,)+ }

pub(super) struct Outputs<$($F: Future,)+> { $(pub(super) $F: MaybeUninit<$F::Output>,)+ }

#[repr(u8)]
pub(super) enum Indexes { $($F,)+ }

Expand All @@ -88,11 +97,9 @@ macro_rules! impl_join_tuple {
#[allow(non_snake_case)]
pub struct $StructName<$($F: Future),+> {
#[pin] futures: $mod_name::Futures<$($F,)+>,
outputs: $mod_name::Outputs<$($F,)+>,
rng: RandomGenerator,
wakers: WakerArray<{$mod_name::LEN}>,
outputs: ($(MaybeUninit<$F::Output>,)+),
state: PollArray<{$mod_name::LEN}>,
completed: u8,
completed: usize,
}

impl<$($F),+> Debug for $StructName<$($F),+>
Expand All @@ -116,48 +123,30 @@ macro_rules! impl_join_tuple {
fn poll(
self: Pin<&mut Self>, cx: &mut Context<'_>
) -> Poll<Self::Output> {
let this = self.project();
let mut this = self.project();

let mut readiness = this.wakers.readiness().lock().unwrap();
readiness.set_waker(cx.waker());

const LEN: u8 = $mod_name::LEN as u8;
let r = this.rng.generate(LEN as u32) as u8;
const LEN: usize = $mod_name::LEN;

let mut futures = this.futures.project();

for index in (0..LEN).map(|n| (r + n).wrapping_rem(LEN) as usize) {
if !readiness.any_ready() {
return Poll::Pending;
} else if !readiness.clear_ready(index) || this.state[index].is_consumed() {
for index in 0..LEN {
if this.state[index].is_consumed() {
continue;
}

drop(readiness);

let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

$(
let fut_index = $mod_name::Indexes::$F as usize;
poll_future!(
fut_index,
index,
this,
outputs,
futures . $F,
cx
);
)+
// generate the needed code to poll `futures.{index}`
poll!(index, this, futures, cx, LEN, $($F,)+);

if *this.completed == LEN {
let out = {
let mut output = $mod_name::Outputs { $($F: MaybeUninit::uninit(),)+ };
core::mem::swap(this.outputs, &mut output);
unsafe { ( $(output.$F.assume_init(),)+ ) }
let mut out = ($(MaybeUninit::<$F::Output>::uninit(),)+);
core::mem::swap(&mut out, this.outputs);
let ($($F,)+) = out;
unsafe { ($($F.assume_init(),)+) }
};

return Poll::Ready(out);
}
readiness = this.wakers.readiness().lock().unwrap();
}

Poll::Pending
Expand All @@ -175,11 +164,9 @@ macro_rules! impl_join_tuple {
fn join(self) -> Self::Future {
let ($($F,)+): ($($F,)+) = self;
$StructName {
futures: $mod_name::Futures { $($F: $F.into_future(),)+ },
rng: RandomGenerator::new(),
wakers: WakerArray::new(),
futures: $mod_name::Futures {$($F: $F.into_future(),)+},
state: PollArray::new(),
outputs: $mod_name::Outputs { $($F: MaybeUninit::uninit(),)+ },
outputs: ($(MaybeUninit::<$F::Output>::uninit(),)+),
completed: 0,
}
}
Expand Down

0 comments on commit 03c26b2

Please sign in to comment.