From 3dc3d15f28b114348d95ed76faa55e3f10dfb316 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Wed, 26 Feb 2025 09:58:57 +0100 Subject: [PATCH] feat: rai_bench (#436) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Maciej Majek Co-authored-by: Bartłomiej Boczek Co-authored-by: MagdalenaKotynia chore: pre-commit --- .gitignore | 2 + examples/manipulation-demo.launch.py | 49 +-- examples/manipulation-demo.py | 20 +- poetry.lock | 15 +- pyproject.toml | 3 +- setup_shell.sh | 3 + src/rai_bench/README.md | 60 ++++ src/rai_bench/pyproject.toml | 17 + src/rai_bench/rai_bench/benchmark_model.py | 290 ++++++++++++++++++ src/rai_bench/rai_bench/main.py | 187 +++++++++++ .../rai_bench/o3de_test_bench/__init__.py | 13 + .../o3de_test_bench/configs/scene1.yaml | 25 ++ .../o3de_test_bench/configs/scene2.yaml | 51 +++ .../o3de_test_bench/configs/scene3.yaml | 25 ++ .../o3de_test_bench/configs/scene4.yaml | 50 +++ .../o3de_test_bench/tasks/__init__.py | 17 + .../o3de_test_bench/tasks/grab_carrot_task.py | 101 ++++++ .../o3de_test_bench/tasks/place_cubes_task.py | 104 +++++++ .../rai/agents/conversational_agent.py | 2 +- src/rai_core/rai/agents/tool_runner.py | 7 +- .../rai/communication/ros2/connectors.py | 33 +- src/rai_core/rai/tools/ros/manipulation.py | 53 +--- src/rai_core/rai/tools/ros/utils.py | 4 +- .../rai_open_set_vision/examples/talker.py | 8 +- .../services/grounded_sam.py | 7 +- .../services/grounding_dino.py | 8 +- .../rai_open_set_vision/tools/gdino_tools.py | 31 +- .../tools/segmentation_tools.py | 114 +++++-- src/rai_sim/rai_sim/o3de/o3de_bridge.py | 93 +++++- src/rai_sim/rai_sim/simulation_bridge.py | 2 +- 30 files changed, 1229 insertions(+), 165 deletions(-) create mode 100644 src/rai_bench/README.md create mode 100644 src/rai_bench/pyproject.toml create mode 100644 src/rai_bench/rai_bench/benchmark_model.py create mode 100644 src/rai_bench/rai_bench/main.py create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/__init__.py create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py diff --git a/.gitignore b/.gitignore index 48d048836..9ee309471 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,5 @@ logs/ src/examples/*-demo artifact_database.pkl + +imgui.ini diff --git a/examples/manipulation-demo.launch.py b/examples/manipulation-demo.launch.py index a9210698f..35720a6af 100644 --- a/examples/manipulation-demo.launch.py +++ b/examples/manipulation-demo.launch.py @@ -12,22 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import rclpy -from launch import LaunchContext, LaunchDescription +from launch import LaunchDescription from launch.actions import ( DeclareLaunchArgument, ExecuteProcess, IncludeLaunchDescription, - OpaqueFunction, - RegisterEventHandler, ) -from launch.event_handlers import OnExecutionComplete, OnProcessStart from launch.launch_description_sources import PythonLaunchDescriptionSource from launch.substitutions import LaunchConfiguration from launch_ros.actions import Node from launch_ros.substitutions import FindPackageShare -from rclpy.qos import QoSProfile, ReliabilityPolicy -from rosgraph_msgs.msg import Clock def generate_launch_description(): @@ -46,21 +40,6 @@ def generate_launch_description(): output="screen", ) - def wait_for_clock_message(context: LaunchContext, *args, **kwargs): - rclpy.init() - node = rclpy.create_node("wait_for_game_launcher") - node.create_subscription( - Clock, - "/clock", - lambda msg: rclpy.shutdown(), - QoSProfile(depth=1, reliability=ReliabilityPolicy.BEST_EFFORT), - ) - rclpy.spin(node) - return None - - # Game launcher will start publishing the clock message after loading the simulation - wait_for_game_launcher = OpaqueFunction(function=wait_for_clock_message) - launch_moveit = IncludeLaunchDescription( PythonLaunchDescriptionSource( [ @@ -72,7 +51,7 @@ def wait_for_clock_message(context: LaunchContext, *args, **kwargs): launch_robotic_manipulation = Node( package="robotic_manipulation", executable="robotic_manipulation", - name="robotic_manipulation_node", + # name="robotic_manipulation_node", output="screen", parameters=[ {"use_sim_time": True}, @@ -90,28 +69,10 @@ def wait_for_clock_message(context: LaunchContext, *args, **kwargs): return LaunchDescription( [ - # Include the game_launcher argument game_launcher_arg, - # Launch the game launcher and wait for it to load launch_game_launcher, - RegisterEventHandler( - event_handler=OnProcessStart( - target_action=launch_game_launcher, - on_start=[ - wait_for_game_launcher, - ], - ) - ), - # Launch the MoveIt node after loading the simulation - RegisterEventHandler( - event_handler=OnExecutionComplete( - target_action=wait_for_game_launcher, - on_completion=[ - launch_openset, - launch_moveit, - launch_robotic_manipulation, - ], - ) - ), + launch_openset, + launch_moveit, + launch_robotic_manipulation, ] ) diff --git a/examples/manipulation-demo.py b/examples/manipulation-demo.py index 6f73248c2..92c820031 100644 --- a/examples/manipulation-demo.py +++ b/examples/manipulation-demo.py @@ -12,37 +12,37 @@ # See the License for the specific language goveself.rning permissions and # limitations under the License. -import threading import rclpy import rclpy.qos from langchain_core.messages import HumanMessage from rai.agents.conversational_agent import create_conversational_agent -from rai.node import RaiBaseNode +from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool -from rai.tools.ros.native import GetCameraImage, Ros2GetTopicsNamesAndTypesTool +from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool from rai.utils.model_initialization import get_llm_model +from rai_open_set_vision.tools import GetGrabbingPointTool def create_agent(): rclpy.init() - node = RaiBaseNode(node_name="manipulation_demo") + connector = ROS2ARIConnector() + node = connector.node node.declare_parameter("conversion_ratio", 1.0) - threading.Thread(target=node.spin).start() - tools = [ GetObjectPositionsTool( - node=node, + connector=connector, target_frame="panda_link0", source_frame="RGBDCamera5", camera_topic="/color_image5", depth_topic="/depth_image5", camera_info_topic="/color_camera_info5", + get_grabbing_point_tool=GetGrabbingPointTool(connector=connector), ), - MoveToPointTool(node=node, manipulator_frame="panda_link0"), - GetCameraImage(node=node), - Ros2GetTopicsNamesAndTypesTool(node=node), + MoveToPointTool(connector=connector, manipulator_frame="panda_link0"), + GetROS2ImageTool(connector=connector), + GetROS2TopicsNamesAndTypesTool(connector=connector), ] llm = get_llm_model(model_type="complex_model", streaming=True) diff --git a/poetry.lock b/poetry.lock index 2d22b986d..9eb39f9fa 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5700,6 +5700,19 @@ torchaudio = "^2.3.1" type = "directory" url = "src/rai_asr" +[[package]] +name = "rai-bench" +version = "0.1.0" +description = "" +optional = false +python-versions = "^3.10" +files = [] +develop = true + +[package.source] +type = "directory" +url = "src/rai_bench" + [[package]] name = "rai-sim" version = "0.0.1" @@ -8318,4 +8331,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "a53906ce2c798e5e0a02c7db25cf00cf36e021186a79429ba1bd8f0836b12db2" +content-hash = "c5469635a5db79c258554ad9f4e49331515940e406fbf912822651a0e0c33dda" diff --git a/pyproject.toml b/pyproject.toml index 863309947..a3edd3730 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ rai = {path = "src/rai_core", develop = true} rai_asr = {path = "src/rai_asr", develop = true} rai_tts = {path = "src/rai_tts", develop = true} rai_sim = {path = "src/rai_sim", develop = true} +rai_bench = {path = "src/rai_bench", develop = true} langchain-core = "^0.3" langchain = "*" @@ -30,7 +31,6 @@ requests = "^2.32.2" pre-commit = "^3.7.0" openai = "^1.23.3" coloredlogs = "^15.0.1" -opencv-python = "^4.9.0.80" markdown = "^3.6" boto3 = "^1.34.98" tqdm = "^4.66.4" @@ -62,6 +62,7 @@ pytest-timeout = "^2.3.1" tomli-w = "^1.1.0" faster-whisper = "^1.1.1" pydub = "^0.25.1" +opencv-python = "^4.11.0.86" [tool.poetry.group.dev.dependencies] ipykernel = "^6.29.4" diff --git a/setup_shell.sh b/setup_shell.sh index cc67a5369..164372ed9 100755 --- a/setup_shell.sh +++ b/setup_shell.sh @@ -30,3 +30,6 @@ esac export PYTHONPATH PYTHONPATH="$(dirname "$(dirname "$(poetry run which python)")")/lib/python$(poetry run python --version | awk '{print $2}' | cut -d. -f1,2)/site-packages:$PYTHONPATH" +PYTHONPATH="src/rai_core:$PYTHONPATH" +PYTHONPATH="src/rai_asr:$PYTHONPATH" +PYTHONPATH="src/rai_tts:$PYTHONPATH" diff --git a/src/rai_bench/README.md b/src/rai_bench/README.md new file mode 100644 index 000000000..e2abd3d4f --- /dev/null +++ b/src/rai_bench/README.md @@ -0,0 +1,60 @@ +## RAI Benchmark + +## Description + +The RAI Bench is a package including benchmarks and providing frame for creating new benchmarks + +## Frame Components + +Frame components can be found in `src/rai_bench/rai_bench/benchmark_model.py` + +- `Task` - abstract class for creating specific task. It introduces helper funtions that make it easier to calculate metrics/scores. Your custom tasks must implement a prompt got agent to do, a way to calculate a result and a validation if given scene config suits the task. +- +- `Scenario` - class defined by a Scene and Task. Can be created manually like: + + ```python + + ``` + +- `Benchmark` - class responsible for running and logging scenarios. + +### O3DE TEST BENCHMARK + +O3DE Test Benchmark (src/rai_bench/rai_bench/o3de_test_bench/), contains 2 Tasks(tasks/) - GrabCarrotTask and PlaceCubesTask (these tasks implement calculating scores) and 4 scene_configs(configs/) for O3DE robotic arm simulation. + +Both tasks calculate score, taking into consideration 4 values: + +- initially_misplaced_now_correct - when the object which was in the incorrect place at the start, is in a correct place at the end +- initially_misplaced_still_incorrect - when the object which was in the incorrect place at the start, is in a incorrect place at the end +- initially_correct_still_correct - when the object which was in the correct place at the start, is in a correct place at the end +- initially_correct_now_incorrect - when the object which was in the correct place at the start, is in a incorrect place at the end + +The result is a value between 0 and 1, calculated like (initially_misplaced_now_correct + initially_correct_still_correct) / number_of_initial_objects. +This score is calculated at the beggining and at the end of each scenario. + +### Example usage + +Example of how to load scenes, define scenarios and run benchmark can be found in `src/rai_bench/rai_bench/benchmark_main.py` + +Scenarios can be loaded manually like: + +```python +one_carrot_simulation_config = O3DExROS2SimulationConfig.load_config( + base_config_path=Path("path_to_scene.yaml"), + connector_config_path=Path("path_to_o3de_config.yaml"), + ) + +Scenario(task=GrabCarrotTask(logger=some_logger), simulation_config=one_carrot_simulation_config) +``` + +or automatically like: + +```python +scenarios = Benchmark.create_scenarios( + tasks=tasks, simulation_configs=simulations_configs + ) +``` + +which will result in list of scenarios with combination of every possible task and scene(task decides if scene config is suitable for it). + +Both approaches can be found in `main.py` diff --git a/src/rai_bench/pyproject.toml b/src/rai_bench/pyproject.toml new file mode 100644 index 000000000..52255eb9a --- /dev/null +++ b/src/rai_bench/pyproject.toml @@ -0,0 +1,17 @@ +[tool.poetry] +name = "rai-bench" +version = "0.1.0" +description = "" +authors = ["jmatejcz "] +readme = "README.md" + +packages = [ + { include = "rai_bench", from = "." }, +] +[tool.poetry.dependencies] +python = "^3.10" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/src/rai_bench/rai_bench/benchmark_model.py b/src/rai_bench/rai_bench/benchmark_model.py new file mode 100644 index 000000000..bc47ef407 --- /dev/null +++ b/src/rai_bench/rai_bench/benchmark_model.py @@ -0,0 +1,290 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import logging +import time +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, List, Union + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from rai.messages import HumanMultimodalMessage +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_sim.simulation_bridge import ( + PoseModel, + SimulationBridge, + SimulationConfig, + SimulationConfigT, + SpawnedEntity, +) + +loggers_type = Union[RcutilsLogger, logging.Logger] + + +class EntitiesMismatchException(Exception): + def __init__(self, message: str) -> None: + super().__init__(message) + + +class Task(ABC): + """ + Task to perform. + Specyfic implementation should implement a way to calculate results. + Abstract provides utility functions for common calculations, that can be usefull when + creating metrics + """ + + def __init__( + self, + logger: loggers_type | None = None, + ) -> None: + if logger: + self.logger = logger + else: + self.logger = logging.getLogger(__name__) + + @abstractmethod + def get_prompt(self) -> str: + pass + + @abstractmethod + def validate_config(self, simulation_config: SimulationConfig) -> bool: + """Task should be able to verify if given config is suitable for specific task + + Args: + simulation_config (SimulationConfig): initial scene setup + Returns: + bool: True is suitable, False otherwise + """ + pass + + @abstractmethod + def calculate_result( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> float: + """ + Calculate result of the task + """ + pass + + def filter_entities_by_prefab_type( + self, entities: List[SpawnedEntity], prefab_types: List[str] + ) -> List[SpawnedEntity]: + """Filter and return only these entities that match provided prefab types""" + return [ent for ent in entities if ent.prefab_name in prefab_types] + + def euclidean_distance(self, pos1: PoseModel, pos2: PoseModel) -> float: + """Calculate euclidean distance between 2 positions""" + return ( + (pos1.translation.x - pos2.translation.x) ** 2 + + (pos1.translation.y - pos2.translation.y) ** 2 + + (pos1.translation.z - pos2.translation.z) ** 2 + ) ** 0.5 + + def is_adjacent(self, pos1: PoseModel, pos2: PoseModel, threshold_distance: float): + """ + Check if positions are adjacent to each other, the threshold_distance is a distance + in simulation, refering to how close they have to be to classify them as adjacent + """ + self.logger.debug( # type: ignore + f"Euclidean distance: {self.euclidean_distance(pos1, pos2)}, pos1: {pos1}, pos2: {pos2}" + ) + return self.euclidean_distance(pos1, pos2) < threshold_distance + + def is_adjacent_to_any( + self, pos1: PoseModel, positions: List[PoseModel], threshold_distance: float + ) -> bool: + """ + Check if given position is adjacent to any position in the given list. + """ + + return any( + self.is_adjacent(pos1, pos2, threshold_distance) for pos2 in positions + ) + + def count_adjacent( + self, positions: List[PoseModel], threshold_distance: float + ) -> int: + """ + Count how many adjacent positions are in the given list. + Note that position has to be adjacent to only 1 other position + to be counted, not all of them + """ + adjacent_count = 0 + + for i, p1 in enumerate(positions): + for j, p2 in enumerate(positions): + if i != j: + if self.is_adjacent(p1, p2, threshold_distance): + adjacent_count += 1 + break + + return adjacent_count + + +class Scenario(Generic[SimulationConfigT]): + """Single instances are run separatly by benchmark""" + + def __init__( + self, + task: Task, + simulation_config: SimulationConfigT, + simulation_config_path: str, + ) -> None: + if not task.validate_config(simulation_config): + raise ValueError("This scene is invalid for this task.") + self.task = task + self.simulation_config = simulation_config + # NOTE (jm) needed for logging which config was used, + # there probably is better method to do it + self.simulation_config_path = simulation_config_path + + +class Benchmark: + """ + Defined by a set of scenarios to be done + """ + + def __init__( + self, + simulation_bridge: SimulationBridge[SimulationConfigT], + scenarios: List[Scenario[SimulationConfigT]], + logger: loggers_type | None = None, + ) -> None: + self.simulation_bridge = simulation_bridge + self.num_of_scenarios = len(scenarios) + self.scenarios = enumerate(iter(scenarios)) + self.results: List[Dict[str, Any]] = [] + if logger: + self._logger = logger + else: + self._logger = logging.getLogger(__name__) + + @classmethod + def create_scenarios( + cls, + tasks: List[Task], + simulation_configs: List[SimulationConfigT], + simulation_configs_paths: List[str], + ) -> List[Scenario[SimulationConfigT]]: + # TODO (jm) hacky_fix, taking paths as args here, not the best solution, + # but more changes to code would be required + scenarios: List[Scenario[SimulationConfigT]] = [] + for task in tasks: + for sim_conf, sim_path in zip(simulation_configs, simulation_configs_paths): + try: + scenarios.append( + Scenario( + task=task, + simulation_config=sim_conf, + simulation_config_path=sim_path, + ) + ) + except ValueError as e: + print( + f"Could not create Scenario from task: {task.get_prompt()} and simulation_config: {sim_conf}, {e}" + ) + return scenarios + + def run_next(self, agent) -> None: + """ + Runs the next scenario + """ + try: + i, scenario = next(self.scenarios) # Get the next scenario + + self.simulation_bridge.setup_scene(scenario.simulation_config) + self._logger.info( # type: ignore + "======================================================================================" + ) + self._logger.info( # type: ignore + f"RUNNING SCENARIO NUMBER {i + 1} / {self.num_of_scenarios}, TASK: {scenario.task.get_prompt()}" + ) + initial_result = scenario.task.calculate_result(self.simulation_bridge) + self._logger.info(f"RESULT OF THE INITIAL SETUP: {initial_result}") # type: ignore + tool_calls_num = 0 + + ts = time.perf_counter() + for state in agent.stream( + {"messages": [HumanMessage(content=scenario.task.get_prompt())]} + ): + graph_node_name = list(state.keys())[0] + msg = state[graph_node_name]["messages"][-1] + + if isinstance(msg, HumanMultimodalMessage): + last_msg = msg.text + elif isinstance(msg, BaseMessage): + if isinstance(msg.content, list): + if len(msg.content) == 1: + if type(msg.content[0]) is dict: + last_msg = msg.content[0].get("text", "") + else: + last_msg = msg.content + self._logger.debug(f"{graph_node_name}: {last_msg}") # type: ignore + + else: + raise ValueError(f"Unexpected type of message: {type(msg)}") + + if isinstance(msg, AIMessage): + # TODO (jm) figure out more robust way of counting tool calls + tool_calls_num += len(msg.tool_calls) + + self._logger.info(f"AI Message: {msg}") # type: ignore + + te = time.perf_counter() + + result = scenario.task.calculate_result(self.simulation_bridge) + total_time = te - ts + self._logger.info( # type: ignore + f"TASK SCORE: {result}, TOTAL TIME: {total_time:.3f}, NUM_OF_TOOL_CALLS: {tool_calls_num}" + ) + + self.results.append( + { + "task": scenario.task.get_prompt(), + "simulation_config": scenario.simulation_config_path, + "initial_score": initial_result, + "final_score": result, + "total_time": f"{total_time:.3f}", + "number_of_tool_calls": tool_calls_num, + } + ) + + except StopIteration: + print("No more scenarios left to run.") + + def get_results(self) -> List[Dict[str, Any]]: + return self.results + + def dump_results_to_csv(self, filename: str) -> None: + if not self.results: + self._logger.warning("No results to save.") # type: ignore + return + + fieldnames = [ + "task", + "initial_score", + "simulation_config", + "final_score", + "total_time", + "number_of_tool_calls", + ] + + with open(filename, mode="w", newline="", encoding="utf-8") as file: + writer = csv.DictWriter(file, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(self.results) + + self._logger.info(f"Results saved to {filename}") # type: ignore diff --git a/src/rai_bench/rai_bench/main.py b/src/rai_bench/rai_bench/main.py new file mode 100644 index 000000000..7875c92c4 --- /dev/null +++ b/src/rai_bench/rai_bench/main.py @@ -0,0 +1,187 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +########### EXAMPLE USAGE ########### +import logging +import time +from pathlib import Path +from typing import List + +import rclpy +from langchain.tools import BaseTool +from rai.agents.conversational_agent import create_conversational_agent +from rai.communication.ros2.connectors import ROS2ARIConnector +from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool +from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool +from rai.utils.model_initialization import get_llm_model +from rai_open_set_vision.tools import GetGrabbingPointTool + +from rai_bench.benchmark_model import Benchmark, Task +from rai_bench.o3de_test_bench.tasks import GrabCarrotTask, PlaceCubesTask +from rai_sim.o3de.o3de_bridge import ( + O3DEngineArmManipulationBridge, + O3DExROS2SimulationConfig, + PoseModel, +) +from rai_sim.simulation_bridge import Rotation, Translation + +if __name__ == "__main__": + rclpy.init() + connector = ROS2ARIConnector() + node = connector.node + node.declare_parameter("conversion_ratio", 1.0) + + # define model + llm = get_llm_model(model_type="complex_model", streaming=True) + + system_prompt = """ + You are a robotic arm with interfaces to detect and manipulate objects. + Here are the coordinates information: + x - front to back (positive is forward) + y - left to right (positive is right) + z - up to down (positive is up) + Before starting the task, make sure to grab the camera image to understand the environment. + """ + # define tools + tools: List[BaseTool] = [ + GetObjectPositionsTool( + connector=connector, + target_frame="panda_link0", + source_frame="RGBDCamera5", + camera_topic="/color_image5", + depth_topic="/depth_image5", + camera_info_topic="/color_camera_info5", + get_grabbing_point_tool=GetGrabbingPointTool(connector=connector), + ), + MoveToPointTool(connector=connector, manipulator_frame="panda_link0"), + GetROS2ImageTool(connector=connector), + GetROS2TopicsNamesAndTypesTool(connector=connector), + ] + # define loggers + log_file = "src/rai_bench/rai_bench/benchmark.log" + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + file_handler.setFormatter(formatter) + + bench_logger = logging.getLogger("Benchmark logger") + bench_logger.setLevel(logging.INFO) + bench_logger.addHandler(file_handler) + + agent_logger = logging.getLogger("Agent logger") + agent_logger.setLevel(logging.INFO) + agent_logger.addHandler(file_handler) + + configs_dir = "src/rai_bench/rai_bench/o3de_test_bench/configs/" + connector_path = configs_dir + "o3de_config.yaml" + #### Create scenarios manually + # load different scenes + # one_carrot_simulation_config = O3DExROS2SimulationConfig.load_config( + # base_config_path=Path(configs_dir + "scene1.yaml"), + # connector_config_path=Path(connector_path), + # ) + # multiple_carrot_simulation_config = O3DExROS2SimulationConfig.load_config( + # base_config_path=Path(configs_dir + "scene2.yaml"), + # connector_config_path=Path(connector_path), + # ) + # red_cubes_simulation_config = O3DExROS2SimulationConfig.load_config( + # base_config_path=Path(configs_dir + "scene3.yaml"), + # connector_config_path=Path(connector_path), + # ) + # multiple_cubes_simulation_config = O3DExROS2SimulationConfig.load_config( + # base_config_path=Path(configs_dir + "scene4.yaml"), + # connector_config_path=Path(connector_path), + # ) + # # combine different scene configs with the tasks to create various scenarios + # scenarios = [ + # Scenario( + # task=GrabCarrotTask(logger=bench_logger), + # simulation_config=one_carrot_simulation_config, + # simulation_config_path=configs_dir + "scene1.yaml", + # ), + # Scenario( + # task=GrabCarrotTask(logger=bench_logger), + # simulation_config=multiple_carrot_simulation_config, + # simulation_config_path=configs_dir + "scene2.yaml", + # ), + # Scenario( + # task=PlaceCubesTask(logger=bench_logger), + # simulation_config=red_cubes_simulation_config, + # simulation_config_path=configs_dir + "scene3.yaml", + # ), + # Scenario( + # task=PlaceCubesTask(logger=bench_logger), + # simulation_config=multiple_cubes_simulation_config, + # simulation_config_path=configs_dir + "scene4.yaml", + # ), + # ] + + ### Create scenarios automatically + simulation_configs_paths = [ + configs_dir + "scene1.yaml", + configs_dir + "scene2.yaml", + configs_dir + "scene3.yaml", + configs_dir + "scene4.yaml", + ] + simulations_configs = [ + O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) + for path in simulation_configs_paths + ] + tasks: List[Task] = [ + GrabCarrotTask(logger=bench_logger), + PlaceCubesTask(logger=bench_logger), + ] + scenarios = Benchmark.create_scenarios( + tasks=tasks, + simulation_configs=simulations_configs, + simulation_configs_paths=simulation_configs_paths, + ) + + # custom request to arm + base_arm_pose = PoseModel( + translation=Translation(x=0.5, y=0.1, z=0.3), + rotation=Rotation(x=1.0, y=0.0, z=0.0, w=0.0), + ) + + o3de = O3DEngineArmManipulationBridge(connector, logger=agent_logger) + # define benchamrk + benchmark = Benchmark( + simulation_bridge=o3de, + scenarios=scenarios, + logger=bench_logger, + ) + for i, s in enumerate(scenarios): + agent = create_conversational_agent( + llm, tools, system_prompt, logger=agent_logger + ) + benchmark.run_next(agent=agent) + o3de.move_arm( + pose=base_arm_pose, + initial_gripper_state=True, + final_gripper_state=False, + frame_id="panda_link0", + ) # return to case position + time.sleep(2) # admire the end position for a second ;) + + bench_logger.info("===============================================================") + bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") + bench_logger.info("===============================================================") + benchmark.dump_results_to_csv(filename="src/rai_bench/rai_bench/results.csv") + + connector.shutdown() + o3de.shutdown() + rclpy.shutdown() diff --git a/src/rai_bench/rai_bench/o3de_test_bench/__init__.py b/src/rai_bench/rai_bench/o3de_test_bench/__init__.py new file mode 100644 index 000000000..97ceef6f0 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/__init__.py @@ -0,0 +1,13 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml new file mode 100644 index 000000000..a683362a6 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml @@ -0,0 +1,25 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: corn2 + prefab_name: corn + pose: + translation: + x: 0.5 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml new file mode 100644 index 000000000..2be04e047 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml @@ -0,0 +1,51 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: carrot2 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: carrot3 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: carrot4 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml new file mode 100644 index 000000000..1eef69a48 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml @@ -0,0 +1,25 @@ +entities: + - name: cube1 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube2 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml new file mode 100644 index 000000000..00b814c93 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml @@ -0,0 +1,50 @@ +entities: + - name: cube1 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube2 + prefab_name: blue_cube + pose: + translation: + x: 0.4 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: cube3 + prefab_name: yellow_cube + pose: + translation: + x: 0.5 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube4 + prefab_name: yellow_cube + pose: + translation: + x: 0.5 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py new file mode 100644 index 000000000..5be82bf8c --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py @@ -0,0 +1,17 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from rai_bench.o3de_test_bench.tasks.grab_carrot_task import GrabCarrotTask +from rai_bench.o3de_test_bench.tasks.place_cubes_task import PlaceCubesTask + +__all__ = ["GrabCarrotTask", "PlaceCubesTask"] diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py new file mode 100644 index 000000000..ca040fd62 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py @@ -0,0 +1,101 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rai_bench.benchmark_model import ( + EntitiesMismatchException, + Task, +) +from rai_sim.o3de.o3de_bridge import ( + SimulationBridge, +) +from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT + + +class GrabCarrotTask(Task): + def get_prompt(self) -> str: + return "Manipulate objects, so that all carrots to the left side of the table (positive y)" + + def validate_config(self, simulation_config: SimulationConfig) -> bool: + for ent in simulation_config.entities: + if ent.prefab_name == "carrot": + return True + + return False + + def calculate_result( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> float: + # TODO (jm) extract common logic to some parent manipulation task? + initially_misplaced_now_correct = 0 # when the object which was in the incorrect place at the start, is in a correct place at the end + initially_misplaced_still_incorrect = 0 # when the object which was in the incorrect place at the start, is in a incorrect place at the end + initially_correct_still_correct = 0 # when the object which was in the correct place at the start, is in a correct place at the end + initially_correct_now_incorrect = 0 # when the object which was in the correct place at the start, is in a incorrect place at the end + + scene_state = simulation_bridge.get_scene_state() + initial_carrots = self.filter_entities_by_prefab_type( + simulation_bridge.spawned_entities, prefab_types=["carrot"] + ) + final_carrots = self.filter_entities_by_prefab_type( + scene_state.entities, prefab_types=["carrot"] + ) + num_initial_carrots = len(initial_carrots) + + if num_initial_carrots != len(final_carrots): + raise EntitiesMismatchException( + "Number of initially spawned entities does not match number of entities present at the end." + ) + + else: + self.logger.debug(f"initial positions: {initial_carrots}") # type: ignore + self.logger.debug(f"current positions: {final_carrots}") # type: ignore + for ini_carrot in initial_carrots: + for final_carrot in final_carrots: + if ini_carrot.name == final_carrot.name: + initial_y = ini_carrot.pose.translation.y + final_y = final_carrot.pose.translation.y + # NOTE the specific coords that refer to for example + # middle of the table can differ across simulations, + # take that into consideration + if ( + initial_y <= 0.0 + ): # Carrot started in the incorrect place (right side) + if final_y >= 0.0: + initially_misplaced_now_correct += ( + 1 # Moved to correct side + ) + else: + initially_misplaced_still_incorrect += ( + 1 # Stayed on incorrect side + ) + else: # Carrot started in the correct place (left side) + if final_y >= 0.0: + initially_correct_still_correct += ( + 1 # Stayed on correct side + ) + else: + initially_correct_now_incorrect += ( + 1 # Moved incorrectly to the wrong side + ) + break + else: + raise EntitiesMismatchException( + f"Entity with name: {ini_carrot.name} which was present in initial scene, not found in final scene." + ) + + self.logger.info( # type: ignore + f"initially_misplaced_now_correct: {initially_misplaced_now_correct}, initially_misplaced_still_incorrect: {initially_misplaced_still_incorrect}, initially_correct_still_correct: {initially_correct_still_correct}, initially_correct_now_incorrect: {initially_correct_now_incorrect}" + ) + return ( + initially_misplaced_now_correct + initially_correct_still_correct + ) / num_initial_carrots diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py new file mode 100644 index 000000000..26bdd590e --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py @@ -0,0 +1,104 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rai_bench.benchmark_model import ( + EntitiesMismatchException, + Task, +) +from rai_sim.o3de.o3de_bridge import SimulationBridge +from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT + + +class PlaceCubesTask(Task): + def get_prompt(self) -> str: + return "Manipulate objects, so that all cubes are adjacent to at least one cube" + + def validate_config(self, simulation_config: SimulationConfig) -> bool: + cube_types = ["red_cube", "blue_cube", "yellow_cube"] + cubes_num = 0 + for ent in simulation_config.entities: + if ent.prefab_name in cube_types: + cubes_num += 1 + if cubes_num > 1: + return True + + return False + + def calculate_result( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> float: + # TODO (jm) extract common logic to some parent manipulation task? + initially_misplaced_now_correct = 0 # when the object which was in the incorrect place at the start, is in a correct place at the end + initially_misplaced_still_incorrect = 0 # when the object which was in the incorrect place at the start, is in a incorrect place at the end + initially_correct_still_correct = 0 # when the object which was in the correct place at the start, is in a correct place at the end + initially_correct_now_incorrect = 0 # when the object which was in the correct place at the start, is in a incorrect place at the end + + cube_types = ["red_cube", "blue_cube", "yellow_cube"] + scene_state = simulation_bridge.get_scene_state() + + initial_cubes = self.filter_entities_by_prefab_type( + simulation_bridge.spawned_entities, prefab_types=cube_types + ) + final_cubes = self.filter_entities_by_prefab_type( + scene_state.entities, prefab_types=cube_types + ) + num_of_objects = len(initial_cubes) + + if num_of_objects != len(final_cubes): + raise EntitiesMismatchException( + "Number of initially spawned entities does not match number of entities present at the end." + ) + + else: + ini_poses = [cube.pose for cube in initial_cubes] + final_poses = [cube.pose for cube in final_cubes] + # NOTE the specific coords that refer to for example + # middle of the table can differ across simulations, + # take that into consideration + self.logger.debug(f"initial positions: {initial_cubes}") + self.logger.debug(f"current positions: {final_cubes}") + for i, ini_cube in enumerate(initial_cubes): + for j, final_cube in enumerate(final_cubes): + if ini_cube.name == final_cube.name: + was_adjacent_initially = self.is_adjacent_to_any( + ini_cube.pose, + [p for p in ini_poses if p != ini_cube.pose], + 0.15, + ) + is_adjacent_finally = self.is_adjacent_to_any( + final_cube.pose, + [p for p in final_poses if p != final_cube.pose], + 0.15, + ) + if not was_adjacent_initially and is_adjacent_finally: + initially_misplaced_now_correct += 1 + elif not was_adjacent_initially and not is_adjacent_finally: + initially_misplaced_still_incorrect += 1 + elif was_adjacent_initially and is_adjacent_finally: + initially_correct_still_correct += 1 + elif was_adjacent_initially and not is_adjacent_finally: + initially_correct_now_incorrect += 1 + + break + else: + raise EntitiesMismatchException( + f"Entity with name: {ini_cube.name} which was present in initial scene, not found in final scene." + ) + + self.logger.info( + f"initially_misplaced_now_correct: {initially_misplaced_now_correct}, initially_misplaced_still_incorrect: {initially_misplaced_still_incorrect}, initially_correct_still_correct: {initially_correct_still_correct}, initially_correct_now_incorrect: {initially_correct_now_incorrect}" + ) + return ( + initially_misplaced_now_correct + initially_correct_still_correct + ) / num_of_objects diff --git a/src/rai_core/rai/agents/conversational_agent.py b/src/rai_core/rai/agents/conversational_agent.py index 072a1c164..739b159ae 100644 --- a/src/rai_core/rai/agents/conversational_agent.py +++ b/src/rai_core/rai/agents/conversational_agent.py @@ -56,7 +56,7 @@ def create_conversational_agent( debug=False, ): _logger = None - if isinstance(logger, RcutilsLogger): + if logger: _logger = logger else: _logger = logging.getLogger(__name__) diff --git a/src/rai_core/rai/agents/tool_runner.py b/src/rai_core/rai/agents/tool_runner.py index 12e0889d3..5c35ac9a8 100644 --- a/src/rai_core/rai/agents/tool_runner.py +++ b/src/rai_core/rai/agents/tool_runner.py @@ -69,8 +69,13 @@ def run_one(call: ToolCall): ts = time.perf_counter() output = self.tools_by_name[call["name"]].invoke(call, config) # type: ignore te = time.perf_counter() - ts + tool_output_log = ( + str(output.content)[:1000] + "..." + if len(str(output.content)) > 1000 + else "" + ) self.logger.info( - f"Tool {call['name']} completed in {te:.2f} seconds. Tool output: {str(output.content)[:100]}{'...' if len(str(output.content)) > 100 else ''}" + f"Tool {call['name']} completed in {te:.2f} seconds. Tool output: {tool_output_log}" ) self.logger.debug( f"Tool {call['name']} output: \n\n{str(output.content)}" diff --git a/src/rai_core/rai/communication/ros2/connectors.py b/src/rai_core/rai/communication/ros2/connectors.py index af7953d78..fb7db7d3d 100644 --- a/src/rai_core/rai/communication/ros2/connectors.py +++ b/src/rai_core/rai/communication/ros2/connectors.py @@ -70,6 +70,7 @@ def __init__( self._service_api = ROS2ServiceAPI(self._node) self._actions_api = ROS2ActionAPI(self._node) self._tf_buffer = Buffer(node=self._node) + self.tf_listener = TransformListener(self._tf_buffer, self._node) self._executor = MultiThreadedExecutor() self._executor.add_node(self._node) @@ -180,7 +181,6 @@ def get_transform( source_frame: str, timeout_sec: float = 5.0, ) -> TransformStamped: - tf_listener = TransformListener(self._tf_buffer, self._node) transform_available = self.wait_for_transform( self._tf_buffer, target_frame, source_frame, timeout_sec ) @@ -192,20 +192,25 @@ def get_transform( target_frame, source_frame, rclpy.time.Time(), - timeout=Duration(seconds=timeout_sec), + timeout=Duration(seconds=int(timeout_sec)), ) - tf_listener.unregister() + return transform def terminate_action(self, action_handle: str, **kwargs: Any): self._actions_api.terminate_goal(action_handle) + @property + def node(self) -> Node: + return self._node + def shutdown(self): - self._executor.shutdown() - self._thread.join() + self.tf_listener.unregister() + self._node.destroy_node() self._actions_api.shutdown() self._topic_api.shutdown() - self._node.destroy_node() + self._executor.shutdown() + self._thread.join() class ROS2HRIMessage(HRIMessage): @@ -279,15 +284,19 @@ def __init__( ] _targets = [ - target - if isinstance(target, tuple) - else (target, TopicConfig(is_subscriber=False)) + ( + target + if isinstance(target, tuple) + else (target, TopicConfig(is_subscriber=False)) + ) for target in targets ] _sources = [ - source - if isinstance(source, tuple) - else (source, TopicConfig(is_subscriber=True)) + ( + source + if isinstance(source, tuple) + else (source, TopicConfig(is_subscriber=True)) + ) for source in sources ] diff --git a/src/rai_core/rai/tools/ros/manipulation.py b/src/rai_core/rai/tools/ros/manipulation.py index 8deacd603..9436490b0 100644 --- a/src/rai_core/rai/tools/ros/manipulation.py +++ b/src/rai_core/rai/tools/ros/manipulation.py @@ -15,21 +15,15 @@ from typing import Literal, Type import numpy as np -import rclpy -import rclpy.callback_groups -import rclpy.executors -import rclpy.qos -import rclpy.subscription -import rclpy.task from geometry_msgs.msg import Point, Pose, PoseStamped, Quaternion from langchain_core.tools import BaseTool from pydantic import BaseModel, Field from rai_open_set_vision.tools import GetGrabbingPointTool -from rclpy.client import Client -from rclpy.node import Node from tf2_geometry_msgs import do_transform_pose +from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.utils import TF2TransformFetcher +from rai.utils.ros_async import get_future_result from rai_interfaces.srv import ManipulatorMoveTo @@ -52,8 +46,7 @@ class MoveToPointTool(BaseTool): "success of grabbing or releasing objects. Use additional sensors or tools for that information." ) - node: Node - client: Client + connector: ROS2ARIConnector = Field(..., exclude=True) manipulator_frame: str = Field(..., description="Manipulator frame") min_z: float = Field(default=0.135, description="Minimum z coordinate [m]") @@ -72,16 +65,6 @@ class MoveToPointTool(BaseTool): args_schema: Type[MoveToPointToolInput] = MoveToPointToolInput - def __init__(self, node: Node, **kwargs): - super().__init__( - node=node, - client=node.create_client( - ManipulatorMoveTo, - "/manipulator_move_to", - ), - **kwargs, - ) - def _run( self, x: float, @@ -89,6 +72,10 @@ def _run( z: float, task: Literal["grab", "drop"], ) -> str: + client = self.connector.node.create_client( + ManipulatorMoveTo, + "/manipulator_move_to", + ) pose_stamped = PoseStamped() pose_stamped.header.frame_id = self.manipulator_frame pose_stamped.pose = Pose( @@ -117,21 +104,18 @@ def _run( request.initial_gripper_state = False # closed request.final_gripper_state = True # open - future = self.client.call_async(request) - self.node.get_logger().debug( + future = client.call_async(request) + self.connector.node.get_logger().debug( f"Calling ManipulatorMoveTo service with request: x={request.target_pose.pose.position.x:.2f}, y={request.target_pose.pose.position.y:.2f}, z={request.target_pose.pose.position.z:.2f}" ) + response = get_future_result(future, timeout_sec=5.0) + if response is None: + return f"Service call failed for point ({x:.2f}, {y:.2f}, {z:.2f})." - rclpy.spin_until_future_complete(self.node, future, timeout_sec=5.0) - - if future.result() is not None: - response = future.result() - if response.success: - return f"End effector successfully positioned at coordinates ({x:.2f}, {y:.2f}, {z:.2f}). Note: The status of object interaction (grab/drop) is not confirmed by this movement." - else: - return f"Failed to position end effector at coordinates ({x:.2f}, {y:.2f}, {z:.2f})." + if response.success: + return f"End effector successfully positioned at coordinates ({x:.2f}, {y:.2f}, {z:.2f}). Note: The status of object interaction (grab/drop) is not confirmed by this movement." else: - return f"Service call failed for point ({x:.2f}, {y:.2f}, {z:.2f})." + return f"Failed to position end effector at coordinates ({x:.2f}, {y:.2f}, {z:.2f})." class GetObjectPositionsToolInput(BaseModel): @@ -153,14 +137,9 @@ class GetObjectPositionsTool(BaseTool): camera_topic: str # rgb camera topic depth_topic: str camera_info_topic: str # rgb camera info topic - node: Node + connector: ROS2ARIConnector = Field(..., exclude=True) get_grabbing_point_tool: GetGrabbingPointTool - def __init__(self, node: Node, **kwargs): - super(GetObjectPositionsTool, self).__init__( - node=node, get_grabbing_point_tool=GetGrabbingPointTool(node=node), **kwargs - ) - args_schema: Type[GetObjectPositionsToolInput] = GetObjectPositionsToolInput @staticmethod diff --git a/src/rai_core/rai/tools/ros/utils.py b/src/rai_core/rai/tools/ros/utils.py index 8c34e23c4..1e207e31c 100644 --- a/src/rai_core/rai/tools/ros/utils.py +++ b/src/rai_core/rai/tools/ros/utils.py @@ -151,7 +151,9 @@ def wait_for_message( if msg_info is not None: return True, msg_info[0] finally: - node.destroy_subscription(sub) + # TODO(boczekbartek): uncomment when rclpy resolves: https://github.com/ros2/rclpy/issues/1142 + # node.destroy_subscription(sub) + pass return False, None diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/examples/talker.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/examples/talker.py index 370c6e488..72fe82198 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/examples/talker.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/examples/talker.py @@ -28,7 +28,9 @@ def __init__(self): self.declare_parameter("image_path", "") self.cli = self.create_client(RAIGroundingDino, "grounding_dino_classify") while not self.cli.wait_for_service(timeout_sec=1.0): - self.get_logger().info("service not available, waiting again...") + self.get_logger().info( + "service grounding_dino_classify not available, waiting again..." + ) self.req = RAIGroundingDino.Request() self.bridge = CvBridge() @@ -56,7 +58,9 @@ def __init__(self): super().__init__(node_name="GSClientExample", parameter_overrides=[]) self.cli = self.create_client(RAIGroundedSam, "grounded_sam_segment") while not self.cli.wait_for_service(timeout_sec=1.0): - self.get_logger().info("service not available, waiting again...") + self.get_logger().info( + "service grounded_sam_segment not available, waiting again..." + ) self.req = RAIGroundedSam.Request() self.bridge = CvBridge() diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py index 45fe9eb52..85fa70488 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py @@ -34,9 +34,6 @@ class GSamService(Node): def __init__(self): super().__init__(node_name=GSAM_NODE_NAME, parameter_overrides=[]) - self.srv = self.create_service( - RAIGroundedSam, GSAM_SERVICE_NAME, self.segment_callback - ) self.declare_parameter("weights_path", "") try: @@ -49,6 +46,10 @@ def __init__(self): self.get_logger().error("Could not load model") raise Exception("Could not load model") + self.srv = self.create_service( + RAIGroundedSam, GSAM_SERVICE_NAME, self.segment_callback + ) + def _init_weight_path(self): try: found_path = get_package_share_directory("rai_open_set_vision") diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py index 7927d7cf8..eba1ee44a 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py @@ -43,9 +43,7 @@ class GDRequest(TypedDict): class GDinoService(Node): def __init__(self): super().__init__(node_name=GDINO_NODE_NAME, parameter_overrides=[]) - self.srv = self.create_service( - RAIGroundingDino, GDINO_SERVICE_NAME, self.classify_callback - ) + self.declare_parameter("weights_path", "") try: weight_path = self.get_parameter("weights_path").value @@ -57,6 +55,10 @@ def __init__(self): self.get_logger().error("Could not load model") raise Exception("Could not load model") + self.srv = self.create_service( + RAIGroundingDino, GDINO_SERVICE_NAME, self.classify_callback + ) + def _init_weight_path(self) -> Path: try: found_path = get_package_share_directory("rai_open_set_vision") diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py index 336ec4bc7..62656c5a8 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py @@ -15,14 +15,11 @@ from typing import List, NamedTuple, Type import numpy as np -import rclpy -import rclpy.qos import sensor_msgs.msg from pydantic import BaseModel, Field -from rai.node import RaiBaseNode +from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.ros import Ros2BaseInput, Ros2BaseTool from rai.tools.ros.utils import convert_ros_img_to_ndarray -from rai.tools.utils import wait_for_message from rai.utils.ros_async import get_future_result from rclpy.exceptions import ( ParameterNotDeclaredException, @@ -82,7 +79,7 @@ class DistanceMeasurement(NamedTuple): # --------------------- Tools --------------------- class GroundingDinoBaseTool(Ros2BaseTool): - node: RaiBaseNode = Field(..., exclude=True, required=True) + connector: ROS2ARIConnector = Field(..., exclude=True) box_threshold: float = Field(default=0.35, description="Box threshold for GDINO") text_threshold: float = Field(default=0.45, description="Text threshold for GDINO") @@ -90,9 +87,11 @@ class GroundingDinoBaseTool(Ros2BaseTool): def _call_gdino_node( self, camera_img_message: sensor_msgs.msg.Image, object_names: list[str] ) -> Future: - cli = self.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) + cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) while not cli.wait_for_service(timeout_sec=1.0): - self.node.get_logger().info("service not available, waiting again...") + self.node.get_logger().info( + f"service {GDINO_SERVICE_NAME} not available, waiting again..." + ) req = RAIGroundingDino.Request() req.source_img = camera_img_message req.classes = " , ".join(object_names) @@ -103,20 +102,16 @@ def _call_gdino_node( return future def get_img_from_topic(self, topic: str, timeout_sec: int = 2): - success, msg = wait_for_message( - sensor_msgs.msg.Image, - self.node, - topic, - qos_profile=rclpy.qos.qos_profile_sensor_data, - time_to_wait=timeout_sec, - ) - - if success: - self.node.get_logger().info(f"Received message of type from topic {topic}") + msg = self.connector.receive_message(topic, timeout_sec=timeout_sec).payload + + if msg is not None: + self.connector.node.get_logger().info( + f"Received message of {type(msg)} from topic {topic}" + ) return msg else: error = f"No message received in {timeout_sec} seconds from topic {topic}" - self.node.get_logger().error(error) + self.connector.node.get_logger().error(error) return error def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py index 29ff0fe18..043802d8f 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py @@ -18,9 +18,10 @@ import numpy as np import rclpy import sensor_msgs.msg +from langchain_core.tools import BaseTool from pydantic import Field -from rai.node import RaiBaseNode -from rai.tools.ros import Ros2BaseInput, Ros2BaseTool +from rai.communication.ros2.connectors import ROS2ARIConnector +from rai.tools.ros import Ros2BaseInput from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray from rai.utils.ros_async import get_future_result from rclpy import Future @@ -64,8 +65,8 @@ class GetGrabbingPointInput(Ros2BaseInput): # --------------------- Tools --------------------- -class GetSegmentationTool(Ros2BaseTool): - node: RaiBaseNode = Field(..., exclude=True) +class GetSegmentationTool: + connector: ROS2ARIConnector = Field(..., exclude=True) name: str = "" description: str = "" @@ -84,7 +85,7 @@ def _get_gsam_response(self, future: Future) -> Optional[RAIGroundedSam.Response return get_future_result(future) def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: - msg = self.node.get_raw_message_from_topic(topic) + msg = self.connector.receive_message(topic).payload if type(msg) is sensor_msgs.msg.Image: return msg else: @@ -93,9 +94,11 @@ def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: def _call_gdino_node( self, camera_img_message: sensor_msgs.msg.Image, object_name: str ) -> Future: - cli = self.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) + cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) while not cli.wait_for_service(timeout_sec=1.0): - self.node.get_logger().info("service not available, waiting again...") + self.node.get_logger().info( + f"service {GDINO_SERVICE_NAME} not available, waiting again..." + ) req = RAIGroundingDino.Request() req.source_img = camera_img_message req.classes = object_name @@ -108,9 +111,11 @@ def _call_gdino_node( def _call_gsam_node( self, camera_img_message: sensor_msgs.msg.Image, data: RAIGroundingDino.Response ): - cli = self.node.create_client(RAIGroundedSam, "grounded_sam_segment") + cli = self.connector.node.create_client(RAIGroundedSam, "grounded_sam_segment") while not cli.wait_for_service(timeout_sec=1.0): - self.node.get_logger().info("service not available, waiting again...") + self.node.get_logger().info( + "service grounded_sam_segment not available, waiting again..." + ) req = RAIGroundedSam.Request() req.detections = data.detections req.source_img = camera_img_message @@ -126,9 +131,11 @@ def _run( camera_img_msg = self._get_image_message(camera_topic) future = self._call_gdino_node(camera_img_msg, object_name) - logger = self.node.get_logger() + logger = self.connector.node.get_logger() try: - conversion_ratio = self.node.get_parameter("conversion_ratio").value + conversion_ratio = self.connector.node.get_parameter( + "conversion_ratio" + ).value if not isinstance(conversion_ratio, float): logger.error( f"Parameter conversion_ratio was set badly: {type(conversion_ratio)}: {conversion_ratio} expected float. Using default value 0.001" @@ -185,19 +192,72 @@ def depth_to_point_cloud( return points -class GetGrabbingPointTool(GetSegmentationTool): +class GetGrabbingPointTool(BaseTool): + connector: ROS2ARIConnector = Field(..., exclude=True) + name: str = "GetGrabbingPointTool" description: str = "Get the grabbing point of an object" pcd: List[Any] = [] args_schema: Type[GetGrabbingPointInput] = GetGrabbingPointInput + box_threshold: float = Field(default=0.35, description="Box threshold for GDINO") + text_threshold: float = Field(default=0.45, description="Text threshold for GDINO") + + def _get_gdino_response( + self, future: Future + ) -> Optional[RAIGroundingDino.Response]: + return get_future_result(future) + + def _get_gsam_response(self, future: Future) -> Optional[RAIGroundedSam.Response]: + return get_future_result(future) + + def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: + msg = self.connector.receive_message(topic).payload + if type(msg) is sensor_msgs.msg.Image: + return msg + else: + raise Exception("Received wrong message") + + def _call_gdino_node( + self, camera_img_message: sensor_msgs.msg.Image, object_name: str + ) -> Future: + cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) + while not cli.wait_for_service(timeout_sec=1.0): + self.connector.node.get_logger().info( + "service not available, waiting again..." + ) + req = RAIGroundingDino.Request() + req.source_img = camera_img_message + req.classes = object_name + req.box_threshold = self.box_threshold + req.text_threshold = self.text_threshold + + future = cli.call_async(req) + return future + + def _call_gsam_node( + self, camera_img_message: sensor_msgs.msg.Image, data: RAIGroundingDino.Response + ): + cli = self.connector.node.create_client(RAIGroundedSam, "grounded_sam_segment") + while not cli.wait_for_service(timeout_sec=1.0): + self.connector.node.get_logger().info( + "service not available, waiting again..." + ) + req = RAIGroundedSam.Request() + req.detections = data.detections + req.source_img = camera_img_message + future = cli.call_async(req) + + return future def _get_camera_info_message(self, topic: str) -> sensor_msgs.msg.CameraInfo: for _ in range(3): - msg = self.node.get_raw_message_from_topic(topic, timeout_sec=3.0) + msg = self.connector.receive_message(topic, timeout_sec=3.0).payload if isinstance(msg, sensor_msgs.msg.CameraInfo): return msg - self.node.get_logger().warn("Received wrong message type. Retrying...") + self.connector.node.get_logger().warn( + "Received wrong message type. Retrying..." + ) raise Exception("Failed to receive correct CameraInfo message after 3 attempts") @@ -259,16 +319,18 @@ def _run( camera_info_topic: str, object_name: str, ): - camera_img_msg = self._get_image_message(camera_topic) - depth_msg = self._get_image_message(depth_topic) + camera_img_msg = self.connector.receive_message(camera_topic).payload + depth_msg = self.connector.receive_message(depth_topic).payload camera_info = self._get_camera_info_message(camera_info_topic) intrinsic = self._get_intrinsic_from_camera_info(camera_info) future = self._call_gdino_node(camera_img_msg, object_name) - logger = self.node.get_logger() + logger = self.connector.node.get_logger() try: - conversion_ratio = self.node.get_parameter("conversion_ratio").value + conversion_ratio = self.connector.node.get_parameter( + "conversion_ratio" + ).value if not isinstance(conversion_ratio, float): logger.error( f"Parameter conversion_ratio was set badly: {type(conversion_ratio)}: {conversion_ratio} expected float. Using default value 0.001" @@ -280,21 +342,17 @@ def _run( ) conversion_ratio = 0.001 resolved = None - while rclpy.ok(): - resolved = self._get_gdino_response(future) - if resolved is not None: - break + + resolved = get_future_result(future) assert resolved is not None future = self._call_gsam_node(camera_img_msg, resolved) ret = [] - while rclpy.ok(): - resolved = self._get_gsam_response(future) - if resolved is not None: - for img_msg in resolved.masks: - ret.append(convert_ros_img_to_base64(img_msg)) - break + resolved = get_future_result(future) + if resolved is not None: + for img_msg in resolved.masks: + ret.append(convert_ros_img_to_base64(img_msg)) assert resolved is not None rets = [] for mask_msg in resolved.masks: diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py index 45e377d39..d514beb78 100644 --- a/src/rai_sim/rai_sim/o3de/o3de_bridge.py +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -18,13 +18,16 @@ import subprocess import time from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import yaml -from geometry_msgs.msg import Point, Pose, Quaternion +from geometry_msgs.msg import Point, Pose, PoseStamped, Quaternion from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage +from rai.utils.ros_async import get_future_result +from std_msgs.msg import Header from tf2_geometry_msgs import do_transform_pose +from rai_interfaces.srv import ManipulatorMoveTo from rai_sim.simulation_bridge import ( Entity, PoseModel, @@ -40,6 +43,9 @@ class O3DExROS2SimulationConfig(SimulationConfig): binary_path: Path robotic_stack_command: str + required_services: List[str] + required_topics: List[str] + required_actions: List[str] @classmethod def load_config( @@ -199,6 +205,35 @@ def get_scene_state(self) -> SceneState: ) return SceneState(entities=entities) + def _is_robotic_stack_ready( + self, simulation_config: O3DExROS2SimulationConfig, retries: int = 30 + ) -> bool: + i = 0 + while i < retries: + topics = self.connector.get_topics_names_and_types() + services = self.connector.node.get_service_names_and_types() + topics_names = [tp[0] for tp in topics] + service_names = [srv[0] for srv in services] + self.logger.info( + f"required services: {simulation_config.required_services}" + ) + self.logger.info(f"required topics: {simulation_config.required_topics}") + self.logger.info(f"required actions: {simulation_config.required_actions}") + # NOTE actions will be listed in services and topics + if ( + all(srv in service_names for srv in simulation_config.required_services) + and all(tp in topics_names for tp in simulation_config.required_topics) + and all( + ac in service_names for ac in simulation_config.required_actions + ) + ): + self.logger.info("All required services are available.") + return True + + time.sleep(5) + retries += 1 + return False + def setup_scene(self, simulation_config: O3DExROS2SimulationConfig): if self.current_binary_path != simulation_config.binary_path: if self.current_sim_process: @@ -211,6 +246,11 @@ def setup_scene(self, simulation_config: O3DExROS2SimulationConfig): while self.spawned_entities: self._despawn_entity(self.spawned_entities[0]) + if not self._is_robotic_stack_ready(simulation_config=simulation_config): + raise RuntimeError( + "Not all required services, topics and actions are available" + ) + for entity in simulation_config.entities: self._spawn_entity(entity) @@ -304,3 +344,52 @@ def from_ros2_pose(self, pose: Pose) -> PoseModel: ) return PoseModel(translation=translation, rotation=rotation) + + +class O3DEngineArmManipulationBridge(O3DExROS2Bridge): + def move_arm( + self, + pose: PoseModel, + initial_gripper_state: bool, + final_gripper_state: bool, + frame_id: str, + ): + """Moves arm to a given position + + Args: + pose (PoseModel): where to move arm + initial_gripper_state (bool): False means closed grip, True means open grip + final_gripper_state (bool): False means closed grip, True means open grip + frame_id (str): reference frame + """ + + request = ManipulatorMoveTo.Request() + request.initial_gripper_state = initial_gripper_state + request.final_gripper_state = final_gripper_state + + request.target_pose = PoseStamped() + request.target_pose.header = Header() + request.target_pose.header.frame_id = frame_id + + request.target_pose.pose.position.x = pose.translation.x + request.target_pose.pose.position.x = pose.translation.y + request.target_pose.pose.position.z = pose.translation.z + + if pose.rotation: + request.target_pose.pose.orientation.x = pose.rotation.x + request.target_pose.pose.orientation.y = pose.rotation.y + request.target_pose.pose.orientation.z = pose.rotation.z + request.target_pose.pose.orientation.w = pose.rotation.w + + client = self.connector.node.create_client( + ManipulatorMoveTo, + "/manipulator_move_to", + ) + while not client.wait_for_service(timeout_sec=5.0): + self.connector.node.get_logger().info("Service not available, waiting...") + + self.connector.node.get_logger().info("Making request to arm manipulator...") + future = client.call_async(request) + result = get_future_result(future, timeout_sec=5.0) + + self.connector.node.get_logger().debug(f"Moving arm result: {result}") diff --git a/src/rai_sim/rai_sim/simulation_bridge.py b/src/rai_sim/rai_sim/simulation_bridge.py index 93959382d..a1e1fcbd2 100644 --- a/src/rai_sim/rai_sim/simulation_bridge.py +++ b/src/rai_sim/rai_sim/simulation_bridge.py @@ -36,7 +36,7 @@ class Rotation(BaseModel): class PoseModel(BaseModel): translation: Translation - rotation: Optional[Rotation] + rotation: Optional[Rotation] = None class Entity(BaseModel):