Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for greedy (or batched) message streaming #121

Merged
merged 2 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.

from typing import Any, Callable, Generic, Iterable, Iterator, List, Optional, TypeVar
from typing import Any, Callable, Generator, Generic, Iterable, List, Literal, Optional, TypeVar, Union, overload

import tf2_ros
from rclpy.node import Node
Expand Down Expand Up @@ -47,7 +47,7 @@ def link(self) -> Filter:
@property
def history(self) -> List[MessageT]:
"""Gets the entire history of messages received so far."""
return list(self._tape.content())
return self._tape.content(greedy=True)

@property
def latest(self) -> Optional[MessageT]:
Expand Down Expand Up @@ -80,29 +80,71 @@ def recall(self, callback: Callable[[MessageT], None]) -> Tunnel:
tunnel.registerCallback(callback)
return tunnel

@overload
def stream(
self,
*,
forward_only: bool = False,
expunge: bool = False,
buffer_size: Optional[int] = None,
timeout_sec: Optional[float] = None,
) -> Iterator[MessageT]:
) -> Generator[MessageT, None, None]:
"""Overload for plain streaming."""

@overload
def stream(
self,
*,
greedy: Literal[True],
forward_only: bool = False,
expunge: bool = False,
buffer_size: Optional[int] = None,
timeout_sec: Optional[float] = None,
) -> Generator[List[MessageT], None, None]:
"""Overload for greedy, batched streaming."""

def stream(
self,
*,
greedy: bool = False,
forward_only: bool = False,
expunge: bool = False,
buffer_size: Optional[int] = None,
timeout_sec: Optional[float] = None,
) -> Generator[Union[MessageT, List[MessageT]], None, None]:
"""Iterates over messages as they come.

Iteration stops when the given timeout expires or when the associated context
is shutdown. Note that iterating over the message stream is a blocking operation.

Args:
greedy: if true, greedily batch messages as they become available.
forward_only: whether to ignore previously received messages.
expunge: if true, wipe out the message history after reading
if it applies (i.e. non-forward only streams).
buffer_size: optional maximum size for the incoming messages buffer.
If none is provided, the buffer will grow unbounded.
timeout_sec: optional timeout, in seconds, for a new message to be received.

Returns:
a lazy iterator over messages.
a lazy iterator over messages, one message at a time or in batches if greedy.

Raises:
TimeoutError: if streaming times out waiting for a new message.
"""
if greedy:
# use boolean literals to help mypy
return self._tape.content(
follow=True,
greedy=True,
expunge=expunge,
forward_only=forward_only,
buffer_size=buffer_size,
timeout_sec=timeout_sec,
)
return self._tape.content(
follow=True,
expunge=expunge,
forward_only=forward_only,
buffer_size=buffer_size,
timeout_sec=timeout_sec,
Expand Down
127 changes: 107 additions & 20 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
import weakref
from collections.abc import Mapping, MutableSet
from typing import Any, Callable, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Generator, Generic, List, Literal, Optional, Tuple, TypeVar, Union, overload

import rclpy.clock
import rclpy.duration
Expand Down Expand Up @@ -109,22 +109,42 @@ def write(self, data: U) -> bool:
return False
return True

def read(self, timeout_sec: Optional[float] = None) -> Optional[U]:
def try_read(self) -> Optional[U]:
    """Attempt a non-blocking read from the stream.

    Returns:
        the next item when one is immediately available, or ``None``
        when there is nothing queued or the stream was interrupted.
    """
    try:
        item = self._queue.get_nowait()
    except queue.Empty:
        return None
    self._queue.task_done()
    return item

def read(self, timeout_sec: Optional[float] = None) -> U:
    """Read data from the stream, blocking until data is available.

    Args:
        timeout_sec: optional read timeout, in seconds. If ``None``,
            block indefinitely until data arrives.

    Returns:
        the data read.

    Raises:
        InterruptedError: if the stream is interrupted while waiting.
        TimeoutError: if the read times out before new data arrives.
    """
    # Keep the try body minimal: only the blocking get() can raise queue.Empty.
    try:
        data = self._queue.get(timeout=timeout_sec)
    except queue.Empty as e:
        raise TimeoutError() from e
    self._queue.task_done()
    # A ``None`` item marks stream interruption — surface it as an
    # exception rather than returning the sentinel to the caller.
    if data is None:
        raise InterruptedError()
    return data

def interrupt(self) -> None:
"""Interrupt the stream and wake up the reader."""
Expand Down Expand Up @@ -216,61 +236,128 @@ def head(self) -> Optional[T]:
return None
return self._content[0]

@overload
def content(
self,
*,
follow: bool = ...,
forward_only: bool = ...,
expunge: bool = ...,
buffer_size: Optional[int] = ...,
timeout_sec: Optional[float] = ...,
label: Optional[str] = ...,
) -> Generator[T, None, None]:
"""Overload for non-greedy iteration."""

@overload
def content(
self,
*,
greedy: Literal[True],
follow: Literal[True],
forward_only: bool = ...,
expunge: bool = ...,
buffer_size: Optional[int] = ...,
timeout_sec: Optional[float] = ...,
label: Optional[str] = ...,
) -> Generator[List[T], None, None]:
"""Overload for greedy batched iteration."""

@overload
def content(
self,
*,
greedy: Literal[True],
expunge: bool = ...,
buffer_size: Optional[int] = ...,
timeout_sec: Optional[float] = ...,
label: Optional[str] = ...,
) -> List[T]:
"""Overload for greedy full reads."""

def content(
self,
*,
greedy: bool = False,
follow: bool = False,
forward_only: bool = False,
expunge: bool = False,
buffer_size: Optional[int] = None,
timeout_sec: Optional[float] = None,
label: Optional[str] = None,
) -> Iterator[T]:
) -> Union[Generator[Union[T, List[T]], None, None], List[T]]:
"""Iterate over the data tape.

When following the data tape, iteration stops when the given timeout
expires and when the data tape is closed.

Args:
greedy: if true, greedily batch content as it becomes available.
follow: whether to follow the data tape as it gets written or not.
forward_only: if true, ignore existing content and only look ahead
when following the data tape.
expunge: if true, wipe out existing content in the data tape after
reading if it applies (i.e. non-forward only iterations).
buffer_size: optional buffer size when following the data tape.
If none is provided, the buffer will grow as necessary.
timeout_sec: optional timeout, in seconds, when following the data tape.
label: optional label to qualify logs and warnings.

Returns:
a lazy iterator over the data tape.
a lazy iterator over the data tape, one item at a time or in batches if greedy.

Raises:
TimeoutError: if iteration times out waiting for new data.
"""
# Here we split the generator in two, so that setup code is executed eagerly.
with self._lock:
content: Optional[collections.deque] = None
if not forward_only and self._content is not None:
content = self._content.copy()
if self._content is not None and expunge:
self._content.clear()
stream: Optional[Tape.Stream] = None
if follow and not self._closed:
stream = Tape.Stream(buffer_size, label)
self._streams.add(stream)

def _generator() -> Iterator:
if content is not None and stream is None and greedy:
return list(content)

def _generator() -> Generator[Union[T, List[T]], None, None]:
nonlocal content, stream
try:
if content is not None:
yield from content
if greedy:
yield list(content)
else:
yield from content

if stream is not None:
while not self._closed:
feedback = stream.read(timeout_sec)
if feedback is None:
break
yield feedback
while not stream.consumed:
# This is safe as long as there is
# a single reader for the stream,
# which is currently the case.
feedback = stream.read(timeout_sec)
if feedback is None:
continue
yield feedback
if greedy:
batch = [feedback]
while True:
feedback = stream.try_read()
if feedback is None:
break
batch.append(feedback)
yield batch
else:
yield feedback

last_batch: List[T] = []
while not stream.consumed:
feedback = stream.try_read()
if feedback is not None:
last_batch.append(feedback)
if not greedy:
yield from last_batch
else:
yield last_batch
finally:
if stream is not None:
with self._lock:
Expand Down
50 changes: 50 additions & 0 deletions bdai_ros2_wrappers/test/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,62 @@
# Copyright (c) 2023 Boston Dynamics AI Institute Inc. All rights reserved.

import argparse
import contextlib
import itertools

import pytest

from bdai_ros2_wrappers.utilities import Tape, either_or, ensure, namespace_with


def test_tape_content_iteration() -> None:
    """Plain iteration yields everything written, in order."""
    recorder: Tape[int] = Tape()
    values = list(range(10))
    for value in values:
        recorder.write(value)
    assert list(recorder.content()) == values


def test_tape_content_destructive_iteration() -> None:
    """Reading with ``expunge=True`` wipes the stored history."""
    recorder: Tape[int] = Tape()
    values = list(range(10))
    for value in values:
        recorder.write(value)
    assert list(recorder.content(expunge=True)) == values
    # Nothing must remain after the destructive read.
    assert len(list(recorder.content())) == 0


def test_tape_content_greedy_iteration() -> None:
    """A greedy, non-following read returns the full history as a list."""
    recorder: Tape[int] = Tape()
    values = list(range(10))
    for value in values:
        recorder.write(value)
    assert recorder.content(greedy=True) == values


def test_tape_content_following() -> None:
    """A following stream yields the history first, then new writes."""
    recorder: Tape[int] = Tape()
    values = list(range(10))
    for value in values:
        recorder.write(value)
    stream = recorder.content(follow=True)
    try:
        assert [next(stream) for _ in range(10)] == values
        recorder.write(10)
        assert next(stream) == 10
    finally:
        # Close the generator so the underlying stream is released.
        stream.close()


def test_tape_content_greedy_following() -> None:
    """A greedy following stream yields all pending items as one batch."""
    recorder: Tape[int] = Tape()
    values = list(range(10))
    for value in values:
        recorder.write(value)
    stream = recorder.content(greedy=True, follow=True)
    try:
        # First batch is the whole pre-existing history.
        assert next(stream) == values
        recorder.write(10)
        recorder.write(20)
        # Subsequent writes are coalesced into a single batch.
        assert next(stream) == [10, 20]
    finally:
        stream.close()


def test_tape_drops_unused_streams() -> None:
tape: Tape[int] = Tape(max_length=0)

Expand Down
Loading