Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
amessing-bdai committed Dec 6, 2023
1 parent 160438d commit 69a4fde
Show file tree
Hide file tree
Showing 30 changed files with 152 additions and 499 deletions.
22 changes: 5 additions & 17 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/action_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
class ActionClientWrapper(rclpy.action.ActionClient):
"""A wrapper for ros2's ActionClient for extra functionality"""

def __init__(
self, action_type: Type[Action], action_name: str, node: Optional[Node] = None
) -> None:
def __init__(self, action_type: Type[Action], action_name: str, node: Optional[Node] = None) -> None:
"""Constructor
Args:
Expand All @@ -24,9 +22,7 @@ def __init__(
"""
node = node or scope.node()
if node is None:
raise ValueError(
"no ROS 2 node available (did you use bdai_ros2_wrapper.process.main?)"
)
raise ValueError("no ROS 2 node available (did you use bdai_ros2_wrapper.process.main?)")
self._node = node
super().__init__(self._node, action_type, action_name)
self._node.get_logger().info(f"Waiting for action server for {action_name}")
Expand Down Expand Up @@ -64,9 +60,7 @@ def _on_failure() -> None:
nonlocal failed
failed = True

handle = self.send_goal_async_handle(
action_name=action_name, goal=goal, on_failure_callback=_on_failure
)
handle = self.send_goal_async_handle(action_name=action_name, goal=goal, on_failure_callback=_on_failure)
handle.set_on_cancel_success_callback(_on_cancel_succeeded)
if not handle.wait_for_result(timeout_sec=timeout_sec):
# If the action didn't fail and wasn't canceled then it timed out and should be canceled
Expand Down Expand Up @@ -102,11 +96,7 @@ def send_goal_async_handle(
Returns:
ActionHandle: An object to manage the asynchronous lifecycle of the action request
"""
handle = ActionHandle(
action_name=action_name,
logger=self._node.get_logger(),
context=self._node.context,
)
handle = ActionHandle(action_name=action_name, logger=self._node.get_logger(), context=self._node.context)
if result_callback is not None:
handle.set_result_callback(result_callback)

Expand All @@ -117,9 +107,7 @@ def send_goal_async_handle(
send_goal_future = self.send_goal_async(goal)
else:
handle.set_feedback_callback(feedback_callback)
send_goal_future = self.send_goal_async(
goal, feedback_callback=handle.get_feedback_callback
)
send_goal_future = self.send_goal_async(goal, feedback_callback=handle.get_feedback_callback)
handle.set_send_goal_future(send_goal_future)

return handle
26 changes: 6 additions & 20 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/action_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@ class ActionHandle(object):
as holding the various callbacks for sending an ActionGoal (cancel, failure, feedback, result)
"""

def __init__(
self,
action_name: str,
logger: Optional[RcutilsLogger] = None,
context: Optional[Context] = None,
):
def __init__(self, action_name: str, logger: Optional[RcutilsLogger] = None, context: Optional[Context] = None):
"""Constructor
Args:
Expand Down Expand Up @@ -73,8 +68,7 @@ def wait_for_result(self, timeout_sec: Optional[float] = None) -> bool:
True if successful; False if the timeout was triggered or the action was rejected, cancelled, or abort
"""
return self._wait_for_result_event.wait(timeout=timeout_sec) and (
self._result is not None
and self._result.status == GoalStatus.STATUS_SUCCEEDED
self._result is not None and self._result.status == GoalStatus.STATUS_SUCCEEDED
)

def wait_for_acceptance(self, timeout_sec: Optional[float] = None) -> bool:
Expand All @@ -97,31 +91,23 @@ def set_send_goal_future(self, send_goal_future: Future) -> None:
self._send_goal_future = send_goal_future
self._send_goal_future.add_done_callback(self._goal_response_callback)

def set_feedback_callback(
self, feedback_callback: Callable[[Action.Feedback], None]
) -> None:
def set_feedback_callback(self, feedback_callback: Callable[[Action.Feedback], None]) -> None:
"""Sets the callback used to process feedback received while an Action is being executed"""
self._feedback_callback = feedback_callback

def set_result_callback(
self, result_callback: Callable[[Action.Result], None]
) -> None:
def set_result_callback(self, result_callback: Callable[[Action.Result], None]) -> None:
"""Sets the callback for processing the result from executing an Action"""
self._result_callback = result_callback

def set_on_failure_callback(self, on_failure_callback: Callable) -> None:
"""Set the callback to execute upon failure"""
self._on_failure_callback = on_failure_callback

def set_on_cancel_success_callback(
self, on_cancel_success_callback: Callable
) -> None:
def set_on_cancel_success_callback(self, on_cancel_success_callback: Callable) -> None:
"""Set the callback to execute upon successfully canceling the action"""
self._on_cancel_success_callback = on_cancel_success_callback

def set_on_cancel_failure_callback(
self, on_cancel_failure_callback: Callable
) -> None:
def set_on_cancel_failure_callback(self, on_cancel_failure_callback: Callable) -> None:
"""Set the callback to execute upon failing to cancel the action"""
self._on_cancel_failure_callback = on_cancel_failure_callback

Expand Down
4 changes: 1 addition & 3 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from rclpy.utilities import get_default_context


def wait_for_shutdown(
*, timeout_sec: Optional[float] = None, context: Optional[Context] = None
) -> bool:
def wait_for_shutdown(*, timeout_sec: Optional[float] = None, context: Optional[Context] = None) -> bool:
"""
Wait for context shutdown.
Expand Down
122 changes: 31 additions & 91 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,7 @@ def __str__(self) -> str:
class Worker(threading.Thread):
"""A worker in its own daemonized OS thread."""

def __init__(
self, executor_weakref: weakref.ref, stop_on_timeout: bool = True
) -> None:
def __init__(self, executor_weakref: weakref.ref, stop_on_timeout: bool = True) -> None:
"""
Initializes the worker.
Expand Down Expand Up @@ -139,24 +137,18 @@ def run(self) -> None:
work: typing.Optional[AutoScalingThreadPool.Work] = None
while True:
if AutoScalingThreadPool._interpreter_shutdown:
self._logger.debug(
"Interpreter is shutting down! Terminating worker..."
)
self._logger.debug("Interpreter is shutting down! Terminating worker...")
self._runqueue.put(None)
break

executor: typing.Optional[
AutoScalingThreadPool
] = self._executor_weakref()
executor: typing.Optional[AutoScalingThreadPool] = self._executor_weakref()
if executor is None:
self._logger.debug("Executor is gone! Terminating worker...")
self._runqueue.put(None)
break

if executor._shutdown:
self._logger.debug(
"Executor is shutting down! Terminating worker..."
)
self._logger.debug("Executor is shutting down! Terminating worker...")
self._runqueue.put(None)
break

Expand Down Expand Up @@ -249,9 +241,7 @@ def __init__(
if max_workers <= 0:
raise ValueError("Maximum number of workers must be a positive number")
if max_workers < min_workers:
raise ValueError(
"Maximum number of workers must be larger than or equal to the minimum number of workers"
)
raise ValueError("Maximum number of workers must be larger than or equal to the minimum number of workers")
self._max_workers = max_workers

if max_idle_time is None:
Expand Down Expand Up @@ -282,45 +272,33 @@ def __init__(
self._shutdown_lock = threading.Lock()
self._scaling_event = threading.Condition()

self._waitqueues: typing.Dict[
typing.Callable[..., typing.Any], collections.deque
] = collections.defaultdict(collections.deque)
self._runlists: typing.Dict[
typing.Callable[..., typing.Any], typing.List[AutoScalingThreadPool.Work]
] = collections.defaultdict(list)
self._waitqueues: typing.Dict[typing.Callable[..., typing.Any], collections.deque] = collections.defaultdict(
collections.deque
)
self._runlists: typing.Dict[typing.Callable[..., typing.Any], typing.List[AutoScalingThreadPool.Work]] = (
collections.defaultdict(list)
)
self._runslots = threading.Semaphore(0)

runqueue: queue.SimpleQueue = queue.SimpleQueue()
self._weak_self = weakref.ref(self, lambda ref: runqueue.put(None))

with AutoScalingThreadPool._lock:
if AutoScalingThreadPool._interpreter_shutdown:
raise RuntimeError(
"cannot create thread pool while interpreter is shutting down"
)
raise RuntimeError("cannot create thread pool while interpreter is shutting down")
self._runqueue = runqueue
self._logger.debug(
"Registering runqueue for external wake up on interpreter shutdown..."
)
self._logger.debug("Registering runqueue for external wake up on interpreter shutdown...")
AutoScalingThreadPool._all_runqueues.add(runqueue)
self._logger.debug("Done registering runqueue")

self._workers: weakref.WeakSet[
AutoScalingThreadPool.Worker
] = weakref.WeakSet()
self._workers: weakref.WeakSet[AutoScalingThreadPool.Worker] = weakref.WeakSet()
if self._min_workers > 0:
with self._scaling_event:
self._logger.debug(
f"Pre-populating pool with {self._min_workers} workers"
)
self._logger.debug(f"Pre-populating pool with {self._min_workers} workers")
for _ in range(self._min_workers): # fire up stable worker pool
worker = AutoScalingThreadPool.Worker(
self._weak_self, stop_on_timeout=False
)
worker = AutoScalingThreadPool.Worker(self._weak_self, stop_on_timeout=False)
# register worker for external joining on interpreter shutdown
self._logger.debug(
"Registering worker for external joining on interpreter shutdown..."
)
self._logger.debug("Registering worker for external joining on interpreter shutdown...")
AutoScalingThreadPool._all_workers.add(worker)
self._logger.debug("Done registering worker")
self._logger.debug("Adding worker to the pool...")
Expand All @@ -345,25 +323,15 @@ def scaling_event(self) -> threading.Condition:
def working(self) -> bool:
"""Whether work is ongoing or not."""
with self._submit_lock:
return any(
work.pending()
for runlist in self._runlists.values()
for work in runlist
) or any(
work.pending()
for waitqueue in self._waitqueues.values()
for work in waitqueue
return any(work.pending() for runlist in self._runlists.values() for work in runlist) or any(
work.pending() for waitqueue in self._waitqueues.values() for work in waitqueue
)

@property
def capped(self) -> bool:
"""Whether submission quotas are in force or not."""
with self._submit_lock:
return any(
work.pending()
for waitqueue in self._waitqueues.values()
for work in waitqueue
)
return any(work.pending() for waitqueue in self._waitqueues.values() for work in waitqueue)

def wait(self, timeout: typing.Optional[float] = None) -> bool:
"""
Expand All @@ -379,14 +347,8 @@ def wait(self, timeout: typing.Optional[float] = None) -> bool:
True if all work completed, False if the wait timed out.
"""
with self._submit_lock:
futures = [
work.future for runlist in self._runlists.values() for work in runlist
]
futures += [
work.future
for waitqueue in self._waitqueues.values()
for work in waitqueue
]
futures = [work.future for runlist in self._runlists.values() for work in runlist]
futures += [work.future for waitqueue in self._waitqueues.values() for work in waitqueue]
done, not_done = concurrent.futures.wait(futures, timeout=timeout)
return len(not_done) == 0

Expand All @@ -395,9 +357,7 @@ def _cleanup_after(self, work: "AutoScalingThreadPool.Work") -> bool:
with self._submit_lock:
self._logger.debug(f"Cleaning up after work '{work}'")
self._runlists[work.fn].remove(work)
if (
work.fn in self._waitqueues and self._waitqueues[work.fn]
): # continue with pending work
if work.fn in self._waitqueues and self._waitqueues[work.fn]: # continue with pending work
self._logger.debug("Have similar work pending!")
self._logger.debug("Fetching pending work...")
work = self._waitqueues[work.fn].popleft()
Expand Down Expand Up @@ -443,11 +403,7 @@ def _do_submit(self, work: "AutoScalingThreadPool.Work") -> None:
# NOTE(hidmic): cannot recreate type signature for method override
# See https://github.com/python/typeshed/blob/main/stdlib/concurrent/futures/_base.pyi.
def submit( # type: ignore
self,
fn: typing.Callable[..., typing.Any],
/,
*args: typing.Any,
**kwargs: typing.Any,
self, fn: typing.Callable[..., typing.Any], /, *args: typing.Any, **kwargs: typing.Any
) -> concurrent.futures.Future:
"""
Submits work to the pool.
Expand All @@ -473,9 +429,7 @@ def submit( # type: ignore
f"Submitting work '{work}'...",
)
if self._submission_quota > len(self._runlists[work.fn]):
if (
work.fn in self._waitqueues and self._waitqueues[work.fn]
): # prioritize pending work
if work.fn in self._waitqueues and self._waitqueues[work.fn]: # prioritize pending work
self._logger.debug("Have similar work pending")
self._logger.debug(f"Work '{work}' put to wait", work)
self._waitqueues[work.fn].append(work)
Expand Down Expand Up @@ -559,11 +513,7 @@ class AutoScalingMultiThreadedExecutor(rclpy.executors.Executor):
class Task:
"""A bundle of an executable task and its associated entity."""

def __init__(
self,
task: rclpy.task.Task,
entity: typing.Optional[rclpy.executors.WaitableEntityType],
) -> None:
def __init__(self, task: rclpy.task.Task, entity: typing.Optional[rclpy.executors.WaitableEntityType]) -> None:
self.task = task
self.entity = entity
self.callback_group = entity.callback_group if entity is not None else None
Expand Down Expand Up @@ -627,9 +577,7 @@ def __init__(
submission_quota=max_threads_per_callback_group,
logger=logger,
)
self._wip: typing.Dict[
AutoScalingMultiThreadedExecutor.Task, concurrent.futures.Future
] = {}
self._wip: typing.Dict[AutoScalingMultiThreadedExecutor.Task, concurrent.futures.Future] = {}

@property
def thread_pool(self) -> AutoScalingThreadPool:
Expand Down Expand Up @@ -657,9 +605,7 @@ def _do_spin_once(self, *args: typing.Any, **kwargs: typing.Any) -> None:
# submission future is done. That is, a task could be legitimately ready for
# dispatch and be missed. Fortunately, this will only delay dispatch until the
# next spin cycle.
if task not in self._wip or (
self._wip[task].done() and not task.done()
):
if task not in self._wip or (self._wip[task].done() and not task.done()):
self._wip[task] = self._thread_pool.submit(task)
for task in list(self._wip):
if task.done():
Expand Down Expand Up @@ -709,9 +655,7 @@ def shutdown(self, timeout_sec: typing.Optional[float] = None) -> bool:


@contextlib.contextmanager
def background(
executor: rclpy.executors.Executor,
) -> typing.Iterator[rclpy.executors.Executor]:
def background(executor: rclpy.executors.Executor) -> typing.Iterator[rclpy.executors.Executor]:
"""
Pushes an executor to a background thread.
Expand All @@ -727,9 +671,7 @@ def background(
background_thread = threading.Thread(target=executor.spin)
executor.spin = bind_to_thread(executor.spin, background_thread)
executor.spin_once = bind_to_thread(executor.spin_once, background_thread)
executor.spin_until_future_complete = bind_to_thread(
executor.spin_until_future_complete, background_thread
)
executor.spin_until_future_complete = bind_to_thread(executor.spin_until_future_complete, background_thread)
executor.spin_once_until_future_complete = bind_to_thread(
executor.spin_once_until_future_complete, background_thread
)
Expand All @@ -742,9 +684,7 @@ def background(


@contextlib.contextmanager
def foreground(
executor: rclpy.executors.Executor,
) -> typing.Iterator[rclpy.executors.Executor]:
def foreground(executor: rclpy.executors.Executor) -> typing.Iterator[rclpy.executors.Executor]:
"""
Manages an executor in the current thread.
Expand Down
7 changes: 1 addition & 6 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
from rclpy.utilities import get_default_context


def wait_for_future(
future: Future,
timeout_sec: Optional[float] = None,
*,
context: Optional[Context] = None
) -> bool:
def wait_for_future(future: Future, timeout_sec: Optional[float] = None, *, context: Optional[Context] = None) -> bool:
"""Blocks while waiting for a future to become done
Args:
Expand Down
Loading

0 comments on commit 69a4fde

Please sign in to comment.