More progress after pair programming.
ashay-bdai committed Sep 10, 2024
1 parent e62e8cf commit aed5708
Showing 8 changed files with 80 additions and 38 deletions.
3 changes: 2 additions & 1 deletion predicators/approaches/base_approach.py
@@ -11,7 +11,7 @@
from predicators.structs import Action, Dataset, InteractionRequest, \
InteractionResult, Metrics, ParameterizedOption, Predicate, State, Task, \
Type
from predicators.utils import ExceptionWithInfo
from predicators.utils import ExceptionWithInfo, create_vlm_by_name


class BaseApproach(abc.ABC):
@@ -29,6 +29,7 @@ def __init__(self, initial_predicates: Set[Predicate],
self._train_tasks = train_tasks
self._metrics: Metrics = defaultdict(float)
self._set_seed(CFG.seed)
self._vlm = create_vlm_by_name(CFG.vlm_model_name) # pragma: no cover

@classmethod
@abc.abstractmethod
1 change: 1 addition & 0 deletions predicators/approaches/bilevel_planning_approach.py
@@ -58,6 +58,7 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
seed = self._seed + self._num_calls
nsrts = self._get_current_nsrts()
preds = self._get_current_predicates()
import pdb; pdb.set_trace()

# Run task planning only and then greedily sample and execute in the
# policy.
2 changes: 1 addition & 1 deletion predicators/approaches/spot_wrapper_approach.py
@@ -55,7 +55,7 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
def _policy(state: State) -> Action:
nonlocal base_approach_policy, need_stow
# If we think that we're done, return the done action.
if task.goal_holds(state):
if task.goal_holds(state, self._vlm):
extra_info = SpotActionExtraInfo("done", [], None, tuple(),
None, tuple())
return utils.create_spot_env_action(extra_info)
38 changes: 26 additions & 12 deletions predicators/envs/spot_env.py
@@ -88,7 +88,7 @@ class _SpotObservation:
class _TruncatedSpotObservation:
"""An observation for a SpotEnv."""
# Camera name to image
images: Dict[str, RGBDImageWithContext]
rgbd_images: Dict[str, RGBDImageWithContext]
# Objects in the environment
objects_in_view: Set[Object]
# Objects seen only by the hand camera
@@ -186,7 +186,7 @@ def get_robot(


@functools.lru_cache(maxsize=None)
def get_robot_only(self) -> Tuple[Optional[Robot], Optional[LeaseClient]]:
def get_robot_only() -> Tuple[Optional[Robot], Optional[LeaseClient]]:
hostname = CFG.spot_robot_ip
sdk = create_standard_sdk("PredicatorsClient-")
robot = sdk.create_robot(hostname)
@@ -265,7 +265,6 @@ def __init__(self, use_gui: bool = True) -> None:
if not CFG.bilevel_plan_without_sim:
self._initialize_pybullet()
_SIMULATED_SPOT_ROBOT = self._sim_robot
import pdb; pdb.set_trace()
robot, localizer, lease_client = get_robot()
self._robot = robot
self._localizer = localizer
@@ -1485,9 +1484,9 @@ def _get_sweeping_surface_for_container(container: Object,
def _get_vlm_query_str(pred_name: str, objects: Sequence[Object]) -> str:
return pred_name + "(" + ", ".join(str(obj.name) for obj in objects) + ")" # pragma: no cover
_VLMOn = utils.create_vlm_predicate(
"VLMOn"
[_movable_object_type, _immovable_object_type],
_get_vlm_query_str
"VLMOn",
[_movable_object_type, _base_object_type],
lambda o: _get_vlm_query_str("VLMOn", o)
)
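
For reference, here is a minimal sketch of the query string this helper is expected to build for the VLM prompt; SimpleObject is an illustrative stand-in for the repository's Object class, not repository code.

# Illustrative sketch only -- mirrors _get_vlm_query_str above with a stand-in object type.
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class SimpleObject:
    name: str  # stand-in for predicators.structs.Object


def get_vlm_query_str(pred_name: str, objects: Sequence[SimpleObject]) -> str:
    # Render a ground predicate as "Name(obj1, obj2)" for the VLM prompt.
    return pred_name + "(" + ", ".join(str(obj.name) for obj in objects) + ")"


print(get_vlm_query_str("VLMOn", [SimpleObject("cup"), SimpleObject("table")]))
# Expected output: VLMOn(cup, table)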

_ALL_PREDICATES = {
@@ -2428,8 +2427,17 @@ class VLMTestEnv(SpotRearrangementEnv):
@classmethod
def get_name(cls) -> str:
return "spot_vlm_test_env"

def _get_dry_task(self, train_or_test: str,
task_idx: int) -> EnvironmentTask:
raise NotImplementedError("No dry task for VLMTestEnv.")

def _create_operators() -> Iterator[STRIPSOperator]:
@property
def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]:
"""Get an object from a perception detection ID."""
raise NotImplementedError("No dry task for VLMTestEnv.")

def _create_operators(self) -> Iterator[STRIPSOperator]:
# Pick object
robot = Variable("?robot", _robot_type)
obj = Variable("?object", _movable_object_type)
@@ -2480,11 +2488,16 @@ def _generate_test_tasks(self) -> List[EnvironmentTask]:
goal = self._generate_goal_description() # currently just one goal
return [EnvironmentTask(None, goal) for _ in range(CFG.num_test_tasks)]

def __init__(self, use_ui: bool = True) -> None:
super().__init__(use_gui)
def _generate_train_tasks(self) -> List[EnvironmentTask]:
goal = self._generate_goal_description() # currently just one goal
return [
EnvironmentTask(None, goal) for _ in range(CFG.num_train_tasks)
]

def __init__(self, use_gui: bool = True) -> None:
robot, lease_client = get_robot_only()
self._robot = robot
self._lease_cient = lease_client
self._lease_client = lease_client
self._strips_operators: Set[STRIPSOperator] = set()
# Used to do [something] when the agent thinks the goal is reached
# but the human says it is not.
@@ -2494,12 +2507,14 @@ def __init__(self, use_ui: bool = True) -> None:
self._last_action: Optional[Action] = None
# Create constant objects.
self._spot_object = Object("robot", _robot_type)
op_to_name = {o.name for o in _create_operators()}
op_to_name = {o.name: o for o in self._create_operators()}
op_names_to_keep = {
"Pick",
"Place"
}
self._strips_operators = {op_to_name[o] for o in op_names_to_keep}
self._train_tasks = []
self._test_tasks = []

def _actively_construct_env_task(self) -> EnvironmentTask:
assert self._robot is not None
@@ -2564,7 +2579,6 @@ class SpotCubeEnv(SpotRearrangementEnv):
attempts to place an April Tag cube onto a particular table."""

def __init__(self, use_gui: bool = True) -> None:
import pdb; pdb.set_trace()
super().__init__(use_gui)

op_to_name = {o.name: o for o in _create_operators()}
2 changes: 1 addition & 1 deletion predicators/ground_truth_models/spot_env/options.py
@@ -1019,7 +1019,7 @@ class SpotEnvsGroundTruthOptionFactory(GroundTruthOptionFactory):
@classmethod
def get_env_names(cls) -> Set[str]:
return {
"spot_vlm_test_env"
"spot_vlm_test_env",
"spot_cube_env",
"spot_soda_floor_env",
"spot_soda_table_env",
55 changes: 37 additions & 18 deletions predicators/perception/spot_perceiver.py
@@ -579,9 +579,13 @@ def render_mental_images(self, observation: Observation,
return [img]


class SpotMinimumPerceiver(BasePerceiver):
class SpotMinimalPerceiver(BasePerceiver):
"""A perceiver for spot envs with minimal functionality."""

def render_mental_images(self, observation: Observation,
env_task: EnvironmentTask) -> Video:
raise NotImplementedError()

@classmethod
def get_name(cls) -> str:
return "spot_minimal_perceiver"
@@ -648,40 +652,51 @@ def update_perceiver_with_action(self, action: Action) -> None:
self._prev_action = action

def reset(self, env_task: EnvironmentTask) -> Task:
init_obs = env_task.init_bos
imgs = init_obs.rgbd_images
self._robot = init_obs.robot
# import pdb; pdb.set_trace()
# init_obs = env_task.init_obs
# imgs = init_obs.rgbd_images
# self._robot = init_obs.robot
# state = self._create_state()
# state.simulator_state["images"] = [imgs]
# state.set(self._robot, "gripper_open_percentage", init_obs.gripper_open_percentage)
# self._curr_state = state
self._curr_env = get_or_create_env(CFG.env)
state = self._create_state()
state.simulator_state["images"] = [imgs]
state.set(self._robot, "gripper_open_percentage") = init_obs.gripper_open_percentage
state.simulator_state = {}
state.simulator_state["images"] = []
self._curr_state = state
goal = self._create_goal(state, env_task.goal_description)
return Task(state, goal)

def step(self, observation: Observation) -> State:
self._waiting_for_observation = False
self._robot = observation.robot
imgs = observation.rgbd_images
self._curr_state.simulator_state["images"].append([imgs])
self._curr_state.set(self._robot, "gripper_open_percentage") = observation.gripper_open_percentage
return self._curr_state.copy()
imgs = [v.rgb for _, v in imgs.items()]
self._curr_state = self._create_state()
self._curr_state.simulator_state["images"] = imgs
self._gripper_open_percentage = observation.gripper_open_percentage
ret_state = self._curr_state.copy()
return ret_state

def _create_state(self) -> State:
if self._waiting_for_observation:
return DefaultState
# Build the continuous part of the state.
assert self._robot is not None
table = Object("talbe", _immovable_object_type)
table = Object("table", _immovable_object_type)
cup = Object("cup", _movable_object_type)
pan = Object("pan", _container_type)
state_dict = {
self._robot: {
"gripper_open_percentage": self._gripper_open_percentage,
"x": self._robot_pos.x,
"y": self._robot_pos.y,
"z": self._robot_pos.z,
"qw": self._robot_pos.rot.w,
"qx": self._robot_pos.rot.x,
"qy": self._robot_pos.rot.y,
"qz": self._robot_pos.rot.z,
"x": 0,
"y": 0,
"z": 0,
"qw": 0,
"qx": 0,
"qy": 0,
"qz": 0,
},
table: {
"x": 0,
@@ -739,4 +754,8 @@ def _create_state(self) -> State:
"is_sweeper": 0
}
}
return State(state_dict)
state_dict = {k: list(v.values()) for k, v in state_dict.items()}
ret_state = State(state_dict)
ret_state.simulator_state = {}
ret_state.simulator_state["images"] = []
return ret_state
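
For reference, a minimal self-contained sketch of the pattern the rewritten perceiver follows: numeric features live in the state dictionary, while raw camera frames ride along in a simulator_state side channel. MinimalState and perceive_step below are illustrative assumptions, not repository code.

# Illustrative sketch only -- approximates the reset/step flow above with stand-in types.
from typing import Dict, List


class MinimalState:
    """Stand-in for predicators.structs.State."""

    def __init__(self, data: Dict[str, List[float]]) -> None:
        self.data = data
        # Side channel for non-symbolic observations (e.g. raw RGB frames).
        self.simulator_state: Dict[str, list] = {"images": []}

    def copy(self) -> "MinimalState":
        new = MinimalState({k: list(v) for k, v in self.data.items()})
        new.simulator_state = {"images": list(self.simulator_state["images"])}
        return new


def perceive_step(rgb_images: Dict[str, object],
                  gripper_open_percentage: float) -> MinimalState:
    # Rebuild the state from scratch each step: placeholder zeros stand in for
    # poses, and only the latest camera frames are attached for the VLM to read.
    state = MinimalState({"robot": [gripper_open_percentage, 0, 0, 0, 0, 0, 0, 0]})
    state.simulator_state["images"] = list(rgb_images.values())
    return state.copy()


print(perceive_step({"hand_cam": "frame-bytes"}, 0.5).simulator_state["images"])  # ['frame-bytes']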
6 changes: 3 additions & 3 deletions predicators/pretrained_model_interface.py
@@ -102,7 +102,7 @@ def sample_completions(self,
if not os.path.exists(cache_filepath):
if CFG.llm_use_cache_only:
raise ValueError("No cached response found for prompt.")
logging.debug(f"Querying model {model_id} with new prompt.")
print(f"Querying model {model_id} with new prompt.")
# Query the model.
completions = self._sample_completions(prompt, imgs, temperature,
seed, stop_token,
@@ -118,11 +118,11 @@
for i, img in enumerate(imgs):
filename_suffix = str(i) + ".jpg"
img.save(os.path.join(imgs_folderpath, filename_suffix))
logging.debug(f"Saved model response to {cache_filepath}.")
print(f"Saved model response to {cache_filepath}.")
# Load the saved completion.
with open(cache_filepath, 'r', encoding='utf-8') as f:
cache_str = f.read()
logging.debug(f"Loaded model response from {cache_filepath}.")
print(f"Loaded model response from {cache_filepath}.")
assert cache_str.count(_CACHE_SEP) == num_completions
cached_prompt, completion_strs = cache_str.split(_CACHE_SEP, 1)
assert cached_prompt == prompt
11 changes: 9 additions & 2 deletions predicators/structs.py
@@ -489,9 +489,16 @@ def __post_init__(self) -> None:
for atom in self.goal:
assert isinstance(atom, GroundAtom)

def goal_holds(self, state: State) -> bool:
def goal_holds(self, state: State, vlm: Optional[Any] = None) -> bool:
"""Return whether the goal of this task holds in the given state."""
return all(goal_atom.holds(state) for goal_atom in self.goal)
from predicators.utils import query_vlm_for_atom_vals
vlm_atoms = set(atom for atom in self.goal if isinstance(atom.predicate, VLMPredicate))
for atom in self.goal:
if atom not in vlm_atoms:
if not atom.holds(state):
return False
true_vlm_atoms = query_vlm_for_atom_vals(vlm_atoms, state, vlm)
return len(true_vlm_atoms) == len(vlm_atoms)

def replace_goal_with_alt_goal(self) -> Task:
"""Return a Task with the goal replaced with the alternative goal if it
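
For reference, a minimal sketch of the goal-checking split introduced above: ordinary atoms are evaluated directly, and VLM-backed atoms are batched into a single query that must confirm all of them. The Atom class and the query_vlm callback below are illustrative stand-ins, not the repository's types; in the repository the batched query is handled by predicators.utils.query_vlm_for_atom_vals.

# Illustrative sketch only -- stand-in types, not the repository's State/GroundAtom classes.
from dataclasses import dataclass
from typing import Callable, Set


@dataclass(frozen=True)
class Atom:
    name: str
    is_vlm: bool  # True if the predicate needs a VLM to evaluate
    classical_value: bool = True  # what holds(state) would return for non-VLM atoms


def goal_holds(goal: Set[Atom], query_vlm: Callable[[Set[Atom]], Set[Atom]]) -> bool:
    # Split the goal into VLM-backed atoms and ordinary atoms.
    vlm_atoms = {a for a in goal if a.is_vlm}
    # Ordinary atoms are checked one by one against the symbolic state.
    for atom in goal - vlm_atoms:
        if not atom.classical_value:
            return False
    # VLM-backed atoms are sent to the model in a single batched query;
    # the goal holds only if every one of them comes back true.
    true_vlm_atoms = query_vlm(vlm_atoms)
    return len(true_vlm_atoms) == len(vlm_atoms)


goal = {Atom("HandEmpty", False), Atom("VLMOn_cup_table", True)}
print(goal_holds(goal, lambda atoms: atoms))  # pretend the VLM confirms every queried atom; prints True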
