Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
NishanthJKumar committed Oct 9, 2023
1 parent 6003c8c commit b0392c2
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 26 deletions.
31 changes: 12 additions & 19 deletions predicators/approaches/grammar_search_invention_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,9 +1057,7 @@ def _select_predicates_by_clustering(
each_dim_uniform = False
break
if each_dim_uniform:
clusters[option][types][max_num_objs] = [
segments
]
clusters[option][types][max_num_objs] = [segments]
logging.info(
f"STEP 4: generated no further sample-based"
f" clusters (uniformly distributed "
Expand All @@ -1074,33 +1072,29 @@ def _select_predicates_by_clustering(
# complexity of the model (chosen by BIC here)
# are hyperparameters.
max_components = min(
len(samples), len(np.unique(samples)), CFG.
grammar_search_clustering_gmm_num_components
)
len(samples), len(np.unique(samples)),
CFG.grammar_search_clustering_gmm_num_components)
n_components = np.arange(1, max_components + 1)
models = [
GMM(n,
covariance_type="full",
GMM(n, covariance_type="full",
random_state=0).fit(samples)
for n in n_components
]
bic = [m.bic(samples) for m in models]
best = models[np.argmin(bic)]
assignments = best.predict(samples)
label_to_segments: Dict[int,
List[Segment]] = {}
label_to_segments: Dict[int, List[Segment]] = {}
for l, assignment in enumerate(assignments):
label_to_segments.setdefault(
assignment, []).append(segments[l])
clusters[option][types][max_num_objs] = list(
label_to_segments.values())
logging.info(
f"STEP 4: generated "
f"{len(label_to_segments.keys())}"
f"sample-based clusters for cluster "
f"{i+j+k+1} from STEP 3 involving option "
f"{option}, type {types}, and max num "
f"objects {max_num_objs}.")
logging.info(f"STEP 4: generated "
f"{len(label_to_segments.keys())}"
f"sample-based clusters for cluster "
f"{i+j+k+1} from STEP 3 involving option "
f"{option}, type {types}, and max num "
f"objects {max_num_objs}.")

# We could avoid these loops by creating the final set of clusters
# as part of STEP 4, but this is not prohibitively slow and serves
Expand Down Expand Up @@ -1160,8 +1154,7 @@ def _select_predicates_by_clustering(
if i == j:
continue
if consistent_shared_add_effects_per_cluster[
i] == consistent_shared_add_effects_per_cluster[
j]:
i] == consistent_shared_add_effects_per_cluster[j]:
logging.info(
f"Final clusters {i} and {j} cannot be "
f"disambiguated after removing the inconsistent"
Expand Down
15 changes: 8 additions & 7 deletions tests/approaches/test_nsrt_learning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ def test_predicate_invention_with_oracle_clustering():
solve_exceptions=ApproachFailure,
additional_settings=additional_settings)


def test_predicate_invention_with_custom_clustering():
"""Test for predicate invention with a custom clustering algorithm."""
additional_settings = {
Expand All @@ -455,10 +456,10 @@ def test_predicate_invention_with_custom_clustering():
offline_data_method="demo",
solve_exceptions=ApproachFailure,
additional_settings=additional_settings)
# _test_approach(env_name="repeated_nextto",
# num_train_tasks=10,
# approach_name="grammar_search_invention",
# strips_learner="cluster_and_intersect",
# offline_data_method="demo",
# solve_exceptions=ApproachFailure,
# additional_settings=additional_settings)
_test_approach(env_name="repeated_nextto",
num_train_tasks=10,
approach_name="grammar_search_invention",
strips_learner="pnad_search",
offline_data_method="demo",
solve_exceptions=ApproachFailure,
additional_settings=additional_settings)

0 comments on commit b0392c2

Please sign in to comment.