Skip to content

Commit

Permalink
Add option to run walking env in inference/test mode.
Browse files Browse the repository at this point in the history
  • Loading branch information
vaxenburg committed Mar 26, 2024
1 parent 7a95687 commit 4f84c00
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions flybody/fly_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
from flybody.tasks.arenas.ball import BallFloor
from flybody.tasks.arenas.hills import SineBumps, SineTrench
from flybody.tasks.pattern_generators import WingBeatPatternGenerator
from flybody.tasks.trajectory_loaders import (HDF5FlightTrajectoryLoader,
HDF5WalkingTrajectoryLoader)
from flybody.tasks.trajectory_loaders import (
HDF5FlightTrajectoryLoader,
HDF5WalkingTrajectoryLoader,
InferenceWalkingTrajectoryLoader,
)


def flight_imitation(wpg_pattern_path: str,
Expand Down Expand Up @@ -60,13 +63,15 @@ def flight_imitation(wpg_pattern_path: str,
strip_singleton_obs_buffer_dim=True)


def walk_imitation(ref_path: str,
def walk_imitation(ref_path: str | None = None,
random_state: np.random.RandomState | None = None,
terminal_com_dist: float = 0.3):
"""Requires a fruitfly to track a reference walking fly.
Args:
ref_path: Path to reference trajectory dataset.
ref_path: Path to reference trajectory dataset. If not provided, task
will run in inference mode with InferenceWalkingTrajectoryLoader,
without loading actual walking dataset.
random_state: Random state for reproducibility.
terminal_com_dist: Episode will be terminated when distance from model
CoM to ghost CoM exceeds terminal_com_dist. Can be float('inf').
Expand All @@ -77,8 +82,13 @@ def walk_imitation(ref_path: str,
walker = fruitfly.FruitFly
arena = floors.Floor()
# Initialize a walking trajectory loader.
traj_generator = HDF5WalkingTrajectoryLoader(
path=ref_path, random_state=random_state)
if ref_path is not None:
inference_mode = False
traj_generator = HDF5WalkingTrajectoryLoader(
path=ref_path, random_state=random_state)
else:
inference_mode = True
traj_generator = InferenceWalkingTrajectoryLoader()
# Build a task that rewards the agent for tracking a walking ghost.
time_limit = 10.0
task = WalkImitation(walker=walker,
Expand All @@ -87,6 +97,7 @@ def walk_imitation(ref_path: str,
terminal_com_dist=terminal_com_dist,
mocap_joint_names=traj_generator.get_joint_names(),
mocap_site_names=traj_generator.get_site_names(),
inference_mode=inference_mode,
joint_filter=0.01,
future_steps=64,
time_limit=time_limit)
Expand Down

0 comments on commit 4f84c00

Please sign in to comment.