Skip to content

Commit

Permalink
sync: add watch::Receiver::wait_for (tokio-rs#5611)
Browse files Browse the repository at this point in the history
  • Loading branch information
debadree25 authored Apr 24, 2023
1 parent 11b8807 commit c1778ed
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 8 deletions.
54 changes: 54 additions & 0 deletions tokio/src/sync/tests/loom_watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::sync::watch;

use loom::future::block_on;
use loom::thread;
use std::sync::Arc;

#[test]
fn smoke() {
Expand Down Expand Up @@ -34,3 +35,56 @@ fn smoke() {
th.join().unwrap();
})
}

#[test]
fn wait_for_test() {
loom::model(move || {
let (tx, mut rx) = watch::channel(false);

let tx_arc = Arc::new(tx);
let tx1 = tx_arc.clone();
let tx2 = tx_arc.clone();

let th1 = thread::spawn(move || {
for _ in 0..2 {
tx1.send_modify(|_x| {});
}
});

let th2 = thread::spawn(move || {
tx2.send(true).unwrap();
});

assert_eq!(*block_on(rx.wait_for(|x| *x)).unwrap(), true);

th1.join().unwrap();
th2.join().unwrap();
});
}

#[test]
fn wait_for_returns_correct_value() {
loom::model(move || {
let (tx, mut rx) = watch::channel(0);

let jh = thread::spawn(move || {
tx.send(1).unwrap();
tx.send(2).unwrap();
tx.send(3).unwrap();
});

// Stop at the first value we are called at.
let mut stopped_at = usize::MAX;
let returned = *block_on(rx.wait_for(|x| {
stopped_at = *x;
true
}))
.unwrap();

// Check that it returned the same value as the one we returned
// `true` for.
assert_eq!(stopped_at, returned);

jh.join().unwrap();
});
}
110 changes: 102 additions & 8 deletions tokio/src/sync/watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,18 +595,93 @@ impl<T> Receiver<T> {
/// }
/// ```
pub async fn changed(&mut self) -> Result<(), error::RecvError> {
changed_impl(&self.shared, &mut self.version).await
}

/// Waits for a value that satisifes the provided condition.
///
/// This method will call the provided closure whenever something is sent on
/// the channel. Once the closure returns `true`, this method will return a
/// reference to the value that was passed to the closure.
///
/// Before `wait_for` starts waiting for changes, it will call the closure
/// on the current value. If the closure returns `true` when given the
/// current value, then `wait_for` will immediately return a reference to
/// the current value. This is the case even if the current value is already
/// considered seen.
///
/// The watch channel only keeps track of the most recent value, so if
/// several messages are sent faster than `wait_for` is able to call the
/// closure, then it may skip some updates. Whenever the closure is called,
/// it will be called with the most recent value.
///
/// When this function returns, the value that was passed to the closure
/// when it returned `true` will be considered seen.
///
/// If the channel is closed, then `wait_for` will return a `RecvError`.
/// Once this happens, no more messages can ever be sent on the channel.
/// When an error is returned, it is guaranteed that the closure has been
/// called on the last value, and that it returned `false` for that value.
/// (If the closure returned `true`, then the last value would have been
/// returned instead of the error.)
///
/// Like the `borrow` method, the returned borrow holds a read lock on the
/// inner value. This means that long-lived borrows could cause the producer
/// half to block. It is recommended to keep the borrow as short-lived as
/// possible. See the documentation of `borrow` for more information on
/// this.
///
/// [`Receiver::changed()`]: crate::sync::watch::Receiver::changed
///
/// # Examples
///
/// ```
/// use tokio::sync::watch;
///
/// #[tokio::main]
///
/// async fn main() {
/// let (tx, _rx) = watch::channel("hello");
///
/// tx.send("goodbye").unwrap();
///
/// // here we subscribe to a second receiver
/// // now in case of using `changed` we would have
/// // to first check the current value and then wait
/// // for changes or else `changed` would hang.
/// let mut rx2 = tx.subscribe();
///
/// // in place of changed we have use `wait_for`
/// // which would automatically check the current value
/// // and wait for changes until the closure returns true.
/// assert!(rx2.wait_for(|val| *val == "goodbye").await.is_ok());
/// assert_eq!(*rx2.borrow(), "goodbye");
/// }
/// ```
pub async fn wait_for(
&mut self,
mut f: impl FnMut(&T) -> bool,
) -> Result<Ref<'_, T>, error::RecvError> {
let mut closed = false;
loop {
// In order to avoid a race condition, we first request a notification,
// **then** check the current value's version. If a new version exists,
// the notification request is dropped.
let notified = self.shared.notify_rx.notified();
{
let inner = self.shared.value.read().unwrap();

if let Some(ret) = maybe_changed(&self.shared, &mut self.version) {
return ret;
let new_version = self.shared.state.load().version();
let has_changed = self.version != new_version;
self.version = new_version;

if (!closed || has_changed) && f(&inner) {
return Ok(Ref { inner, has_changed });
}
}

notified.await;
// loop around again in case the wake-up was spurious
if closed {
return Err(error::RecvError(()));
}

// Wait for the value to change.
closed = changed_impl(&self.shared, &mut self.version).await.is_err();
}
}

Expand Down Expand Up @@ -655,6 +730,25 @@ fn maybe_changed<T>(
None
}

async fn changed_impl<T>(
shared: &Shared<T>,
version: &mut Version,
) -> Result<(), error::RecvError> {
loop {
// In order to avoid a race condition, we first request a notification,
// **then** check the current value's version. If a new version exists,
// the notification request is dropped.
let notified = shared.notify_rx.notified();

if let Some(ret) = maybe_changed(shared, version) {
return ret;
}

notified.await;
// loop around again in case the wake-up was spurious
}
}

impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
let version = self.version;
Expand Down

0 comments on commit c1778ed

Please sign in to comment.