Skip to content

Commit

Permalink
Merge pull request #29 from jwodder/nursery
Browse files Browse the repository at this point in the history
Add a simple tokio-based nursery
  • Loading branch information
jwodder authored Jan 10, 2025
2 parents ae88c3b + fb3edd0 commit 7b1256e
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 0 deletions.
17 changes: 17 additions & 0 deletions crates/tokio-nursery/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "rswodlib-tokio-nursery"
edition.workspace = true
description = "Simple tokio-based task group/nursery"
authors.workspace = true
repository.workspace = true
license.workspace = true

[dependencies]
futures-util = { version = "0.3.29", default-features = false, features = ["std"] }
tokio = { version = "1.29.0", features = ["rt", "sync"] }

[dev-dependencies]
tokio = { version = "1.29.0", features = ["macros", "rt"] }

[lints]
workspace = true
1 change: 1 addition & 0 deletions crates/tokio-nursery/LICENSE
21 changes: 21 additions & 0 deletions crates/tokio-nursery/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
This crate defines a [tokio][]-based *task group* or *nursery* for spawning
asynchronous tasks and retrieving their return values. The API is based on
[`async_nursery`][], which would have been perfect for my needs at the time,
except that it doesn't support creating a nursery inside a Tokio runtime.

Usage
=====

Call `Nursery::new()` to receive a `(Nursery<T>, NurseryStream<T>)` pair, where
`T` is the output type of the futures that you'll be spawning in the nursery.
Call `nursery.spawn(future)` to spawn a future. The nursery is clonable &
sendable, and so it can be used to spawn tasks from within other tasks.

The `NurseryStream` is a [`Stream`][] of the values returned by the tasks as
they complete; if a task panics, the panic is propagated. Once the `Nursery`
object and all of its clones have been dropped, and once all spawned futures
have completed, the stream will close.

[tokio]: https://tokio.rs
[`async_nursery`]: https://crates.io/crates/async_nursery
[`Stream`]: https://docs.rs/futures-util/latest/futures_util/stream/trait.Stream.html
120 changes: 120 additions & 0 deletions crates/tokio-nursery/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use futures_util::{FutureExt, Stream};
use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};

type UnwindResult<T> = Result<T, Box<dyn std::any::Any + Send>>;

#[derive(Debug)]
pub struct Nursery<T> {
sender: UnboundedSender<UnwindResult<T>>,
}

impl<T: Send + 'static> Nursery<T> {
pub fn new() -> (Nursery<T>, NurseryStream<T>) {
let (sender, receiver) = unbounded_channel();
(Nursery { sender }, NurseryStream { receiver })
}

pub fn spawn<Fut>(&self, fut: Fut)
where
Fut: Future<Output = T> + Send + 'static,
{
let sender = self.sender.clone();
tokio::spawn(async move {
let task = std::panic::AssertUnwindSafe(fut).catch_unwind();
let _ = sender.send(task.await);
});
}
}

// Clone can't be derived, as that would erroneously add `T: Clone` bounds to
// the impl.
impl<T> Clone for Nursery<T> {
fn clone(&self) -> Nursery<T> {
Nursery {
sender: self.sender.clone(),
}
}
}

#[derive(Debug)]
pub struct NurseryStream<T> {
receiver: UnboundedReceiver<UnwindResult<T>>,
}

impl<T: 'static> Stream for NurseryStream<T> {
type Item = T;

/// Poll for one of the tasks in the nursery to complete and return its
/// return value.
///
/// # Panics
///
/// If a task panics, this method resumes unwinding the panic.
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
match ready!(self.receiver.poll_recv(cx)) {
Some(Ok(r)) => Some(r).into(),
Some(Err(e)) => std::panic::resume_unwind(e),
None => None.into(),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use futures_util::StreamExt;

#[test]
fn nursery_is_send() {
#[allow(dead_code)]
fn require_send<T: Send>(_t: T) {}

#[allow(dead_code)]
fn check_nursery_send<T: Send + 'static>() {
let (nursery, _) = Nursery::<T>::new();
require_send(nursery);
}
}

#[tokio::test]
async fn collect() {
let (nursery, nursery_stream) = Nursery::new();
nursery.spawn(std::future::ready(1));
nursery.spawn(std::future::ready(2));
nursery.spawn(std::future::ready(3));
drop(nursery);
let mut values = nursery_stream.collect::<Vec<_>>().await;
values.sort_unstable();
assert_eq!(values, vec![1, 2, 3]);
}

#[tokio::test]
async fn nested_spawn() {
let (nursery, nursery_stream) = Nursery::new();
let inner = nursery.clone();
nursery.spawn(async move {
inner.spawn(std::future::ready(0));
std::future::ready(1).await
});
nursery.spawn(std::future::ready(2));
nursery.spawn(std::future::ready(3));
drop(nursery);
let mut values = nursery_stream.collect::<Vec<_>>().await;
values.sort_unstable();
assert_eq!(values, vec![0, 1, 2, 3]);
}

#[tokio::test]
async fn reraise_panic() {
let (nursery, mut nursery_stream) = Nursery::new();
nursery.spawn(async { panic!("I can't take this anymore!") });
drop(nursery);
let r = std::panic::AssertUnwindSafe(nursery_stream.next())
.catch_unwind()
.await;
assert!(r.is_err());
}
}

0 comments on commit 7b1256e

Please sign in to comment.