Skip to content

Commit

Permalink
Format docstrings.
Browse files Browse the repository at this point in the history
  • Loading branch information
vaxenburg committed Mar 20, 2024
1 parent be87a5c commit 4d7352f
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 82 deletions.
2 changes: 1 addition & 1 deletion flybody/fruitfly/fruitfly.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def initialize_episode(self, physics: 'mjcf.Physics',
body_mass = physics.named.model.body_subtreemass[
'walker/thorax'] # gram.
self._weight = np.linalg.norm(physics.model.opt.gravity) * body_mass
# Fold wings if not used.
# Retract wings if not used.
if not self._use_wings and self.name == 'walker':
for s in ['left', 'right']:
for dof in ['yaw', 'roll', 'pitch']:
Expand Down
8 changes: 4 additions & 4 deletions flybody/tasks/template_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ class TemplateTask(Walking):
def __init__(self, claw_friction: Optional[float] = 1.0, **kwargs):
"""Template class for walking fly tasks.
Args:
claw_friction: Friction of claw geoms with floor.
**kwargs: Arguments passed to the superclass constructor.
"""
Args:
claw_friction: Friction of claw geoms with floor.
**kwargs: Arguments passed to the superclass constructor.
"""

super().__init__(add_ghost=False, ghost_visible_legs=False, **kwargs)

Expand Down
82 changes: 41 additions & 41 deletions flybody/tasks/trajectory_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ def __init__(self,
random_state: Optional[np.random.RandomState] = None):
"""Initializes the base trajectory loader.
Args:
path: Path to hdf5 dataset file with reference rajectories.
traj_indices: List of trajectory indices to use, e.g. for train/test
splitting etc. If None, use all available trajectories.
random_state: Random state for reproducibility.
"""
Args:
path: Path to hdf5 dataset file with reference rajectories.
traj_indices: List of trajectory indices to use, e.g. for train/test
splitting etc. If None, use all available trajectories.
random_state: Random state for reproducibility.
"""

if random_state is None:
self._random_state = np.random.RandomState(None)
Expand Down Expand Up @@ -72,12 +72,12 @@ def __init__(
):
"""Initializes the flight trajectory loader.
Args:
path: Path to hdf5 dataset file with reference rajectories.
traj_indices: List of trajectory indices to use, e.g. for train/test
splitting etc. If None, use all available trajectories.
random_state: Random state for reproducibility.
"""
Args:
path: Path to hdf5 dataset file with reference rajectories.
traj_indices: List of trajectory indices to use, e.g. for train/test
splitting etc. If None, use all available trajectories.
random_state: Random state for reproducibility.
"""
super().__init__(path, traj_indices, random_state=random_state)

self._com_qpos = []
Expand All @@ -103,17 +103,17 @@ def get_trajectory(
end_step: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
"""Returns a flight trajectory from the dataset.
Args:
traj_idx: Index of the desired trajectory. If None, a random trajectory
is selected.
start_step: Start index for the trajectory slice. If None, defaults to
the beginning.
end_step: End index for the trajectory slice. If None, defaults to the
end.
Returns:
tuple: Two numpy arrays for com_qpos and com_qvel respectively.
"""
Args:
traj_idx: Index of the desired trajectory. If None, a random
trajectory is selected.
start_step: Start index for the trajectory slice. If None, defaults
to the beginning.
end_step: End index for the trajectory slice. If None, defaults to
the end.
Returns:
tuple: Two numpy arrays for com_qpos and com_qvel respectively.
"""
if traj_idx is None:
traj_idx = self._random_state.choice(self._traj_indices)

Expand All @@ -139,12 +139,12 @@ def __init__(
):
"""Initializes the walking trajectory loader.
Args:
path: Path to hdf5 dataset file with reference rajectories.
traj_indices: List of trajectory indices to use, e.g. for train/test
splitting etc. If None, use all available trajectories.
random_state: Random state for reproducibility.
"""
Args:
path: Path to hdf5 dataset file with reference rajectories.
traj_indices: List of trajectory indices to use, e.g. for train/test
splitting etc. If None, use all available trajectories.
random_state: Random state for reproducibility.
"""

super().__init__(path, traj_indices, random_state=random_state)

Expand All @@ -163,18 +163,18 @@ def get_trajectory(
end_step: Optional[int] = None) -> Dict[str, np.ndarray]:
"""Returns a walking trajectory from the dataset.
Args:
traj_idx: Index of the desired trajectory. If None, a random trajectory
is selected.
start_step: Start index for the trajectory slice. If None, defaults to
the beginning.
end_step: End index for the trajectory slice. If None, defaults to the
end.
Returns:
dict: Dictionary containing qpos, qvel, root2site, and joint_quat of the
trajectory.
"""
Args:
traj_idx: Index of the desired trajectory. If None, a random
trajectory is selected.
start_step: Start index for the trajectory slice. If None, defaults
to the beginning.
end_step: End index for the trajectory slice. If None, defaults to
the end.
Returns:
dict: Dictionary containing qpos, qvel, root2site, and joint_quat
of the trajectory.
"""
if traj_idx is None:
traj_idx = self._random_state.choice(self._traj_indices)

Expand Down
38 changes: 18 additions & 20 deletions flybody/tasks/vision_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,21 @@ def __init__(self,
init_pos_x_range: Optional[tuple] = (-5, -5),
init_pos_y_range: Optional[tuple] = (0, 0),
**kwargs):
"""Task of learning a policy for flying and maneuvering while using a wing
beat pattern generator with controllable wing beat frequency.
Args:
wpg: Wing beat generator.
floor_contacts_fatal: Whether to terminate the episode when the fly
contacts the floor.
eye_camera_fovy: Field of view of the eye camera.
eye_camera_size: Size of the eye camera.
target_height_range: Range of target height.
target_speed_range: Range of target speed.
init_pos_x_range: Range of initial x position.
init_pos_y_range: Range of initial y position.
**kwargs: Arguments passed to the superclass constructor.
"""
"""Task of learning a policy for flying and maneuvering while using a
wing beat pattern generator with controllable wing beat frequency.
Args:
wpg: Wing beat generator.
floor_contacts_fatal: Whether to terminate the episode when the fly
contacts the floor.
eye_camera_fovy: Field of view of the eye camera.
eye_camera_size: Size of the eye camera.
target_height_range: Range of target height.
target_speed_range: Range of target speed.
init_pos_x_range: Range of initial x position.
init_pos_y_range: Range of initial y position.
**kwargs: Arguments passed to the superclass constructor.
"""

super().__init__(add_ghost=False,
num_user_actions=1,
Expand Down Expand Up @@ -110,9 +110,9 @@ def initialize_episode(self, physics: 'mjcf.Physics',
random_state: np.random.RandomState):
"""Randomly selects a starting point and set the walker.
Environment call sequence:
check_termination, get_reward_factors, get_discount
"""
Environment call sequence:
check_termination, get_reward_factors, get_discount
"""
super().initialize_episode(physics, random_state)

init_x = random_state.uniform(*self._init_pos_x_range)
Expand Down Expand Up @@ -252,9 +252,7 @@ def target_speed(self):
@composer.observable
def task_input(self):
"""Task-specific input, framed as an observable."""

def get_task_input(physics: 'mjcf.Physics'):
del physics
return np.hstack([self._target_height, self._target_speed])

return observable.Generic(get_task_input)
21 changes: 11 additions & 10 deletions flybody/tasks/walk_imitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@ def __init__(self,
**kwargs):
"""This task is a combination of imitation walking and ghost tracking.
Args:
traj_generator: Trajectory generator for generating walking trajectories.
mocap_joint_names: Names of mocap joints.
mocap_site_names: Names of mocap sites.
terminal_com_dist: Episode will be terminated when CoM distance from model
to ghost exceeds terminal_com_dist.
claw_friction: Friction of claw.
trajectory_sites: Whether to render trajectory sites.
**kwargs: Arguments passed to the superclass constructor.
"""
Args:
traj_generator: Trajectory generator for generating walking
trajectories.
mocap_joint_names: Names of mocap joints.
mocap_site_names: Names of mocap sites.
terminal_com_dist: Episode will be terminated when CoM distance
from model to ghost exceeds terminal_com_dist.
claw_friction: Friction of claw.
trajectory_sites: Whether to render trajectory sites.
**kwargs: Arguments passed to the superclass constructor.
"""

super().__init__(add_ghost=True, ghost_visible_legs=False, **kwargs)

Expand Down
10 changes: 4 additions & 6 deletions flybody/tasks/walk_on_ball.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ class WalkOnBall(Walking):
def __init__(self, claw_friction: Optional[float] = 1.0, **kwargs):
"""Task of tethered fly walking on floating ball.
Args:
claw_friction: Friction of claw geoms with floor.
**kwargs: Arguments passed to the superclass constructor.
"""
Args:
claw_friction: Friction of claw geoms with floor.
**kwargs: Arguments passed to the superclass constructor.
"""

super().__init__(add_ghost=False, ghost_visible_legs=False, **kwargs)

Expand Down Expand Up @@ -84,8 +84,6 @@ def check_termination(self, physics: 'mjcf.Physics') -> bool:
@composer.observable
def ball_qvel(self):
"""Simple observable of ball rotational velocity."""

def get_ball_qvel(physics: 'mjcf.Physics'):
return physics.named.data.qvel['ball']

return observable.Generic(get_ball_qvel)

0 comments on commit 4d7352f

Please sign in to comment.