From da795be42cb5bd30d0717a788c932f9411ae9e4b Mon Sep 17 00:00:00 2001 From: jmatejcz Date: Mon, 10 Mar 2025 09:55:56 +0100 Subject: [PATCH] moved common logic to ManipulationTask --- .../rai_bench/examples/o3de_test_benchmark.py | 30 ++++-- .../o3de_test_bench/tasks/grab_carrot_task.py | 83 ++-------------- .../tasks/group_vegetables_task.py | 59 +---------- .../tasks/manipulation_task.py | 98 ++++++++++++++++++- .../o3de_test_bench/tasks/place_cubes_task.py | 74 +------------- 5 files changed, 128 insertions(+), 216 deletions(-) diff --git a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py index 6068806d..2da09f2d 100644 --- a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py +++ b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py @@ -21,21 +21,31 @@ import rclpy from langchain.tools import BaseTool -from rai.agents.conversational_agent import create_conversational_agent -from rai.communication.ros2.connectors import ROS2ARIConnector -from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool -from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool -from rai.utils.model_initialization import get_llm_model +from rai.agents.conversational_agent import create_conversational_agent # type: ignore +from rai.communication.ros2.connectors import ROS2ARIConnector # type: ignore +from rai.tools.ros.manipulation import ( # type: ignore + GetObjectPositionsTool, + MoveToPointTool, +) +from rai.tools.ros2.topics import ( # type: ignore + GetROS2ImageTool, + GetROS2TopicsNamesAndTypesTool, +) +from rai.utils.model_initialization import get_llm_model # type: ignore from rai_open_set_vision.tools import GetGrabbingPointTool -from rai_bench.benchmark_model import Benchmark, Task -from rai_bench.o3de_test_bench.tasks import GrabCarrotTask, PlaceCubesTask -from rai_sim.o3de.o3de_bridge import ( +from rai_bench.benchmark_model import Benchmark, Task # type: ignore +from rai_bench.o3de_test_bench.tasks import ( # type: ignore + GrabCarrotTask, + GroupVegetablesTask, + PlaceCubesTask, +) +from rai_sim.o3de.o3de_bridge import ( # type: ignore O3DEngineArmManipulationBridge, O3DExROS2SimulationConfig, Pose, ) -from rai_sim.simulation_bridge import Rotation, Translation +from rai_sim.simulation_bridge import Rotation, Translation # type: ignore if __name__ == "__main__": rclpy.init() @@ -142,6 +152,7 @@ configs_dir + "scene2.yaml", configs_dir + "scene3.yaml", configs_dir + "scene4.yaml", + configs_dir + "scene5.yaml", ] simulations_configs = [ O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) @@ -150,6 +161,7 @@ tasks: List[Task] = [ GrabCarrotTask(logger=bench_logger), PlaceCubesTask(logger=bench_logger), + GroupVegetablesTask(logger=bench_logger), ] scenarios = Benchmark.create_scenarios( tasks=tasks, diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py index 773aed05..593d37a8 100644 --- a/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py @@ -13,21 +13,20 @@ # limitations under the License. from typing import List, Tuple -from rai_bench.benchmark_model import ( - EntitiesMismatchException, +from rai_bench.o3de_test_bench.tasks.manipulation_task import ( + ManipulationTask, # type: ignore ) -from rai_bench.o3de_test_bench.tasks.manipulation_task import ManipulationTask -from rai_sim.o3de.o3de_bridge import ( - SimulationBridge, +from rai_sim.simulation_bridge import ( # type: ignore + SimulationConfig, + SpawnedEntity, ) -from rai_sim.simulation_bridge import SimulationConfig, SpawnedEntity, SimulationConfigT class GrabCarrotTask(ManipulationTask): obj_types = ["carrot"] def get_prompt(self) -> str: - return "Manipulate objects, so that all carrots to the left side of the table (positive y)" + return "Manipulate objects, so that all carrots are on the left side of the table (positive y)" def validate_config(self, simulation_config: SimulationConfig) -> bool: for ent in simulation_config.entities: @@ -37,76 +36,6 @@ def validate_config(self, simulation_config: SimulationConfig) -> bool: return False def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]: - """Calculate how many objects are positioned correct and incorrect""" correct = sum(1 for ent in entities if ent.pose.translation.y > 0.0) incorrect: int = len(entities) - correct return correct, incorrect - - def calculate_initial_placements( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> tuple[int, int]: - """ - Calculates the number of objects that are correctly and incorrectly placed initially. - """ - initial_carrots = self.filter_entities_by_prefab_type( - simulation_bridge.spawned_entities, prefab_types=self.obj_types - ) - initially_correct, initially_incorrect = self.calculate_correct( - entities=initial_carrots - ) - - self.logger.info( # type: ignore - f"Initially correctly placed carrots: {initially_correct}, Initially incorrectly placed carrots: {initially_incorrect}" - ) - return initially_correct, initially_incorrect - - def calculate_final_placements( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> tuple[int, int]: - """ - Calculates the number of objects that are correctly and incorrectly placed at the end of the simulation. - """ - scene_state = simulation_bridge.get_scene_state() - final_carrots = self.filter_entities_by_prefab_type( - scene_state.entities, prefab_types=self.obj_types - ) - final_correct, final_incorrect = self.calculate_correct(entities=final_carrots) - - self.logger.info( # type: ignore - f"Finally correctly placed carrots: {final_correct}, Finally incorrectly placed carrots: {final_incorrect}" - ) - return final_correct, final_incorrect - - def calculate_result( - self, simulation_bridge: SimulationBridge[SimulationConfig] - ) -> float: - """ - Calculates a score from 0.0 to 1.0, where 0.0 represents the initial placements or worse and 1.0 represents perfect final placements. - """ - initially_correct, initially_incorrect = self.calculate_initial_placements( - simulation_bridge - ) - final_correct, final_incorrect = self.calculate_final_placements( - simulation_bridge - ) - - total_objects = initially_correct + initially_incorrect - if total_objects == 0: - return 1.0 - elif (initially_correct + initially_incorrect) != ( - final_correct + final_incorrect - ): - raise EntitiesMismatchException( - "number of initial entities does not match final entities number." - ) - elif initially_incorrect == 0: - pass - # NOTE all objects are placed correctly - # no point in running task - raise ValueError("All objects are placed correctly at the start.") - else: - corrected = final_correct - initially_correct - score = max(0.0, corrected / initially_incorrect) - - self.logger.info(f"Calculated score: {score:.2f}") # type: ignore - return score diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/group_vegetables_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/group_vegetables_task.py index 37ac45fc..064c5b5b 100644 --- a/src/rai_bench/rai_bench/o3de_test_bench/tasks/group_vegetables_task.py +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/group_vegetables_task.py @@ -14,12 +14,10 @@ from typing import List, Tuple -from rai_bench.o3de_test_bench.tasks.manipulation_task import ManipulationTask -from rai_bench.benchmark_model import ( - EntitiesMismatchException, +from rai_bench.o3de_test_bench.tasks.manipulation_task import ( + ManipulationTask, # type: ignore ) -from rai_sim.o3de.o3de_bridge import SimulationBridge # type: ignore -from rai_sim.simulation_bridge import SimulationConfig, SpawnedEntity, SimulationConfigT # type: ignore +from rai_sim.simulation_bridge import SimulationConfig, SpawnedEntity # type: ignore class GroupVegetablesTask(ManipulationTask): @@ -68,54 +66,3 @@ def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]: misclustered.extend(veggies) return len(properly_clustered), len(misclustered) - - def calculate_initial_placements( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> Tuple[int, int]: - """Calculate the number of initially correct and incorrect placements.""" - initial_veggies = self.filter_entities_by_prefab_type( - simulation_bridge.spawned_entities, self.obj_types - ) - initially_correct, initially_incorrect = self.calculate_correct(initial_veggies) - - self.logger.info(f"Initially correct: {initially_correct}, Initially incorrect: {initially_incorrect}") # type: ignore - return initially_correct, initially_incorrect - - def calculate_final_placements( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> Tuple[int, int]: - """Calculate the number of correctly and incorrectly placed objects at the end of the simulation.""" - scene_state = simulation_bridge.get_scene_state() - final_veggies = self.filter_entities_by_prefab_type( - scene_state.entities, self.obj_types - ) - final_correct, final_incorrect = self.calculate_correct(final_veggies) - - self.logger.info(f"Final correct: {final_correct}, Final incorrect: {final_incorrect}") # type: ignore - return final_correct, final_incorrect - - def calculate_result( - self, simulation_bridge: SimulationBridge[SimulationConfig] - ) -> float: - """Calculates a score from 0.0 to 1.0 based on placement improvements.""" - initially_correct, initially_incorrect = self.calculate_initial_placements( - simulation_bridge - ) - final_correct, final_incorrect = self.calculate_final_placements( - simulation_bridge - ) - - total_objects = initially_correct + initially_incorrect - if total_objects == 0: - return 1.0 - elif total_objects != (final_correct + final_incorrect): - raise EntitiesMismatchException( - "Mismatch in initial and final entity counts." - ) - elif initially_incorrect == 0: - raise ValueError("All objects are placed correctly at the start.") - else: - corrected = final_correct - initially_correct - score = max(0.0, corrected / initially_incorrect) - self.logger.info(f"Calculated score: {score:.2f}") # type: ignore - return score diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/manipulation_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/manipulation_task.py index a31395f3..3d249981 100644 --- a/src/rai_bench/rai_bench/o3de_test_bench/tasks/manipulation_task.py +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/manipulation_task.py @@ -13,16 +13,28 @@ # limitations under the License. import logging -from typing import Union +from abc import ABC, abstractmethod +from typing import List, Tuple, Union from rclpy.impl.rcutils_logger import RcutilsLogger -from rai_bench.benchmark_model import Task # type: ignore +from rai_bench.benchmark_model import ( # type: ignore + EntitiesMismatchException, + Task, # type: ignore +) +from rai_sim.simulation_bridge import ( # type: ignore + SimulationBridge, + SimulationConfig, + SimulationConfigT, + SpawnedEntity, +) loggers_type = Union[RcutilsLogger, logging.Logger] -class ManipulationTask(Task): +class ManipulationTask(Task, ABC): + obj_types: List[str] = [] + def __init__(self, logger: loggers_type | None = None): super().__init__(logger) self.initially_misplaced_now_correct = 0 @@ -35,3 +47,83 @@ def reset_values(self): self.initially_misplaced_still_incorrect = 0 self.initially_correct_still_correct = 0 self.initially_correct_now_incorrect = 0 + + @abstractmethod + def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]: + """ + This method should implement calculation of how many objects + are positioned correctly and incorrectly + + first int of the tuple must be number of correctly placed objects + second int s- number of incorrectly placed objects + """ + pass + + def calculate_initial_placements( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> tuple[int, int]: + """ + Calculates the number of objects that are correctly and incorrectly placed initially. + """ + initial_carrots = self.filter_entities_by_prefab_type( + simulation_bridge.spawned_entities, prefab_types=self.obj_types + ) + initially_correct, initially_incorrect = self.calculate_correct( + entities=initial_carrots + ) + + self.logger.info( # type: ignore + f"Initially correctly placed carrots: {initially_correct}, Initially incorrectly placed carrots: {initially_incorrect}" + ) + return initially_correct, initially_incorrect + + def calculate_final_placements( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> tuple[int, int]: + """ + Calculates the number of objects that are correctly and incorrectly placed at the end of the simulation. + """ + scene_state = simulation_bridge.get_scene_state() + final_carrots = self.filter_entities_by_prefab_type( + scene_state.entities, prefab_types=self.obj_types + ) + final_correct, final_incorrect = self.calculate_correct(entities=final_carrots) + + self.logger.info( # type: ignore + f"Finally correctly placed carrots: {final_correct}, Finally incorrectly placed carrots: {final_incorrect}" + ) + return final_correct, final_incorrect + + def calculate_result( + self, simulation_bridge: SimulationBridge[SimulationConfig] + ) -> float: + """ + Calculates a score from 0.0 to 1.0, where 0.0 represents the initial placements or worse and 1.0 represents perfect final placements. + """ + initially_correct, initially_incorrect = self.calculate_initial_placements( + simulation_bridge + ) + final_correct, final_incorrect = self.calculate_final_placements( + simulation_bridge + ) + + total_objects = initially_correct + initially_incorrect + if total_objects == 0: + return 1.0 + elif (initially_correct + initially_incorrect) != ( + final_correct + final_incorrect + ): + raise EntitiesMismatchException( + "number of initial entities does not match final entities number." + ) + elif initially_incorrect == 0: + pass + # NOTE all objects are placed correctly + # no point in running task + raise ValueError("All objects are placed correctly at the start.") + else: + corrected = final_correct - initially_correct + score = max(0.0, corrected / initially_incorrect) + + self.logger.info(f"Calculated score: {score:.2f}") # type: ignore + return score diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py index b651f6c9..e92326b0 100644 --- a/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py @@ -13,12 +13,10 @@ # limitations under the License. from typing import List, Tuple -from rai_bench.benchmark_model import ( - EntitiesMismatchException, +from rai_bench.o3de_test_bench.tasks.manipulation_task import ( + ManipulationTask, # type: ignore ) -from rai_bench.o3de_test_bench.tasks.manipulation_task import ManipulationTask -from rai_sim.o3de.o3de_bridge import SimulationBridge -from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT, SpawnedEntity +from rai_sim.simulation_bridge import SimulationConfig, SpawnedEntity # type: ignore class PlaceCubesTask(ManipulationTask): @@ -48,69 +46,3 @@ def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]: ) incorrect: int = len(entities) - correct return correct, incorrect - - def calculate_initial_placements( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> tuple[int, int]: - """ - Calculates the number of objects that are correctly and incorrectly placed initially. - """ - initial_cubes = self.filter_entities_by_prefab_type( - simulation_bridge.spawned_entities, prefab_types=self.obj_types - ) - initially_correct, initially_incorrect = self.calculate_correct( - entities=initial_cubes - ) - - self.logger.info( # type: ignore - f"Initially correctly placed cubes: {initially_correct}, Initially incorrectly placed cubes: {initially_incorrect}" - ) - return initially_correct, initially_incorrect - - def calculate_final_placements( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> tuple[int, int]: - """ - Calculates the number of objects that are correctly and incorrectly placed at the end of the simulation. - """ - scene_state = simulation_bridge.get_scene_state() - final_cubes = self.filter_entities_by_prefab_type( - scene_state.entities, prefab_types=self.obj_types - ) - final_correct, final_incorrect = self.calculate_correct(entities=final_cubes) - - self.logger.info( # type: ignore - f"Finally correctly placed cubes: {final_correct}, Finally incorrectly placed cubes: {final_incorrect}" - ) - return final_correct, final_incorrect - - def calculate_result( - self, simulation_bridge: SimulationBridge[SimulationConfig] - ) -> float: - """ - Calculates a score from 0.0 to 1.0, where 0.0 represents the initial placements or worse and 1.0 represents perfect final placements. - """ - initially_correct, initially_incorrect = self.calculate_initial_placements( - simulation_bridge - ) - final_correct, final_incorrect = self.calculate_final_placements( - simulation_bridge - ) - - total_objects = initially_correct + initially_incorrect - if total_objects == 0: - return 1.0 - elif (initially_correct + initially_incorrect) != ( - final_correct + final_incorrect - ): - raise EntitiesMismatchException( - "number of initial entities does not match final entities number." - ) - elif initially_incorrect == 0: - raise ValueError("All objects are placed correctly at the start.") - else: - corrected = final_correct - initially_correct - score = max(0.0, corrected / initially_incorrect) - - self.logger.info(f"Calculated score: {score:.2f}") # type: ignore - return score