Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jm/feat/o3de bench more tasks #452

Draft
wants to merge 15 commits into
base: development
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 111 additions & 6 deletions src/rai_bench/rai_bench/benchmark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,22 @@
import logging
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Union
from typing import Any, Dict, Generic, List, Set, TypeVar, Union

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from rai.messages import HumanMultimodalMessage
from rclpy.impl.rcutils_logger import RcutilsLogger

from rai_sim.simulation_bridge import (
Entity,
Pose,
SimulationBridge,
SimulationConfig,
SimulationConfigT,
SpawnedEntity,
)

loggers_type = Union[RcutilsLogger, logging.Logger]
EntityT = TypeVar("EntityT", bound=Entity)


class EntitiesMismatchException(Exception):
Expand Down Expand Up @@ -82,11 +83,30 @@ def calculate_result(
"""
pass

def get_initial_and_current_positions(
self,
simulation_bridge: SimulationBridge[SimulationConfig],
object_types: List[str],
):
scene_state = simulation_bridge.get_scene_state()
initial_objects = self.filter_entities_by_prefab_type(
simulation_bridge.spawned_entities, object_types=object_types
)
final_objects = self.filter_entities_by_prefab_type(
scene_state.entities, object_types=object_types
)

if len(initial_objects) != len(final_objects):
raise EntitiesMismatchException(
"Number of initially spawned entities does not match number of entities present at the end."
)
return initial_objects, final_objects

def filter_entities_by_prefab_type(
self, entities: List[SpawnedEntity], prefab_types: List[str]
) -> List[SpawnedEntity]:
self, entities: List[EntityT], object_types: List[str]
) -> List[EntityT]:
"""Filter and return only these entities that match provided prefab types"""
return [ent for ent in entities if ent.prefab_name in prefab_types]
return [ent for ent in entities if ent.prefab_name in object_types]

def euclidean_distance(self, pos1: Pose, pos2: Pose) -> float:
"""Calculate euclidean distance between 2 positions"""
Expand Down Expand Up @@ -134,6 +154,91 @@ def count_adjacent(self, positions: List[Pose], threshold_distance: float) -> in

return adjacent_count

def build_neighbourhood_list(
self, entities: List[EntityT]
) -> Dict[EntityT, List[EntityT]]:
"""Assignes a list of neighbours to every object based on threshold distance"""
neighbourhood_graph: Dict[EntityT, List[EntityT]] = {
entity: [] for entity in entities
}
for entity in entities:
neighbourhood_graph[entity] = [
other
for other in entities
if entity != other and self.is_adjacent(entity.pose, other.pose, 0.15)
]
return neighbourhood_graph

def group_entities_by_type(
self, entities: List[EntityT]
) -> Dict[str, List[EntityT]]:
"""Returns dictionary of entities grouped by type"""
entities_by_type: Dict[str, List[EntityT]] = {}
for entity in entities:
entities_by_type.setdefault(entity.prefab_name, []).append(entity)
return entities_by_type

def check_neighbourhood_types(
self,
neighbourhood: List[EntityT],
allowed_types: List[str],
) -> bool:
"""Check if ALL neighbours are given types"""
return not neighbourhood or all(
adj.prefab_name in allowed_types for adj in neighbourhood
)

def find_clusters(
self, neighbourhood_list: Dict[EntityT, List[EntityT]]
) -> List[List[EntityT]]:
"""Find clusters of entities using DFS algorithm, lone entities are counted as a cluster"""
visited: Set[EntityT] = set()
clusters: List[List[EntityT]] = []

def dfs(node: EntityT, cluster: List[EntityT]):
visited.add(node)
cluster.append(node)
for neighbor in neighbourhood_list.get(node, []):
if neighbor not in visited:
dfs(neighbor, cluster)

for node in neighbourhood_list.keys():
if node not in visited:
component: List[EntityT] = []
dfs(node, component)
clusters.append(component)

return clusters

def group_entities_by_z_coordinate(
# TODO (jm) figure out how to group by other coords and orientation, without reapeting code
self,
entities: List[EntityT],
margin: float,
) -> List[List[EntityT]]:
"""
Groups entities that are aligned along a z axis within a margin (top to bottom).
Usefull for checking if objects form lines or towers
"""

entities = sorted(entities, key=lambda ent: ent.pose.translation.z)
groups: List[List[EntityT]] = []

for entity in entities:
placed = False
for group in groups:
if (
abs(group[0].pose.translation.z - entity.pose.translation.z)
<= margin
):
group.append(entity)
placed = True
break
if not placed:
groups.append([entity])

return groups


class Scenario(Generic[SimulationConfigT]):
"""
Expand Down Expand Up @@ -270,11 +375,11 @@ def run_next(self, agent) -> None:
te = time.perf_counter()

result = scenario.task.calculate_result(self.simulation_bridge)

total_time = te - ts
self._logger.info( # type: ignore
f"TASK SCORE: {result}, TOTAL TIME: {total_time:.3f}, NUM_OF_TOOL_CALLS: {tool_calls_num}"
)

scenario_result: Dict[str, Any] = {
"task": scenario.task.get_prompt(),
"simulation_config": scenario.simulation_config_path,
Expand Down
41 changes: 32 additions & 9 deletions src/rai_bench/rai_bench/examples/o3de_test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,38 @@
from pathlib import Path
from typing import List


import rclpy
from langchain.tools import BaseTool
from rai.agents.conversational_agent import create_conversational_agent
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool
from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool
from rai.utils.model_initialization import get_llm_model
from rai.agents.conversational_agent import create_conversational_agent # type: ignore
from rai.communication.ros2.connectors import ROS2ARIConnector # type: ignore
from rai.tools.ros.manipulation import ( # type: ignore
GetObjectPositionsTool,
MoveToPointTool,
)
from rai.tools.ros2.topics import ( # type: ignore
GetROS2ImageTool,
GetROS2TopicsNamesAndTypesTool,
)
from rai.utils.model_initialization import get_llm_model # type: ignore
from rai_open_set_vision.tools import GetGrabbingPointTool

from rai_bench.benchmark_model import Benchmark, Task
from rai_bench.o3de_test_bench.tasks import GrabCarrotTask, PlaceCubesTask
from rai_sim.o3de.o3de_bridge import (
from rai_bench.benchmark_model import Benchmark, Task # type: ignore
from rai_bench.o3de_test_bench.tasks import ( # type: ignore
BuildCubeTowerTask,
GrabCarrotTask,
GroupVegetablesTask,
PlaceCubesTask,
BuildYellowCubeTowerTask,
BuildBlueCubeTowerTask,
BuildRedCubeTowerTask,
)
from rai_sim.o3de.o3de_bridge import ( # type: ignore
O3DEngineArmManipulationBridge,
O3DExROS2SimulationConfig,
Pose,
)
from rai_sim.simulation_bridge import Rotation, Translation
from rai_sim.simulation_bridge import Rotation, Translation # type: ignore

if __name__ == "__main__":
rclpy.init()
Expand Down Expand Up @@ -142,6 +157,9 @@
configs_dir + "scene2.yaml",
configs_dir + "scene3.yaml",
configs_dir + "scene4.yaml",
configs_dir + "scene5.yaml",
configs_dir + "scene6.yaml",
configs_dir + "scene7.yaml",
]
simulations_configs = [
O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path))
Expand All @@ -150,6 +168,11 @@
tasks: List[Task] = [
GrabCarrotTask(logger=bench_logger),
PlaceCubesTask(logger=bench_logger),
GroupVegetablesTask(logger=bench_logger),
BuildCubeTowerTask(logger=bench_logger),
BuildRedCubeTowerTask(logger=bench_logger),
BuildYellowCubeTowerTask(logger=bench_logger),
BuildBlueCubeTowerTask(logger=bench_logger),
]
scenarios = Benchmark.create_scenarios(
tasks=tasks,
Expand Down
50 changes: 50 additions & 0 deletions src/rai_bench/rai_bench/o3de_test_bench/configs/scene5.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
entities:
- name: carrot1
prefab_name: carrot
pose:
translation:
x: 0.5
y: -0.3
z: 0.05
rotation:
x: 0.0
y: 0.0
z: 0.0
w: 1.0
- name: corn1
prefab_name: apple
pose:
translation:
x: 0.5
y: 0.4
z: 0.05
rotation:
x: 0.0
y: 0.0
z: 0.0
w: 1.0

- name: carrot2
prefab_name: carrot
pose:
translation:
x: 0.1
y: -0.3
z: 0.05
rotation:
x: 0.0
y: 0.0
z: 0.0
w: 1.0
- name: corn2
prefab_name: apple
pose:
translation:
x: 0.1
y: 0.4
z: 0.05
rotation:
x: 0.0
y: 0.0
z: 0.0
w: 1.0
17 changes: 16 additions & 1 deletion src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from rai_bench.o3de_test_bench.tasks.build_tower_task import (
BuildBlueCubeTowerTask,
BuildCubeTowerTask,
BuildRedCubeTowerTask,
BuildYellowCubeTowerTask,
)
from rai_bench.o3de_test_bench.tasks.grab_carrot_task import GrabCarrotTask
from rai_bench.o3de_test_bench.tasks.group_vegetables_task import GroupVegetablesTask
from rai_bench.o3de_test_bench.tasks.place_cubes_task import PlaceCubesTask

__all__ = ["GrabCarrotTask", "PlaceCubesTask"]
__all__ = [
"BuildBlueCubeTowerTask",
"BuildCubeTowerTask",
"BuildRedCubeTowerTask",
"BuildYellowCubeTowerTask",
"GrabCarrotTask",
"GroupVegetablesTask",
"PlaceCubesTask",
]
Loading
Loading