Skip to content

Commit

Permalink
Add GeneralizedGuard
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <[email protected]>
  • Loading branch information
mhidalgo-bdai committed Jul 18, 2024
1 parent 41b3177 commit 95619ab
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 2 deletions.
38 changes: 37 additions & 1 deletion bdai_ros2_wrappers/bdai_ros2_wrappers/callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from bdai_ros2_wrappers.executors import assign_coroutine
from bdai_ros2_wrappers.futures import AnyFuture, FutureLike, as_proper_future
from bdai_ros2_wrappers.utilities import take_kwargs
from bdai_ros2_wrappers.utilities import fqn, take_kwargs


def starmap_async(func: Callable[..., AnyFuture], iterable: Iterable[Tuple[Any, ...]]) -> Future:
Expand Down Expand Up @@ -167,6 +167,42 @@ def asynchronous(self, *args: Any, **kwargs: Any) -> Any:
return self.wrapped_callable.asynchronous(self.composed_callable(*args, **inner_kwargs), **outer_kwargs)


class GeneralizedGuard(GeneralizedDecorator):
"""A decorator that guards generalized callable invocations."""

def __init__(
self,
condition: Callable[[], bool],
wrapped_callable: GeneralizedCallable,
message: Optional[str] = None
) -> None:
"""Initializes generalized guard.
Args:
condition: boolean predicate to guard invocations.
wrapped_callable: the guarded callable.
message: optional human-readable message to raise whenever
the guard condition does not hold.
"""
super().__init__(wrapped_callable)
self.condition = condition
if message is None:
message = fqn(condition)
self.message = message

def synchronous(self, *args: Any, **kwargs: Any) -> Any:
"""Invokes callable synchronously if the guarded condition holds true, raises otherwise."""
if not self.condition():
raise RuntimeError(self.message)
return self.wrapped_callable.synchronous(*args, **kwargs)

def asynchronous(self, *args: Any, **kwargs: Any) -> Any:
"""Invokes callable asynchronously if the guarded condition holds true, raises otherwise."""
if not self.condition():
raise RuntimeError(self.message)
return self.wrapped_callable.asynchronous(*args, **kwargs)


class GeneralizedFunction(GeneralizedCallable):
"""A generalized callable defined by parts."""

Expand Down
22 changes: 21 additions & 1 deletion bdai_ros2_wrappers/test/test_callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from bdai_ros2_wrappers.callables import generalized_method
from bdai_ros2_wrappers.callables import GeneralizedGuard, generalized_method
from bdai_ros2_wrappers.futures import wait_for_future
from bdai_ros2_wrappers.scope import ROSAwareScope

Expand Down Expand Up @@ -182,3 +182,23 @@ def test_composed_method(ros: ROSAwareScope) -> None:
assert future.result() is True
assert "extras" in bucket._storage
assert bucket._storage["extras"] == (0, "more data", False)


def test_guarded_method(ros: ROSAwareScope) -> None:
bucket = Bucket(ros, {"my-data": "some data"})

read_permission = False

def allowed() -> bool:
nonlocal read_permission
return read_permission

guarded_read = GeneralizedGuard(allowed, bucket.read)
Bucket.read.rebind(bucket, guarded_read)

with pytest.raises(RuntimeError):
bucket.read("my-data")

read_permission = True

assert bucket.read("my-data") == "some data"

0 comments on commit 95619ab

Please sign in to comment.