Skip to content

Commit

Permalink
Add generalized methods (#93)
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <[email protected]>
  • Loading branch information
mhidalgo-bdai authored May 16, 2024
1 parent 75ea457 commit 8de80b5
Show file tree
Hide file tree
Showing 3 changed files with 410 additions and 1 deletion.
208 changes: 207 additions & 1 deletion bdai_ros2_wrappers/bdai_ros2_wrappers/callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import abc
import functools
from typing import Any, Callable, Iterable, Optional, Tuple
import inspect
from typing import Any, Callable, Iterable, Literal, Optional, Tuple, Type, Union, overload

from rclpy.task import Future

from bdai_ros2_wrappers.executors import assign_coroutine
from bdai_ros2_wrappers.futures import AnyFuture, FutureLike, as_proper_future
from bdai_ros2_wrappers.utilities import take_kwargs

Expand Down Expand Up @@ -163,3 +165,207 @@ def asynchronous(self, *args: Any, **kwargs: Any) -> Any:
if self.starred:
return self.wrapped_callable.asynchronous(*self.composed_callable(*args, **inner_kwargs), **outer_kwargs)
return self.wrapped_callable.asynchronous(self.composed_callable(*args, **inner_kwargs), **outer_kwargs)


class GeneralizedFunction(GeneralizedCallable):
"""A generalized callable defined by parts."""

def __init__(
self,
synchronous_callable: Optional[Callable] = None,
asynchronous_callable: Optional[Callable] = None,
) -> None:
"""Initialize generalized function.
Args:
synchronous_callable: optional synchronous body.
asynchronous_callable: optional asynchronous body.
"""
self._synchronous_callable = synchronous_callable
self._asynchronous_callable = asynchronous_callable

def synchronous(self, *args: Any, **kwargs: Any) -> Any:
"""Invoke function synchronously (ie. potentially blocking)."""
if self._synchronous_callable is None:
raise NotImplementedError("synchronous invocation is not supported")
return self._synchronous_callable(*args, **kwargs)

def asynchronous(self, *args: Any, **kwargs: Any) -> Any:
"""Invoke function asynchronously, returning a future-like object."""
if self._asynchronous_callable is None:
raise NotImplementedError("asynchronous invocation is not supported")
return self._asynchronous_callable(*args, **kwargs)


class GeneralizedMethod:
"""A data descriptor for generalized callables bound to class instances."""

class Unbound:
"""An unbound generalized method descriptor."""

def __init__(self, method: "GeneralizedMethod") -> None:
"""Initialize unbound descriptor.
Args:
method: associated generalized method.
"""
self.synchronous_callable: Optional[Callable] = None
self.asynchronous_callable: Optional[Callable] = None
if not method.transitional:
if inspect.iscoroutinefunction(method.prototype):
self.asynchronous_callable = method.prototype
else:
self.synchronous_callable = method.prototype
if method.synchronous_overload is not None:
self.synchronous_callable = method.synchronous_overload
if method.asynchronous_overload is not None:
self.asynchronous_callable = method.asynchronous_overload

self.default_callable: Optional[Callable] = None
if not method.transitional:
if self.synchronous_callable is not None:
self.default_callable = self.synchronous_callable
else:
self.default_callable = self.asynchronous_callable
else:
self.default_callable = method.prototype

def __get__(
self,
instance: Optional[Any],
owner: Optional[Type] = None,
) -> Union["GeneralizedMethod.Unbound", "GeneralizedMethod.Bound"]:
if instance is None:
return self
synchronous_callable = self.synchronous_callable
if synchronous_callable is not None:
synchronous_callable = synchronous_callable.__get__(instance, owner)
assert synchronous_callable is not None
asynchronous_callable = self.asynchronous_callable
if asynchronous_callable is not None:
asynchronous_callable = asynchronous_callable.__get__(instance, owner)
assert asynchronous_callable is not None
if inspect.iscoroutinefunction(self.asynchronous_callable):
asynchronous_callable = assign_coroutine(asynchronous_callable, instance.executor)
default_callable = self.default_callable
if default_callable is not None:
default_callable = default_callable.__get__(instance, owner)
assert default_callable is not None
if inspect.iscoroutinefunction(self.default_callable):
default_callable = assign_coroutine(default_callable, instance.executor)
implementation = GeneralizedFunction(synchronous_callable, asynchronous_callable)
return GeneralizedMethod.Bound(implementation, default_callable)

class Bound(VectorizingCallable, ComposableCallable):
"""A bound generalized method callable."""

def __init__(self, body: GeneralizedCallable, default_callable: Optional[Callable] = None) -> None:
"""Initialize bound method callable.
Args:
body: method body as a generalized callable
default_callable: optionally override default plain calls, defaults to synchronous calls.
"""
self.body = body
if default_callable is None:
default_callable = body.synchronous
self._default_callable = default_callable

def __getattr__(self, name: str) -> Any:
return getattr(self.body, name)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Invoke method (optionally pre-existing)."""
return self._default_callable(*args, **kwargs)

def synchronous(self, *args: Any, **kwargs: Any) -> Any:
"""Invoke method synchronously."""
return self.body.synchronous(*args, **kwargs)

def asynchronous(self, *args: Any, **kwargs: Any) -> Any:
"""Invoke method asynchronously."""
return self.body.asynchronous(*args, **kwargs)

def __init__(self, prototype: Callable, transitional: bool) -> None:
"""Initializes the generalized method.
Args:
prototype: method prototype, usually just a signature but
may also be used as an overload for convenience (iff the
function type matches the missing overload).
transitional: a transitional method will stick to its
prototype for default invocations, simplifying the
adoption of generalized methods in existing codebases.
"""
self.prototype = prototype
self.transitional = transitional
self.synchronous_overload: Optional[Callable] = None
self.asynchronous_overload: Optional[Callable] = None

def sync_overload(self, func: Callable) -> Callable:
"""Register `func` as this method synchronous overload."""
if self.synchronous_overload is not None:
raise RuntimeError("cannot redefine synchronous overload")
self.synchronous_overload = func
return func

def async_overload(self, func: Callable) -> Callable:
"""Register `func` as this method asynchronous overload."""
if self.asynchronous_overload is not None:
raise RuntimeError("cannot redefine asynchronous overload")
self.asynchronous_overload = func
return func

def __set_name__(self, owner: Type, name: str) -> None:
self.__attribute_name = f"__{name}_method"
setattr(owner, self.__attribute_name, GeneralizedMethod.Unbound(self))

def rebind(self, instance: Any, body: GeneralizedCallable) -> None:
"""Change this method's `body` for the given `instance`."""
default_callable: Optional[Callable] = None
if self.transitional:
default_callable = self.prototype.__get__(instance)
bound_method = GeneralizedMethod.Bound(body, default_callable)
setattr(instance, self.__attribute_name, bound_method)

@overload
def __get__(self, instance: Literal[None], owner: Optional[Type] = ...) -> "GeneralizedMethod":
...

@overload
def __get__(self, instance: Any, owner: Optional[Type] = ...) -> "GeneralizedMethod.Bound":
...

def __get__(
self,
instance: Optional[Any],
owner: Optional[Type] = None,
) -> Union["GeneralizedMethod", "GeneralizedMethod.Bound"]:
if instance is None:
return self
return getattr(instance, self.__attribute_name)


@overload
def generalized_method(func: Callable, *, transitional: bool = ...) -> GeneralizedMethod:
...


@overload
def generalized_method(*, transitional: bool = ...) -> Callable:
...


def generalized_method(
func: Optional[Callable] = None,
*,
transitional: bool = False,
) -> Union[Callable, GeneralizedMethod]:
"""Define a generalized method by decoration."""

def _decorator(func: Callable) -> GeneralizedMethod:
return GeneralizedMethod(func, transitional)

if func is None:
return _decorator
return _decorator(func)
19 changes: 19 additions & 0 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import concurrent.futures
import contextlib
import dataclasses
import functools
import inspect
import logging
import os
Expand All @@ -14,6 +15,7 @@

import rclpy.executors

from bdai_ros2_wrappers.futures import FutureLike
from bdai_ros2_wrappers.utilities import bind_to_thread, fqn


Expand Down Expand Up @@ -728,3 +730,20 @@ def foreground(executor: rclpy.executors.Executor) -> typing.Iterator[rclpy.exec
yield executor
finally:
executor.shutdown()


def assign_coroutine(
coroutine: typing.Callable[..., typing.Awaitable],
executor: rclpy.executors.Executor,
) -> typing.Callable[..., FutureLike]:
"""Assign a `coroutine` to a given `executor`.
An assigned coroutine will return a future-like object
that will be serviced by the associated executor.
"""

@functools.wraps(coroutine)
def __wrapper(*args: typing.Any, **kwargs: typing.Any) -> FutureLike:
return executor.create_task(coroutine, *args, **kwargs)

return __wrapper
Loading

0 comments on commit 8de80b5

Please sign in to comment.