Skip to content

Commit

Permalink
found bug with interpret and fixed it!
Browse files Browse the repository at this point in the history
  • Loading branch information
NishanthJKumar committed Nov 21, 2024
1 parent aef0db0 commit 360594a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
11 changes: 8 additions & 3 deletions predicators/datasets/generate_atom_trajs_with_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,6 @@ def _generate_ground_atoms_with_vlm_oo_code_gen(
ground_atoms = utils.abstract(state, candidates | known_predicates)
ground_atoms_traj.append(ground_atoms)
ground_atoms_trajs.append(ground_atoms_traj)
import ipdb; ipdb.set_trace()
return ground_atoms_trajs


Expand Down Expand Up @@ -870,12 +869,18 @@ def _create_prompt_from_image_option_traj(
for i, a in enumerate(image_option_traj.actions):
state = image_option_traj.states[i]
demo_str.append(f"state {i}:")
demo_str.append(state.dict_str(indent=2, object_features=True))
# NOTE: it's important to set the round_feat_vals argument to False
# here. If we set it to True, then the VLM might mistakenly propose
# predicates that work given rounding, but fail otherwise.
# So for instance, a predicate classifier that does `== 0` would
# work for a value 0.00123 rounded to a single decimal place,
# but wouldn't actually work when deployed on the number 0.00123!
demo_str.append(state.dict_str(indent=2, object_features=True, round_feat_vals=False))
demo_str.append(f"action {i}: {a.name}")
num_states = len(image_option_traj.states)
state = image_option_traj.states[-1]
demo_str.append(f"state {num_states}:")
demo_str.append(state.dict_str(indent=2, object_features=True))
demo_str.append(state.dict_str(indent=2, object_features=True, round_feat_vals=False))
demo_str_ = '\n'.join(demo_str)
template = template.replace("[DEMO_TRAJECTORY]", demo_str_)

Expand Down
4 changes: 2 additions & 2 deletions predicators/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,14 @@ def pretty_str(self) -> str:
suffix = "\n" + "#" * ll + "\n"
return prefix + "\n\n".join(table_strs) + suffix

def dict_str(self, indent: int = 0, object_features: bool = True) -> str:
def dict_str(self, indent: int = 0, object_features: bool = True, round_feat_vals: bool = True) -> str:
"""Return a dictionary representation of the state."""
state_dict = {}
for obj in self:
obj_dict = {}
if obj.type.name == "robot" or object_features:
for attribute, value in zip(obj.type.feature_names, self[obj]):
if isinstance(value, (float, int, np.float32)):
if isinstance(value, (float, int, np.float32)) and round_feat_vals:
value = round(float(value), 1)
obj_dict[attribute] = value
obj_name = obj.name
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
"ImageHash",
"google-generativeai",
"tenacity",
"opencv-python",
"torchvision"
],
include_package_data=True,
extras_require={
Expand Down

0 comments on commit 360594a

Please sign in to comment.