
Commit

baby rapidlearn done
matrixbalto committed Jan 8, 2025
1 parent 0c1f131 commit a3c4d11
Showing 7 changed files with 336 additions and 44 deletions.
284 changes: 269 additions & 15 deletions predicators/approaches/bridge_policy_approach.py
@@ -401,7 +401,10 @@ def __init__(self,
self._policy_logs: List[str] = []
self._plan_logs = []
self._current_task: Optional[Task] = None

self.num_bridge_steps = 0
self.bridge_done = False
self.planning_policy = None
self.num_planner_steps = 0

def _Can_plan(self, state: State, _: Sequence[Object]) -> bool:
if (MPDQNFunction._old_vectorize_state(self.mapleq._q_function, state) != # pylint: disable=protected-access
@@ -581,13 +584,17 @@ def _act_policy(s: State) -> Action:
if just_starting:
task = self._train_tasks[train_task_idx]
self._current_control = "planner"
self.num_bridge_steps = 0
self.bridge_done = False
self.num_planner_steps = 0
_, option_policy = self._get_option_policy_by_planning(
task, CFG.timeout)
self._current_policy = utils.option_policy_to_policy(
option_policy,
max_option_steps=CFG.max_num_steps_option_rollout,
raise_error_on_repeated_state=True,
)
self.planning_policy = self._current_policy
just_starting = False
return policy(s)

@@ -614,6 +621,7 @@ def learn_from_interaction_results(
plan_logs = self._plan_logs
reward_bonuses = []
for i in range(len(results)):
reward_bonus = []
result = results[i]
policy_log = policy_logs[:len(result.states[:-1])]
plan_log = plan_logs[:len(result.states[:-1])]
@@ -628,25 +636,68 @@ def _update_maple_data(self) -> None:
or policy_log[max(j - 1, 0)] == "bridge"
]

# if CFG.rl_rwd_shape:
# for j, _ in enumerate(result.actions):
# reward_bonus = []
# sussy = plan_log[j] # This is the planned sequence of states at timestep j
# rwd = 0
# # Check if we're currently in bridge mode or were in bridge mode in the previous step
# if policy_log[j] == "bridge" or policy_log[max(j - 1, 0)] == "bridge":
# # Get the abstract state (set of atoms) for the next state
# current_atoms = utils.abstract(result.states[j+1], self._get_current_predicates())

# # Loop through each state in the planned sequence
# for g in range(len(sussy)):
# # If our current state contains all atoms from a planned state
# if sussy[g].issubset(current_atoms):
# # Give reward based on how far along the plan we are
# # e.g., if we match state 4 out of 10 states, reward would be 4/(10*2) = 0.2
# rwd = (g/(len(sussy)*2))
# break
# reward_bonus.append(rwd)

# reward_bonuses.extend(reward_bonus)


num_rewards = 0
if CFG.rl_rwd_shape:
    for j, _ in enumerate(result.actions):
        rwd = -1  # Default reward is -1.

        # Only apply reward shaping during bridge mode or right after it.
        if policy_log[j] == "bridge" or policy_log[max(j - 1, 0)] == "bridge":
            current_atoms = utils.abstract(result.states[j + 1],
                                           self._get_current_predicates())

            # Find the last planner-controlled step before this point.
            # (Use a separate loop variable so we do not shadow the
            # trajectory index i.)
            last_planner_idx = -1
            for k in range(j, -1, -1):
                if policy_log[k] == "planner":
                    last_planner_idx = k
                    break

            # Fractional progress bonus: how far along the planned
            # sequence of abstract states are we? No break, so the
            # furthest matching plan state wins.
            planned_atom_seq = plan_log[j]
            rwd = 0
            for g in range(len(planned_atom_seq)):
                if planned_atom_seq[g].issubset(current_atoms):
                    rwd = g / (len(planned_atom_seq) * 2)

            if last_planner_idx != -1:
                # All planned states from the last planner step onwards.
                safe_states = plan_log[j][last_planner_idx + 1:]

                # If the current state contains all atoms of any of these
                # future planned states, the bridge has recovered the plan.
                for future_state in safe_states:
                    if future_state.issubset(current_atoms):
                        print("FOUND MATCH")
                        print(future_state)
                        print(current_atoms)
                        print(safe_states)
                        rwd = 1000  # Large bonus for reaching a safe state.
                        num_rewards += 1
                        break

        reward_bonus.append(rwd)

print("num rewards", num_rewards)
reward_bonuses.append(reward_bonus)
mapleq_states.append(result.states[-1])
new_traj = LowLevelTrajectory(mapleq_states, mapleq_actions)
self._trajs.append(new_traj)
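For reference, here is a minimal, self-contained sketch of the shaping rule added in the hunk above, using frozensets of strings as stand-ins for the ground-atom sets produced by utils.abstract. The helper name shaped_bonus and the toy atoms are illustrative only; the 1000 "safe state" bonus and the g / (2 * len(plan)) progress term mirror the code above.

from typing import FrozenSet, Sequence

Atoms = FrozenSet[str]  # toy stand-in for a set of ground atoms

def shaped_bonus(current_atoms: Atoms,
                 planned_atom_seq: Sequence[Atoms],
                 last_planner_idx: int) -> float:
    # Fractional progress: reward the furthest planned abstract state
    # whose atoms are already contained in the current state.
    bonus = 0.0
    for g, planned in enumerate(planned_atom_seq):
        if planned.issubset(current_atoms):
            bonus = g / (len(planned_atom_seq) * 2)
    # Large bonus if any planned state *after* the last planner step is
    # reached: the bridge has returned execution to a "safe" state.
    for future in planned_atom_seq[last_planner_idx + 1:]:
        if future.issubset(current_atoms):
            return 1000.0
    return bonus

# The agent has reached the atoms of plan step 2, which lies past the
# failure point (step 1), so it earns the large bonus.
plan = [frozenset({"At(A)"}), frozenset({"Holding(B)"}),
        frozenset({"At(C)", "Holding(B)"}), frozenset({"Done()"})]
print(shaped_bonus(frozenset({"At(C)", "Holding(B)", "Extra()"}), plan, 1))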
@@ -823,4 +874,207 @@ def _policy(s: State) -> Action:
action = self._current_policy(s)
return action

return _policy


class RapidLearnApproach(RLBridgePolicyApproach):
    """Like the RL bridge policy approach, but the bridge policy terminates
    as soon as it reaches a state that is at least as far along the plan as
    the state where execution broke, and it receives a positive reward for
    doing so.

    Concretely, the policy returned by _solve() checks whether the current
    state's abstract atoms contain all atoms of a later state in the plan;
    if so, it sets self.bridge_done = True and hands control back to the
    planner's policy.
    """

def __init__(self,
initial_predicates: Set[Predicate],
initial_options: Set[ParameterizedOption],
types: Set[Type],
action_space: Box,
train_tasks: List[Task],
task_planning_heuristic: str = "default",
max_skeletons_optimized: int = -1) -> None:
super().__init__(initial_predicates,
initial_options,
types,
action_space,
train_tasks,
task_planning_heuristic,
max_skeletons_optimized)
self.bridge_done = False
self.last_action = None
self.num_bridge_steps = 0
self.planning_policy = None

@classmethod
def get_name(cls) -> str:
return "rapid_learn"

def _solve(self,
task: Task,
timeout: int,
train_or_test: str = "test") -> Callable[[State], Action]:
# Start by planning. Note that we cannot start with the bridge policy
# because the bridge policy takes as input the last failed NSRT.

# print("\n=== NEW TASK ===")
# print("bridge_done before reset:", self.bridge_done)
self.bridge_done = False
# print("bridge_done after reset:", self.bridge_done)

self.num_planner_steps = 0
self._current_task = task
self._current_control = "planner"
task_list, option_policy = self._get_option_policy_by_planning(task, timeout)
self._current_policy = utils.option_policy_to_policy(
option_policy,
max_option_steps=CFG.max_num_steps_option_rollout,
raise_error_on_repeated_state=True,
)
self.planning_policy = self._current_policy
self._current_plan = task_list
if not self._maple_initialized:
self.mapleq = MPDQNApproach(self._get_current_predicates(),
self._initial_options, self._types,
self._action_space, self._train_tasks, self.CallPlanner)
self._maple_initialized = True
self._init_nsrts()

# self.mapleq._q_function.set_grounding( # pylint: disable=protected-access
# all_objects, goals, all_ground_nsrts)

# Prevent infinite loops by detecting if the bridge policy is called
# twice with the same state.
last_bridge_policy_state = DefaultState
self.num_bridge_steps = 0
self.num_planner_steps = 0


def _policy(s: State) -> Action:
nonlocal last_bridge_policy_state

# If we are in bridge-policy mode, check whether the current state
# contains the atoms of a later state in the plan; if so, the bridge
# has recovered and control can return to the planner.
if self._current_control == "bridge":
print(self._current_control)
if self.num_bridge_steps > CFG.max_num_bridge_steps:
self.bridge_done = True
self._current_control = "planner"
self._current_policy = self.planning_policy
self.num_bridge_steps = 0
print("bridge done")
else:
current_atoms = utils.abstract(s, self._get_current_predicates())
print("NUM PLANNER STEPS", self.num_planner_steps)
for state in self._current_plan[self.num_planner_steps:]:
if state.issubset(current_atoms):
self.bridge_done = True
self._current_control = "planner"
self._current_policy = self.planning_policy
print("bridge done")
print(self._current_plan)
break

# # Normal execution. Either keep executing the current option, or
# # switch to the next option if it has terminated.
# try:
# action = self._current_policy(s)
# if train_or_test == "train":
# self._policy_logs.append(self._current_control)
# self._plan_logs.append(self._current_plan)

# if self._current_control == "bridge":
# num_bridge_steps += 1

# return action
# except utils.OptionExecutionFailure as e:
# # failed = True
# if self._current_control == "bridge":
# import ipdb;ipdb.set_trace()
# logging.debug("failed" + self._current_control)



try:
action = self._current_policy(s)
print(f"Current control: {self._current_control}, Action: {action}")
if self._current_control == "planner":
self.num_planner_steps += 1
if train_or_test == "train":
self._policy_logs.append(self._current_control)
self._plan_logs.append(self._current_plan)
if self._current_control == "bridge":
self.num_bridge_steps += 1
self.last_action = action
return action
except utils.OptionExecutionFailure as e:
print("current state", s)
print(f"Option execution failed: {e}")
print(f"Current control: {self._current_control}")
print(f"Bridge done: {self.bridge_done}")
if self.bridge_done or self._current_control == "planner":
# Get the next option from the option policy
try:
action = self._current_policy(s)
# current_option = option_policy(s)
except utils.OptionExecutionFailure:
    # If the retry also fails, fall back to the last failed option
    # recorded in the outer exception's error info.
    current_option = e.info["last_failed_option"]
    if current_option is not None:
        # Force the failed option's policy to execute.
        action = current_option.policy(s)
    else:
        # Nothing to execute; repeat the last successful action.
        action = self.last_action
print("planner execution error", action)



# Switch control from planner to bridge.
# assert self._current_control == "planner"
if not self.bridge_done:
print("switching to bridge", self.bridge_done)
self._current_control = "bridge"
self.num_bridge_steps = 0
self.mapleq._q_function._last_planner_state = s

self._bridge_called_state = s
self._current_policy = self.mapleq._solve( # pylint: disable=protected-access
task, timeout, train_or_test)
action = self._current_policy(s)
if self._current_control == "bridge":
self.num_bridge_steps += 1


if train_or_test == "train":
self._policy_logs.append(self._current_control)
self._plan_logs.append(self._current_plan)
return action

return _policy
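As a rough illustration of the bridge-termination test inside _policy above, the sketch below reduces it to a plain atom-subset check. The function name bridge_recovered, the toy string atoms, and the explicit parameters are assumptions for illustration; in the diff the plan is self._current_plan, the cutoff is CFG.max_num_bridge_steps, and the atoms come from utils.abstract.

from typing import FrozenSet, Sequence

Atoms = FrozenSet[str]  # toy stand-in for a set of ground atoms

def bridge_recovered(current_atoms: Atoms,
                     plan: Sequence[Atoms],
                     num_planner_steps: int,
                     num_bridge_steps: int,
                     max_bridge_steps: int) -> bool:
    # Give up and hand control back to the planner after too many
    # bridge steps.
    if num_bridge_steps > max_bridge_steps:
        return True
    # Otherwise the bridge is done once the current abstract state
    # contains all atoms of some not-yet-reached planned state.
    return any(planned.issubset(current_atoms)
               for planned in plan[num_planner_steps:])

# Two planner steps already executed; the current atoms subsume the
# third planned state, so control returns to the planner.
plan = [frozenset({"A"}), frozenset({"B"}), frozenset({"B", "C"})]
print(bridge_recovered(frozenset({"B", "C", "D"}), plan, 2, 3, 50))  # True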

21 changes: 16 additions & 5 deletions predicators/approaches/maple_q_approach.py
@@ -127,7 +127,7 @@ def _learn_nsrts(self, trajectories: List[LowLevelTrajectory],
# On the first cycle, we need to register the ground NSRTs, goals, and
# objects in the Q function so that it can define its inputs.

if not online_learning_cycle and CFG.approach != "rl_bridge_policy" and CFG.approach != "rl_first_bridge":
if not online_learning_cycle and CFG.approach not in (
        "rl_bridge_policy", "rl_first_bridge", "rapid_learn"):
all_ground_nsrts: Set[_GroundNSRT] = set()
if CFG.sesame_grounder == "naive":
for nsrt in self._nsrts:
@@ -154,8 +154,7 @@ def _learn_nsrts(self, trajectories: List[LowLevelTrajectory],
raise ValueError(
f"Unrecognized sesame_grounder: {CFG.sesame_grounder}")
goals = [t.goal for t in self._train_tasks]
self._q_function.set_grounding(all_objects, goals,
all_ground_nsrts)
self._q_function.set_grounding(all_objects, goals, all_ground_nsrts)
# Update the data using the updated self._segmented_trajs.
if isinstance(self, MPDQNApproach):
MPDQNApproach._update_maple_data(self, reward_bonuses) # pylint: disable=protected-access
@@ -199,6 +198,7 @@ def _update_maple_data(self) -> None:
terminal = reward > 0 or seg_i == len(segmented_traj) - 1
if terminal:
already_terminal = terminal

self._q_function.add_datum_to_replay_buffer(
(s, goal, o, ns, reward, terminal))

@@ -258,8 +258,13 @@ def _update_maple_data(self, reward_bonuses) -> None:
len(self._interaction_goals)
new_traj_goals = self._interaction_goals[goal_offset + start_idx:]



for traj_i, segmented_traj in enumerate(new_trajs):
self._last_seen_segment_traj_idx += 1
reward_bonus = reward_bonuses[traj_i]
# if sum(reward_bonus) > 0:
# import ipdb;ipdb.set_trace()
for seg_i, segment in enumerate(segmented_traj):
s = segment.states[0]
goal = new_traj_goals[traj_i]
@@ -270,12 +275,18 @@ def _update_maple_data(self, reward_bonuses) -> None:
reward += 0.5

if CFG.rl_rwd_shape:
reward += reward_bonuses[0]
reward_bonuses.pop(0)
reward += reward_bonus[0]
reward_bonus.pop(0)


# if reward > 0:
# print(s, o, reward)

terminal = reward > 0 or seg_i == len(segmented_traj) - 1


# if reward >= 1:
# import ipdb;ipdb.set_trace()

self._q_function.add_datum_to_replay_buffer(
(s, goal, o, ns, reward, terminal))
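To show how the per-trajectory bonuses produced in bridge_policy_approach.py are consumed here, the sketch below mirrors the pop-from-the-front pattern in the _update_maple_data hunk above. The helper name apply_bonuses and the plain float lists are assumptions for illustration; in the diff the rewards come from segmented trajectories and the tuples go into the Q-function's replay buffer.

from typing import List, Tuple

def apply_bonuses(segment_rewards: List[float],
                  reward_bonus: List[float]) -> List[Tuple[float, bool]]:
    # One bonus is consumed from the front of the per-trajectory list for
    # each segment; terminal is flagged on a positive reward or on the
    # last segment, mirroring the loop above.
    out = []
    for seg_i, reward in enumerate(segment_rewards):
        reward += reward_bonus[0]
        reward_bonus.pop(0)
        terminal = reward > 0 or seg_i == len(segment_rewards) - 1
        out.append((reward, terminal))
    return out

# Three segments; the second earns the large "safe state" bonus.
print(apply_bonuses([0.0, 0.0, 0.0], [-1.0, 1000.0, -1.0]))
# -> [(-1.0, False), (1000.0, True), (-1.0, True)]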
