Merge pull request cbfinn#2 from cbfinn/master
Remove hack of sampling goals in policy opt code
cbfinn authored Jun 10, 2017
2 parents 45d0cd5 + ebee1a2 commit 5e47c95
Showing 11 changed files with 39 additions and 18 deletions.
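
Before the per-file hunks, a minimal sketch (not part of the commit) of the interface this change standardizes on: every task-distribution environment exposes a sample_goals(num_goals) method returning an array of task parameters, and the meta-training loop asks the environment for goals instead of hard-coding per-env ranges. The class name ToyPointEnv is illustrative only; the goal range mirrors the point-mass example below.

import numpy as np

class ToyPointEnv:
    """Illustrative stand-in for the point-mass envs edited in this commit."""

    def sample_goals(self, num_goals):
        # 2D goal positions; same range as examples/point_env_randgoal.py below.
        return np.random.uniform(-0.5, 0.5, size=(num_goals, 2))

    def reset(self, reset_args=None):
        # The meta-training loop samples goals once per meta-batch and passes
        # each goal back in through reset_args.
        self.goal = reset_args if reset_args is not None else self.sample_goals(1)[0]
        return np.zeros(2)  # placeholder observation
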
3 changes: 3 additions & 0 deletions examples/point_env_randgoal.py
@@ -16,6 +16,9 @@ def observation_space(self):
def action_space(self):
return Box(low=-0.1, high=0.1, shape=(2,))

def sample_goals(self, num_goals):
return np.random.uniform(-0.5, 0.5, size=(num_goals, 2, ))

def reset(self, reset_args=None):
goal = reset_args
if goal is not None:
3 changes: 3 additions & 0 deletions examples/point_env_randgoal_oracle.py
@@ -21,6 +21,9 @@ def observation_space(self):
def action_space(self):
return Box(low=-0.1, high=0.1, shape=(2,))

def sample_goals(self, num_goals):
return np.random.uniform(-0.5, 0.5, size=(num_goals, 2, ))

def reset(self, reset_args=None):
goal = reset_args
if goal is not None:
4 changes: 4 additions & 0 deletions rllab/envs/mujoco/ant_env_direc_oracle.py
@@ -25,6 +25,10 @@ def get_current_obs(self):
]).reshape(-1)
return np.r_[obs, np.array([self.goal_direction])]

def sample_goals(self, num_goals):
# for fwd/bwd env, goal direc is backwards if < 1.5, forwards if > 1.5
return np.random.uniform(0.0, 3.0, (num_goals, ))

@overrides
def reset(self, init_state=None, reset_args=None, **kwargs):
goal_vel = reset_args
3 changes: 3 additions & 0 deletions rllab/envs/mujoco/ant_env_oracle.py
@@ -26,6 +26,9 @@ def get_current_obs(self):
]).reshape(-1)
return np.r_[obs, np.array([self._goal_vel])]

def sample_goals(self, num_goals):
return np.random.uniform(0.0, 3.0, (num_goals, ))

@overrides
def reset(self, init_state=None, reset_args=None, **kwargs):
goal_vel = reset_args
3 changes: 3 additions & 0 deletions rllab/envs/mujoco/ant_env_rand.py
@@ -27,6 +27,9 @@ def get_current_obs(self):
self.get_body_com("torso"),
]).reshape(-1)

def sample_goals(self, num_goals):
return np.random.uniform(0.0, 3.0, (num_goals, ))

@overrides
def reset(self, init_state=None, reset_args=None, **kwargs):
goal_vel = reset_args
4 changes: 4 additions & 0 deletions rllab/envs/mujoco/ant_env_rand_direc.py
@@ -25,6 +25,10 @@ def get_current_obs(self):
self.get_body_com("torso"),
]).reshape(-1)

def sample_goals(self, num_goals):
# for fwd/bwd env, goal direc is backwards if < 1.5, forwards if > 1.5
return np.random.uniform(0.0, 3.0, (num_goals, ))

@overrides
def reset(self, init_state=None, reset_args=None, **kwargs):
goal_vel = reset_args
4 changes: 4 additions & 0 deletions rllab/envs/mujoco/half_cheetah_env_direc_oracle.py
@@ -19,6 +19,10 @@ def __init__(self, *args, **kwargs):
super(HalfCheetahEnvDirecOracle, self).__init__(*args, **kwargs)
Serializable.__init__(self, *args, **kwargs)

def sample_goals(self, num_goals):
# for fwd/bwd env, goal direc is backwards if < 1.0, forwards if > 1.0
return np.random.uniform(0.0, 2.0, (num_goals, ))

@overrides
def reset(self, init_state=None, reset_args=None, **kwargs):
if reset_args is None:
3 changes: 3 additions & 0 deletions rllab/envs/mujoco/half_cheetah_env_oracle.py
@@ -20,6 +20,9 @@ def __init__(self, *args, **kwargs):
self._goal_vel = None
Serializable.__init__(self, *args, **kwargs)

def sample_goals(self, num_goals):
return np.random.uniform(0.0, 2.0, (num_goals, ))

@overrides
def reset(self, init_state=None, reset_args=None, **kwargs):
goal_vel = reset_args
3 changes: 3 additions & 0 deletions rllab/envs/mujoco/half_cheetah_env_rand.py
@@ -20,6 +20,9 @@ def __init__(self, goal=None, *args, **kwargs):
super(HalfCheetahEnvRand, self).__init__(*args, **kwargs)
Serializable.__init__(self, *args, **kwargs)

def sample_goals(self, num_goals):
return np.random.uniform(0.0, 2.0, (num_goals, ))

@overrides
def reset(self, init_state=None, reset_args=None, **kwargs):
goal_vel = reset_args
4 changes: 4 additions & 0 deletions rllab/envs/mujoco/half_cheetah_env_rand_direc.py
@@ -27,6 +27,10 @@ def __init__(self, goal_vel=None, *args, **kwargs):
self.goal_vel = goal_vel
self.reset(reset_args=goal_vel)

def sample_goals(self, num_goals):
# for fwd/bwd env, goal direc is backwards if < 1.0, forwards if > 1.0
return np.random.uniform(0.0, 2.0, (num_goals, ))

@overrides
def reset(self, init_state=None, reset_args=None, **kwargs):
goal_vel = reset_args
23 changes: 5 additions & 18 deletions sandbox/rocky/tf/algos/batch_sensitive_polopt.py
@@ -150,24 +150,11 @@ def train(self):
itr_start_time = time.time()
with logger.prefix('itr #%d | ' % itr):
logger.log("Sampling set of tasks/goals for this meta-batch...")
# TODO - this is a hacky way to specify tasks.
if self.env.observation_space.shape[0] <= 4: # pointmass (oracle=4, normal=2)
learner_env_goals = np.zeros((self.meta_batch_size, 2, ))
# 2d
learner_env_goals = np.random.uniform(-0.5, 0.5, size=(self.meta_batch_size, 2, ))

elif self.env.spec.action_space.shape[0] == 8: # ant
# 0.0 to 3.0 is what specifies the task -- goal vel ranges 0-3.0.
# for fwd/bwd env, goal direc is backwards if < 1.5, forwards if > 1.5
learner_env_goals = np.random.uniform(0.0, 3.0, (self.meta_batch_size, ))

elif self.env.spec.action_space.shape[0] == 6: # cheetah
# 0.0 to 2.0 is what specifies the task -- goal vel ranges 0-2.0.
# for fwd/bwd env, goal direc is backwards if < 1.0, forwards if > 1.0
learner_env_goals = np.random.uniform(0.0, 2.0, (self.meta_batch_size, ))

else:
raise NotImplementedError('unrecognized env')

env = self.env
while 'sample_goals' not in dir(env):
env = env.wrapped_env
learner_env_goals = env.sample_goals(self.meta_batch_size)

self.policy.switch_to_init_dist() # Switch to pre-update policy

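
For context, a hedged sketch of how the new loop above behaves with wrapped environments: it follows wrapped_env attributes until it reaches an object that defines sample_goals, then delegates goal sampling to it. The Wrapper and GoalEnv classes here are illustrative stand-ins, not classes from the repository; the point is that the deleted branching on observation/action dimensions is replaced by delegation, so a new task distribution only needs to implement sample_goals.

import numpy as np

class GoalEnv:
    # Illustrative env exposing the sample_goals interface (cheetah-style range).
    def sample_goals(self, num_goals):
        return np.random.uniform(0.0, 2.0, (num_goals,))

class Wrapper:
    # Illustrative wrapper exposing its inner env via .wrapped_env,
    # as rllab-style wrappers do.
    def __init__(self, wrapped_env):
        self.wrapped_env = wrapped_env

def sample_meta_batch_goals(env, meta_batch_size):
    # Same unwrapping loop as the hunk above: peel wrappers until an env
    # implementing sample_goals is found, then delegate to it.
    while 'sample_goals' not in dir(env):
        env = env.wrapped_env
    return env.sample_goals(meta_batch_size)

goals = sample_meta_batch_goals(Wrapper(GoalEnv()), meta_batch_size=40)  # shape: (40,)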
