From dd8a09f2113eb371ed93d087d7705a6a53c1719c Mon Sep 17 00:00:00 2001 From: Michel Hidalgo Date: Fri, 27 Sep 2024 17:09:52 -0300 Subject: [PATCH 1/2] Add support for greedy (or batched) message streaming Signed-off-by: Michel Hidalgo --- .../bdai_ros2_wrappers/feeds.py | 50 ++++++- .../bdai_ros2_wrappers/utilities.py | 127 +++++++++++++++--- bdai_ros2_wrappers/test/test_utilities.py | 50 +++++++ 3 files changed, 203 insertions(+), 24 deletions(-) diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py index a2f213c..9cdab12 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py @@ -1,6 +1,6 @@ # Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved. -from typing import Any, Callable, Generic, Iterable, Iterator, List, Optional, TypeVar +from typing import Any, Callable, Generator, Generic, Iterable, List, Literal, Optional, TypeVar, Union, overload import tf2_ros from rclpy.node import Node @@ -47,7 +47,7 @@ def link(self) -> Filter: @property def history(self) -> List[MessageT]: """Gets the entire history of messages received so far.""" - return list(self._tape.content()) + return self._tape.content(greedy=True) @property def latest(self) -> Optional[MessageT]: @@ -80,29 +80,71 @@ def recall(self, callback: Callable[[MessageT], None]) -> Tunnel: tunnel.registerCallback(callback) return tunnel + @overload def stream( self, *, forward_only: bool = False, + expunge: bool = False, buffer_size: Optional[int] = None, timeout_sec: Optional[float] = None, - ) -> Iterator[MessageT]: + ) -> Generator[MessageT, None, None]: + """Overload for plain streaming.""" + + @overload + def stream( + self, + *, + greedy: Literal[True], + forward_only: bool = False, + expunge: bool = False, + buffer_size: Optional[int] = None, + timeout_sec: Optional[float] = None, + ) -> Generator[List[MessageT], None, None]: + """Overload for greedy, batched streaming.""" + + def stream( + self, + *, + greedy: bool = False, + forward_only: bool = False, + expunge: bool = False, + buffer_size: Optional[int] = None, + timeout_sec: Optional[float] = None, + ) -> Generator[Union[MessageT, List[MessageT]], None, None]: """Iterates over messages as they come. Iteration stops when the given timeout expires or when the associated context is shutdown. Note that iterating over the message stream is a blocking operation. Args: + greedy: if true, greedily batch messages as it becomes available. forward_only: whether to ignore previosuly received messages. + expunge: if true, wipe out the message history after reading + if it applies (i.e. non-forward only streams). buffer_size: optional maximum size for the incoming messages buffer. If none is provided, the buffer will be grow unbounded. timeout_sec: optional timeout, in seconds, for a new message to be received. Returns: - a lazy iterator over messages. + a lazy iterator over messages, one message at a time or in batches if greedy. + + Raises: + TimeoutError: if streaming times out waiting for a new message. """ + if greedy: + # use boolean literals to help mypy + return self._tape.content( + follow=True, + greedy=True, + expunge=expunge, + forward_only=forward_only, + buffer_size=buffer_size, + timeout_sec=timeout_sec, + ) return self._tape.content( follow=True, + expunge=expunge, forward_only=forward_only, buffer_size=buffer_size, timeout_sec=timeout_sec, diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py index 7233c25..5f06f42 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py @@ -9,7 +9,7 @@ import warnings import weakref from collections.abc import Mapping, MutableSet -from typing import Any, Callable, Generic, Iterator, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Generator, Generic, List, Literal, Optional, Tuple, TypeVar, Union, overload import rclpy.clock import rclpy.duration @@ -109,22 +109,42 @@ def write(self, data: U) -> bool: return False return True - def read(self, timeout_sec: Optional[float] = None) -> Optional[U]: + def try_read(self) -> Optional[U]: + """Try to read data from the stream. + + Returns: + data if the read is successful and ``None`` + if there is nothing to be read or the stream + is interrupted. + """ + try: + data = self._queue.get_nowait() + self._queue.task_done() + except queue.Empty: + return None + return data + + def read(self, timeout_sec: Optional[float] = None) -> U: """Read data from the stream. Args: timeout_sec: optional read timeout, in seconds. Returns: - data if the read is successful and ``None`` + data read + + Raises: + InterruptedError if the read if the read is successful and ``None`` if the read times out or is interrupted. """ try: data = self._queue.get(timeout=timeout_sec) - except queue.Empty: - return None - self._queue.task_done() - return data + self._queue.task_done() + if data is None: + raise InterruptedError() + return data + except queue.Empty as e: + raise TimeoutError() from e def interrupt(self) -> None: """Interrupt the stream and wake up the reader.""" @@ -216,61 +236,128 @@ def head(self) -> Optional[T]: return None return self._content[0] + @overload def content( self, *, + follow: bool = ..., + forward_only: bool = ..., + expunge: bool = ..., + buffer_size: Optional[int] = ..., + timeout_sec: Optional[float] = ..., + label: Optional[str] = ..., + ) -> Generator[T, None, None]: + """Overload for non-greedy iteration.""" + + @overload + def content( + self, + *, + greedy: Literal[True], + follow: Literal[True], + forward_only: bool = ..., + expunge: bool = ..., + buffer_size: Optional[int] = ..., + timeout_sec: Optional[float] = ..., + label: Optional[str] = ..., + ) -> Generator[List[T], None, None]: + """Overload for greedy batched iteration.""" + + @overload + def content( + self, + *, + greedy: Literal[True], + expunge: bool = ..., + buffer_size: Optional[int] = ..., + timeout_sec: Optional[float] = ..., + label: Optional[str] = ..., + ) -> List[T]: + """Overload for greedy full reads.""" + + def content( + self, + *, + greedy: bool = False, follow: bool = False, forward_only: bool = False, + expunge: bool = False, buffer_size: Optional[int] = None, timeout_sec: Optional[float] = None, label: Optional[str] = None, - ) -> Iterator[T]: + ) -> Union[Generator[Union[T, List[T]], None, None], List[T]]: """Iterate over the data tape. When following the data tape, iteration stops when the given timeout expires and when the data tape is closed. Args: + greedy: if true, greedily batch content as it becomes available. follow: whether to follow the data tape as it gets written or not. forward_only: if true, ignore existing content and only look ahead when following the data tape. + expunge: if true, wipe out existing content in the data tape after + reading if it applies (i.e. non-forward only iterations). buffer_size: optional buffer size when following the data tape. If none is provided, the buffer will grow as necessary. timeout_sec: optional timeout, in seconds, when following the data tape. label: optional label to qualify logs and warnings. Returns: - a lazy iterator over the data tape. + a lazy iterator over the data tape, one item at a time or in batches if greedy. + + Raises: + TimeoutError: if iteration times out waiting for new data. """ # Here we split the generator in two, so that setup code is executed eagerly. with self._lock: content: Optional[collections.deque] = None if not forward_only and self._content is not None: content = self._content.copy() + if self._content is not None and expunge: + self._content.clear() stream: Optional[Tape.Stream] = None if follow and not self._closed: stream = Tape.Stream(buffer_size, label) self._streams.add(stream) - def _generator() -> Iterator: + if content is not None and stream is None and greedy: + return list(content) + + def _generator() -> Generator[Union[T, List[T]], None, None]: nonlocal content, stream try: if content is not None: - yield from content + if greedy: + yield list(content) + else: + yield from content + if stream is not None: while not self._closed: - feedback = stream.read(timeout_sec) - if feedback is None: - break - yield feedback - while not stream.consumed: - # This is safe as long as there is - # a single reader for the stream, - # which is currently the case. feedback = stream.read(timeout_sec) if feedback is None: continue - yield feedback + if greedy: + batch = [feedback] + while True: + feedback = stream.try_read() + if feedback is None: + break + batch.append(feedback) + yield batch + else: + yield feedback + + last_batch: List[T] = [] + while not stream.consumed: + feedback = stream.try_read() + if feedback is not None: + last_batch.append(feedback) + if not greedy: + yield from last_batch + else: + yield last_batch finally: if stream is not None: with self._lock: diff --git a/bdai_ros2_wrappers/test/test_utilities.py b/bdai_ros2_wrappers/test/test_utilities.py index faa2f57..671d3c5 100644 --- a/bdai_ros2_wrappers/test/test_utilities.py +++ b/bdai_ros2_wrappers/test/test_utilities.py @@ -1,12 +1,62 @@ # Copyright (c) 2023 Boston Dynamics AI Institute Inc. All rights reserved. import argparse +import contextlib +import itertools import pytest from bdai_ros2_wrappers.utilities import Tape, either_or, ensure, namespace_with +def test_tape_content_iteration() -> None: + tape: Tape[int] = Tape() + expected_sequence = list(range(10)) + for i in expected_sequence: + tape.write(i) + assert list(tape.content()) == expected_sequence + + +def test_tape_content_destructive_iteration() -> None: + tape: Tape[int] = Tape() + expected_sequence = list(range(10)) + for i in expected_sequence: + tape.write(i) + assert list(tape.content(expunge=True)) == expected_sequence + assert len(list(tape.content())) == 0 + + +def test_tape_content_greedy_iteration() -> None: + tape: Tape[int] = Tape() + expected_sequence = list(range(10)) + for i in expected_sequence: + tape.write(i) + assert tape.content(greedy=True) == expected_sequence + + +def test_tape_content_following() -> None: + tape: Tape[int] = Tape() + expected_sequence = list(range(10)) + for i in expected_sequence: + tape.write(i) + with contextlib.closing(tape.content(follow=True)) as stream: + assert list(itertools.islice(stream, 10)) == expected_sequence + tape.write(10) + assert next(stream) == 10 + + +def test_tape_content_greedy_following() -> None: + tape: Tape[int] = Tape() + expected_sequence = list(range(10)) + for i in expected_sequence: + tape.write(i) + with contextlib.closing(tape.content(greedy=True, follow=True)) as stream: + assert next(stream) == expected_sequence + tape.write(10) + tape.write(20) + assert next(stream) == [10, 20] + + def test_tape_drops_unused_streams() -> None: tape: Tape[int] = Tape(max_length=0) From fdbee4c0f7e5e18a86a2e1e29adc80139a401996 Mon Sep 17 00:00:00 2001 From: Michel Hidalgo Date: Fri, 27 Sep 2024 17:36:08 -0300 Subject: [PATCH 2/2] Remove bad InterruptedError logic Signed-off-by: Michel Hidalgo --- .../bdai_ros2_wrappers/utilities.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py index 5f06f42..5810714 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py @@ -113,7 +113,7 @@ def try_read(self) -> Optional[U]: """Try to read data from the stream. Returns: - data if the read is successful and ``None`` + data if the read is successful, and ``None`` if there is nothing to be read or the stream is interrupted. """ @@ -124,27 +124,25 @@ def try_read(self) -> Optional[U]: return None return data - def read(self, timeout_sec: Optional[float] = None) -> U: + def read(self, timeout_sec: Optional[float] = None) -> Optional[U]: """Read data from the stream. Args: timeout_sec: optional read timeout, in seconds. Returns: - data read + data if the read is successful, and ``None`` + if the stream is interrupted. Raises: - InterruptedError if the read if the read is successful and ``None`` - if the read times out or is interrupted. + TImeoutError if the read times out. """ try: data = self._queue.get(timeout=timeout_sec) - self._queue.task_done() - if data is None: - raise InterruptedError() - return data except queue.Empty as e: raise TimeoutError() from e + self._queue.task_done() + return data def interrupt(self) -> None: """Interrupt the stream and wake up the reader."""