Skip to content

Commit

Permalink
Add inference-time flight trajectory loader.
Browse files Browse the repository at this point in the history
vaxenburg committed Dec 4, 2024
1 parent 781e610 commit d679422
Showing 1 changed file with 75 additions and 24 deletions.
99 changes: 75 additions & 24 deletions flybody/tasks/trajectory_loaders.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
"""Reference trajectory loaders for fruit fly imitation tasks."""

from typing import Sequence, Tuple, Optional, Dict
from typing import Sequence
from abc import ABC, abstractmethod

import h5py
import numpy as np

from flybody.tasks.synthetic_trajectories import constant_speed_trajectory
from flybody.tasks.constants import _WALK_CONTROL_TIMESTEP
from flybody.tasks.constants import _FLY_CONTROL_TIMESTEP, _WALK_CONTROL_TIMESTEP


class HDF5TrajectoryLoader(ABC):
"""Base class for loading and serving trajectories from hdf5 datasets."""

def __init__(self,
path: str,
traj_indices: Optional[Sequence[int]] = None,
random_state: Optional[np.random.RandomState] = None):
traj_indices: Sequence[int] | None = None,
random_state: np.random.RandomState | None = None):
"""Initializes the base trajectory loader.
Args:
@@ -57,9 +57,9 @@ def traj_indices(self):

@abstractmethod
def get_trajectory(self,
traj_idx: Optional[int] = None,
start_step: Optional[int] = None,
end_step: Optional[int] = None):
traj_idx: int | None = None,
start_step: int | None = None,
end_step: int | None = None):
"""Returns a trajectory."""
raise NotImplementedError("Subclasses should implement this.")

@@ -70,19 +70,24 @@ class HDF5FlightTrajectoryLoader(HDF5TrajectoryLoader):
def __init__(
self,
path: str,
traj_indices: Optional[Sequence[int]] = None,
random_state: Optional[np.random.RandomState] = None,
traj_indices: Sequence[int] | None = None,
randomize_start_step: bool = True,
random_state: np.random.RandomState | None = None,
):
"""Initializes the flight trajectory loader.
Args:
path: Path to hdf5 dataset file with reference rajectories.
traj_indices: List of trajectory indices to use, e.g. for train/test
splitting etc. If None, use all available trajectories.
randomize_start_step: Whether to select random start point in each
get_trajectory call.
random_state: Random state for reproducibility.
"""
super().__init__(path, traj_indices, random_state=random_state)

self._randomize_start_step = randomize_start_step

self._com_qpos = []
self._com_qvel = []

@@ -92,27 +97,28 @@ def __init__(
key = str(idx).zfill(n_zeros)
self._com_qpos.append(f['trajectories'][key]['com_qpos'][()])
self._com_qvel.append(f['trajectories'][key]['com_qvel'][()])
assert self._com_qpos[-1].shape[0] == self._com_qvel[-1].shape[
0]
assert self._com_qpos[-1].shape[0] == self._com_qvel[-1].shape[0]

def trajectory_len(self, traj_idx: int) -> int:
"""Returns length of trajectory with index traj_idx."""
return len(self._com_qpos[traj_idx])

def get_trajectory(
self,
traj_idx: Optional[int] = None,
start_step: Optional[int] = None,
end_step: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
traj_idx: int | None = None,
start_step: int | None = None,
end_step: int | None = None) -> tuple[np.ndarray, np.ndarray]:
"""Returns a flight trajectory from the dataset.
Args:
traj_idx: Index of the desired trajectory. If None, a random
trajectory is selected.
trajectory out of traj_indices is selected.
start_step: Start index for the trajectory slice. If None, defaults
to the beginning.
to the beginning. If attribute randomize_start_step is True,
start_step is ignored.
end_step: End index for the trajectory slice. If None, defaults to
the end.
the end. If attribute randomize_start_step is True,
end_step is ignored.
Returns:
tuple: Two numpy arrays for com_qpos and com_qvel respectively.
@@ -121,8 +127,12 @@ def get_trajectory(
traj_idx = self._random_state.choice(self._traj_indices)

traj_len = len(self._com_qpos[traj_idx])
start_step = 0 if start_step is None else start_step
end_step = traj_len if end_step is None else end_step
if self._randomize_start_step:
start_step = self._random_state.randint(traj_len - 50)
end_step = traj_len
else:
start_step = 0 if start_step is None else start_step
end_step = traj_len if end_step is None else end_step

com_qpos = self._com_qpos[traj_idx][start_step:end_step].copy()
com_qvel = self._com_qvel[traj_idx][start_step:end_step]
@@ -131,14 +141,55 @@ def get_trajectory(
return com_qpos, com_qvel


class InferenceFlightTrajectoryLoader():
"""Simple drop-in inference-time replacement for flight trajectory loader.
This trajectory loader can be used for bypassing loading actual flight
datasets and loading custom trajectories instead, e.g. at inference time.
A simple synthetic flight trajectory is automatically set upon this class
initialization.
To use this class with other custom trajectories, create qpos and qvel for
your custom trajectory and then set this trajectory for loading in the
flight task by calling:
env.task._traj_generator.set_next_trajectory(qpos, qvel)
"""

def __init__(self):
# Initially, set a simple synthetic trajectory, e.g. for quick testing.
qpos, qvel = constant_speed_trajectory(
n_steps=200, speed=20, init_pos=(0, 0, 1),
body_rot_angle_y=-47.5, control_timestep=_FLY_CONTROL_TIMESTEP)
self.set_next_trajectory(qpos, qvel)

def set_next_trajectory(self, com_qpos: np.ndarray, com_qvel: np.ndarray):
"""Set new trajectory to be returned by get_trajectory.
Args:
com_qpos: Center-of-mass trajectory, (time, 7).
com_qvel: Velocity of CoM trajectory, (time, 6).
"""
self._com_qpos = com_qpos.copy()
self._com_qpos[:, :2] -= self._com_qpos[0, :2]
self._com_qvel = com_qvel

def get_trajectory(self, traj_idx: int):
del traj_idx # Unused.
if not hasattr(self, '_com_qpos'):
raise AttributeError(
'Trajectory not set yet. Call set_next_trajectory first.')
return self._com_qpos, self._com_qvel


class HDF5WalkingTrajectoryLoader(HDF5TrajectoryLoader):
"""Loads and serves trajectories from hdf5 walking imitation dataset."""

def __init__(
self,
path: str,
traj_indices: Optional[Sequence[int]] = None,
random_state: Optional[np.random.RandomState] = None,
traj_indices: Sequence[int] | None = None,
random_state: np.random.RandomState | None = None,
):
"""Initializes the walking trajectory loader.
@@ -161,9 +212,9 @@ def trajectory_len(self, traj_idx: int) -> int:

def get_trajectory(
self,
traj_idx: Optional[int] = None,
start_step: Optional[int] = None,
end_step: Optional[int] = None) -> Dict[str, np.ndarray]:
traj_idx: int | None = None,
start_step: int | None = None,
end_step: int | None = None) -> dict[str, np.ndarray]:
"""Returns a walking trajectory from the dataset.
Args:

0 comments on commit d679422

Please sign in to comment.