Skip to content

Commit

Permalink
moved common logic to ManipulationTask
Browse files Browse the repository at this point in the history
  • Loading branch information
jmatejcz committed Mar 10, 2025
1 parent 37f2dee commit da795be
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 216 deletions.
30 changes: 21 additions & 9 deletions src/rai_bench/rai_bench/examples/o3de_test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,31 @@

import rclpy
from langchain.tools import BaseTool
from rai.agents.conversational_agent import create_conversational_agent
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool
from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool
from rai.utils.model_initialization import get_llm_model
from rai.agents.conversational_agent import create_conversational_agent # type: ignore
from rai.communication.ros2.connectors import ROS2ARIConnector # type: ignore
from rai.tools.ros.manipulation import ( # type: ignore
GetObjectPositionsTool,
MoveToPointTool,
)
from rai.tools.ros2.topics import ( # type: ignore
GetROS2ImageTool,
GetROS2TopicsNamesAndTypesTool,
)
from rai.utils.model_initialization import get_llm_model # type: ignore
from rai_open_set_vision.tools import GetGrabbingPointTool

from rai_bench.benchmark_model import Benchmark, Task
from rai_bench.o3de_test_bench.tasks import GrabCarrotTask, PlaceCubesTask
from rai_sim.o3de.o3de_bridge import (
from rai_bench.benchmark_model import Benchmark, Task # type: ignore
from rai_bench.o3de_test_bench.tasks import ( # type: ignore
GrabCarrotTask,
GroupVegetablesTask,
PlaceCubesTask,
)
from rai_sim.o3de.o3de_bridge import ( # type: ignore
O3DEngineArmManipulationBridge,
O3DExROS2SimulationConfig,
Pose,
)
from rai_sim.simulation_bridge import Rotation, Translation
from rai_sim.simulation_bridge import Rotation, Translation # type: ignore

if __name__ == "__main__":
rclpy.init()
Expand Down Expand Up @@ -142,6 +152,7 @@
configs_dir + "scene2.yaml",
configs_dir + "scene3.yaml",
configs_dir + "scene4.yaml",
configs_dir + "scene5.yaml",
]
simulations_configs = [
O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path))
Expand All @@ -150,6 +161,7 @@
tasks: List[Task] = [
GrabCarrotTask(logger=bench_logger),
PlaceCubesTask(logger=bench_logger),
GroupVegetablesTask(logger=bench_logger),
]
scenarios = Benchmark.create_scenarios(
tasks=tasks,
Expand Down
83 changes: 6 additions & 77 deletions src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,20 @@
# limitations under the License.
from typing import List, Tuple

from rai_bench.benchmark_model import (
EntitiesMismatchException,
from rai_bench.o3de_test_bench.tasks.manipulation_task import (
ManipulationTask, # type: ignore
)
from rai_bench.o3de_test_bench.tasks.manipulation_task import ManipulationTask
from rai_sim.o3de.o3de_bridge import (
SimulationBridge,
from rai_sim.simulation_bridge import ( # type: ignore
SimulationConfig,
SpawnedEntity,
)
from rai_sim.simulation_bridge import SimulationConfig, SpawnedEntity, SimulationConfigT


class GrabCarrotTask(ManipulationTask):
obj_types = ["carrot"]

def get_prompt(self) -> str:
return "Manipulate objects, so that all carrots to the left side of the table (positive y)"
return "Manipulate objects, so that all carrots are on the left side of the table (positive y)"

def validate_config(self, simulation_config: SimulationConfig) -> bool:
for ent in simulation_config.entities:
Expand All @@ -37,76 +36,6 @@ def validate_config(self, simulation_config: SimulationConfig) -> bool:
return False

def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]:
"""Calculate how many objects are positioned correct and incorrect"""
correct = sum(1 for ent in entities if ent.pose.translation.y > 0.0)
incorrect: int = len(entities) - correct
return correct, incorrect

def calculate_initial_placements(
self, simulation_bridge: SimulationBridge[SimulationConfigT]
) -> tuple[int, int]:
"""
Calculates the number of objects that are correctly and incorrectly placed initially.
"""
initial_carrots = self.filter_entities_by_prefab_type(
simulation_bridge.spawned_entities, prefab_types=self.obj_types
)
initially_correct, initially_incorrect = self.calculate_correct(
entities=initial_carrots
)

self.logger.info( # type: ignore
f"Initially correctly placed carrots: {initially_correct}, Initially incorrectly placed carrots: {initially_incorrect}"
)
return initially_correct, initially_incorrect

def calculate_final_placements(
self, simulation_bridge: SimulationBridge[SimulationConfigT]
) -> tuple[int, int]:
"""
Calculates the number of objects that are correctly and incorrectly placed at the end of the simulation.
"""
scene_state = simulation_bridge.get_scene_state()
final_carrots = self.filter_entities_by_prefab_type(
scene_state.entities, prefab_types=self.obj_types
)
final_correct, final_incorrect = self.calculate_correct(entities=final_carrots)

self.logger.info( # type: ignore
f"Finally correctly placed carrots: {final_correct}, Finally incorrectly placed carrots: {final_incorrect}"
)
return final_correct, final_incorrect

def calculate_result(
self, simulation_bridge: SimulationBridge[SimulationConfig]
) -> float:
"""
Calculates a score from 0.0 to 1.0, where 0.0 represents the initial placements or worse and 1.0 represents perfect final placements.
"""
initially_correct, initially_incorrect = self.calculate_initial_placements(
simulation_bridge
)
final_correct, final_incorrect = self.calculate_final_placements(
simulation_bridge
)

total_objects = initially_correct + initially_incorrect
if total_objects == 0:
return 1.0
elif (initially_correct + initially_incorrect) != (
final_correct + final_incorrect
):
raise EntitiesMismatchException(
"number of initial entities does not match final entities number."
)
elif initially_incorrect == 0:
pass
# NOTE all objects are placed correctly
# no point in running task
raise ValueError("All objects are placed correctly at the start.")
else:
corrected = final_correct - initially_correct
score = max(0.0, corrected / initially_incorrect)

self.logger.info(f"Calculated score: {score:.2f}") # type: ignore
return score
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@

from typing import List, Tuple

from rai_bench.o3de_test_bench.tasks.manipulation_task import ManipulationTask
from rai_bench.benchmark_model import (
EntitiesMismatchException,
from rai_bench.o3de_test_bench.tasks.manipulation_task import (
ManipulationTask, # type: ignore
)
from rai_sim.o3de.o3de_bridge import SimulationBridge # type: ignore
from rai_sim.simulation_bridge import SimulationConfig, SpawnedEntity, SimulationConfigT # type: ignore
from rai_sim.simulation_bridge import SimulationConfig, SpawnedEntity # type: ignore


class GroupVegetablesTask(ManipulationTask):
Expand Down Expand Up @@ -68,54 +66,3 @@ def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]:
misclustered.extend(veggies)

return len(properly_clustered), len(misclustered)

def calculate_initial_placements(
self, simulation_bridge: SimulationBridge[SimulationConfigT]
) -> Tuple[int, int]:
"""Calculate the number of initially correct and incorrect placements."""
initial_veggies = self.filter_entities_by_prefab_type(
simulation_bridge.spawned_entities, self.obj_types
)
initially_correct, initially_incorrect = self.calculate_correct(initial_veggies)

self.logger.info(f"Initially correct: {initially_correct}, Initially incorrect: {initially_incorrect}") # type: ignore
return initially_correct, initially_incorrect

def calculate_final_placements(
self, simulation_bridge: SimulationBridge[SimulationConfigT]
) -> Tuple[int, int]:
"""Calculate the number of correctly and incorrectly placed objects at the end of the simulation."""
scene_state = simulation_bridge.get_scene_state()
final_veggies = self.filter_entities_by_prefab_type(
scene_state.entities, self.obj_types
)
final_correct, final_incorrect = self.calculate_correct(final_veggies)

self.logger.info(f"Final correct: {final_correct}, Final incorrect: {final_incorrect}") # type: ignore
return final_correct, final_incorrect

def calculate_result(
self, simulation_bridge: SimulationBridge[SimulationConfig]
) -> float:
"""Calculates a score from 0.0 to 1.0 based on placement improvements."""
initially_correct, initially_incorrect = self.calculate_initial_placements(
simulation_bridge
)
final_correct, final_incorrect = self.calculate_final_placements(
simulation_bridge
)

total_objects = initially_correct + initially_incorrect
if total_objects == 0:
return 1.0
elif total_objects != (final_correct + final_incorrect):
raise EntitiesMismatchException(
"Mismatch in initial and final entity counts."
)
elif initially_incorrect == 0:
raise ValueError("All objects are placed correctly at the start.")
else:
corrected = final_correct - initially_correct
score = max(0.0, corrected / initially_incorrect)
self.logger.info(f"Calculated score: {score:.2f}") # type: ignore
return score
98 changes: 95 additions & 3 deletions src/rai_bench/rai_bench/o3de_test_bench/tasks/manipulation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,28 @@
# limitations under the License.

import logging
from typing import Union
from abc import ABC, abstractmethod
from typing import List, Tuple, Union

from rclpy.impl.rcutils_logger import RcutilsLogger

from rai_bench.benchmark_model import Task # type: ignore
from rai_bench.benchmark_model import ( # type: ignore
EntitiesMismatchException,
Task, # type: ignore
)
from rai_sim.simulation_bridge import ( # type: ignore
SimulationBridge,
SimulationConfig,
SimulationConfigT,
SpawnedEntity,
)

loggers_type = Union[RcutilsLogger, logging.Logger]


class ManipulationTask(Task):
class ManipulationTask(Task, ABC):
obj_types: List[str] = []

def __init__(self, logger: loggers_type | None = None):
super().__init__(logger)
self.initially_misplaced_now_correct = 0
Expand All @@ -35,3 +47,83 @@ def reset_values(self):
self.initially_misplaced_still_incorrect = 0
self.initially_correct_still_correct = 0
self.initially_correct_now_incorrect = 0

@abstractmethod
def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]:
"""
This method should implement calculation of how many objects
are positioned correctly and incorrectly
first int of the tuple must be number of correctly placed objects
second int s- number of incorrectly placed objects
"""
pass

def calculate_initial_placements(
self, simulation_bridge: SimulationBridge[SimulationConfigT]
) -> tuple[int, int]:
"""
Calculates the number of objects that are correctly and incorrectly placed initially.
"""
initial_carrots = self.filter_entities_by_prefab_type(
simulation_bridge.spawned_entities, prefab_types=self.obj_types
)
initially_correct, initially_incorrect = self.calculate_correct(
entities=initial_carrots
)

self.logger.info( # type: ignore
f"Initially correctly placed carrots: {initially_correct}, Initially incorrectly placed carrots: {initially_incorrect}"
)
return initially_correct, initially_incorrect

def calculate_final_placements(
self, simulation_bridge: SimulationBridge[SimulationConfigT]
) -> tuple[int, int]:
"""
Calculates the number of objects that are correctly and incorrectly placed at the end of the simulation.
"""
scene_state = simulation_bridge.get_scene_state()
final_carrots = self.filter_entities_by_prefab_type(
scene_state.entities, prefab_types=self.obj_types
)
final_correct, final_incorrect = self.calculate_correct(entities=final_carrots)

self.logger.info( # type: ignore
f"Finally correctly placed carrots: {final_correct}, Finally incorrectly placed carrots: {final_incorrect}"
)
return final_correct, final_incorrect

def calculate_result(
self, simulation_bridge: SimulationBridge[SimulationConfig]
) -> float:
"""
Calculates a score from 0.0 to 1.0, where 0.0 represents the initial placements or worse and 1.0 represents perfect final placements.
"""
initially_correct, initially_incorrect = self.calculate_initial_placements(
simulation_bridge
)
final_correct, final_incorrect = self.calculate_final_placements(
simulation_bridge
)

total_objects = initially_correct + initially_incorrect
if total_objects == 0:
return 1.0
elif (initially_correct + initially_incorrect) != (
final_correct + final_incorrect
):
raise EntitiesMismatchException(
"number of initial entities does not match final entities number."
)
elif initially_incorrect == 0:
pass
# NOTE all objects are placed correctly
# no point in running task
raise ValueError("All objects are placed correctly at the start.")
else:
corrected = final_correct - initially_correct
score = max(0.0, corrected / initially_incorrect)

self.logger.info(f"Calculated score: {score:.2f}") # type: ignore
return score
Loading

0 comments on commit da795be

Please sign in to comment.