Skip to content

Commit

Permalink
simpler fix to interpret issue
Browse files Browse the repository at this point in the history
  • Loading branch information
NishanthJKumar committed Nov 21, 2024
1 parent e97dd48 commit a3b71a1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 20 deletions.
14 changes: 2 additions & 12 deletions predicators/datasets/generate_atom_trajs_with_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,22 +869,12 @@ def _create_prompt_from_image_option_traj(
for i, a in enumerate(image_option_traj.actions):
state = image_option_traj.states[i]
demo_str.append(f"state {i}:")
# NOTE: it's important to set the round_feat_vals argument to False
# here. If we set it to True, then the VLM might mistakenly propose
# predicates that work given rounding, but fail otherwise.
# So for instance, a predicate classifier that does `== 0` would
# work for a value 0.00123 rounded to a single decimal place,
# but wouldn't actually work when deployed on the number 0.00123!
demo_str.append(
state.dict_str(indent=2,
object_features=True,
round_feat_vals=False))
demo_str.append(state.dict_str(indent=2, object_features=True))
demo_str.append(f"action {i}: {a.name}")
num_states = len(image_option_traj.states)
state = image_option_traj.states[-1]
demo_str.append(f"state {num_states}:")
demo_str.append(
state.dict_str(indent=2, object_features=True, round_feat_vals=False))
demo_str.append(state.dict_str(indent=2, object_features=True))
demo_str_ = '\n'.join(demo_str)
template = template.replace("[DEMO_TRAJECTORY]", demo_str_)

Expand Down
9 changes: 1 addition & 8 deletions predicators/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,20 +215,13 @@ def pretty_str(self) -> str:
suffix = "\n" + "#" * ll + "\n"
return prefix + "\n\n".join(table_strs) + suffix

def dict_str(self,
indent: int = 0,
object_features: bool = True,
round_feat_vals: bool = True) -> str:
def dict_str(self, indent: int = 0, object_features: bool = True) -> str:
"""Return a dictionary representation of the state."""
state_dict = {}
for obj in self:
obj_dict = {}
if obj.type.name == "robot" or object_features:
for attribute, value in zip(obj.type.feature_names, self[obj]):
if isinstance(
value,
(float, int, np.float32)) and round_feat_vals:
value = round(float(value), 1)
obj_dict[attribute] = value
obj_name = obj.name
state_dict[f"{obj_name}:{obj.type.name}"] = obj_dict
Expand Down

0 comments on commit a3b71a1

Please sign in to comment.