
Commit

rapid learn fixed, coffee fixed
matrixbalto committed Jan 11, 2025
1 parent a3c4d11 commit 6519fa0
Showing 8 changed files with 891 additions and 280 deletions.
178 changes: 155 additions & 23 deletions predicators/approaches/bridge_policy_approach.py
@@ -608,6 +608,8 @@ def _termination_fn(s: State) -> bool:
lambda s: None, _termination_fn)
return request



def learn_from_interaction_results(
self, results: Sequence[InteractionResult]) -> None:
# Turn state action pairs from results into trajectories
@@ -670,11 +672,22 @@ def learn_from_interaction_results(

# Find the last planner state before this point
last_planner_idx = -1
last_planner_state = None
for i in range(j, -1, -1):
if policy_log[i] == "planner":
last_planner_state = result.states[i]
break
last_planner_atoms = utils.abstract(last_planner_state, self._get_current_predicates())
# last_planner_idx = next(j for j, atoms in enumerate(plan_log) if last_planner_atoms.issubset(atoms))
for i in range(len(plan_log[j])):
if plan_log[j][i].issubset(last_planner_atoms):
last_planner_idx = i
break

# import ipdb;ipdb.set_trace()
# try:
# last_planner_idx = plan_log[j].index(last_planner_atoms)
# except:
# import ipdb;ipdb.set_trace()
if last_planner_idx != -1:
# Get all states from last planner state onwards in the plan
safe_states = plan_log[j][last_planner_idx+1:]
@@ -687,16 +700,16 @@ def learn_from_interaction_results(
print(future_state)
print(current_atoms)
print(safe_states)
# import ipdb;ipdb.set_trace()
import ipdb;ipdb.set_trace()
rwd = 1000 # We found a match, this is a safe state
num_rewards += 1
break


reward_bonus.append(rwd)
print("num rewards", num_rewards)
# if num_rewards > 1:
# import ipdb;ipdb.set_trace()
if num_rewards >= 1:
import ipdb;ipdb.set_trace()
reward_bonuses.append(reward_bonus)
mapleq_states.append(result.states[-1])
new_traj = LowLevelTrajectory(mapleq_states, mapleq_actions)
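
The reward check above abstracts the last planner-controlled state, locates it in the logged plan by subset containment, and grants the bonus whenever a later planned atom set already holds in the current abstract state. A minimal standalone sketch of that check, with plain Python sets of ground-atom strings standing in for the abstracted states (the helper name and the example atoms are hypothetical, not the repository API):

from typing import List, Set


def safe_state_bonus(current_atoms: Set[str],
                     last_planner_atoms: Set[str],
                     planned_atom_sets: List[Set[str]],
                     bonus: float = 1000.0,
                     default: float = -1.0) -> float:
    """Return `bonus` if the current abstract state covers any planned atom
    set that comes after the last planner-controlled one, else `default`."""
    # Locate the last planner state inside the logged plan by subset test.
    last_planner_idx = -1
    for i, atoms in enumerate(planned_atom_sets):
        if atoms.issubset(last_planner_atoms):
            last_planner_idx = i
            break
    if last_planner_idx == -1:
        return default
    # Any later planned state whose atoms all hold now counts as "safe".
    for future_atoms in planned_atom_sets[last_planner_idx + 1:]:
        if future_atoms.issubset(current_atoms):
            return bonus
    return default


# Example: the bridge reached a state that satisfies the plan's next step.
plan = [{"HandEmpty()"}, {"Holding(jug)"}, {"JugInMachine(jug)"}]
print(safe_state_bonus({"Holding(jug)", "RobotNearMachine()"},
                       {"HandEmpty()"}, plan))  # -> 1000.0

Subset containment rather than equality lets the current state carry extra atoms and still count as matching a planned state, which is what the `issubset` calls in the diff rely on.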
@@ -994,9 +1007,21 @@ def _policy(s: State) -> Action:
self._current_control = "planner"
self._current_policy = self.planning_policy
# import ipdb;ipdb.set_trace()
print("bridge done")
print(self._current_plan)
# print("bridge done")
# print(self._current_plan)
break
if self.bridge_done:
current_task = Task(s, self._current_task.goal)
task_list, option_policy = self._get_option_policy_by_planning(
current_task, CFG.timeout)
self._current_plan = task_list
self._current_policy = utils.option_policy_to_policy(
option_policy,
max_option_steps=CFG.max_num_steps_option_rollout,
raise_error_on_repeated_state=True,
)

self.bridge_done = False

# # Normal execution. Either keep executing the current option, or
# # switch to the next option if it has terminated.
@@ -1034,8 +1059,8 @@ def _policy(s: State) -> Action:
print("current state", s)
print(f"Option execution failed: {e}")
print(f"Current control: {self._current_control}")
print(f"Bridge done: {self.bridge_done}")
if self.bridge_done or self._current_control == "planner":
print(f"Failed option: {e.info['last_failed_option']}")
if self._current_control == "planner":
# Get the next option from the option policy
try:
action = self._current_policy(s)
@@ -1049,26 +1074,26 @@ def _policy(s: State) -> Action:
memory = e.info.get("memory", {})
action = current_option.policy(s)

else:
action = self.last_action
print("planner execution error", action)
# else:
# action = self.last_action
# print("planner execution error", action)



# Switch control from planner to bridge.
# assert self._current_control == "planner"
if not self.bridge_done:
print("switching to bridge", self.bridge_done)
self._current_control = "bridge"
self.num_bridge_steps=0
self.mapleq._q_function._last_planner_state = s

self._bridge_called_state = s
self._current_policy = self.mapleq._solve( # pylint: disable=protected-access
task, timeout, train_or_test)
action = self._current_policy(s)
if self._current_control == "bridge":
self.num_bridge_steps += 1

print("switching to bridge")
self._current_control = "bridge"
self.num_bridge_steps=0
self.mapleq._q_function._last_planner_state = s

self._bridge_called_state = s
self._current_policy = self.mapleq._solve( # pylint: disable=protected-access
task, timeout, train_or_test)
action = self._current_policy(s)
if self._current_control == "bridge":
self.num_bridge_steps += 1


if train_or_test == "train":
@@ -1078,3 +1103,110 @@ def _policy(s: State) -> Action:

return _policy
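
The policy above hands control back and forth: once the bridge signals it is done, it re-plans from the current state and resets the flag, and when an option raised by the planner fails, the learned bridge policy takes over. A minimal sketch of that handoff under assumed interfaces (`replan`, `bridge_policy`, and the `controller` dict are hypothetical stand-ins for the approach's internals, not the repository API):

from typing import Any, Dict


class OptionExecutionFailure(Exception):
    """Raised when the planner's current option cannot be executed."""


def step(state: Any, controller: Dict[str, Any]) -> Any:
    """Return an action; switch planner -> bridge on option failure and
    re-plan once the bridge reports it is done."""
    if controller["bridge_done"]:
        # The bridge recovered: plan afresh from the current state and
        # hand control back to the planner.
        controller["policy"] = controller["replan"](state)
        controller["control"] = "planner"
        controller["bridge_done"] = False
    try:
        return controller["policy"](state)
    except OptionExecutionFailure:
        # The planner's option failed: the learned bridge policy takes over.
        controller["control"] = "bridge"
        controller["policy"] = controller["bridge_policy"]
        return controller["policy"](state)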



def learn_from_interaction_results(
self, results: Sequence[InteractionResult]) -> None:
# Turn state action pairs from results into trajectories
# If we haven't collected any new results on this cycle, skip learning
# for efficiency.
if not results:
return None
all_states = []
all_actions = []
policy_logs = self._policy_logs
plan_logs = self._plan_logs
reward_bonuses = []
count=-1
mapleq_states = []
mapleq_actions = []
reward_bonuses = []
for i in range(len(results)):
reward_bonus = []
result = results[i]
policy_log = policy_logs[:len(result.states[:-1])]
plan_log = plan_logs[:len(result.states[:-1])]
# mapleq_states = [
# state for j, state in enumerate(result.states[:-1])
# if policy_log[j] == "bridge"
# or policy_log[max(j - 1, 0)] == "bridge"
# ]
# mapleq_actions = [
# action for j, action in enumerate(result.actions)
# if policy_log[j] == "bridge"
# or policy_log[max(j - 1, 0)] == "bridge"
# ]



add_to_traj = False

num_rewards = 0
for j in range(len(result.states)):
if j<len(policy_log) and policy_log[j] == "bridge":
if not add_to_traj:
#begin new trajectory!
add_to_traj = True
mapleq_states.append([])
mapleq_actions.append([])
reward_bonuses.append([])
count+=1
#add to the current trajectory
mapleq_states[count].append(result.states[j])
mapleq_actions[count].append(result.actions[j])
rwd = -1 # Default reward is -1
current_atoms = utils.abstract(result.states[j+1], self._get_current_predicates())
# Find the last planner state before this point
last_planner_idx = -1
last_planner_state = None
for i in range(j, -1, -1):
if policy_log[i] == "planner":
last_planner_state = result.states[i]
break
last_planner_atoms = utils.abstract(last_planner_state, self._get_current_predicates())
for i in range(len(plan_log[j])):
if plan_log[j][i].issubset(last_planner_atoms):
last_planner_idx = i
break

if last_planner_idx != -1:
# Get all states from last planner state onwards in the plan
safe_states = plan_log[j][last_planner_idx+1:]

# Check if current state satisfies postconditions of all remaining operators
# by checking if it contains atoms from any future state in the plan
for future_state in safe_states:
if future_state.issubset(current_atoms):
print("FOUND MATCH")
print(future_state)
print(current_atoms)
print(safe_states)
# import ipdb;ipdb.set_trace()
rwd = 1000 # We found a match, this is a safe state
num_rewards += 1
break
reward_bonuses[count].append(rwd)

else:

if add_to_traj:
#end the current trajectory and add the last state
#also add the trajectory to the list of trajectories
add_to_traj = False
mapleq_states[count].append(result.states[j])
new_traj = LowLevelTrajectory(mapleq_states[count], mapleq_actions[count])
self._trajs.append(new_traj)


print("num rewards", num_rewards)
# if num_rewards >= 1:
# import ipdb;ipdb.set_trace()
all_states.extend(mapleq_states)
all_actions.extend(mapleq_actions)
policy_logs = policy_logs[len(result.states) - 1:]
plan_logs = plan_logs[len(result.states) - 1:]
self.mapleq.get_interaction_requests()
self.mapleq._learn_nsrts(self._trajs, 0, [] * len(self._trajs), reward_bonuses) # pylint: disable=protected-access
self._policy_logs = []
self._plan_logs = []
return None
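
This version keeps one trajectory per contiguous run of bridge-controlled steps: a new trajectory opens when the policy log switches to "bridge" and closes, including the follow-up state, when control returns to the planner. A minimal sketch of that segmentation with plain lists standing in for `LowLevelTrajectory` (the helper name is hypothetical):

from typing import Any, List, Tuple


def split_bridge_trajectories(states: List[Any], actions: List[Any],
                              policy_log: List[str]
                              ) -> List[Tuple[List[Any], List[Any]]]:
    """Return (states, actions) pairs, one per contiguous run of
    bridge-controlled steps, each closed with the state that follows it."""
    trajs: List[Tuple[List[Any], List[Any]]] = []
    open_traj = False
    for j, control in enumerate(policy_log):
        if control == "bridge":
            if not open_traj:
                trajs.append(([], []))   # begin a new trajectory
                open_traj = True
            trajs[-1][0].append(states[j])
            trajs[-1][1].append(actions[j])
        elif open_traj:
            trajs[-1][0].append(states[j])  # close with the follow-up state
            open_traj = False
    if open_traj:                            # rollout ended under the bridge
        trajs[-1][0].append(states[len(policy_log)])
    return trajs


# Example: planner, bridge, bridge, planner -> one two-step trajectory.
print(split_bridge_trajectories(["s0", "s1", "s2", "s3", "s4"],
                                ["a0", "a1", "a2", "a3"],
                                ["planner", "bridge", "bridge", "planner"]))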
85 changes: 59 additions & 26 deletions predicators/approaches/maple_q_approach.py
@@ -253,40 +253,73 @@ def _update_maple_data(self, reward_bonuses) -> None:
start_idx = self._last_seen_segment_traj_idx + 1
new_trajs = self._segmented_trajs[start_idx:]

goal_offset = CFG.max_initial_demos
assert len(self._segmented_trajs) == goal_offset + \
len(self._interaction_goals)
new_traj_goals = self._interaction_goals[goal_offset + start_idx:]

if CFG.rl_rwd_shape:

for traj_i, segmented_traj in enumerate(new_trajs):
self._last_seen_segment_traj_idx += 1
reward_bonus = reward_bonuses[traj_i]
# if sum(reward_bonus) > 0:
# import ipdb;ipdb.set_trace()
for seg_i, segment in enumerate(segmented_traj):
s = segment.states[0]
goal = None
o = segment.get_option()
ns = segment.states[-1]
# reward = 1.0 if goal.issubset(segment.final_atoms) else 0.0
reward = 0
# if CFG.use_callplanner and o.parent == self.CallPlanner and reward == 1:
# reward += 0.5

for traj_i, segmented_traj in enumerate(new_trajs):
self._last_seen_segment_traj_idx += 1
reward_bonus = reward_bonuses[traj_i]
# if sum(reward_bonus) > 0:
# import ipdb;ipdb.set_trace()
for seg_i, segment in enumerate(segmented_traj):
s = segment.states[0]
goal = new_traj_goals[traj_i]
o = segment.get_option()
ns = segment.states[-1]
reward = 1.0 if goal.issubset(segment.final_atoms) else 0.0
if CFG.use_callplanner and o.parent == self.CallPlanner and reward == 1:
reward += 0.5

if CFG.rl_rwd_shape:
reward += reward_bonus[0]
reward_bonus.pop(0)




# if reward > 0:
# print(s, o, reward)

terminal = reward > 0 or seg_i == len(segmented_traj) - 1
# if reward > 0:
# print(s, o, reward)

terminal = reward > 0 or seg_i == len(segmented_traj) - 1


# if reward >= 1:
# import ipdb;ipdb.set_trace()

self._q_function.add_datum_to_replay_buffer(
(s, goal, o, ns, reward, terminal))


else:
goal_offset = CFG.max_initial_demos
assert len(self._segmented_trajs) == goal_offset + \
len(self._interaction_goals)
new_traj_goals = self._interaction_goals[goal_offset + start_idx:]

# if reward >= 1:


for traj_i, segmented_traj in enumerate(new_trajs):
self._last_seen_segment_traj_idx += 1
reward_bonus = reward_bonuses[traj_i]
# if sum(reward_bonus) > 0:
# import ipdb;ipdb.set_trace()
for seg_i, segment in enumerate(segmented_traj):
s = segment.states[0]
goal = new_traj_goals[traj_i]
o = segment.get_option()
ns = segment.states[-1]
reward = 1.0 if goal.issubset(segment.final_atoms) else 0.0
if CFG.use_callplanner and o.parent == self.CallPlanner and reward == 1:
reward += 0.5

# if reward > 0:
# print(s, o, reward)

terminal = reward > 0 or seg_i == len(segmented_traj) - 1

self._q_function.add_datum_to_replay_buffer(
(s, goal, o, ns, reward, terminal))

# if reward >= 1:
# import ipdb;ipdb.set_trace()

self._q_function.add_datum_to_replay_buffer(
(s, goal, o, ns, reward, terminal))
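
With reward shaping enabled, each segment's reward is the base goal reward plus the next entry popped from the per-trajectory bonus list, and a positive reward or the final segment marks the transition terminal. A minimal sketch of that replay-buffer update, with a plain list as the buffer and hypothetical (state, option, next_state, goal_reached) tuples standing in for the segment objects:

from typing import Any, List, Tuple


def add_shaped_transitions(replay_buffer: List[Tuple],
                           segments: List[Tuple[Any, Any, Any, bool]],
                           reward_bonus: List[float],
                           use_shaping: bool) -> None:
    for seg_i, (s, o, ns, goal_reached) in enumerate(segments):
        reward = 1.0 if goal_reached else 0.0
        if use_shaping and reward_bonus:
            # One bonus entry is consumed per executed segment, mirroring
            # the pop(0) in the diff above.
            reward += reward_bonus.pop(0)
        # A positive reward or the final segment ends the episode.
        terminal = reward > 0 or seg_i == len(segments) - 1
        replay_buffer.append((s, o, ns, reward, terminal))


buf: List[Tuple] = []
add_shaped_transitions(buf,
                       [("s0", "opt0", "s1", False),
                        ("s1", "opt1", "s2", True)],
                       [-1.0, 1000.0], use_shaping=True)
print(buf[1])  # ('s1', 'opt1', 's2', 1001.0, True)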
23 changes: 13 additions & 10 deletions predicators/envs/coffee.py
@@ -709,10 +709,10 @@ class CoffeeLidEnv(CoffeeEnv):
def __init__(self, use_gui: bool = True) -> None:
super().__init__(use_gui)

#to remove lid you need a low level action that gets the stickiness exactly right
self._lid_type = Type(
"lid",
["x", "y", "stickiness", "on_cup"])
# #to remove lid you need a low level action that gets the stickiness exactly right
# self._lid_type = Type(
# "lid",
# ["x", "y", "stickiness", "on_cup"])

@classmethod
def get_name(cls) -> str:
@@ -722,7 +722,8 @@ def get_name(cls) -> str:
def types(self) -> Set[Type]:
return {
self._cup_type, self._jug_type, self._machine_type,
self._robot_type, self._lid_type
self._robot_type,
# self._lid_type
}

def simulate(self, state: State, action: Action) -> State:
@@ -826,10 +827,10 @@ def simulate(self, state: State, action: Action) -> State:
elif abs(fingers - self.closed_fingers) < self.grasp_finger_tol and \
sq_dist_to_handle < self.grasp_position_tol and \
abs(jug_rot) < self.pick_jug_rot_tol and norm_dwrist!=-1:
print("ig we r grasping")
print("sq_dist_to_handle < self.grasp_position_tol", sq_dist_to_handle < self.grasp_position_tol)
print("abs(jug_rot) < self.pick_jug_rot_tol", abs(jug_rot) < self.pick_jug_rot_tol)
print("abs(fingers - self.closed_fingers) < self.grasp_finger_tol")
# print("ig we r grasping")
# print("sq_dist_to_handle < self.grasp_position_tol", sq_dist_to_handle < self.grasp_position_tol)
# print("abs(jug_rot) < self.pick_jug_rot_tol", abs(jug_rot) < self.pick_jug_rot_tol)
# print("abs(fingers - self.closed_fingers) < self.grasp_finger_tol")
# Snap to the handle.
handle_x, handle_y, handle_z = handle_pos
next_state.set(self._robot, "x", handle_x)
@@ -841,12 +842,14 @@ def simulate(self, state: State, action: Action) -> State:
next_state.set(self._jug, "is_held", 1.0)
# Check if the jug should be rotated.
elif self._Twisting_holds(state, [self._robot, self._jug]):
print("WE ROTATE THE JUG with dwrist", dwrist, "and rot", state.get(self._jug, "rot"))
# print("WE ROTATE THE JUG with dwrist", dwrist, "and rot", state.get(self._jug, "rot"))
# Rotate the jug.
rot = state.get(self._jug, "rot")
next_state.set(self._jug, "rot", rot + dwrist)
# If the jug is close enough to the dispense area and the machine is
# on, the jug should get filled.
# print("so uhhhh we did not grasp the jug </3: abs(fingers - self.closed_fingers) < self.grasp_finger_tol", abs(fingers - self.closed_fingers) < self.grasp_finger_tol, " sq_dist_to_handle < self.grasp_position_tol ", sq_dist_to_handle < self.grasp_position_tol,
# " abs(jug_rot) < self.pick_jug_rot_tol ", abs(jug_rot) < self.pick_jug_rot_tol, " norm_dwrist!=-1 ", norm_dwrist!=-1)
jug_in_machine = self._JugInMachine_holds(next_state,
[self._jug, self._machine])
machine_on = self._MachineOn_holds(next_state, [self._machine])
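
The grasp branch above fires only when the fingers are closed within tolerance, the gripper is near the handle, the jug is roughly upright, and the wrist delta is not the -1 sentinel. A minimal standalone sketch of that condition with hypothetical tolerance defaults (the real values are attributes of the environment class):

def grasp_succeeds(fingers: float, closed_fingers: float,
                   sq_dist_to_handle: float, jug_rot: float,
                   norm_dwrist: float,
                   grasp_finger_tol: float = 1e-2,
                   grasp_position_tol: float = 1e-4,
                   pick_jug_rot_tol: float = 1e-1) -> bool:
    """Gripper closed, near the handle, jug roughly upright, and the wrist
    command is not the -1 sentinel."""
    return (abs(fingers - closed_fingers) < grasp_finger_tol
            and sq_dist_to_handle < grasp_position_tol
            and abs(jug_rot) < pick_jug_rot_tol
            and norm_dwrist != -1)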
