diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/callables.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/callables.py index 8ff8181..1c2cd5a 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/callables.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/callables.py @@ -2,10 +2,12 @@ import abc import functools -from typing import Any, Callable, Iterable, Optional, Tuple +import inspect +from typing import Any, Callable, Iterable, Literal, Optional, Tuple, Type, Union, overload from rclpy.task import Future +from bdai_ros2_wrappers.executors import assign_coroutine from bdai_ros2_wrappers.futures import AnyFuture, FutureLike, as_proper_future from bdai_ros2_wrappers.utilities import take_kwargs @@ -163,3 +165,207 @@ def asynchronous(self, *args: Any, **kwargs: Any) -> Any: if self.starred: return self.wrapped_callable.asynchronous(*self.composed_callable(*args, **inner_kwargs), **outer_kwargs) return self.wrapped_callable.asynchronous(self.composed_callable(*args, **inner_kwargs), **outer_kwargs) + + +class GeneralizedFunction(GeneralizedCallable): + """A generalized callable defined by parts.""" + + def __init__( + self, + synchronous_callable: Optional[Callable] = None, + asynchronous_callable: Optional[Callable] = None, + ) -> None: + """Initialize generalized function. + + Args: + synchronous_callable: optional synchronous body. + asynchronous_callable: optional asynchronous body. + """ + self._synchronous_callable = synchronous_callable + self._asynchronous_callable = asynchronous_callable + + def synchronous(self, *args: Any, **kwargs: Any) -> Any: + """Invoke function synchronously (ie. potentially blocking).""" + if self._synchronous_callable is None: + raise NotImplementedError("synchronous invocation is not supported") + return self._synchronous_callable(*args, **kwargs) + + def asynchronous(self, *args: Any, **kwargs: Any) -> Any: + """Invoke function asynchronously, returning a future-like object.""" + if self._asynchronous_callable is None: + raise NotImplementedError("asynchronous invocation is not supported") + return self._asynchronous_callable(*args, **kwargs) + + +class GeneralizedMethod: + """A data descriptor for generalized callables bound to class instances.""" + + class Unbound: + """An unbound generalized method descriptor.""" + + def __init__(self, method: "GeneralizedMethod") -> None: + """Initialize unbound descriptor. + + Args: + method: associated generalized method. + """ + self.synchronous_callable: Optional[Callable] = None + self.asynchronous_callable: Optional[Callable] = None + if not method.transitional: + if inspect.iscoroutinefunction(method.prototype): + self.asynchronous_callable = method.prototype + else: + self.synchronous_callable = method.prototype + if method.synchronous_overload is not None: + self.synchronous_callable = method.synchronous_overload + if method.asynchronous_overload is not None: + self.asynchronous_callable = method.asynchronous_overload + + self.default_callable: Optional[Callable] = None + if not method.transitional: + if self.synchronous_callable is not None: + self.default_callable = self.synchronous_callable + else: + self.default_callable = self.asynchronous_callable + else: + self.default_callable = method.prototype + + def __get__( + self, + instance: Optional[Any], + owner: Optional[Type] = None, + ) -> Union["GeneralizedMethod.Unbound", "GeneralizedMethod.Bound"]: + if instance is None: + return self + synchronous_callable = self.synchronous_callable + if synchronous_callable is not None: + synchronous_callable = synchronous_callable.__get__(instance, owner) + assert synchronous_callable is not None + asynchronous_callable = self.asynchronous_callable + if asynchronous_callable is not None: + asynchronous_callable = asynchronous_callable.__get__(instance, owner) + assert asynchronous_callable is not None + if inspect.iscoroutinefunction(self.asynchronous_callable): + asynchronous_callable = assign_coroutine(asynchronous_callable, instance.executor) + default_callable = self.default_callable + if default_callable is not None: + default_callable = default_callable.__get__(instance, owner) + assert default_callable is not None + if inspect.iscoroutinefunction(self.default_callable): + default_callable = assign_coroutine(default_callable, instance.executor) + implementation = GeneralizedFunction(synchronous_callable, asynchronous_callable) + return GeneralizedMethod.Bound(implementation, default_callable) + + class Bound(VectorizingCallable, ComposableCallable): + """A bound generalized method callable.""" + + def __init__(self, body: GeneralizedCallable, default_callable: Optional[Callable] = None) -> None: + """Initialize bound method callable. + + Args: + body: method body as a generalized callable + default_callable: optionally override default plain calls, defaults to synchronous calls. + """ + self.body = body + if default_callable is None: + default_callable = body.synchronous + self._default_callable = default_callable + + def __getattr__(self, name: str) -> Any: + return getattr(self.body, name) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Invoke method (optionally pre-existing).""" + return self._default_callable(*args, **kwargs) + + def synchronous(self, *args: Any, **kwargs: Any) -> Any: + """Invoke method synchronously.""" + return self.body.synchronous(*args, **kwargs) + + def asynchronous(self, *args: Any, **kwargs: Any) -> Any: + """Invoke method asynchronously.""" + return self.body.asynchronous(*args, **kwargs) + + def __init__(self, prototype: Callable, transitional: bool) -> None: + """Initializes the generalized method. + + Args: + prototype: method prototype, usually just a signature but + may also be used as an overload for convenience (iff the + function type matches the missing overload). + transitional: a transitional method will stick to its + prototype for default invocations, simplifying the + adoption of generalized methods in existing codebases. + """ + self.prototype = prototype + self.transitional = transitional + self.synchronous_overload: Optional[Callable] = None + self.asynchronous_overload: Optional[Callable] = None + + def sync_overload(self, func: Callable) -> Callable: + """Register `func` as this method synchronous overload.""" + if self.synchronous_overload is not None: + raise RuntimeError("cannot redefine synchronous overload") + self.synchronous_overload = func + return func + + def async_overload(self, func: Callable) -> Callable: + """Register `func` as this method asynchronous overload.""" + if self.asynchronous_overload is not None: + raise RuntimeError("cannot redefine asynchronous overload") + self.asynchronous_overload = func + return func + + def __set_name__(self, owner: Type, name: str) -> None: + self.__attribute_name = f"__{name}_method" + setattr(owner, self.__attribute_name, GeneralizedMethod.Unbound(self)) + + def rebind(self, instance: Any, body: GeneralizedCallable) -> None: + """Change this method's `body` for the given `instance`.""" + default_callable: Optional[Callable] = None + if self.transitional: + default_callable = self.prototype.__get__(instance) + bound_method = GeneralizedMethod.Bound(body, default_callable) + setattr(instance, self.__attribute_name, bound_method) + + @overload + def __get__(self, instance: Literal[None], owner: Optional[Type] = ...) -> "GeneralizedMethod": + ... + + @overload + def __get__(self, instance: Any, owner: Optional[Type] = ...) -> "GeneralizedMethod.Bound": + ... + + def __get__( + self, + instance: Optional[Any], + owner: Optional[Type] = None, + ) -> Union["GeneralizedMethod", "GeneralizedMethod.Bound"]: + if instance is None: + return self + return getattr(instance, self.__attribute_name) + + +@overload +def generalized_method(func: Callable, *, transitional: bool = ...) -> GeneralizedMethod: + ... + + +@overload +def generalized_method(*, transitional: bool = ...) -> Callable: + ... + + +def generalized_method( + func: Optional[Callable] = None, + *, + transitional: bool = False, +) -> Union[Callable, GeneralizedMethod]: + """Define a generalized method by decoration.""" + + def _decorator(func: Callable) -> GeneralizedMethod: + return GeneralizedMethod(func, transitional) + + if func is None: + return _decorator + return _decorator(func) diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py index 43809fb..2d2ad7c 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py @@ -4,6 +4,7 @@ import concurrent.futures import contextlib import dataclasses +import functools import inspect import logging import os @@ -14,6 +15,7 @@ import rclpy.executors +from bdai_ros2_wrappers.futures import FutureLike from bdai_ros2_wrappers.utilities import bind_to_thread, fqn @@ -728,3 +730,20 @@ def foreground(executor: rclpy.executors.Executor) -> typing.Iterator[rclpy.exec yield executor finally: executor.shutdown() + + +def assign_coroutine( + coroutine: typing.Callable[..., typing.Awaitable], + executor: rclpy.executors.Executor, +) -> typing.Callable[..., FutureLike]: + """Assign a `coroutine` to a given `executor`. + + An assigned coroutine will return a future-like object + that will be serviced by the associated executor. + """ + + @functools.wraps(coroutine) + def __wrapper(*args: typing.Any, **kwargs: typing.Any) -> FutureLike: + return executor.create_task(coroutine, *args, **kwargs) + + return __wrapper diff --git a/bdai_ros2_wrappers/test/test_callables.py b/bdai_ros2_wrappers/test/test_callables.py new file mode 100644 index 0000000..077c94a --- /dev/null +++ b/bdai_ros2_wrappers/test/test_callables.py @@ -0,0 +1,184 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved. + +from collections.abc import Mapping +from typing import Any, Optional + +import pytest + +from bdai_ros2_wrappers.callables import generalized_method +from bdai_ros2_wrappers.futures import wait_for_future +from bdai_ros2_wrappers.scope import ROSAwareScope + + +class Bucket: + def __init__( + self, + ros: ROSAwareScope, + storage: Optional[Mapping] = None, + ) -> None: + self.executor = ros.executor + if storage is None: + storage = {} + self._storage = dict(storage) + + @generalized_method(transitional=True) + def create(self, content: Any) -> str: + name = str(hash(content)) + if name in self._storage: + raise RuntimeError() + self._storage[name] = content + return name + + @create.sync_overload + def _create_sync(self, name: str, content: Any) -> bool: + if name in self._storage: + return False + self._storage[name] = content + return True + + @create.async_overload + async def _create_async(self, name: str, content: Any) -> bool: + return self._create_sync(name, content) + + @generalized_method + def read(self, name: str) -> Optional[Any]: + return self._storage.get(name) + + @read.async_overload + async def _read_async(self, name: str) -> Optional[Any]: + return self.read(name) + + @generalized_method + async def update(self, name: str, content: Any) -> bool: + if name not in self._storage: + return False + self._storage[name] = content + return True + + @generalized_method + def delete(self, name: str) -> bool: + if name not in self._storage: + return False + del self._storage[name] + return True + + +def test_transitional_method(ros: ROSAwareScope) -> None: + bucket = Bucket(ros) + name = bucket.create("some data") + assert name in bucket._storage + assert bucket._storage[name] == "some data" + + with pytest.raises(RuntimeError): + bucket.create("some data") + + assert not bucket.create.synchronously(name, "some other data") + assert bucket.create.synchronously("my-data", "some other data") + assert "my-data" in bucket._storage + assert bucket._storage["my-data"] == "some other data" + + future = bucket.create.asynchronously("my-data", "more data") + assert wait_for_future(future, timeout_sec=5.0) + assert future.result() is False + + future = bucket.create.asynchronously("extras", "more data") + assert wait_for_future(future, timeout_sec=5.0) + assert future.result() is True + assert "extras" in bucket._storage + assert bucket._storage["extras"] == "more data" + + +def test_nominal_method(ros: ROSAwareScope) -> None: + bucket = Bucket(ros, {"my-data": "some data"}) + + assert bucket.read("my-data") == "some data" + assert bucket.read.synchronously("my-data") == "some data" + future = bucket.read.asynchronously("my-data") + assert wait_for_future(future, timeout_sec=5.0) + assert future.result() == "some data" + + assert not bucket.read("other-data") + assert not bucket.read.synchronously("other-data") + future = bucket.read.asynchronously("other-data") + assert wait_for_future(future, timeout_sec=5.0) + assert future.result() is None + + +def test_sync_only_method(ros: ROSAwareScope) -> None: + bucket = Bucket( + ros, + { + "my-data": "some data", + "extras": "more data", + "old": "old data", + }, + ) + assert bucket.delete("my-data") + assert "my-data" not in bucket._storage + assert not bucket.delete("my-data") + assert bucket.delete.synchronously("extras") + assert "extras" not in bucket._storage + assert not bucket.delete.synchronously("extras") + with pytest.raises(NotImplementedError): + bucket.delete.asynchronously("old") + assert "old" in bucket._storage + + +def test_async_only_method(ros: ROSAwareScope) -> None: + bucket = Bucket(ros, {"my-data": "some data"}) + future = bucket.update("my-data", "new data") + assert wait_for_future(future, timeout_sec=5.0) + assert future.result() is True + assert bucket._storage["my-data"] == "new data" + + future = bucket.update.asynchronously("my-data", "newer data") + assert wait_for_future(future, timeout_sec=5.0) + assert future.result() is True + assert bucket._storage["my-data"] == "newer data" + + with pytest.raises(NotImplementedError): + bucket.update.synchronously("my-data", "") + assert bucket._storage["my-data"] == "newer data" + + +def test_vectorized_method(ros: ROSAwareScope) -> None: + bucket = Bucket( + ros, + { + "my-data": "some data", + "extras": "more data", + }, + ) + data = bucket.read.vectorized(["my-data", "extras"]) + assert data == ["some data", "more data"] + + data = bucket.read.vectorized.synchronously(["my-data", "extras"]) + assert data == ["some data", "more data"] + + future = bucket.read.vectorized.asynchronously(["my-data", "extras"]) + assert wait_for_future(future, timeout_sec=5.0) + assert future.result() == ["some data", "more data"] + + +def test_composed_method(ros: ROSAwareScope) -> None: + bucket = Bucket(ros) + Bucket.create.rebind( + bucket, + bucket.create.compose( + (lambda name, *data: (name, data)), + starred=True, + ), + ) + name = bucket.create("some data") + assert name in bucket._storage + assert bucket._storage[name] == "some data" + + assert bucket.create.synchronously("my-data", "some other data", 1, True) + assert "my-data" in bucket._storage + assert bucket._storage["my-data"] == ("some other data", 1, True) + + future = bucket.create.asynchronously("extras", 0, "more data", False) + assert wait_for_future(future, timeout_sec=5.0) + assert future.result() is True + assert "extras" in bucket._storage + assert bucket._storage["extras"] == (0, "more data", False)