Clustering via reverse engineering (#1556)
* Initial commit.

* Fix a minor bug.

* Small changes to satisfy mypy.

* Fix linting.

* Add tests.

* fixes

* fix minor grammatical issue

* Change check for non-zero types.

---------

Co-authored-by: Nishanth Kumar <[email protected]>
Co-authored-by: Nishanth Kumar <[email protected]>
3 people authored Oct 31, 2023
1 parent 00c7860 commit 0a60aa7
Showing 3 changed files with 314 additions and 31 deletions.
312 changes: 281 additions & 31 deletions predicators/approaches/grammar_search_invention_approach.py
@@ -9,10 +9,13 @@
from dataclasses import dataclass, field
from functools import cached_property
from operator import le
from typing import Callable, Dict, FrozenSet, Iterator, List, Sequence, Set, \
Tuple
from typing import Any, Callable, Dict, FrozenSet, Iterator, List, Sequence, \
Set, Tuple

import numpy as np
from gym.spaces import Box
from scipy.stats import kstest
from sklearn.mixture import GaussianMixture as GMM

from predicators import utils
from predicators.approaches.nsrt_learning_approach import NSRTLearningApproach
@@ -887,13 +890,287 @@ def _get_successors(

return set(kept_predicates)

@staticmethod
def _get_consistent_predicates(
predicates: Set[Predicate], clusters: List[List[Segment]]
) -> Tuple[Set[Predicate], Set[Predicate]]:
"""Returns all predicates that are consistent with respect to a set of
segment clusters.
A consistent predicate is is either an add effect, a delete
effect, or doesn't change, within each cluster, for all
clusters.
"""

consistent: Set[Predicate] = set()
inconsistent: Set[Predicate] = set()
for pred in predicates:
keep_pred = True
for seg_list in clusters:
segment_0 = seg_list[0]
pred_in_add_effs_0 = pred in [
atom.predicate for atom in segment_0.add_effects
]
pred_in_del_effs_0 = pred in [
atom.predicate for atom in segment_0.delete_effects
]
for seg in seg_list[1:]:
pred_in_curr_add_effs = pred in [
atom.predicate for atom in seg.add_effects
]
pred_in_curr_del_effs = pred in [
atom.predicate for atom in seg.delete_effects
]
not_consis_add = pred_in_add_effs_0 != pred_in_curr_add_effs
not_consis_del = pred_in_del_effs_0 != pred_in_curr_del_effs
if not_consis_add or not_consis_del:
keep_pred = False
inconsistent.add(pred)
logging.info(f"Inconsistent predicate: {pred.name}")
break
if not keep_pred:
break
else:
consistent.add(pred)
return consistent, inconsistent

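# A minimal, self-contained illustration of the consistency notion
# above, with hypothetical predicate names and plain sets standing in
# for segments (an illustrative sketch, not code from this commit):
toy_cluster = [
    {"add": {"Holding"}, "delete": {"OnTable"}},  # segment 0
    {"add": {"Holding"}, "delete": {"OnTable"}},  # segment 1
]
toy_pred = "Holding"
toy_in_add_0 = toy_pred in toy_cluster[0]["add"]
toy_in_del_0 = toy_pred in toy_cluster[0]["delete"]
toy_consistent = all(
    (toy_pred in seg["add"]) == toy_in_add_0
    and (toy_pred in seg["delete"]) == toy_in_del_0
    for seg in toy_cluster[1:])
assert toy_consistent  # "Holding" behaves the same way in every segment.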
def _select_predicates_by_clustering(
self, candidates: Dict[Predicate, float],
initial_predicates: Set[Predicate], dataset: Dataset,
atom_dataset: List[GroundAtomTrajectory]) -> Set[Predicate]:
"""Cluster segments from the atom_dataset into clusters corresponding
to operators and use this to select predicates."""

if CFG.grammar_search_pred_clusterer == "option-type-number-sample":
# This procedure tries to reverse engineer the clusters of segments
# that correspond to the oracle operators and selects predicates
# that are add effects in those clusters, letting pnad_search
# downstream handle chainability.

assert CFG.segmenter == "option_changes"
segments = [
seg for ll_traj, atom_seq in atom_dataset for seg in
segment_trajectory(ll_traj, initial_predicates, atom_seq)
]

# Step 1:
# Cluster segments by the option that generated them. We know that
# at the very least, operators are 1 to 1 with options.
option_to_segments: Dict[Any, Any] = {} # Dict[str, List[Segment]]
for seg in segments:
name = seg.get_option().name
option_to_segments.setdefault(name, []).append(seg)
logging.info(f"STEP 1: generated {len(option_to_segments.keys())} "
f"option-based clusters.")
clusters = option_to_segments.copy() # Tree-like structure.

# Step 2:
# Further cluster by the types that appear in a segment's add
# effects. Operators have a fixed number of typed arguments.
for i, pair in enumerate(option_to_segments.items()):
option, segments = pair
types_to_segments: Dict[Tuple[Type, ...], List[Segment]] = {}
for seg in segments:
types_in_effects = [
set(a.predicate.types) for a in seg.add_effects
]
# To cluster on type, there must be types. That is, there
# must be add effects in the segment and the object
# arguments for at least one add effect must be nonempty.
assert len(types_in_effects) > 0 and len(
set.union(*types_in_effects)) > 0
types = tuple(sorted(list(set.union(*types_in_effects))))
types_to_segments.setdefault(types, []).append(seg)
logging.info(
f"STEP 2: generated {len(types_to_segments.keys())} "
f"type-based clusters for cluster {i+1} from STEP 1 "
f"involving option {option}.")
clusters[option] = types_to_segments

# Step 3:
# Further cluster by the maximum number of objects that appear in a
# segment's add effects. Note that using the maximum here is somewhat
# arbitrary; alternatively, one could cluster on every possible number
# of objects rather than only the maximum seen in the add effects of a
# particular segment.
for i, (option, types_to_segments) in enumerate(clusters.items()):
for j, (types,
segments) in enumerate(types_to_segments.items()):
num_to_segments: Dict[int, List[Segment]] = {}
for seg in segments:
max_num_objs = max(
len(a.objects) for a in seg.add_effects)
num_to_segments.setdefault(max_num_objs,
[]).append(seg)
logging.info(
f"STEP 3: generated {len(num_to_segments.keys())} "
f"num-object-based clusters for cluster {i+j+1} from "
f"STEP 2 involving option {option} and type {types}.")
clusters[option][types] = num_to_segments

# Step 4:
# Further cluster by sample, if a sample is present. The idea here
# is to separate things like PickFromTop and PickFromSide.
for i, (option, types_to_num) in enumerate(clusters.items()):
for j, (types,
num_to_segments) in enumerate(types_to_num.items()):
for k, (max_num_objs,
segments) in enumerate(num_to_segments.items()):
# If the segments in this cluster have no sample, then
# don't cluster further.
if len(segments[0].get_option().params) == 0:
clusters[option][types][max_num_objs] = [segments]
logging.info(
f"STEP 4: generated no further sample-based "
f"clusters (no parameter) for cluster {i+j+k+1}"
f" from STEP 3 involving option {option}, type "
f"{types}, and max num objects {max_num_objs}."
)
continue
# pylint: disable=line-too-long
# If the parameters are described by a uniform
# distribution, then don't cluster further. This
# helps prevent overfitting. A proper implementation
# would do a multi-dimensional test
# (https://ieeexplore.ieee.org/document/4767477,
# https://ui.adsabs.harvard.edu/abs/1987MNRAS.225..155F/abstract,
# https://stats.stackexchange.com/questions/30982/how-to-test-uniformity-in-several-dimensions)
# but for now we will only check each dimension
# individually to keep the implementation simple.
# pylint: enable=line-too-long
samples = np.array(
[seg.get_option().params for seg in segments])
each_dim_uniform = True
for d in range(samples.shape[1]):
col = samples[:, d]
minimum = col.min()
maximum = col.max()
null_hypothesis = np.random.uniform(
minimum, maximum, len(col))
p_value = kstest(col, null_hypothesis).pvalue

# We use a significance value of 0.05.
if p_value < 0.05:
each_dim_uniform = False
break
if each_dim_uniform:
clusters[option][types][max_num_objs] = [segments]
logging.info(
f"STEP 4: generated no further sample-based"
f" clusters (uniformly distributed "
f"parameter) for cluster {i+j+k+1} "
f"from STEP 3 involving option {option}, "
f" type {types}, and max num objects "
f"{max_num_objs}.")
continue
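# An alternative to the two-sample check above is a one-sample KS
# test against the fitted uniform CDF directly. A small
# self-contained sketch on synthetic data (illustrative values
# only, not code from this commit):
import numpy as np
from scipy.stats import kstest
toy_col = np.random.default_rng(0).uniform(0.2, 0.8, size=100)
toy_min, toy_max = toy_col.min(), toy_col.max()
toy_p = kstest(toy_col, "uniform",
               args=(toy_min, toy_max - toy_min)).pvalue
# A p-value >= 0.05 here is consistent with a uniform distribution.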
# Determine clusters by assignment from a
# Gaussian Mixture Model. The maximum number of
# components and the penalty on model complexity
# (provided here by BIC) are hyperparameter
# choices.
max_components = min(
len(samples), len(np.unique(samples)),
CFG.grammar_search_clustering_gmm_num_components)
n_components = np.arange(1, max_components + 1)
models = [
GMM(n, covariance_type="full",
random_state=0).fit(samples)
for n in n_components
]
bic = [m.bic(samples) for m in models]
best = models[np.argmin(bic)]
assignments = best.predict(samples)
label_to_segments: Dict[int, List[Segment]] = {}
for l, assignment in enumerate(assignments):
label_to_segments.setdefault(
assignment, []).append(segments[l])
clusters[option][types][max_num_objs] = list(
label_to_segments.values())
logging.info(f"STEP 4: generated "
f"{len(label_to_segments.keys())}"
f"sample-based clusters for cluster "
f"{i+j+k+1} from STEP 3 involving option "
f"{option}, type {types}, and max num "
f"objects {max_num_objs}.")

# We could avoid these loops by creating the final set of clusters
# as part of STEP 4, but this is not prohibitively slow and serves
# to clarify the nested dictionary structure, which we may make use
# of in follow-up work that modifies the clusters more.
final_clusters = []
for option in clusters.keys():
for types in clusters[option].keys():
for max_num_objs in clusters[option][types].keys():
for cluster in clusters[option][types][max_num_objs]:
final_clusters.append(cluster)
logging.info(f"Total {len(final_clusters)} final clusters.")

# Step 5:
# Extract predicates from the pure intersection of the add effects
# in each cluster.
extracted_preds = set()
shared_add_effects_per_cluster = []
for c in final_clusters:
grounded_add_effects_per_segment = [
seg.add_effects for seg in c
]
ungrounded_add_effects_per_segment = []
for effs in grounded_add_effects_per_segment:
ungrounded_add_effects_per_segment.append(
set(a.predicate for a in effs))
shared_add_effects_in_cluster = set.intersection(
*ungrounded_add_effects_per_segment)
shared_add_effects_per_cluster.append(
shared_add_effects_in_cluster)
extracted_preds |= shared_add_effects_in_cluster

# Step 6:
# Remove inconsistent predicates except if removing them prevents us
# from disambiguating two or more clusters (i.e. their add effect
# sets are the same after removing the inconsistent predicates). The
# idea here is that HoldingTop and HoldingSide are inconsistent
# within the PlaceOnTable cluster in painting, but we don't want to
# remove them, since we had generated them specifically to
# disambiguate segments in the cluster with the Pick option.
# A consistent predicate is one that, within each cluster, is always
# an add effect, always a delete effect, or never changes.
# Note that it is possible that removing two inconsistent predicates
# together leaves two clusters indistinguishable, even though keeping
# either one of them alone would disambiguate the clusters.
# For now, we just add both back, which is not ideal.
consistent, inconsistent = self._get_consistent_predicates(
extracted_preds, list(final_clusters))
predicates_to_keep: Set[Predicate] = consistent
consistent_shared_add_effects_per_cluster = [
add_effs - inconsistent
for add_effs in shared_add_effects_per_cluster
]
num_clusters = len(final_clusters)
for i in range(num_clusters):
for j in range(num_clusters):
if i == j:
continue
if consistent_shared_add_effects_per_cluster[
i] == consistent_shared_add_effects_per_cluster[j]:
logging.info(
f"Final clusters {i} and {j} cannot be "
f"disambiguated after removing the inconsistent"
f" predicates.")
predicates_to_keep |= \
shared_add_effects_per_cluster[i]
predicates_to_keep |= \
shared_add_effects_per_cluster[j]
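# A toy illustration of the guard above, with hypothetical predicate
# names (an illustrative sketch, not code from this commit): removing
# the inconsistent predicate would leave the two clusters with
# identical add-effect sets, so it is added back.
toy_shared = [{"Holding", "HoldingTop"}, {"Holding"}]
toy_inconsistent = {"HoldingTop"}
toy_consistent_shared = [s - toy_inconsistent for s in toy_shared]
assert toy_consistent_shared[0] == toy_consistent_shared[1]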

# Remove the initial predicates.
predicates_to_keep -= initial_predicates

logging.info(
f"\nSelected {len(predicates_to_keep)} predicates out of "
f"{len(candidates)} candidates:")
for pred in sorted(predicates_to_keep):
logging.info(f"{pred}")
return predicates_to_keep

if CFG.grammar_search_pred_clusterer == "oracle":
assert CFG.offline_data_method == "demo+gt_operators"
assert dataset.annotations is not None and len(
@@ -932,35 +1209,8 @@ def _select_predicates_by_clustering
# Finally, select predicates that are consistent (i.e., each is
# always an add effect, always a delete effect, or never changes)
# within all demos.
predicates_to_keep: Set[Predicate] = set()
# for pred in consistent_add_effs_preds:
for pred in non_static_predicates:
keep_pred = True
for seg_list in gt_op_to_segments.values():
segment_0 = seg_list[0]
pred_in_add_effs_0 = pred in [
atom.predicate for atom in segment_0.add_effects
]
pred_in_del_effs_0 = pred in [
atom.predicate for atom in segment_0.delete_effects
]
for seg in seg_list[1:]:
pred_in_curr_add_effs = pred in [
atom.predicate for atom in seg.add_effects
]
pred_in_curr_del_effs = pred in [
atom.predicate for atom in seg.delete_effects
]
if not ((pred_in_add_effs_0 == pred_in_curr_add_effs)
and
(pred_in_del_effs_0 == pred_in_curr_del_effs)):
keep_pred = False
break
if not keep_pred:
break

else:
predicates_to_keep.add(pred)
predicates_to_keep, _ = self._get_consistent_predicates(
non_static_predicates, list(gt_op_to_segments.values()))

# Before returning, remove all the initial predicates.
predicates_to_keep -= initial_predicates
3 changes: 3 additions & 0 deletions predicators/settings.py
@@ -632,6 +632,9 @@ class GlobalSettings:
grammar_search_expected_nodes_allow_noops = True
grammar_search_classifier_pretty_str_names = ["?x", "?y", "?z"]

# grammar search clustering algorithm parameters
grammar_search_clustering_gmm_num_components = 10
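# As a usage sketch (hedged: this pattern is assumed from the
# repository's convention of overriding GlobalSettings via
# utils.update_config, and the value 5 is illustrative, not part of
# this commit), the new knob could be adjusted with:
#     from predicators import utils
#     utils.update_config(
#         {"grammar_search_clustering_gmm_num_components": 5})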

@classmethod
def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]:
"""A workaround for global settings that are derived from the
