From b0392c255f98e852a0a375e20e3018373210bdef Mon Sep 17 00:00:00 2001 From: Nishanth Kumar Date: Mon, 9 Oct 2023 10:19:18 -0400 Subject: [PATCH] fixes --- .../grammar_search_invention_approach.py | 31 +++++++------------ .../approaches/test_nsrt_learning_approach.py | 15 ++++----- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/predicators/approaches/grammar_search_invention_approach.py b/predicators/approaches/grammar_search_invention_approach.py index b8fe9d073a..7a2393410e 100644 --- a/predicators/approaches/grammar_search_invention_approach.py +++ b/predicators/approaches/grammar_search_invention_approach.py @@ -1057,9 +1057,7 @@ def _select_predicates_by_clustering( each_dim_uniform = False break if each_dim_uniform: - clusters[option][types][max_num_objs] = [ - segments - ] + clusters[option][types][max_num_objs] = [segments] logging.info( f"STEP 4: generated no further sample-based" f" clusters (uniformly distributed " @@ -1074,33 +1072,29 @@ def _select_predicates_by_clustering( # complexity of the model (chosen by BIC here) # are hyperparameters. max_components = min( - len(samples), len(np.unique(samples)), CFG. - grammar_search_clustering_gmm_num_components - ) + len(samples), len(np.unique(samples)), + CFG.grammar_search_clustering_gmm_num_components) n_components = np.arange(1, max_components + 1) models = [ - GMM(n, - covariance_type="full", + GMM(n, covariance_type="full", random_state=0).fit(samples) for n in n_components ] bic = [m.bic(samples) for m in models] best = models[np.argmin(bic)] assignments = best.predict(samples) - label_to_segments: Dict[int, - List[Segment]] = {} + label_to_segments: Dict[int, List[Segment]] = {} for l, assignment in enumerate(assignments): label_to_segments.setdefault( assignment, []).append(segments[l]) clusters[option][types][max_num_objs] = list( label_to_segments.values()) - logging.info( - f"STEP 4: generated " - f"{len(label_to_segments.keys())}" - f"sample-based clusters for cluster " - f"{i+j+k+1} from STEP 3 involving option " - f"{option}, type {types}, and max num " - f"objects {max_num_objs}.") + logging.info(f"STEP 4: generated " + f"{len(label_to_segments.keys())}" + f"sample-based clusters for cluster " + f"{i+j+k+1} from STEP 3 involving option " + f"{option}, type {types}, and max num " + f"objects {max_num_objs}.") # We could avoid these loops by creating the final set of clusters # as part of STEP 4, but this is not prohibitively slow and serves @@ -1160,8 +1154,7 @@ def _select_predicates_by_clustering( if i == j: continue if consistent_shared_add_effects_per_cluster[ - i] == consistent_shared_add_effects_per_cluster[ - j]: + i] == consistent_shared_add_effects_per_cluster[j]: logging.info( f"Final clusters {i} and {j} cannot be " f"disambiguated after removing the inconsistent" diff --git a/tests/approaches/test_nsrt_learning_approach.py b/tests/approaches/test_nsrt_learning_approach.py index 1ab1b757b8..485479ff16 100644 --- a/tests/approaches/test_nsrt_learning_approach.py +++ b/tests/approaches/test_nsrt_learning_approach.py @@ -434,6 +434,7 @@ def test_predicate_invention_with_oracle_clustering(): solve_exceptions=ApproachFailure, additional_settings=additional_settings) + def test_predicate_invention_with_custom_clustering(): """Test for predicate invention with a custom clustering algorithm.""" additional_settings = { @@ -455,10 +456,10 @@ def test_predicate_invention_with_custom_clustering(): offline_data_method="demo", solve_exceptions=ApproachFailure, additional_settings=additional_settings) - # _test_approach(env_name="repeated_nextto", - # num_train_tasks=10, - # approach_name="grammar_search_invention", - # strips_learner="cluster_and_intersect", - # offline_data_method="demo", - # solve_exceptions=ApproachFailure, - # additional_settings=additional_settings) \ No newline at end of file + _test_approach(env_name="repeated_nextto", + num_train_tasks=10, + approach_name="grammar_search_invention", + strips_learner="pnad_search", + offline_data_method="demo", + solve_exceptions=ApproachFailure, + additional_settings=additional_settings)