forked from TuragaLab/flybody
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path fly_envs.py
executable file
·259 lines (229 loc) · 8.31 KB
/
fly_envs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
"""Create examples of flight and walking task environments for fruitfly."""
from typing import Callable
import numpy as np
from dm_control import mujoco
from dm_control import composer
from dm_control.locomotion.arenas import floors
from vnl_ray.fruitfly import fruitfly
from vnl_ray.tasks.flight_imitation import FlightImitationWBPG
from vnl_ray.tasks.walk_imitation import WalkImitation
from vnl_ray.tasks.walk_on_ball import WalkOnBall
from vnl_ray.tasks.vision_flight import VisionFlightImitationWBPG
from vnl_ray.tasks.template_task import TemplateTask
from vnl_ray.tasks.arenas.ball import BallFloor
from vnl_ray.tasks.arenas.hills import SineBumps, SineTrench
from vnl_ray.tasks.pattern_generators import WingBeatPatternGenerator
from vnl_ray.tasks.trajectory_loaders import (
HDF5FlightTrajectoryLoader,
HDF5WalkingTrajectoryLoader,
InferenceWalkingTrajectoryLoader,
)
def flight_imitation(
    wpg_pattern_path: str,
    ref_path: str,
    random_state: np.random.RandomState | None = None,
    terminal_com_dist: float = 2.0,
):
    """Build an environment in which a fruitfly must track a flying reference.

    Args:
        wpg_pattern_path: Path to the baseline wing beat pattern consumed by
            the wing beat pattern generator (WPG).
        ref_path: Path to the reference flight trajectory dataset.
        random_state: Random state for reproducibility. The same object is
            handed to both the trajectory loader and the environment.
        terminal_com_dist: The episode terminates once the distance between
            the model CoM and the ghost CoM exceeds this value. May be
            float('inf').

    Returns:
        Environment for the flight tracking task.
    """
    episode_seconds = 0.6

    # Keyword arguments are evaluated left to right, so the floor, the wing
    # beat pattern generator and the trajectory loader are created in that
    # order.
    flight_task = FlightImitationWBPG(
        walker=fruitfly.FruitFly,
        arena=floors.Floor(),
        wbpg=WingBeatPatternGenerator(base_pattern_path=wpg_pattern_path),
        traj_generator=HDF5FlightTrajectoryLoader(
            path=ref_path, random_state=random_state
        ),
        terminal_com_dist=terminal_com_dist,
        initialize_qvel=True,
        time_limit=episode_seconds,
        joint_filter=0.0,
        future_steps=5,
    )

    return composer.Environment(
        time_limit=episode_seconds,
        task=flight_task,
        random_state=random_state,
        strip_singleton_obs_buffer_dim=True,
    )
def walk_imitation(
    ref_path: str | None = None,
    random_state: np.random.RandomState | None = None,
    terminal_com_dist: float = 0.3,
):
    """Build an environment in which a fruitfly must track a walking reference.

    Args:
        ref_path: Path to the reference trajectory dataset. When omitted, the
            task runs in inference mode with InferenceWalkingTrajectoryLoader
            and no actual walking dataset is loaded.
        random_state: Random state for reproducibility.
        terminal_com_dist: The episode terminates once the distance between
            the model CoM and the ghost CoM exceeds this value. May be
            float('inf').

    Returns:
        Environment for the walking tracking task.
    """
    walker_cls = fruitfly.FruitFly
    floor = floors.Floor()

    # Inference mode is selected purely by the absence of a dataset path.
    inference_mode = ref_path is None
    if inference_mode:
        trajectories = InferenceWalkingTrajectoryLoader()
    else:
        trajectories = HDF5WalkingTrajectoryLoader(
            path=ref_path, random_state=random_state
        )

    # The task rewards the agent for tracking a walking ghost; joint and site
    # names for the mocap data come from whichever loader was chosen above.
    episode_seconds = 10.0
    walking_task = WalkImitation(
        walker=walker_cls,
        arena=floor,
        traj_generator=trajectories,
        terminal_com_dist=terminal_com_dist,
        mocap_joint_names=trajectories.get_joint_names(),
        mocap_site_names=trajectories.get_site_names(),
        inference_mode=inference_mode,
        joint_filter=0.01,
        future_steps=64,
        time_limit=episode_seconds,
    )

    return composer.Environment(
        time_limit=episode_seconds,
        task=walking_task,
        random_state=random_state,
        strip_singleton_obs_buffer_dim=True,
    )
def walk_on_ball(random_state: np.random.RandomState | None = None):
    """Requires a tethered fruitfly to walk on a floating ball.

    Args:
        random_state: Random state for reproducibility.

    Returns:
        Environment for fly walking on ball.
    """
    # Fruitfly walker class and an arena containing the ball.
    walker = fruitfly.FruitFly
    arena = BallFloor(
        ball_pos=(-0.05, 0, -0.419),
        ball_radius=0.454,
        ball_density=0.0025,
        skybox=False,
    )
    # Build the walk-on-ball task. No reference trajectory is supplied here,
    # so unlike the imitation tasks there is no ghost to track.
    time_limit = 2.0
    task = WalkOnBall(
        walker=walker,
        arena=arena,
        joint_filter=0.01,
        adhesion_filter=0.007,
        time_limit=time_limit,
    )
    return composer.Environment(
        time_limit=time_limit,
        task=task,
        random_state=random_state,
        strip_singleton_obs_buffer_dim=True,
    )
def vision_guided_flight(
wpg_pattern_path: str,
bumps_or_trench: str = "bumps",
random_state: np.random.RandomState | None = None,
**kwargs_arena
):
"""Vision-guided flight tasks: 'bumps' and 'trench'.
Args:
wpg_pattern_path: Path to baseline wing beat pattern for WPG.
bumps_or_trench: Whether to create 'bumps' or 'trench' vision task.
random_state: Random state for reproducibility.
kwargs_arena: kwargs to be passed on to arena.
Returns:
Environment for vision-guided flight task.
"""
if bumps_or_trench == "bumps":
arena = SineBumps
elif bumps_or_trench == "trench":
arena = SineTrench
else:
raise ValueError("Only 'bumps' and 'trench' terrains are supported.")
# Build fruitfly walker and arena.
walker = fruitfly.FruitFly
arena = arena(**kwargs_arena)
# Initialize a wing beat pattern generator.
wbpg = WingBeatPatternGenerator(base_pattern_path=wpg_pattern_path)
# Build task.
time_limit = 0.4
task = VisionFlightImitationWBPG(
walker=walker,
arena=arena,
wbpg=wbpg,
time_limit=time_limit,
joint_filter=0.0,
floor_contacts=True,
floor_contacts_fatal=True,
)
return composer.Environment(
time_limit=time_limit,
task=task,
random_state=random_state,
strip_singleton_obs_buffer_dim=True,
)
def template_task(
    random_state: np.random.RandomState | None = None,
    joint_filter: float = 0.01,
    adhesion_filter: float = 0.007,
    time_limit: float = 1.0,
    mjcb_control: Callable | None = None,
    observables_options: dict | None = None,
    action_corruptor: Callable | None = None,
):
    """Build an empty, no-op walking task, intended for testing.

    Args:
        random_state: Random state for reproducibility.
        joint_filter: Timescale of the filter for joint actuators; 0 disables
            it.
        adhesion_filter: Timescale of the filter for adhesion actuators; 0
            disables it.
        time_limit: Episode time limit, used by both the task and the
            environment.
        mjcb_control: Optional MuJoCo control callback, a callable taking
            (model, data). See
            https://mujoco.readthedocs.io/en/stable/APIreference/APIglobals.html#mjcb-control
        observables_options (optional): Either a dict of dicts of
            configuration options keyed on observable names, or a single dict
            of options to propagate to every observable.
        action_corruptor (optional): A callable that receives an action,
            modifies it and returns it, e.g. to inject random noise.

    Returns:
        Template walking environment.
    """
    no_op_task = TemplateTask(
        walker=fruitfly.FruitFly,
        arena=floors.Floor(),
        joint_filter=joint_filter,
        adhesion_filter=adhesion_filter,
        observables_options=observables_options,
        mjcb_control=mjcb_control,
        action_corruptor=action_corruptor,
        time_limit=time_limit,
    )

    # Clear the global MuJoCo control callback after the task is built.
    # NOTE(review): presumably this stops a callback installed by an earlier
    # environment from leaking into this one, with TemplateTask installing
    # `mjcb_control` itself later. Confirm against TemplateTask.
    mujoco.set_mjcb_control(None)

    return composer.Environment(
        time_limit=time_limit,
        task=no_op_task,
        random_state=random_state,
        strip_singleton_obs_buffer_dim=True,
    )