Skip to content

Commit

Permalink
Merge pull request #290 from Learning-and-Intelligent-Systems/upstrea…
Browse files Browse the repository at this point in the history
…m-vlm-preds-merge

Update current master to be in sync with upstream predicators
  • Loading branch information
NishanthJKumar authored Aug 22, 2024
2 parents d1bc95f + 8b9420e commit ccaa689
Show file tree
Hide file tree
Showing 74 changed files with 6,285 additions and 687 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Default owners for all files.
* @NishanthJKumar @wmcclinton @tsilver-bdai
@NishanthJKumar
10 changes: 8 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ pylint_recursive.py
.coverage
results
eval_trajectories
images
videos
logs
logs*
saved_approaches
saved_datasets
scripts/results
pretrained_model_cache
pretrained_model_cache*
tests/datasets/mock_vlm_datasets/cache/
machines.txt
*_vision_data
tests/_fake_trajs
Expand All @@ -29,6 +31,10 @@ predicators/envs/assets/task_jsons/spot/last.json
spot_perception_outputs
spot_perception_debug_dir/
sas_plan
downward/
Gymnasium-Robotics/
predicators/datasets/vlm_input_data_prompts/vision_api/prompt.txt
predicators/datasets/vlm_input_data_prompts/vision_api/response.txt
side-state-view.png
top-down-state-view.png

Expand Down
249 changes: 239 additions & 10 deletions predicators/approaches/bridge_policy_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,24 @@

import logging
import time
from typing import Callable, List, Optional, Sequence, Set
from typing import Callable, Dict, List, Optional, Sequence, Set

import numpy as np
from gym.spaces import Box

from predicators import utils
from predicators.approaches import ApproachFailure, ApproachTimeout
from predicators.approaches.maple_q_approach import MapleQApproach
from predicators.approaches.oracle_approach import OracleApproach
from predicators.bridge_policies import BridgePolicyDone, create_bridge_policy
from predicators.nsrt_learning.segmentation import segment_trajectory
from predicators.option_model import _OptionModelBase
from predicators.settings import CFG
from predicators.structs import NSRT, Action, BridgeDataset, DefaultState, \
DemonstrationQuery, DemonstrationResponse, InteractionRequest, \
InteractionResult, ParameterizedOption, Predicate, Query, State, Task, \
Type, _Option
from predicators.utils import OptionExecutionFailure
from predicators.structs import NSRT, Action, Array, BridgeDataset, \
DefaultState, DemonstrationQuery, DemonstrationResponse, \
InteractionRequest, InteractionResult, LiftedAtom, LowLevelTrajectory, \
Object, ParameterizedOption, Predicate, Query, State, Task, Type, \
Variable, _GroundNSRT, _Option


class BridgePolicyApproach(OracleApproach):
Expand Down Expand Up @@ -128,7 +130,7 @@ def _policy(s: State) -> Action:
except BridgePolicyDone:
assert current_control == "bridge"
failed_option = None # not used, but satisfy linting
except OptionExecutionFailure as e:
except utils.OptionExecutionFailure as e:
failed_option = e.info["last_failed_option"]
if failed_option is not None:
all_failed_options.append(failed_option)
Expand Down Expand Up @@ -179,7 +181,7 @@ def _policy(s: State) -> Action:
)
try:
return current_policy(s)
except OptionExecutionFailure as e:
except utils.OptionExecutionFailure as e:
all_failed_options.append(e.info["last_failed_option"])
raise ApproachFailure(
e.args[0], info={"all_failed_options": all_failed_options})
Expand Down Expand Up @@ -238,7 +240,7 @@ def _act_policy(s: State) -> Action:
reached_stuck_state = True
all_failed_options = e.info["all_failed_options"]
# Approach failures not caught in interaction loop.
raise OptionExecutionFailure(e.args[0], e.info)
raise utils.OptionExecutionFailure(e.args[0], e.info)

def _termination_fn(s: State) -> bool:
return reached_stuck_state or task.goal_holds(s)
Expand Down Expand Up @@ -287,7 +289,7 @@ def learn_from_interaction_results(
states = [traj.states[0]]
atoms = atom_traj
else:
states = utils.segment_trajectory_to_state_sequence(
states = utils.segment_trajectory_to_start_end_state_sequence(
segmented_traj)
atoms = utils.segment_trajectory_to_atoms_sequence(
segmented_traj)
Expand Down Expand Up @@ -346,3 +348,230 @@ def learn_from_interaction_results(
))

return self._bridge_policy.learn_from_demos(self._bridge_dataset)


class RLBridgePolicyApproach(BridgePolicyApproach):
"""A simulator-free bilevel planning approach that uses a Deep RL (Maple Q)
bridge policy."""

def __init__(self,
initial_predicates: Set[Predicate],
initial_options: Set[ParameterizedOption],
types: Set[Type],
action_space: Box,
train_tasks: List[Task],
task_planning_heuristic: str = "default",
max_skeletons_optimized: int = -1) -> None:
super().__init__(initial_predicates, initial_options, types,
action_space, train_tasks, task_planning_heuristic,
max_skeletons_optimized)
self._maple_initialized = False
if task_planning_heuristic == "default":
task_planning_heuristic = CFG.sesame_task_planning_heuristic
self._task_planning_heuristic = task_planning_heuristic
self._trajs: List[LowLevelTrajectory] = []
self.CanPlan = Predicate("CanPlan", [], self._Can_plan)
self.CallPlanner = utils.SingletonParameterizedOption(
"CallPlanner",
types=None,
policy=self.call_planner_policy,
params_space=Box(low=np.array([]), high=np.array([]), shape=(0, )),
)
initial_options.add(self.CallPlanner)
self._initial_options = initial_options
self.mapleq=MapleQApproach(self._get_current_predicates(), \
self._initial_options, self._types, \
self._action_space, self._train_tasks)
self._current_control: Optional[str] = None
option_policy = self._get_option_policy_by_planning(
self._train_tasks[0], CFG.timeout)
self._current_policy = utils.option_policy_to_policy(
option_policy,
max_option_steps=CFG.max_num_steps_option_rollout,
raise_error_on_repeated_state=True,
)
self._bridge_called_state = State(data={})
self._policy_logs: List[Optional[str]] = []

def _Can_plan(self, state: State, _: Sequence[Object]) -> bool:
if (self.mapleq._q_function._vectorize_state(state) != # pylint: disable=protected-access
self.mapleq._q_function._vectorize_state( # pylint: disable=protected-access
self._bridge_called_state)).any(): # pylint: disable=protected-access
return True
return False

def call_planner_policy(self, state: State, _: Dict, __: Sequence[Object],
___: Array) -> Action:
"""policy for CallPlanner option."""
self._current_control = "planner"
# create a new task where the init state is our current state
current_task = Task(state, self._train_tasks[0].goal)
option_policy = self._get_option_policy_by_planning(
current_task, CFG.timeout)
self._current_policy = utils.option_policy_to_policy(
option_policy,
max_option_steps=CFG.max_num_steps_option_rollout,
raise_error_on_repeated_state=True,
)

return Action(np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32))

def call_planner_nsrt(self) -> NSRT:
"""CallPlanner NSRT."""
parameters: Sequence[Variable] = []
option_vars = parameters
option = self.CallPlanner
preconditions = {LiftedAtom(self.CanPlan, [])}
add_effects: Set[LiftedAtom] = set()
delete_effects: Set[LiftedAtom] = set()

ignore_effects: Set[Predicate] = set()
call_planner_nsrt = NSRT("CallPlanner", parameters, preconditions,
add_effects, delete_effects, ignore_effects,
option, option_vars, utils.null_sampler)
return call_planner_nsrt

@classmethod
def get_name(cls) -> str:
return "rl_bridge_policy"

@property
def is_learning_based(self) -> bool:
return True

def _init_nsrts(self) -> None:
"""Initializing nsrts for MAPLE Q."""
nsrts = self._get_current_nsrts()
callplanner_nsrt = self.call_planner_nsrt()
nsrts.add(callplanner_nsrt)
predicates = self._get_current_predicates()
all_ground_nsrts: Set[_GroundNSRT] = set()
if CFG.sesame_grounder == "naive":
for nsrt in nsrts:
all_objects = {o for t in self._train_tasks for o in t.init}
all_ground_nsrts.update(
utils.all_ground_nsrts(nsrt, all_objects))
elif CFG.sesame_grounder == "fd_translator": # pragma: no cover
all_objects = set()
for t in self.mapleq._train_tasks: # pylint: disable=protected-access
curr_task_objects = set(t.init)
curr_task_types = {o.type for o in t.init}
curr_init_atoms = utils.abstract(t.init, predicates)
all_ground_nsrts.update(
utils.all_ground_nsrts_fd_translator(
nsrts, curr_task_objects, predicates, curr_task_types,
curr_init_atoms, t.goal))
all_objects.update(curr_task_objects)
else: # pragma: no cover
raise ValueError(
f"Unrecognized sesame_grounder: {CFG.sesame_grounder}")
goals = [t.goal for t in self.mapleq._train_tasks] # pylint: disable=protected-access
self.mapleq._q_function.set_grounding( # pylint: disable=protected-access
all_objects, goals, all_ground_nsrts)

def _solve(self,
task: Task,
timeout: int,
train_or_test: str = "test") -> Callable[[State], Action]:
# Start by planning. Note that we cannot start with the bridge policy
# because the bridge policy takes as input the last failed NSRT.
self._current_control = "planner"
option_policy = self._get_option_policy_by_planning(task, timeout)
self._current_policy = utils.option_policy_to_policy(
option_policy,
max_option_steps=CFG.max_num_steps_option_rollout,
raise_error_on_repeated_state=True,
)
if not self._maple_initialized:
self.mapleq = MapleQApproach(self._get_current_predicates(),
self._initial_options, self._types,
self._action_space, self._train_tasks)
self._maple_initialized = True
self._init_nsrts()

def _policy(s: State) -> Action:
# Normal execution. Either keep executing the current option, or
# switch to the next option if it has terminated.
try:
action = self._current_policy(s)
if train_or_test == "train":
self._policy_logs.append(self._current_control)
return action
except utils.OptionExecutionFailure:
logging.debug(f"Failed control: {self._current_control}")
# Switch control from planner to bridge.
assert self._current_control == "planner"
self._current_control = "bridge"
if train_or_test == "train":
self._policy_logs.append(self._current_control)
self._bridge_called_state = s
self._current_policy = self.mapleq._solve( # pylint: disable=protected-access
task, timeout, train_or_test)
action = self._current_policy(s)

return action

return _policy

def _create_interaction_request(self,
train_task_idx: int) -> InteractionRequest:
task = self._train_tasks[train_task_idx]
policy = self._solve(task, timeout=CFG.timeout, train_or_test="train")
just_starting = True

def _act_policy(s: State) -> Action:
nonlocal just_starting
if just_starting:
self._current_control = "planner"
option_policy = self._get_option_policy_by_planning(
task, CFG.timeout)
self._current_policy = utils.option_policy_to_policy(
option_policy,
max_option_steps=CFG.max_num_steps_option_rollout,
raise_error_on_repeated_state=True,
)
just_starting = False
return policy(s)

def _termination_fn(s: State) -> bool:
return task.goal_holds(s)

# The request's acting policy is from mapleq
# The resulting trajectory is from maple q's sampling
request = InteractionRequest(train_task_idx, _act_policy,
lambda s: None, _termination_fn)
return request

def learn_from_interaction_results(
self, results: Sequence[InteractionResult]) -> None:
# Turn state action pairs from results into trajectories
# If we haven't collected any new results on this cycle, skip learning
# for efficiency.
if not results:
return None
policy_logs = self._policy_logs
for i in range(len(results)):
result = results[i]
policy_log = policy_logs[:len(result.states[:-1])]
# We index max(j - 1, 0) to count for the case when CallPlanner
# is used, since "planner" is added to the corresponding policy_log.
# When j = 0, planner is always in control
mapleq_states = [
state for j, state in enumerate(result.states[:-1])
if policy_log[j] == "bridge"
or policy_log[max(j - 1, 0)] == "bridge"
]
mapleq_actions = [
action for j, action in enumerate(result.actions)
if policy_log[j] == "bridge"
or policy_log[max(j - 1, 0)] == "bridge"
]
mapleq_states.append(result.states[-1])
new_traj = LowLevelTrajectory(mapleq_states, mapleq_actions)
self._trajs.append(new_traj)
policy_logs = policy_logs[len(result.states) - 1:]

self.mapleq.get_interaction_requests()
self.mapleq._learn_nsrts(self._trajs, 0, [] * len(self._trajs)) # pylint: disable=protected-access
self._policy_logs = []
return None
16 changes: 12 additions & 4 deletions predicators/approaches/gnn_option_policy_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,12 @@ def _solve_with_shooting(self, task: Task,
start_time = time.perf_counter()
memory: Dict = {} # optionally updated by predict()
# Keep trying until the timeout.
tries: int = 0
all_num_act: int = 0
total_num_act: int = 0
while time.perf_counter() - start_time < timeout:
tries += 1
all_num_act += total_num_act
total_num_act = 0
state = task.init
plan: List[_Option] = []
Expand Down Expand Up @@ -243,10 +248,13 @@ def _policy(s: State) -> Action:
break
# If num_act is zero, that means that the option is stuck in
# the state, so we should break to avoid infinite loops.
if num_act == 0:
break
total_num_act += num_act
# Break early if we have timed out.
if time.perf_counter() - start_time < timeout:
total_num_act += num_act
if time.perf_counter() - start_time > timeout:
break
if num_act == 0:
break
all_num_act += total_num_act
print(f"Shooting: {all_num_act} actions with {tries} tries in \
{time.perf_counter() - start_time} seconds")
raise ApproachTimeout("Shooting timed out!")
Loading

0 comments on commit ccaa689

Please sign in to comment.