Skip to content

Commit

Permalink
tokio-nursery: Abort tasks when NurseryStream is dropped
Browse files Browse the repository at this point in the history
  • Loading branch information
jwodder committed Jan 11, 2025
1 parent 04519bd commit 2a355bc
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 19 deletions.
1 change: 1 addition & 0 deletions crates/tokio-nursery/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ license.workspace = true

[dependencies]
futures-util = { version = "0.3.29", default-features = false, features = ["std"] }
pin-project-lite = "0.2.16"
tokio = { version = "1.29.0", features = ["rt", "sync"] }

[dev-dependencies]
Expand Down
3 changes: 2 additions & 1 deletion crates/tokio-nursery/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ sendable, and so it can be used to spawn tasks from within other tasks.
The `NurseryStream` is a [`Stream`][] of the values returned by the tasks as
they complete; if a task panics, the panic is propagated. Once the `Nursery`
object and all of its clones have been dropped, and once all spawned futures
have completed, the stream will close.
have completed, the stream will close. If the `NurseryStream` is dropped, all
tasks in the nursery are aborted.

[tokio]: https://tokio.rs
[`async_nursery`]: https://crates.io/crates/async_nursery
Expand Down
91 changes: 73 additions & 18 deletions crates/tokio-nursery/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use futures_util::{FutureExt, Stream};
use futures_util::{stream::FuturesUnordered, Stream, StreamExt};
use pin_project_lite::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};

/// Type returned by [`FutureExt::catch_unwind()`]. If the inner task ran to
/// completion, this is `Ok`; otherwise, if the taks panicked, this is `Err`.
type UnwindResult<T> = Result<T, Box<dyn std::any::Any + Send>>;
use tokio::{
sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
task::{JoinError, JoinHandle},
};

/// A handle for spawning new tasks in a task group/nursery.
///
Expand All @@ -16,7 +16,7 @@ type UnwindResult<T> = Result<T, Box<dyn std::any::Any + Send>>;
/// corresponding [`NurseryStream`] can yield `None`.
#[derive(Debug)]
pub struct Nursery<T> {
sender: UnboundedSender<UnwindResult<T>>,
sender: UnboundedSender<FragileHandle<T>>,
}

impl<T: Send + 'static> Nursery<T> {
Expand All @@ -25,19 +25,21 @@ impl<T: Send + 'static> Nursery<T> {
/// futures that will be spawned in the nursery.
pub fn new() -> (Nursery<T>, NurseryStream<T>) {
let (sender, receiver) = unbounded_channel();
(Nursery { sender }, NurseryStream { receiver })
(
Nursery { sender },
NurseryStream {
receiver,
tasks: FuturesUnordered::new(),
},
)
}

/// Spawn a future that returns `T` in the nursery.
pub fn spawn<Fut>(&self, fut: Fut)
where
Fut: Future<Output = T> + Send + 'static,
{
let sender = self.sender.clone();
tokio::spawn(async move {
let task = std::panic::AssertUnwindSafe(fut).catch_unwind();
let _ = sender.send(task.await);
});
let _ = self.sender.send(FragileHandle::new(tokio::spawn(fut)));
}
}

Expand All @@ -55,9 +57,12 @@ impl<T> Clone for Nursery<T> {
///
/// The corresponding [`Nursery`] and all clones thereof must be dropped before
/// the stream can yield `None`.
///
/// When a `NurseryStream` is dropped, all tasks in the nursery are aborted.
#[derive(Debug)]
pub struct NurseryStream<T> {
receiver: UnboundedReceiver<UnwindResult<T>>,
receiver: UnboundedReceiver<FragileHandle<T>>,
tasks: FuturesUnordered<FragileHandle<T>>,
}

impl<T: 'static> Stream for NurseryStream<T> {
Expand All @@ -70,18 +75,68 @@ impl<T: 'static> Stream for NurseryStream<T> {
///
/// If a task panics, this method resumes unwinding the panic.
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
match ready!(self.receiver.poll_recv(cx)) {
let closed = loop {
match self.receiver.poll_recv(cx) {
Poll::Pending => break false,
Poll::Ready(Some(handle)) => self.tasks.push(handle),
Poll::Ready(None) => break true,
}
};
match ready!(self.tasks.poll_next_unpin(cx)) {
Some(Ok(r)) => Some(r).into(),
Some(Err(e)) => std::panic::resume_unwind(e),
None => None.into(),
Some(Err(e)) => match e.try_into_panic() {
Ok(barf) => std::panic::resume_unwind(barf),
Err(e) => unreachable!(
"Task in nursery should not have been aborted before dropping stream, but got {e:?}"
),
},
None => {
if closed {
// All Nursery clones dropped and all results yielded; end
// of stream
None.into()
} else {
Poll::Pending
}
}
}
}
}

pin_project! {
/// A wrapper around `tokio::task::JoinHandle` that aborts the task on drop.
#[derive(Debug)]
struct FragileHandle<T> {
#[pin]
inner: JoinHandle<T>
}

impl<T> PinnedDrop for FragileHandle<T> {
fn drop(this: Pin<&mut Self>) {
this.project().inner.abort();
}
}
}

impl<T> FragileHandle<T> {
fn new(inner: JoinHandle<T>) -> Self {
FragileHandle { inner }
}
}

impl<T> Future for FragileHandle<T> {
type Output = Result<T, JoinError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
this.inner.poll(cx)
}
}

#[cfg(test)]
mod tests {
use super::*;
use futures_util::StreamExt;
use futures_util::{FutureExt, StreamExt};
use std::time::Duration;
use tokio::time::timeout;

Expand Down

0 comments on commit 2a355bc

Please sign in to comment.