From 96c2e7aa86039c48d6b6e3009936e95cceea9aa5 Mon Sep 17 00:00:00 2001 From: Tom Silver Date: Tue, 17 Oct 2023 10:24:07 -0400 Subject: [PATCH 1/3] fix hierarchical typing edge case --- predicators/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/predicators/utils.py b/predicators/utils.py index b1ed1776cb..45b598680a 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2806,9 +2806,10 @@ def create_pddl_domain(operators: Collection[NSRTOrSTRIPSOperator], for parent_type in sorted(parent_to_children_types): child_types = parent_to_children_types[parent_type] if not child_types: - continue - child_type_str = " ".join(t.name for t in child_types) - types_str += f"\n {child_type_str} - {parent_type.name}" + types_str += f"\n {parent_type.name}" + else: + child_type_str = " ".join(t.name for t in child_types) + types_str += f"\n {child_type_str} - {parent_type.name}" ops_lst = sorted(operators) preds_str = "\n ".join(pred.pddl_str() for pred in preds_lst) ops_strs = "\n\n ".join(op.pddl_str() for op in ops_lst) From adf256898355dbd67aedbe768a0a3483a93f3385 Mon Sep 17 00:00:00 2001 From: Tom Silver Date: Tue, 17 Oct 2023 10:51:33 -0400 Subject: [PATCH 2/3] thanks unit tests --- predicators/utils.py | 9 ++++++++- tests/test_utils.py | 6 +++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/predicators/utils.py b/predicators/utils.py index 45b598680a..13232fb332 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2806,7 +2806,14 @@ def create_pddl_domain(operators: Collection[NSRTOrSTRIPSOperator], for parent_type in sorted(parent_to_children_types): child_types = parent_to_children_types[parent_type] if not child_types: - types_str += f"\n {parent_type.name}" + # Special case: type has no children and also does not appear as a + # child of another type. + is_child_type = any( + parent_type in children + for children in parent_to_children_types.values()) + if not is_child_type: + types_str += f"\n {parent_type.name}" + # Otherwise, the type will appear as a child elsewhere. else: child_type_str = " ".join(t.name for t in child_types) types_str += f"\n {child_type_str} - {parent_type.name}" diff --git a/tests/test_utils.py b/tests/test_utils.py index 2743f7b71b..06e73cb340 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2318,12 +2318,16 @@ def test_create_pddl(): env = ProceduralTasksSpannerPDDLEnv() nsrts = get_gt_nsrts(env.get_name(), env.predicates, get_gt_options(env.get_name())) - domain_str = utils.create_pddl_domain(nsrts, env.predicates, env.types, + # Test case where there is a special type with no parents or children. + monkey_type = Type("monkey", []) + types = env.types | {monkey_type} + domain_str = utils.create_pddl_domain(nsrts, env.predicates, types, "spanner") assert domain_str == """(define (domain spanner) (:requirements :typing) (:types man nut spanner - locatable + monkey locatable location - object) (:predicates From 15b67ee86489d97904f519e589468a74d0f66572 Mon Sep 17 00:00:00 2001 From: Tom Silver Date: Tue, 17 Oct 2023 10:52:46 -0400 Subject: [PATCH 3/3] lint --- predicators/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/predicators/utils.py b/predicators/utils.py index 13232fb332..11db9e37ab 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2806,8 +2806,8 @@ def create_pddl_domain(operators: Collection[NSRTOrSTRIPSOperator], for parent_type in sorted(parent_to_children_types): child_types = parent_to_children_types[parent_type] if not child_types: - # Special case: type has no children and also does not appear as a - # child of another type. + # Special case: type has no children and also does not appear + # as a child of another type. is_child_type = any( parent_type in children for children in parent_to_children_types.values())