diff --git a/Cargo.toml b/Cargo.toml index fab5e45b..6512284c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ route-recognizer = "0.2.0" serde = "1.0.117" serde_json = "1.0.59" stopper = "0.2.0" +waitgroup = "0.1.2" [dev-dependencies] async-std = { version = "1.6.5", features = ["unstable", "attributes"] } diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index 9ae9dc8d..5b444ff3 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -10,7 +10,8 @@ use async_std::prelude::*; use async_std::{io, task}; use futures_util::future::Either; -use futures_util::stream::FuturesUnordered; + +use waitgroup::{WaitGroup, Worker}; /// This represents a tide [Listener](crate::listener::Listener) that /// wraps an [async_std::net::TcpListener]. It is implemented as an @@ -25,7 +26,6 @@ pub struct TcpListener { listener: Option, server: Option>, info: Option, - join_handles: Vec>, } impl TcpListener { @@ -35,7 +35,6 @@ impl TcpListener { listener: None, server: None, info: None, - join_handles: Vec::new(), } } @@ -45,7 +44,6 @@ impl TcpListener { listener: Some(tcp_listener.into()), server: None, info: None, - join_handles: Vec::new(), } } } @@ -53,8 +51,11 @@ impl TcpListener { fn handle_tcp( app: Server, stream: TcpStream, -) -> task::JoinHandle<()> { + wait_group_worker: Worker, +) { task::spawn(async move { + let _wait_group_worker = wait_group_worker; + let local_addr = stream.local_addr().ok(); let peer_addr = stream.peer_addr().ok(); @@ -75,7 +76,7 @@ fn handle_tcp( if let Err(error) = fut.await { log::error!("async-h1 error", { error: error.to_string() }); } - }) + }); } #[async_trait::async_trait] @@ -121,6 +122,7 @@ where } else { Either::Right(incoming) }; + let wait_group = WaitGroup::new(); while let Some(stream) = incoming.next().await { match stream { @@ -133,18 +135,12 @@ where } Ok(stream) => { - let handle = handle_tcp(server.clone(), stream); - self.join_handles.push(handle); + handle_tcp(server.clone(), stream, wait_group.worker()); } }; } - let join_handles = std::mem::take(&mut self.join_handles); - join_handles - .into_iter() - .collect::>>() - .collect::<()>() - .await; + wait_group.wait().await; Ok(()) } diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index 50233ca8..9b6c6e4d 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -11,7 +11,8 @@ use async_std::prelude::*; use async_std::{io, task}; use futures_util::future::Either; -use futures_util::stream::FuturesUnordered; + +use waitgroup::{WaitGroup, Worker}; /// This represents a tide [Listener](crate::listener::Listener) that /// wraps an [async_std::os::unix::net::UnixListener]. It is implemented as an @@ -26,7 +27,6 @@ pub struct UnixListener { listener: Option, server: Option>, info: Option, - join_handles: Vec>, } impl UnixListener { @@ -36,7 +36,6 @@ impl UnixListener { listener: None, server: None, info: None, - join_handles: Vec::new(), } } @@ -46,7 +45,6 @@ impl UnixListener { listener: Some(unix_listener.into()), server: None, info: None, - join_handles: Vec::new(), } } } @@ -54,8 +52,11 @@ impl UnixListener { fn handle_unix( app: Server, stream: UnixStream, -) -> task::JoinHandle<()> { + wait_group_worker: Worker, +) { task::spawn(async move { + let _wait_group_worker = wait_group_worker; + let local_addr = unix_socket_addr_to_string(stream.local_addr()); let peer_addr = unix_socket_addr_to_string(stream.peer_addr()); @@ -76,7 +77,7 @@ fn handle_unix( if let Err(error) = fut.await { log::error!("async-h1 error", { error: error.to_string() }); } - }) + }); } #[async_trait::async_trait] @@ -119,6 +120,7 @@ where } else { Either::Right(incoming) }; + let wait_group = WaitGroup::new(); while let Some(stream) = incoming.next().await { match stream { @@ -131,18 +133,12 @@ where } Ok(stream) => { - let handle = handle_unix(server.clone(), stream); - self.join_handles.push(handle); + handle_unix(server.clone(), stream, wait_group.worker()); } }; } - let join_handles = std::mem::take(&mut self.join_handles); - join_handles - .into_iter() - .collect::>>() - .collect::<()>() - .await; + wait_group.wait().await; Ok(()) }