diff --git a/predicators/approaches/grammar_search_invention_approach.py b/predicators/approaches/grammar_search_invention_approach.py
index 93986d1ef0..2bca2dbf08 100644
--- a/predicators/approaches/grammar_search_invention_approach.py
+++ b/predicators/approaches/grammar_search_invention_approach.py
@@ -9,10 +9,13 @@
 from dataclasses import dataclass, field
 from functools import cached_property
 from operator import le
-from typing import Callable, Dict, FrozenSet, Iterator, List, Sequence, Set, \
-    Tuple
+from typing import Any, Callable, Dict, FrozenSet, Iterator, List, Sequence, \
+    Set, Tuple
 
+import numpy as np
 from gym.spaces import Box
+from scipy.stats import kstest
+from sklearn.mixture import GaussianMixture as GMM
 
 from predicators import utils
 from predicators.approaches.nsrt_learning_approach import NSRTLearningApproach
@@ -887,6 +890,50 @@ def _get_successors(
 
         return set(kept_predicates)
 
+    @staticmethod
+    def _get_consistent_predicates(
+        predicates: Set[Predicate], clusters: List[List[Segment]]
+    ) -> Tuple[Set[Predicate], Set[Predicate]]:
+        """Returns all predicates that are consistent with respect to a set
+        of segment clusters, along with those that are not.
+
+        A consistent predicate is either an add effect, a delete
+        effect, or unchanged, within each cluster, for all
+        clusters.
+        """
+
+        consistent: Set[Predicate] = set()
+        inconsistent: Set[Predicate] = set()
+        for pred in predicates:
+            keep_pred = True
+            for seg_list in clusters:
+                segment_0 = seg_list[0]
+                pred_in_add_effs_0 = pred in [
+                    atom.predicate for atom in segment_0.add_effects
+                ]
+                pred_in_del_effs_0 = pred in [
+                    atom.predicate for atom in segment_0.delete_effects
+                ]
+                for seg in seg_list[1:]:
+                    pred_in_curr_add_effs = pred in [
+                        atom.predicate for atom in seg.add_effects
+                    ]
+                    pred_in_curr_del_effs = pred in [
+                        atom.predicate for atom in seg.delete_effects
+                    ]
+                    not_consis_add = (pred_in_add_effs_0 !=
+                                      pred_in_curr_add_effs)
+                    not_consis_del = (pred_in_del_effs_0 !=
+                                      pred_in_curr_del_effs)
+                    if not_consis_add or not_consis_del:
+                        keep_pred = False
+                        inconsistent.add(pred)
+                        logging.info(f"Inconsistent predicate: {pred.name}")
+                        break
+                if not keep_pred:
+                    break
+            else:
+                consistent.add(pred)
+        return consistent, inconsistent
+
     def _select_predicates_by_clustering(
             self, candidates: Dict[Predicate, float],
             initial_predicates: Set[Predicate], dataset: Dataset,
@@ -894,6 +941,236 @@
         """Cluster segments from the atom_dataset into clusters corresponding
         to operators and use this to select predicates."""
 
+        if CFG.grammar_search_pred_clusterer == "option-type-number-sample":
+            # This procedure tries to reverse-engineer the clusters of
+            # segments that correspond to the oracle operators and selects
+            # predicates that are add effects in those clusters, letting
+            # pnad_search downstream handle chainability.
+
+            assert CFG.segmenter == "option_changes"
+            segments = [
+                seg for ll_traj, atom_seq in atom_dataset for seg in
+                segment_trajectory(ll_traj, initial_predicates, atom_seq)
+            ]
+
+            # Step 1:
+            # Cluster segments by the option that generated them. At the very
+            # least, we know that segments generated by different options
+            # must belong to different operators.
+            # Really Dict[str, List[Segment]], but typed loosely because the
+            # values are replaced with nested dicts in the steps below.
+            option_to_segments: Dict[Any, Any] = {}
+            for seg in segments:
+                name = seg.get_option().name
+                option_to_segments.setdefault(name, []).append(seg)
+            logging.info(f"STEP 1: generated {len(option_to_segments.keys())} "
+                         f"option-based clusters.")
+            clusters = option_to_segments.copy()  # Tree-like structure.
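+            # After STEPS 2-4 below, this tree is refined in place from
+            # {option: [segment, ...]} into nested dicts of the form
+            # {option: {types: {max_num_objs: [cluster, ...]}}}, where each
+            # cluster is a list of segments.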
+
+            # Step 2:
+            # Further cluster by the types that appear in a segment's add
+            # effects. Operators have a fixed number of typed arguments.
+            for i, pair in enumerate(option_to_segments.items()):
+                option, segments = pair
+                types_to_segments: Dict[Tuple[Type, ...], List[Segment]] = {}
+                for seg in segments:
+                    types_in_effects = [
+                        set(a.predicate.types) for a in seg.add_effects
+                    ]
+                    # To cluster on type, there must be types. That is, the
+                    # segment must have at least one add effect, and at least
+                    # one of those add effects must have a nonempty tuple of
+                    # argument types.
+                    assert len(types_in_effects) > 0 and len(
+                        set.union(*types_in_effects)) > 0
+                    types = tuple(sorted(set.union(*types_in_effects)))
+                    types_to_segments.setdefault(types, []).append(seg)
+                logging.info(
+                    f"STEP 2: generated {len(types_to_segments.keys())} "
+                    f"type-based clusters for cluster {i+1} from STEP 1 "
+                    f"involving option {option}.")
+                clusters[option] = types_to_segments
+
+            # Step 3:
+            # Further cluster by the maximum number of objects that appear in
+            # a segment's add effects. Note that using the maximum here is
+            # somewhat arbitrary: alternatively, one could cluster on every
+            # object count that appears among a segment's add effects, rather
+            # than only the maximum.
+            for i, (option, types_to_segments) in enumerate(clusters.items()):
+                for j, (types,
+                        segments) in enumerate(types_to_segments.items()):
+                    num_to_segments: Dict[int, List[Segment]] = {}
+                    for seg in segments:
+                        max_num_objs = max(
+                            len(a.objects) for a in seg.add_effects)
+                        num_to_segments.setdefault(max_num_objs,
+                                                   []).append(seg)
+                    logging.info(
+                        f"STEP 3: generated {len(num_to_segments.keys())} "
+                        f"num-object-based clusters for cluster {i+j+1} from "
+                        f"STEP 2 involving option {option} and type {types}.")
+                    clusters[option][types] = num_to_segments
+
+            # Step 4:
+            # Further cluster by sample, if a sample is present. The idea here
+            # is to separate things like PickFromTop and PickFromSide.
+            for i, (option, types_to_num) in enumerate(clusters.items()):
+                for j, (types,
+                        num_to_segments) in enumerate(types_to_num.items()):
+                    for k, (max_num_objs,
+                            segments) in enumerate(num_to_segments.items()):
+                        # If the segments in this cluster have no sample, then
+                        # don't cluster further.
+                        if len(segments[0].get_option().params) == 0:
+                            clusters[option][types][max_num_objs] = [segments]
+                            logging.info(
+                                f"STEP 4: generated no further sample-based "
+                                f"clusters (no parameter) for cluster "
+                                f"{i+j+k+1} from STEP 3 involving option "
+                                f"{option}, type {types}, and max num objects "
+                                f"{max_num_objs}.")
+                            continue
+                        # pylint: disable=line-too-long
+                        # If the parameters are described by a uniform
+                        # distribution, then don't cluster further. This
+                        # helps prevent overfitting. A proper implementation
+                        # would do a multi-dimensional test
+                        # (https://ieeexplore.ieee.org/document/4767477,
+                        # https://ui.adsabs.harvard.edu/abs/1987MNRAS.225..155F/abstract,
+                        # https://stats.stackexchange.com/questions/30982/how-to-test-uniformity-in-several-dimensions)
+                        # but for now we only check each dimension
+                        # individually to keep the implementation simple.
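+                        # Note that this check is stochastic, since the
+                        # null-hypothesis sample below is drawn randomly. A
+                        # deterministic alternative would be the one-sample
+                        # test kstest(col, "uniform",
+                        # args=(minimum, maximum - minimum)).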
+                        # pylint: enable=line-too-long
+                        samples = np.array(
+                            [seg.get_option().params for seg in segments])
+                        each_dim_uniform = True
+                        for d in range(samples.shape[1]):
+                            col = samples[:, d]
+                            minimum = col.min()
+                            maximum = col.max()
+                            null_hypothesis = np.random.uniform(
+                                minimum, maximum, len(col))
+                            p_value = kstest(col, null_hypothesis).pvalue
+
+                            # We use a significance level of 0.05.
+                            if p_value < 0.05:
+                                each_dim_uniform = False
+                                break
+                        if each_dim_uniform:
+                            clusters[option][types][max_num_objs] = [segments]
+                            logging.info(
+                                f"STEP 4: generated no further sample-based "
+                                f"clusters (uniformly distributed parameter) "
+                                f"for cluster {i+j+k+1} from STEP 3 involving "
+                                f"option {option}, type {types}, and max num "
+                                f"objects {max_num_objs}.")
+                            continue
+                        # Determine clusters via the cluster assignments of a
+                        # Gaussian Mixture Model. The number of components is
+                        # chosen by BIC, which penalizes model complexity; the
+                        # maximum number of components is a hyperparameter.
+                        max_components = min(
+                            len(samples), len(np.unique(samples)),
+                            CFG.grammar_search_clustering_gmm_num_components)
+                        n_components = np.arange(1, max_components + 1)
+                        models = [
+                            GMM(n, covariance_type="full",
+                                random_state=0).fit(samples)
+                            for n in n_components
+                        ]
+                        bic = [m.bic(samples) for m in models]
+                        best = models[np.argmin(bic)]
+                        assignments = best.predict(samples)
+                        label_to_segments: Dict[int, List[Segment]] = {}
+                        for idx, assignment in enumerate(assignments):
+                            label_to_segments.setdefault(
+                                assignment, []).append(segments[idx])
+                        clusters[option][types][max_num_objs] = list(
+                            label_to_segments.values())
+                        logging.info(f"STEP 4: generated "
+                                     f"{len(label_to_segments.keys())} "
+                                     f"sample-based clusters for cluster "
+                                     f"{i+j+k+1} from STEP 3 involving option "
+                                     f"{option}, type {types}, and max num "
+                                     f"objects {max_num_objs}.")
+
+            # We could construct the final set of clusters directly in STEP 4,
+            # but these loops are not prohibitively slow, and they clarify the
+            # nested dictionary structure, which follow-up work that further
+            # modifies the clusters may make use of.
+            final_clusters = []
+            for option in clusters.keys():
+                for types in clusters[option].keys():
+                    for max_num_objs in clusters[option][types].keys():
+                        for cluster in clusters[option][types][max_num_objs]:
+                            final_clusters.append(cluster)
+            logging.info(f"Total {len(final_clusters)} final clusters.")
+
+            # Step 5:
+            # Extract predicates from the pure intersection of the add effects
+            # in each cluster.
+            extracted_preds = set()
+            shared_add_effects_per_cluster = []
+            for c in final_clusters:
+                grounded_add_effects_per_segment = [
+                    seg.add_effects for seg in c
+                ]
+                ungrounded_add_effects_per_segment = []
+                for effs in grounded_add_effects_per_segment:
+                    ungrounded_add_effects_per_segment.append(
+                        set(a.predicate for a in effs))
+                shared_add_effects_in_cluster = set.intersection(
+                    *ungrounded_add_effects_per_segment)
+                shared_add_effects_per_cluster.append(
+                    shared_add_effects_in_cluster)
+                extracted_preds |= shared_add_effects_in_cluster
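+
+            # For example, if every segment in a cluster adds some ground
+            # atom of a predicate such as Holding (possibly over different
+            # objects), then Holding survives the intersection and is
+            # extracted for that cluster.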
+
+            # Step 6:
+            # Remove inconsistent predicates, unless removing them would
+            # prevent us from disambiguating two or more clusters (i.e. their
+            # add effect sets would become identical once the inconsistent
+            # predicates are removed). The idea here is that HoldingTop and
+            # HoldingSide are inconsistent within the PlaceOnTable cluster in
+            # painting, but we don't want to remove them, since we generated
+            # them specifically to disambiguate segments in the cluster with
+            # the Pick option. A consistent predicate is either an add
+            # effect, a delete effect, or unchanged, within each cluster, for
+            # all clusters. Note that it is possible that removing two
+            # inconsistent predicates together leaves two clusters
+            # indistinguishable, even though keeping either one of the two
+            # would disambiguate them. For now, we just add both back, which
+            # is not ideal.
+            consistent, inconsistent = self._get_consistent_predicates(
+                extracted_preds, list(final_clusters))
+            predicates_to_keep: Set[Predicate] = consistent
+            consistent_shared_add_effects_per_cluster = [
+                add_effs - inconsistent
+                for add_effs in shared_add_effects_per_cluster
+            ]
+            num_clusters = len(final_clusters)
+            for i in range(num_clusters):
+                for j in range(num_clusters):
+                    if i == j:
+                        continue
+                    if consistent_shared_add_effects_per_cluster[
+                            i] == consistent_shared_add_effects_per_cluster[j]:
+                        logging.info(
+                            f"Final clusters {i} and {j} cannot be "
+                            f"disambiguated after removing the inconsistent "
+                            f"predicates.")
+                        predicates_to_keep |= \
+                            shared_add_effects_per_cluster[i]
+                        predicates_to_keep |= \
+                            shared_add_effects_per_cluster[j]
+
+            # Remove the initial predicates.
+            predicates_to_keep -= initial_predicates
+
+            logging.info(
+                f"\nSelected {len(predicates_to_keep)} predicates out of "
+                f"{len(candidates)} candidates:")
+            for pred in sorted(predicates_to_keep):
+                logging.info(f"{pred}")
+            return predicates_to_keep
+
         if CFG.grammar_search_pred_clusterer == "oracle":
             assert CFG.offline_data_method == "demo+gt_operators"
             assert dataset.annotations is not None and len(
@@ -932,35 +1209,8 @@ def _select_predicates_by_clustering(
             # Finally, select predicates that are consistent (either, it is
             # an add effect, or a delete effect, or doesn't change)
             # within all demos.
-            predicates_to_keep: Set[Predicate] = set()
-            # for pred in consistent_add_effs_preds:
-            for pred in non_static_predicates:
-                keep_pred = True
-                for seg_list in gt_op_to_segments.values():
-                    segment_0 = seg_list[0]
-                    pred_in_add_effs_0 = pred in [
-                        atom.predicate for atom in segment_0.add_effects
-                    ]
-                    pred_in_del_effs_0 = pred in [
-                        atom.predicate for atom in segment_0.delete_effects
-                    ]
-                    for seg in seg_list[1:]:
-                        pred_in_curr_add_effs = pred in [
-                            atom.predicate for atom in seg.add_effects
-                        ]
-                        pred_in_curr_del_effs = pred in [
-                            atom.predicate for atom in seg.delete_effects
-                        ]
-                        if not ((pred_in_add_effs_0 == pred_in_curr_add_effs)
-                                and
-                                (pred_in_del_effs_0 == pred_in_curr_del_effs)):
-                            keep_pred = False
-                            break
-                    if not keep_pred:
-                        break
-
-                else:
-                    predicates_to_keep.add(pred)
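+            # Reuse the shared helper; only the consistent predicates (the
+            # first return value) are needed here.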
+            predicates_to_keep, _ = self._get_consistent_predicates(
+                non_static_predicates, list(gt_op_to_segments.values()))
 
             # Before returning, remove all the initial predicates.
             predicates_to_keep -= initial_predicates
diff --git a/predicators/settings.py b/predicators/settings.py
index f8aa365902..a856581fe3 100644
--- a/predicators/settings.py
+++ b/predicators/settings.py
@@ -632,6 +632,9 @@ class GlobalSettings:
     grammar_search_expected_nodes_allow_noops = True
     grammar_search_classifier_pretty_str_names = ["?x", "?y", "?z"]
 
+    # grammar search clustering algorithm parameters
+    grammar_search_clustering_gmm_num_components = 10
+
     @classmethod
     def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]:
         """A workaround for global settings that are derived from the
diff --git a/tests/approaches/test_nsrt_learning_approach.py b/tests/approaches/test_nsrt_learning_approach.py
index ee972dece4..485479ff16 100644
--- a/tests/approaches/test_nsrt_learning_approach.py
+++ b/tests/approaches/test_nsrt_learning_approach.py
@@ -433,3 +433,33 @@
         offline_data_method="demo+gt_operators",
         solve_exceptions=ApproachFailure,
         additional_settings=additional_settings)
+
+
+def test_predicate_invention_with_custom_clustering():
+    """Test for predicate invention with a custom clustering algorithm."""
+    additional_settings = {
+        "grammar_search_pred_selection_approach": "clustering",
+        "grammar_search_pred_clusterer": "option-type-number-sample",
+        "segmenter": "option_changes",
+    }
+    _test_approach(env_name="blocks",
+                   num_train_tasks=10,
+                   approach_name="grammar_search_invention",
+                   strips_learner="cluster_and_intersect",
+                   offline_data_method="demo",
+                   solve_exceptions=ApproachFailure,
+                   additional_settings=additional_settings)
+    _test_approach(env_name="painting",
+                   num_train_tasks=10,
+                   approach_name="grammar_search_invention",
+                   strips_learner="cluster_and_intersect",
+                   offline_data_method="demo",
+                   solve_exceptions=ApproachFailure,
+                   additional_settings=additional_settings)
+    _test_approach(env_name="repeated_nextto",
+                   num_train_tasks=10,
+                   approach_name="grammar_search_invention",
+                   strips_learner="pnad_search",
+                   offline_data_method="demo",
+                   solve_exceptions=ApproachFailure,
+                   additional_settings=additional_settings)
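+    # This test can be run in isolation with, e.g.:
+    #   pytest tests/approaches/test_nsrt_learning_approach.py \
+    #       -k custom_clustering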