Skip to content

Commit

Permalink
refactor: adjusted GroupVegetablesTask to rebased code
Browse files Browse the repository at this point in the history
  • Loading branch information
jmatejcz committed Mar 10, 2025
1 parent 1c7ccd2 commit 37f2dee
Showing 1 changed file with 68 additions and 57 deletions.
125 changes: 68 additions & 57 deletions src/rai_bench/rai_bench/o3de_test_bench/tasks/group_vegetables_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,99 +12,110 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import List, Tuple

from rai_bench.o3de_test_bench.tasks.manipulation_task import ManipulationTask
from rai_bench.benchmark_model import (
EntitiesMismatchException,
)
from rai_sim.o3de.o3de_bridge import SimulationBridge # type: ignore
from rai_sim.simulation_bridge import SimulationConfig, SpawnedEntity # type: ignore
from rai_sim.simulation_bridge import SimulationConfig, SpawnedEntity, SimulationConfigT # type: ignore


class GroupVegetablesTask(ManipulationTask):
obj_types = ["tomato", "apple", "corn", "carrot"]

def get_prompt(self) -> str:
return (
"Manipulate objects, so that vegetables will be in separate clusters based on their types. "
"Manipulate objects so that vegetables form separate clusters based on their types. "
"Each cluster must: "
"1. Contain ALL vegetables of the same type "
"2. Contain ONLY vegetables of the same type "
"3. Form a single connected group (all vegetables of the same type must be adjacent) "
"4. Be completely separated from other clusters (objects of different types cannot be adjacent) "
"3. Form a single connected group "
"4. Be completely separated from other clusters "
)

def validate_config(self, simulation_config: SimulationConfig) -> bool:
"""Ensure that there are at least 2 types of vegetables"""
"""Ensure that at least two types of vegetables are present."""
veg_types = {
ent.prefab_name
for ent in simulation_config.entities
if ent.prefab_name in self.obj_types
}
return len(veg_types) > 1

def calculate_result(
self, simulation_bridge: SimulationBridge[SimulationConfig]
) -> float:
# NOTE for now if all veggies from same type don't form single cluster,
# every veggie will be counted as misplaced
# only when they form a signle cluster and none of them is adjacent to other types
# they will be counted as correctly placed
# it can be modified in the future to count only part of veggies as correct
self.reset_values()
initial_veggies, current_veggies = self.get_initial_and_current_positions(
simulation_bridge=simulation_bridge, object_types=self.obj_types
)

initial_veggies_by_type = self.group_entities_by_type(initial_veggies)
current_veggies_by_type = self.group_entities_by_type(current_veggies)
def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]:
"""Count correctly and incorrectly placed objects based on clustering rules."""
properly_clustered: List[SpawnedEntity] = []
misclustered: List[SpawnedEntity] = []

initially_properly_clustered: List[SpawnedEntity] = []
currently_properly_clustered: List[SpawnedEntity] = []
entities_by_type = self.group_entities_by_type(entities)

for veg_type, veggies in initial_veggies_by_type.items():
for veg_type, veggies in entities_by_type.items():
neighbourhood_list = self.build_neighbourhood_list(veggies)
clusters = self.find_clusters(neighbourhood_list)
if len(clusters) == 1:
# there is only 1 cluster so the 1st condition is matched
# now check if every veggie from the cluster is ajacent only to
# other veggies of same type
if all(
self.check_neighbourhood_types(
neighbourhood=neighbourhood_list[veggie],
allowed_types=[veg_type],
neighbourhood=neighbourhood_list[v], allowed_types=[veg_type]
)
for veggie in clusters[0]
for v in clusters[0]
):
initially_properly_clustered.extend(clusters[0])
properly_clustered.extend(clusters[0])
else:
misclustered.extend(clusters[0])
else:
misclustered.extend(veggies)

for veg_type, veggies in current_veggies_by_type.items():
neighbourhood_list = self.build_neighbourhood_list(veggies)
clusters = self.find_clusters(neighbourhood_list)
if len(clusters) == 1:
# there is only 1 cluster so the 1st condition is matched
# now check if every veggie from the cluster is ajacent only to
# other veggies of same type
if all(
self.check_neighbourhood_types(
neighbourhood=neighbourhood_list[veggie],
allowed_types=[veg_type],
)
for veggie in clusters[0]
):
currently_properly_clustered.extend(clusters[0])
return len(properly_clustered), len(misclustered)

self.initially_misplaced_now_correct = len(
set(currently_properly_clustered) - set(initially_properly_clustered)
def calculate_initial_placements(
self, simulation_bridge: SimulationBridge[SimulationConfigT]
) -> Tuple[int, int]:
"""Calculate the number of initially correct and incorrect placements."""
initial_veggies = self.filter_entities_by_prefab_type(
simulation_bridge.spawned_entities, self.obj_types
)
self.initially_misplaced_still_incorrect = len(
set(initial_veggies) - set(initially_properly_clustered)
initially_correct, initially_incorrect = self.calculate_correct(initial_veggies)

self.logger.info(f"Initially correct: {initially_correct}, Initially incorrect: {initially_incorrect}") # type: ignore
return initially_correct, initially_incorrect

def calculate_final_placements(
self, simulation_bridge: SimulationBridge[SimulationConfigT]
) -> Tuple[int, int]:
"""Calculate the number of correctly and incorrectly placed objects at the end of the simulation."""
scene_state = simulation_bridge.get_scene_state()
final_veggies = self.filter_entities_by_prefab_type(
scene_state.entities, self.obj_types
)
self.initially_correct_still_correct = len(
set(initially_properly_clustered) & set(currently_properly_clustered)
final_correct, final_incorrect = self.calculate_correct(final_veggies)

self.logger.info(f"Final correct: {final_correct}, Final incorrect: {final_incorrect}") # type: ignore
return final_correct, final_incorrect

def calculate_result(
self, simulation_bridge: SimulationBridge[SimulationConfig]
) -> float:
"""Calculates a score from 0.0 to 1.0 based on placement improvements."""
initially_correct, initially_incorrect = self.calculate_initial_placements(
simulation_bridge
)
self.initially_correct_now_incorrect = len(
set(initially_properly_clustered) - set(currently_properly_clustered)
final_correct, final_incorrect = self.calculate_final_placements(
simulation_bridge
)

return (
self.initially_misplaced_now_correct + self.initially_correct_still_correct
) / len(initial_veggies)
total_objects = initially_correct + initially_incorrect
if total_objects == 0:
return 1.0
elif total_objects != (final_correct + final_incorrect):
raise EntitiesMismatchException(
"Mismatch in initial and final entity counts."
)
elif initially_incorrect == 0:
raise ValueError("All objects are placed correctly at the start.")
else:
corrected = final_correct - initially_correct
score = max(0.0, corrected / initially_incorrect)
self.logger.info(f"Calculated score: {score:.2f}") # type: ignore
return score

0 comments on commit 37f2dee

Please sign in to comment.