Skip to content

Commit

Permalink
Improve process-wide APIs type annotations (#113)
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <[email protected]>
  • Loading branch information
mhidalgo-bdai authored Aug 14, 2024
1 parent a7857a2 commit 9c42ca6
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 28 deletions.
86 changes: 80 additions & 6 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@

import bdai_ros2_wrappers.context as context
import bdai_ros2_wrappers.scope as scope
from bdai_ros2_wrappers.scope import AnyEntity, AnyEntityFactoryCallable, ROSAwareScope
from bdai_ros2_wrappers.scope import ROSAwareScope
from bdai_ros2_wrappers.tf_listener_wrapper import TFListenerWrapper
from bdai_ros2_wrappers.utilities import either_or

NodeT = typing.TypeVar("NodeT", bound=rclpy.node.Node)
NodeFactoryCallable = typing.Callable[..., NodeT]
GraphFactoryCallable = typing.Callable[..., typing.List[NodeT]]

MainCallableTakingArgs = typing.Callable[[argparse.Namespace], typing.Optional[int]]
MainCallableTakingArgv = typing.Callable[[typing.Sequence[str]], typing.Optional[int]]
MainCallableTakingNoArgs = typing.Callable[[], typing.Optional[int]]
Expand Down Expand Up @@ -239,7 +243,38 @@ def tf_listener() -> typing.Optional[TFListenerWrapper]:
return process.tf_listener


def load(factory: AnyEntityFactoryCallable, *args: typing.Any, **kwargs: typing.Any) -> AnyEntity:
@typing.overload
def load(factory: NodeFactoryCallable[NodeT], *args: typing.Any, **kwargs: typing.Any) -> NodeT:
"""Loads a ROS 2 node within the current ROS 2 aware process scope.
See `ROSAwareProcess` and `ROSAwareScope.load` documentation for further
reference on positional and keyword arguments taken by this function.
Raises:
RuntimeError: if no process is executing.
"""


@typing.overload
def load(factory: GraphFactoryCallable[NodeT], *args: typing.Any, **kwargs: typing.Any) -> typing.List[NodeT]:
"""Loads a ROS 2 graph within the current ROS 2 aware process scope.
See `ROSAwareProcess` and `ROSAwareScope.load` documentation for further
reference on positional and keyword arguments taken by this function.
Raises:
RuntimeError: if no process is executing.
"""


def load(
factory: typing.Union[
NodeFactoryCallable[NodeT],
GraphFactoryCallable[NodeT],
],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.Union[NodeT, typing.List[NodeT]]:
"""Loads a ROS 2 node (or a collection thereof) within the current ROS 2 aware process scope.
See `ROSAwareProcess` and `ROSAwareScope.load` documentation for further
Expand All @@ -254,7 +289,7 @@ def load(factory: AnyEntityFactoryCallable, *args: typing.Any, **kwargs: typing.
return process.load(factory, *args, **kwargs)


def unload(loaded: AnyEntity) -> None:
def unload(loaded: typing.Union[rclpy.node.Node, typing.List[rclpy.node.Node]]) -> None:
"""Unloads a ROS 2 node (or a collection thereof) from the current ROS 2 aware process scope.
See `ROSAwareProcess` and `ROSAwareScope.unload` documentation for further
Expand All @@ -269,11 +304,46 @@ def unload(loaded: AnyEntity) -> None:
process.unload(loaded)


@typing.overload
def managed(
factory: AnyEntityFactoryCallable,
factory: NodeFactoryCallable[NodeT],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.ContextManager[AnyEntity]:
) -> typing.ContextManager[NodeT]:
"""Manages a ROS 2 node within the current ROS 2 aware process scope.
See `ROSAwareProcess` and `ROSAwareScope.managed` documentation for further
reference on positional and keyword arguments taken by this function.
Raises:
RuntimeError: if no process is executing.
"""


@typing.overload
def managed(
factory: GraphFactoryCallable[NodeT],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.ContextManager[typing.List[NodeT]]:
"""Manages a ROS 2 graph within the current ROS 2 aware process scope.
See `ROSAwareProcess` and `ROSAwareScope.managed` documentation for further
reference on positional and keyword arguments taken by this function.
Raises:
RuntimeError: if no process is executing.
"""


def managed(
factory: typing.Union[
NodeFactoryCallable[NodeT],
GraphFactoryCallable[NodeT],
],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.Union[typing.ContextManager[NodeT], typing.ContextManager[typing.List[NodeT]]]:
"""Manages a ROS 2 node (or a collection thereof) within the current ROS 2 aware process scope.
See `ROSAwareProcess` and `ROSAwareScope.managed` documentation for further
Expand All @@ -288,7 +358,11 @@ def managed(
return process.managed(factory, *args, **kwargs)


def spin(factory: typing.Optional[AnyEntityFactoryCallable] = None, *args: typing.Any, **kwargs: typing.Any) -> None:
def spin(
factory: typing.Optional[typing.Union[NodeFactoryCallable, GraphFactoryCallable]] = None,
*args: typing.Any,
**kwargs: typing.Any,
) -> None:
"""Spins current ROS 2 aware process executor (and all ROS 2 nodes in it).
Optionally, manages a ROS 2 node (or a collection thereof) for as long as it spins.
Expand Down
115 changes: 93 additions & 22 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
from bdai_ros2_wrappers.tf_listener_wrapper import TFListenerWrapper
from bdai_ros2_wrappers.utilities import fqn, namespace_with

AnyEntity = typing.Union[rclpy.node.Node, typing.List[rclpy.node.Node]]
NodeFactoryCallable = typing.Callable[..., rclpy.node.Node]
GraphFactoryCallable = typing.Callable[..., typing.Iterable[rclpy.node.Node]]
AnyEntityFactoryCallable = typing.Union[NodeFactoryCallable, GraphFactoryCallable]
NodeT = typing.TypeVar("NodeT", bound=rclpy.node.Node)
NodeFactoryCallable = typing.Callable[..., NodeT]
GraphFactoryCallable = typing.Callable[..., typing.List[NodeT]]


class ROSAwareScope(typing.ContextManager["ROSAwareScope"]):
Expand Down Expand Up @@ -256,10 +255,10 @@ def node(self, node: rclpy.node.Node) -> None:
@typing.overload
def managed(
self,
factory: NodeFactoryCallable,
factory: NodeFactoryCallable[NodeT],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.ContextManager[rclpy.node.Node]:
) -> typing.ContextManager[NodeT]:
"""Manages a ROS 2 node within scope.
Upon context entry, a ROS 2 node is instantiated and loaded.
Expand All @@ -275,10 +274,10 @@ def managed(
@typing.overload
def managed(
self,
factory: GraphFactoryCallable,
factory: GraphFactoryCallable[NodeT],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.ContextManager[typing.List[rclpy.node.Node]]:
) -> typing.ContextManager[typing.List[NodeT]]:
"""Manages a collection (or graph) of ROS 2 nodes within scope.
Upon context entry, ROS 2 nodes are instantiated and loaded.
Expand All @@ -294,10 +293,10 @@ def managed(
@contextlib.contextmanager
def managed(
self,
factory: AnyEntityFactoryCallable,
factory: typing.Union[NodeFactoryCallable[NodeT], GraphFactoryCallable[NodeT]],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.Iterator[AnyEntity]:
) -> typing.Union[typing.Iterator[NodeT], typing.Iterator[typing.List[NodeT]]]:
"""Overloaded method. See above for documentation."""
loaded = self.load(factory, *args, **kwargs)
try:
Expand All @@ -306,7 +305,7 @@ def managed(
self.unload(loaded)

@typing.overload
def load(self, factory: NodeFactoryCallable, *args: typing.Any, **kwargs: typing.Any) -> rclpy.node.Node:
def load(self, factory: NodeFactoryCallable[NodeT], *args: typing.Any, **kwargs: typing.Any) -> NodeT:
"""Instantiates and loads a ROS 2 node.
If a __post_init__ method is defined by the instantiated ROS 2 node, it will be invoked
Expand All @@ -331,10 +330,10 @@ def load(self, factory: NodeFactoryCallable, *args: typing.Any, **kwargs: typing
@typing.overload
def load(
self,
factory: GraphFactoryCallable,
factory: GraphFactoryCallable[NodeT],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.List[rclpy.node.Node]:
) -> typing.List[NodeT]:
"""Instantiates and loads a collection (or graph) of ROS 2 nodes.
For each ROS 2 node instantiated, if a __post_init__ method is defined it will be invoked
Expand All @@ -357,11 +356,11 @@ def load(

def load(
self,
factory: AnyEntityFactoryCallable,
factory: typing.Union[NodeFactoryCallable[NodeT], GraphFactoryCallable[NodeT]],
*args: typing.Any,
namespace: typing.Optional[str] = None,
**kwargs: typing.Any,
) -> AnyEntity:
) -> typing.Union[NodeT, typing.List[NodeT]]:
"""Overloaded method. See above for documentation."""
with self._lock:
if self._stack is None:
Expand All @@ -385,7 +384,7 @@ def load(
self._graph.append(node)
return node

def unload(self, loaded: AnyEntity) -> None:
def unload(self, loaded: typing.Union[rclpy.node.Node, typing.List[rclpy.node.Node]]) -> None:
"""Unloads and destroys ROS 2 nodes.
Args:
Expand Down Expand Up @@ -459,7 +458,7 @@ def spin(self, factory: GraphFactoryCallable, *args: typing.Any, **kwargs: typin

def spin(
self,
factory: typing.Optional[AnyEntityFactoryCallable] = None,
factory: typing.Optional[typing.Union[NodeFactoryCallable, GraphFactoryCallable]] = None,
*args: typing.Any,
**kwargs: typing.Any,
) -> None:
Expand Down Expand Up @@ -553,7 +552,43 @@ def executor() -> typing.Optional[rclpy.executors.Executor]:
return scope.executor


def load(factory: AnyEntityFactoryCallable, *args: typing.Any, **kwargs: typing.Any) -> AnyEntity:
@typing.overload
def load(
factory: NodeFactoryCallable[NodeT],
*args: typing.Any,
**kwargs: typing.Any,
) -> NodeT:
"""Loads a ROS 2 node within the current ROS 2 aware scope.
See `ROSAwareScope.load` documentation for further reference on positional and keyword
arguments taken by this function.
Raises:
RuntimeError: if called outside scope.
"""


@typing.overload
def load(
factory: GraphFactoryCallable[NodeT],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.List[NodeT]:
"""Loads a ROS 2 graph within the current ROS 2 aware scope.
See `ROSAwareScope.load` documentation for further reference on positional and keyword
arguments taken by this function.
Raises:
RuntimeError: if called outside scope.
"""


def load(
factory: typing.Union[NodeFactoryCallable[NodeT], GraphFactoryCallable[NodeT]],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.Union[NodeT, typing.List[NodeT]]:
"""Loads a ROS 2 node (or a collection thereof) within the current ROS 2 aware scope.
See `ROSAwareScope.load` documentation for further reference on positional and keyword
Expand All @@ -568,7 +603,7 @@ def load(factory: AnyEntityFactoryCallable, *args: typing.Any, **kwargs: typing.
return scope.load(factory, *args, **kwargs)


def unload(loaded: AnyEntity) -> None:
def unload(loaded: typing.Union[rclpy.node.Node, typing.List[rclpy.node.Node]]) -> None:
"""Unloads a ROS 2 node (or a collection thereof) from the current ROS 2 aware scope.
See `ROSAwareScope.unload` documentation for further reference on positional and
Expand All @@ -583,11 +618,43 @@ def unload(loaded: AnyEntity) -> None:
scope.unload(loaded)


@typing.overload
def managed(
factory: NodeFactoryCallable[NodeT],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.ContextManager[NodeT]:
"""Manages a ROS 2 node within the current ROS 2 aware scope.
See `ROSAwareScope.managed` documentation for further reference on positional and
keyword arguments taken by this function.
Raises:
RuntimeError: if called outside scope.
"""


@typing.overload
def managed(
factory: AnyEntityFactoryCallable,
factory: GraphFactoryCallable[NodeT],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.ContextManager[AnyEntity]:
) -> typing.ContextManager[typing.List[NodeT]]:
"""Manages a ROS 2 graph within the current ROS 2 aware scope.
See `ROSAwareScope.managed` documentation for further reference on positional and
keyword arguments taken by this function.
Raises:
RuntimeError: if called outside scope.
"""


def managed(
factory: typing.Union[NodeFactoryCallable[NodeT], GraphFactoryCallable[NodeT]],
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.Union[typing.ContextManager[NodeT], typing.ContextManager[typing.List[NodeT]]]:
"""Manages a ROS 2 node (or a collection thereof) within the current ROS 2 aware scope.
See `ROSAwareScope.managed` documentation for further reference on positional and
Expand All @@ -602,7 +669,11 @@ def managed(
return scope.managed(factory, *args, **kwargs)


def spin(factory: typing.Optional[AnyEntityFactoryCallable] = None, *args: typing.Any, **kwargs: typing.Any) -> None:
def spin(
factory: typing.Optional[typing.Union[NodeFactoryCallable, GraphFactoryCallable]] = None,
*args: typing.Any,
**kwargs: typing.Any,
) -> None:
"""Spins current ROS 2 aware scope executor (and all the ROS 2 nodes in it).
Optionally, manages a ROS 2 node (or a collection thereof) for as long as it spins.
Expand Down

0 comments on commit 9c42ca6

Please sign in to comment.