Skip to content

Commit

Permalink
Add tests for generalized callables
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <[email protected]>
  • Loading branch information
mhidalgo-bdai committed May 16, 2024
1 parent bb32b10 commit f360eae
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 20 deletions.
58 changes: 38 additions & 20 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import abc
import functools
import inspect
from typing import Any, Callable, Iterable, Optional, Tuple, Type, Union, overload
from typing import Any, Callable, Iterable, Literal, Optional, Tuple, Type, Union, overload

from rclpy.task import Future

Expand Down Expand Up @@ -211,19 +211,25 @@ def __init__(self, method: "GeneralizedMethod") -> None:
"""
self.synchronous_callable: Optional[Callable] = None
self.asynchronous_callable: Optional[Callable] = None
self.transitional_callable: Optional[Callable] = None
if not method.transitional:
if inspect.iscoroutinefunction(method.prototype):
self.asynchronous_callable = method.prototype
else:
self.synchronous_callable = method.prototype
else:
self.transitional_callable = method.prototype
if method.synchronous_overload is not None:
self.synchronous_callable = method.synchronous_overload
if method.asynchronous_overload is not None:
self.asynchronous_callable = method.asynchronous_overload

self.default_callable: Optional[Callable] = None
if not method.transitional:
if self.synchronous_callable is not None:
self.default_callable = self.synchronous_callable
else:
self.default_callable = self.asynchronous_callable
else:
self.default_callable = method.prototype

def __get__(
self,
instance: Optional[Any],
Expand All @@ -233,32 +239,34 @@ def __get__(
return self
synchronous_callable = self.synchronous_callable
if synchronous_callable is not None:
asynchronous_callable = synchronous_callable.__get__(instance, owner)
synchronous_callable = synchronous_callable.__get__(instance, owner)
assert synchronous_callable is not None
asynchronous_callable = self.asynchronous_callable
if asynchronous_callable is not None:
asynchronous_callable = asynchronous_callable.__get__(instance, owner)
assert asynchronous_callable is not None
if inspect.iscoroutinefunction(self.asynchronous_callable):
asynchronous_callable = assign_coroutine(asynchronous_callable, instance.executor)
default_callable = self.default_callable
if default_callable is not None:
default_callable = default_callable.__get__(instance, owner)
assert default_callable is not None
if inspect.iscoroutinefunction(self.default_callable):
default_callable = assign_coroutine(default_callable, instance.executor)
implementation = GeneralizedFunction(synchronous_callable, asynchronous_callable)
transitional_callable = self.transitional_callable
if transitional_callable is not None:
transitional_callable = transitional_callable.__get__(instance, owner)
return GeneralizedMethod.Bound(implementation, migrating_from=transitional_callable)
return GeneralizedMethod.Bound(implementation, default_callable)

class Bound(VectorizingCallable, ComposableCallable):
"""A bound generalized method callable."""

def __init__(self, body: GeneralizedCallable, migrating_from: Optional[Callable] = None) -> None:
def __init__(self, body: GeneralizedCallable, default_callable: Optional[Callable] = None) -> None:
"""Initialize bound method callable.
Args:
body: method body as a generalized callable
migrating_from: when migrating to generalized methods,
the prior definition may be fed here so as to keep plain
method invocations the same.
default_callable: optionally override default plain calls, defaults to synchronous calls.
"""
self.body = body
default_callable = migrating_from
if default_callable is None:
default_callable = body.synchronous
self._default_callable = default_callable
Expand Down Expand Up @@ -294,30 +302,40 @@ def __init__(self, prototype: Callable, transitional: bool) -> None:
self.synchronous_overload: Optional[Callable] = None
self.asynchronous_overload: Optional[Callable] = None

def sync_overload(self, func: Callable) -> None:
def sync_overload(self, func: Callable) -> Callable:
"""Register `func` as this method synchronous overload."""
if self.synchronous_overload is not None:
raise RuntimeError("cannot redefine synchronous overload")
self.synchronous_overload = func
return func

def async_overload(self, func: Callable) -> None:
def async_overload(self, func: Callable) -> Callable:
"""Register `func` as this method asynchronous overload."""
if self.asynchronous_overload is not None:
raise RuntimeError("cannot redefine asynchronous overload")
self.asynchronous_overload = func
return func

def __set_name__(self, owner: Type, name: str) -> None:
self.__attribute_name = f"__{name}_method"
setattr(owner, self.__attribute_name, GeneralizedMethod.Unbound(self))

def rebind(self, instance: Any, body: GeneralizedCallable) -> None:
"""Change this method's `body` for the given `instance`."""
bound_method = GeneralizedMethod.Bound(
body,
migrating_from=(self.prototype.__get__(instance) if self.transitional else None),
)
default_callable: Optional[Callable] = None
if self.transitional:
default_callable = self.prototype.__get__(instance)
bound_method = GeneralizedMethod.Bound(body, default_callable)
setattr(instance, self.__attribute_name, bound_method)

@overload
def __get__(self, instance: Literal[None], owner: Optional[Type] = ...) -> "GeneralizedMethod":
...

@overload
def __get__(self, instance: Any, owner: Optional[Type] = ...) -> "GeneralizedMethod.Bound":
...

def __get__(
self,
instance: Optional[Any],
Expand Down
185 changes: 185 additions & 0 deletions bdai_ros2_wrappers/test/test_callables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.

from collections.abc import Mapping
from typing import Any, Optional

import pytest

from bdai_ros2_wrappers.callables import generalizedmethod
from bdai_ros2_wrappers.futures import wait_for_future
from bdai_ros2_wrappers.scope import ROSAwareScope


class Bucket:
def __init__(
self,
ros: ROSAwareScope,
storage: Optional[Mapping] = None,
) -> None:
self.executor = ros.executor
if storage is None:
storage = {}
self._storage = dict(storage)

@generalizedmethod(transitional=True)
def create(self, content: Any) -> str:
name = str(hash(content))
if name in self._storage:
raise RuntimeError()
self._storage[name] = content
return name

@create.sync_overload
def _create_sync(self, name: str, content: Any) -> bool:
if name in self._storage:
return False
self._storage[name] = content
return True

@create.async_overload
async def _create_async(self, name: str, content: Any) -> bool:
return self._create_sync(name, content)

@generalizedmethod
def read(self, name: str) -> Optional[Any]:
return self._storage.get(name)

@read.async_overload
async def _read_async(self, name: str) -> Optional[Any]:
# mypy does not handle data descriptors well
return self.read(name)

@generalizedmethod
async def update(self, name: str, content: Any) -> bool:
if name not in self._storage:
return False
self._storage[name] = content
return True

@generalizedmethod
def delete(self, name: str) -> bool:
if name not in self._storage:
return False
del self._storage[name]
return True


def test_transitional_method(ros: ROSAwareScope) -> None:
bucket = Bucket(ros)
name = bucket.create("some data")
assert name in bucket._storage
assert bucket._storage[name] == "some data"

with pytest.raises(RuntimeError):
bucket.create("some data")

assert not bucket.create.synchronously(name, "some other data")
assert bucket.create.synchronously("my-data", "some other data")
assert "my-data" in bucket._storage
assert bucket._storage["my-data"] == "some other data"

future = bucket.create.asynchronously("my-data", "more data")
assert wait_for_future(future, timeout_sec=5.0)
assert future.result() is False

future = bucket.create.asynchronously("extras", "more data")
assert wait_for_future(future, timeout_sec=5.0)
assert future.result() is True
assert "extras" in bucket._storage
assert bucket._storage["extras"] == "more data"


def test_nominal_method(ros: ROSAwareScope) -> None:
bucket = Bucket(ros, {"my-data": "some data"})

assert bucket.read("my-data") == "some data"
assert bucket.read.synchronously("my-data") == "some data"
future = bucket.read.asynchronously("my-data")
assert wait_for_future(future, timeout_sec=5.0)
assert future.result() == "some data"

assert not bucket.read("other-data")
assert not bucket.read.synchronously("other-data")
future = bucket.read.asynchronously("other-data")
assert wait_for_future(future, timeout_sec=5.0)
assert future.result() is None


def test_sync_only_method(ros: ROSAwareScope) -> None:
bucket = Bucket(
ros,
{
"my-data": "some data",
"extras": "more data",
"old": "old data",
},
)
assert bucket.delete("my-data")
assert "my-data" not in bucket._storage
assert not bucket.delete("my-data")
assert bucket.delete.synchronously("extras")
assert "extras" not in bucket._storage
assert not bucket.delete.synchronously("extras")
with pytest.raises(NotImplementedError):
bucket.delete.asynchronously("old")
assert "old" in bucket._storage


def test_async_only_method(ros: ROSAwareScope) -> None:
bucket = Bucket(ros, {"my-data": "some data"})
future = bucket.update("my-data", "new data")
assert wait_for_future(future, timeout_sec=5.0)
assert future.result() is True
assert bucket._storage["my-data"] == "new data"

future = bucket.update.asynchronously("my-data", "newer data")
assert wait_for_future(future, timeout_sec=5.0)
assert future.result() is True
assert bucket._storage["my-data"] == "newer data"

with pytest.raises(NotImplementedError):
bucket.update.synchronously("my-data", "")
assert bucket._storage["my-data"] == "newer data"


def test_vectorized_method(ros: ROSAwareScope) -> None:
bucket = Bucket(
ros,
{
"my-data": "some data",
"extras": "more data",
},
)
data = bucket.read.vectorized(["my-data", "extras"])
assert data == ["some data", "more data"]

data = bucket.read.vectorized.synchronously(["my-data", "extras"])
assert data == ["some data", "more data"]

future = bucket.read.vectorized.asynchronously(["my-data", "extras"])
assert wait_for_future(future, timeout_sec=5.0)
assert future.result() == ["some data", "more data"]


def test_composed_method(ros: ROSAwareScope) -> None:
bucket = Bucket(ros)
Bucket.create.rebind(
bucket,
bucket.create.compose(
(lambda name, *data: (name, data)),
starred=True,
),
)
name = bucket.create("some data")
assert name in bucket._storage
assert bucket._storage[name] == "some data"

assert bucket.create.synchronously("my-data", "some other data", 1, True)
assert "my-data" in bucket._storage
assert bucket._storage["my-data"] == ("some other data", 1, True)

future = bucket.create.asynchronously("extras", 0, "more data", False)
assert wait_for_future(future, timeout_sec=5.0)
assert future.result() is True
assert "extras" in bucket._storage
assert bucket._storage["extras"] == (0, "more data", False)

0 comments on commit f360eae

Please sign in to comment.