diff --git a/flybody/fly_envs.py b/flybody/fly_envs.py index 1ee347d..c5d5881 100755 --- a/flybody/fly_envs.py +++ b/flybody/fly_envs.py @@ -15,8 +15,11 @@ from flybody.tasks.arenas.ball import BallFloor from flybody.tasks.arenas.hills import SineBumps, SineTrench from flybody.tasks.pattern_generators import WingBeatPatternGenerator -from flybody.tasks.trajectory_loaders import (HDF5FlightTrajectoryLoader, - HDF5WalkingTrajectoryLoader) +from flybody.tasks.trajectory_loaders import ( + HDF5FlightTrajectoryLoader, + HDF5WalkingTrajectoryLoader, + InferenceWalkingTrajectoryLoader, +) def flight_imitation(wpg_pattern_path: str, @@ -60,13 +63,15 @@ def flight_imitation(wpg_pattern_path: str, strip_singleton_obs_buffer_dim=True) -def walk_imitation(ref_path: str, +def walk_imitation(ref_path: str | None = None, random_state: np.random.RandomState | None = None, terminal_com_dist: float = 0.3): """Requires a fruitfly to track a reference walking fly. Args: - ref_path: Path to reference trajectory dataset. + ref_path: Path to reference trajectory dataset. If not provided, task + will run in inference mode with InferenceWalkingTrajectoryLoader, + without loading actual walking dataset. random_state: Random state for reproducibility. terminal_com_dist: Episode will be terminated when distance from model CoM to ghost CoM exceeds terminal_com_dist. Can be float('inf'). @@ -77,8 +82,13 @@ def walk_imitation(ref_path: str, walker = fruitfly.FruitFly arena = floors.Floor() # Initialize a walking trajectory loader. - traj_generator = HDF5WalkingTrajectoryLoader( - path=ref_path, random_state=random_state) + if ref_path is not None: + inference_mode = False + traj_generator = HDF5WalkingTrajectoryLoader( + path=ref_path, random_state=random_state) + else: + inference_mode = True + traj_generator = InferenceWalkingTrajectoryLoader() # Build a task that rewards the agent for tracking a walking ghost. time_limit = 10.0 task = WalkImitation(walker=walker, @@ -87,6 +97,7 @@ def walk_imitation(ref_path: str, terminal_com_dist=terminal_com_dist, mocap_joint_names=traj_generator.get_joint_names(), mocap_site_names=traj_generator.get_site_names(), + inference_mode=inference_mode, joint_filter=0.01, future_steps=64, time_limit=time_limit)