diff --git a/crates/tokio-nursery/Cargo.toml b/crates/tokio-nursery/Cargo.toml index d2984d2..1962629 100644 --- a/crates/tokio-nursery/Cargo.toml +++ b/crates/tokio-nursery/Cargo.toml @@ -8,6 +8,7 @@ license.workspace = true [dependencies] futures-util = { version = "0.3.29", default-features = false, features = ["std"] } +pin-project-lite = "0.2.16" tokio = { version = "1.29.0", features = ["rt", "sync"] } [dev-dependencies] diff --git a/crates/tokio-nursery/README.md b/crates/tokio-nursery/README.md index af06168..c826234 100644 --- a/crates/tokio-nursery/README.md +++ b/crates/tokio-nursery/README.md @@ -14,7 +14,8 @@ sendable, and so it can be used to spawn tasks from within other tasks. The `NurseryStream` is a [`Stream`][] of the values returned by the tasks as they complete; if a task panics, the panic is propagated. Once the `Nursery` object and all of its clones have been dropped, and once all spawned futures -have completed, the stream will close. +have completed, the stream will close. If the `NurseryStream` is dropped, all +tasks in the nursery are aborted. [tokio]: https://tokio.rs [`async_nursery`]: https://crates.io/crates/async_nursery diff --git a/crates/tokio-nursery/src/lib.rs b/crates/tokio-nursery/src/lib.rs index 295affc..fbacb2d 100644 --- a/crates/tokio-nursery/src/lib.rs +++ b/crates/tokio-nursery/src/lib.rs @@ -1,12 +1,12 @@ -use futures_util::{FutureExt, Stream}; +use futures_util::{stream::FuturesUnordered, Stream, StreamExt}; +use pin_project_lite::pin_project; use std::future::Future; use std::pin::Pin; use std::task::{ready, Context, Poll}; -use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; - -/// Type returned by [`FutureExt::catch_unwind()`]. If the inner task ran to -/// completion, this is `Ok`; otherwise, if the taks panicked, this is `Err`. -type UnwindResult = Result>; +use tokio::{ + sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + task::{JoinError, JoinHandle}, +}; /// A handle for spawning new tasks in a task group/nursery. /// @@ -16,7 +16,7 @@ type UnwindResult = Result>; /// corresponding [`NurseryStream`] can yield `None`. #[derive(Debug)] pub struct Nursery { - sender: UnboundedSender>, + sender: UnboundedSender>, } impl Nursery { @@ -25,7 +25,13 @@ impl Nursery { /// futures that will be spawned in the nursery. pub fn new() -> (Nursery, NurseryStream) { let (sender, receiver) = unbounded_channel(); - (Nursery { sender }, NurseryStream { receiver }) + ( + Nursery { sender }, + NurseryStream { + receiver, + tasks: FuturesUnordered::new(), + }, + ) } /// Spawn a future that returns `T` in the nursery. @@ -33,11 +39,7 @@ impl Nursery { where Fut: Future + Send + 'static, { - let sender = self.sender.clone(); - tokio::spawn(async move { - let task = std::panic::AssertUnwindSafe(fut).catch_unwind(); - let _ = sender.send(task.await); - }); + let _ = self.sender.send(FragileHandle::new(tokio::spawn(fut))); } } @@ -55,9 +57,12 @@ impl Clone for Nursery { /// /// The corresponding [`Nursery`] and all clones thereof must be dropped before /// the stream can yield `None`. +/// +/// When a `NurseryStream` is dropped, all tasks in the nursery are aborted. #[derive(Debug)] pub struct NurseryStream { - receiver: UnboundedReceiver>, + receiver: UnboundedReceiver>, + tasks: FuturesUnordered>, } impl Stream for NurseryStream { @@ -70,18 +75,68 @@ impl Stream for NurseryStream { /// /// If a task panics, this method resumes unwinding the panic. fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(self.receiver.poll_recv(cx)) { + let closed = loop { + match self.receiver.poll_recv(cx) { + Poll::Pending => break false, + Poll::Ready(Some(handle)) => self.tasks.push(handle), + Poll::Ready(None) => break true, + } + }; + match ready!(self.tasks.poll_next_unpin(cx)) { Some(Ok(r)) => Some(r).into(), - Some(Err(e)) => std::panic::resume_unwind(e), - None => None.into(), + Some(Err(e)) => match e.try_into_panic() { + Ok(barf) => std::panic::resume_unwind(barf), + Err(e) => unreachable!( + "Task in nursery should not have been aborted before dropping stream, but got {e:?}" + ), + }, + None => { + if closed { + // All Nursery clones dropped and all results yielded; end + // of stream + None.into() + } else { + Poll::Pending + } + } + } + } +} + +pin_project! { + /// A wrapper around `tokio::task::JoinHandle` that aborts the task on drop. + #[derive(Debug)] + struct FragileHandle { + #[pin] + inner: JoinHandle + } + + impl PinnedDrop for FragileHandle { + fn drop(this: Pin<&mut Self>) { + this.project().inner.abort(); } } } +impl FragileHandle { + fn new(inner: JoinHandle) -> Self { + FragileHandle { inner } + } +} + +impl Future for FragileHandle { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + this.inner.poll(cx) + } +} + #[cfg(test)] mod tests { use super::*; - use futures_util::StreamExt; + use futures_util::{FutureExt, StreamExt}; use std::time::Duration; use tokio::time::timeout;