Skip to content

Commit

Permalink
Wait tasks with waitgroup
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty authored and pbzweihander committed Nov 23, 2021
1 parent 54dbb93 commit 595e47d
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 28 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ route-recognizer = "0.2.0"
serde = "1.0.117"
serde_json = "1.0.59"
stopper = "0.2.0"
waitgroup = "0.1.2"

[dev-dependencies]
async-std = { version = "1.6.5", features = ["unstable", "attributes"] }
Expand Down
24 changes: 10 additions & 14 deletions src/listener/tcp_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use async_std::prelude::*;
use async_std::{io, task};

use futures_util::future::Either;
use futures_util::stream::FuturesUnordered;

use waitgroup::{WaitGroup, Worker};

/// This represents a tide [Listener](crate::listener::Listener) that
/// wraps an [async_std::net::TcpListener]. It is implemented as an
Expand All @@ -25,7 +26,6 @@ pub struct TcpListener<State> {
listener: Option<net::TcpListener>,
server: Option<Server<State>>,
info: Option<ListenInfo>,
join_handles: Vec<task::JoinHandle<()>>,
}

impl<State> TcpListener<State> {
Expand All @@ -35,7 +35,6 @@ impl<State> TcpListener<State> {
listener: None,
server: None,
info: None,
join_handles: Vec::new(),
}
}

Expand All @@ -45,16 +44,18 @@ impl<State> TcpListener<State> {
listener: Some(tcp_listener.into()),
server: None,
info: None,
join_handles: Vec::new(),
}
}
}

fn handle_tcp<State: Clone + Send + Sync + 'static>(
app: Server<State>,
stream: TcpStream,
) -> task::JoinHandle<()> {
wait_group_worker: Worker,
) {
task::spawn(async move {
let _wait_group_worker = wait_group_worker;

let local_addr = stream.local_addr().ok();
let peer_addr = stream.peer_addr().ok();

Expand All @@ -75,7 +76,7 @@ fn handle_tcp<State: Clone + Send + Sync + 'static>(
if let Err(error) = fut.await {
log::error!("async-h1 error", { error: error.to_string() });
}
})
});
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -121,6 +122,7 @@ where
} else {
Either::Right(incoming)
};
let wait_group = WaitGroup::new();

while let Some(stream) = incoming.next().await {
match stream {
Expand All @@ -133,18 +135,12 @@ where
}

Ok(stream) => {
let handle = handle_tcp(server.clone(), stream);
self.join_handles.push(handle);
handle_tcp(server.clone(), stream, wait_group.worker());
}
};
}

let join_handles = std::mem::take(&mut self.join_handles);
join_handles
.into_iter()
.collect::<FuturesUnordered<task::JoinHandle<()>>>()
.collect::<()>()
.await;
wait_group.wait().await;

Ok(())
}
Expand Down
24 changes: 10 additions & 14 deletions src/listener/unix_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use async_std::prelude::*;
use async_std::{io, task};

use futures_util::future::Either;
use futures_util::stream::FuturesUnordered;

use waitgroup::{WaitGroup, Worker};

/// This represents a tide [Listener](crate::listener::Listener) that
/// wraps an [async_std::os::unix::net::UnixListener]. It is implemented as an
Expand All @@ -26,7 +27,6 @@ pub struct UnixListener<State> {
listener: Option<net::UnixListener>,
server: Option<Server<State>>,
info: Option<ListenInfo>,
join_handles: Vec<task::JoinHandle<()>>,
}

impl<State> UnixListener<State> {
Expand All @@ -36,7 +36,6 @@ impl<State> UnixListener<State> {
listener: None,
server: None,
info: None,
join_handles: Vec::new(),
}
}

Expand All @@ -46,16 +45,18 @@ impl<State> UnixListener<State> {
listener: Some(unix_listener.into()),
server: None,
info: None,
join_handles: Vec::new(),
}
}
}

fn handle_unix<State: Clone + Send + Sync + 'static>(
app: Server<State>,
stream: UnixStream,
) -> task::JoinHandle<()> {
wait_group_worker: Worker,
) {
task::spawn(async move {
let _wait_group_worker = wait_group_worker;

let local_addr = unix_socket_addr_to_string(stream.local_addr());
let peer_addr = unix_socket_addr_to_string(stream.peer_addr());

Expand All @@ -76,7 +77,7 @@ fn handle_unix<State: Clone + Send + Sync + 'static>(
if let Err(error) = fut.await {
log::error!("async-h1 error", { error: error.to_string() });
}
})
});
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -119,6 +120,7 @@ where
} else {
Either::Right(incoming)
};
let wait_group = WaitGroup::new();

while let Some(stream) = incoming.next().await {
match stream {
Expand All @@ -131,18 +133,12 @@ where
}

Ok(stream) => {
let handle = handle_unix(server.clone(), stream);
self.join_handles.push(handle);
handle_unix(server.clone(), stream, wait_group.worker());
}
};
}

let join_handles = std::mem::take(&mut self.join_handles);
join_handles
.into_iter()
.collect::<FuturesUnordered<task::JoinHandle<()>>>()
.collect::<()>()
.await;
wait_group.wait().await;

Ok(())
}
Expand Down

0 comments on commit 595e47d

Please sign in to comment.