forked from talmolab/Brax-Rodent-Run
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_ppo.py
569 lines (503 loc) · 20.9 KB
/
custom_ppo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Proximal policy optimization training.
See: https://arxiv.org/pdf/1707.06347.pdf
"""
import functools
import time
from typing import Callable, Optional, Tuple, Union
from absl import logging
from brax import base
from brax import envs
from brax.training import acting
from brax.training import gradients
from brax.training import pmap
from brax.training import types
from brax.training.acme import running_statistics
from brax.training.acme import specs
import orbax.checkpoint
# from brax.training.agents.ppo import losses as ppo_losses
import custom_losses as ppo_losses
# from brax.training.agents.ppo import networks as ppo_networks
import custom_ppo_networks
from brax.training.types import Params
from brax.training.types import PRNGKey
from brax.v1 import envs as envs_v1
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax
import custom_wrappers
from etils import epath
# Name of the device axis used by jax.pmap and the cross-device collectives
# (gradient averaging, running-statistics updates) in this module.
_PMAP_AXIS_NAME = "i"

# Type alias: a (normalizer statistics, network params) pair.
InferenceParams = Tuple[running_statistics.NestedMeanStd, Params]
# Re-export of the Brax metrics type used in callbacks and return values.
Metrics = types.Metrics
@flax.struct.dataclass
class TrainingState:
    """Contains training state for the learner.

    Declared with ``flax.struct.dataclass`` so it is a JAX pytree: `train`
    carries it through ``jax.lax.scan`` and replicates it across local devices
    with ``jax.device_put_replicated``. Field order matters (constructor order
    and pytree flatten order), so do not reorder fields.
    """

    # Optax optimizer state for `params`; created via `optimizer.init(params)`.
    optimizer_state: optax.OptState
    # Policy and value network parameters (`.policy` / `.value`).
    params: ppo_losses.PPONetworkParams
    # Running observation statistics used for observation normalization.
    normalizer_params: running_statistics.RunningStatisticsState
    # Environment steps consumed so far; increased by a fixed amount per
    # training step. NOTE(review): initialised with a Python int 0 in `train`
    # despite the jnp.ndarray annotation.
    env_steps: jnp.ndarray
def _unpmap(v):
    """Select the first device's copy of every leaf in a replicated pytree.

    Args:
      v: a pytree whose leaves carry a leading device axis (e.g. the output
        of ``jax.pmap`` or ``jax.device_put_replicated``).

    Returns:
      A pytree with the same structure holding ``leaf[0]`` for every leaf.
    """

    def _first_replica(leaf):
        return leaf[0]

    return jax.tree_util.tree_map(_first_replica, v)
def _strip_weak_type(tree):
    """Convert every leaf of `tree` to a JAX array with a concrete dtype.

    Brax user code is sometimes ambiguous about ``weak_type``. Casting each
    leaf to its own dtype with ``astype`` drops weak typing, which avoids extra
    jit recompilations when weakly- and strongly-typed inputs alternate.

    Args:
      tree: any pytree of array-like leaves.

    Returns:
      A pytree of the same structure with ``jnp`` array leaves.
    """

    def _to_strong_type(leaf):
        array = jnp.asarray(leaf)
        dtype = array.dtype
        return array.astype(dtype)

    return jax.tree_util.tree_map(_to_strong_type, tree)
def train(
    environment: Union[envs_v1.Env, envs.Env],
    num_timesteps: int,
    episode_length: int,
    checkpoint_manager: orbax.CheckpointManager,
    action_repeat: int = 1,
    num_envs: int = 1,
    max_devices_per_host: Optional[int] = None,
    num_eval_envs: int = 128,
    learning_rate: float = 1e-4,
    entropy_cost: float = 1e-4,
    kl_weight: float = 1e-3,
    discounting: float = 0.9,
    seed: int = 0,
    unroll_length: int = 10,
    batch_size: int = 32,
    num_minibatches: int = 16,
    num_updates_per_batch: int = 2,
    num_evals: int = 1,
    num_resets_per_eval: int = 0,
    normalize_observations: bool = False,
    reward_scaling: float = 1.0,
    clipping_epsilon: float = 0.3,
    gae_lambda: float = 0.95,
    deterministic_eval: bool = False,
    network_factory: types.NetworkFactory[
        custom_ppo_networks.PPOImitationNetworks
    ] = custom_ppo_networks.make_intention_ppo_networks,
    progress_fn: Callable[[int, Metrics], None] = lambda *args: None,
    normalize_advantage: bool = True,
    eval_env: Optional[envs.Env] = None,
    policy_params_fn: Callable[..., None] = lambda *args: None,
    randomization_fn: Optional[
        Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]
    ] = None,
    restore_checkpoint_path: Optional[str] = None,
    freeze_mask=None,
):
    """PPO training.

    Args:
      environment: the environment to train
      num_timesteps: the total number of environment steps to use during training
      episode_length: the length of an environment episode
      checkpoint_manager: orbax checkpoint manager. On process 0, after every
        eval, the unreplicated tuple `(normalizer_params, params, env_steps)`
        (with `params` holding both policy and value parameters) is saved with
        the eval iteration index as the step.
      action_repeat: the number of timesteps to repeat an action
      num_envs: the number of parallel environments to use for rollouts
        NOTE: `num_envs` must be divisible by the total number of chips since each
          chip gets `num_envs // total_number_of_chips` environments to roll out
        NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since
          data generated by `num_envs` parallel envs gets used for gradient
          updates over `num_minibatches` of data, where each minibatch has a
          leading dimension of `batch_size`
      max_devices_per_host: maximum number of chips to use per host process
      num_eval_envs: the number of envs to use for evaluation. Each env will run 1
        episode, and all envs run in parallel during eval.
      learning_rate: learning rate for ppo loss
      entropy_cost: entropy reward for ppo loss, higher values increase entropy
        of the policy
      kl_weight: weight forwarded to `custom_losses.compute_ppo_loss`
        (presumably scaling a KL regularizer on the intention latent; confirm
        in custom_losses).
      discounting: discounting rate
      seed: random seed
      unroll_length: the number of timesteps to unroll in each environment. The
        PPO loss is computed over `unroll_length` timesteps
      batch_size: the batch size for each minibatch SGD step
      num_minibatches: the number of times to run the SGD step, each with a
        different minibatch with leading dimension of `batch_size`
      num_updates_per_batch: the number of times to run the gradient update over
        all minibatches before doing a new environment rollout
      num_evals: the number of evals to run during the entire training run.
        Increasing the number of evals increases total training time
      num_resets_per_eval: the number of environment resets to run between each
        eval. The environment resets occur on the host
      normalize_observations: whether to normalize observations
      reward_scaling: float scaling for reward
      clipping_epsilon: clipping epsilon for PPO loss
      gae_lambda: General advantage estimation lambda
      deterministic_eval: whether to run the eval with a deterministic policy
      network_factory: function that generates networks for policy and value
        functions. Called as `network_factory(obs_size, reference_obs_size,
        action_size, preprocess_observations_fn=...)`, where
        `reference_obs_size` is read from `env_state.info["reference_obs_size"]`.
      progress_fn: a user-defined callback function for reporting/plotting metrics
      normalize_advantage: whether to normalize advantage estimate
      eval_env: an optional environment for eval only, defaults to `environment`
      policy_params_fn: a user-defined callback, called on process 0 after each
        eval as `policy_params_fn(current_step, make_policy,
        (normalizer_params, params, env_steps), key)`, where `params` holds both
        policy and value parameters. Can be used for saving policy checkpoints.
      randomization_fn: a user-defined callback function that generates randomized
        environments
      restore_checkpoint_path: optional path to an orbax checkpoint holding
        `(normalizer_params, params, env_steps)`. If the path exists those values
        are restored and the optimizer state is freshly initialized. If it does
        not exist, a warning is logged and training starts from scratch.
      freeze_mask: optional labels passed to `optax.multi_transform`. Parameters
        labelled "encoder" are trained with Adam, and parameters labelled
        "decoder" receive zero updates. When restoring from a checkpoint with a
        mask, the restored `policy["params"]["encoder"]` subtree is replaced by
        its fresh initialization.

    Returns:
      Tuple of (make_policy function, (normalizer_params, policy_params),
      metrics of the last evaluation; empty on processes other than 0).
    """
    # Explicit submodule import: `import flax` alone is not guaranteed to load
    # `flax.training.orbax_utils`, which the checkpoint code below relies on.
    from flax.training import orbax_utils

    assert batch_size * num_minibatches % num_envs == 0
    xt = time.time()

    process_count = jax.process_count()
    process_id = jax.process_index()
    local_device_count = jax.local_device_count()
    local_devices_to_use = local_device_count
    if max_devices_per_host:
        local_devices_to_use = min(local_devices_to_use, max_devices_per_host)
    logging.info(
        "Device count: %d, process count: %d (id %d), local device count: %d, "
        "devices to be used count: %d",
        jax.device_count(),
        process_count,
        process_id,
        local_device_count,
        local_devices_to_use,
    )
    device_count = local_devices_to_use * process_count

    # The number of environment steps executed for every training step.
    env_step_per_training_step = (
        batch_size * unroll_length * num_minibatches * action_repeat
    )
    num_evals_after_init = max(num_evals - 1, 1)
    # The number of training_step calls per training_epoch call.
    # equals to ceil(num_timesteps / (num_evals * env_step_per_training_step *
    #                                 num_resets_per_eval))
    num_training_steps_per_epoch = np.ceil(
        num_timesteps
        / (
            num_evals_after_init
            * env_step_per_training_step
            * max(num_resets_per_eval, 1)
        )
    ).astype(int)

    key = jax.random.PRNGKey(seed)
    global_key, local_key = jax.random.split(key)
    del key
    local_key = jax.random.fold_in(local_key, process_id)
    local_key, key_env, eval_key = jax.random.split(local_key, 3)
    # key_networks should be global, so that networks are initialized the same
    # way for different processes.
    key_policy, key_value, policy_params_fn_key = jax.random.split(global_key, 3)
    del global_key

    assert num_envs % device_count == 0

    v_randomization_fn = None
    if randomization_fn is not None:
        # One randomized system per env on a single device. Each device rolls out
        # num_envs // device_count envs (see the key_envs reshape below);
        # dividing by local_device_count would be wrong with several processes
        # or when max_devices_per_host limits the devices used.
        randomization_batch_size = num_envs // device_count
        # all devices get the same randomization rng
        randomization_rng = jax.random.split(key_env, randomization_batch_size)
        v_randomization_fn = functools.partial(
            randomization_fn, rng=randomization_rng
        )

    if isinstance(environment, envs.Env):
        wrap_for_training = custom_wrappers.wrap
    else:
        wrap_for_training = envs_v1.wrappers.wrap_for_training

    env = wrap_for_training(
        environment,
        episode_length=episode_length,
        action_repeat=action_repeat,
        randomization_fn=v_randomization_fn,
    )

    reset_fn = jax.jit(jax.vmap(env.reset))
    key_envs = jax.random.split(key_env, num_envs // process_count)
    # Leading axis: local device; second axis: envs on that device.
    key_envs = jnp.reshape(
        key_envs, (local_devices_to_use, -1) + key_envs.shape[1:]
    )
    env_state = reset_fn(key_envs)

    normalize = lambda x, y: x
    if normalize_observations:
        normalize = running_statistics.normalize
    ppo_network = network_factory(
        env_state.obs.shape[-1],
        int(_unpmap(env_state.info["reference_obs_size"])[0]),
        env.action_size,
        preprocess_observations_fn=normalize,
    )
    make_policy = custom_ppo_networks.make_inference_fn(ppo_network)

    if freeze_mask is not None:
        # "decoder"-labelled params get zero updates (frozen); "encoder" trains.
        optimizer = optax.multi_transform(
            {
                "encoder": optax.adam(learning_rate=learning_rate),
                "decoder": optax.set_to_zero(),
            },
            freeze_mask,
        )
        logging.info("Freezing layers")
    else:
        optimizer = optax.adam(learning_rate=learning_rate)

    loss_fn = functools.partial(
        ppo_losses.compute_ppo_loss,
        ppo_network=ppo_network,
        entropy_cost=entropy_cost,
        kl_weight=kl_weight,
        discounting=discounting,
        reward_scaling=reward_scaling,
        gae_lambda=gae_lambda,
        clipping_epsilon=clipping_epsilon,
        normalize_advantage=normalize_advantage,
    )

    gradient_update_fn = gradients.gradient_update_fn(
        loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True
    )

    def minibatch_step(
        carry,
        data: types.Transition,
        normalizer_params: running_statistics.RunningStatisticsState,
    ):
        # One gradient update on a single minibatch.
        optimizer_state, params, key = carry
        key, key_loss = jax.random.split(key)
        (_, metrics), params, optimizer_state = gradient_update_fn(
            params,
            normalizer_params,
            data,
            key_loss,
            optimizer_state=optimizer_state,
        )
        return (optimizer_state, params, key), metrics

    def sgd_step(
        carry,
        unused_t,
        data: types.Transition,
        normalizer_params: running_statistics.RunningStatisticsState,
    ):
        # One pass over all minibatches of a freshly shuffled copy of `data`.
        optimizer_state, params, key = carry
        key, key_perm, key_grad = jax.random.split(key, 3)

        def convert_data(x: jnp.ndarray):
            x = jax.random.permutation(key_perm, x)
            x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])
            return x

        shuffled_data = jax.tree_util.tree_map(convert_data, data)
        (optimizer_state, params, _), metrics = jax.lax.scan(
            functools.partial(minibatch_step, normalizer_params=normalizer_params),
            (optimizer_state, params, key_grad),
            shuffled_data,
            length=num_minibatches,
        )
        return (optimizer_state, params, key), metrics

    def training_step(
        carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t
    ) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]:
        # Collect rollouts with the current policy, then run the SGD updates.
        training_state, state, key = carry
        key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)

        policy = make_policy(
            (training_state.normalizer_params, training_state.params.policy)
        )

        def f(carry, unused_t):
            current_state, current_key = carry
            current_key, next_key = jax.random.split(current_key)
            next_state, data = acting.generate_unroll(
                env,
                current_state,
                policy,
                current_key,
                unroll_length,
                extra_fields=("truncation",),
            )
            return (next_state, next_key), data

        (state, _), data = jax.lax.scan(
            f,
            (state, key_generate_unroll),
            (),
            length=batch_size * num_minibatches // num_envs,
        )
        # Have leading dimensions (batch_size * num_minibatches, unroll_length)
        data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data)
        data = jax.tree_util.tree_map(
            lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data
        )
        assert data.discount.shape[1:] == (unroll_length,)

        # Update normalization params and normalize observations.
        normalizer_params = running_statistics.update(
            training_state.normalizer_params,
            data.observation,
            pmap_axis_name=_PMAP_AXIS_NAME,
        )

        (optimizer_state, params, _), metrics = jax.lax.scan(
            functools.partial(sgd_step, data=data, normalizer_params=normalizer_params),
            (training_state.optimizer_state, training_state.params, key_sgd),
            (),
            length=num_updates_per_batch,
        )

        new_training_state = TrainingState(
            optimizer_state=optimizer_state,
            params=params,
            normalizer_params=normalizer_params,
            env_steps=training_state.env_steps + env_step_per_training_step,
        )
        return (new_training_state, state, new_key), metrics

    def training_epoch(
        training_state: TrainingState, state: envs.State, key: PRNGKey
    ) -> Tuple[TrainingState, envs.State, Metrics]:
        (training_state, state, _), loss_metrics = jax.lax.scan(
            training_step,
            (training_state, state, key),
            (),
            length=num_training_steps_per_epoch,
        )
        loss_metrics = jax.tree_util.tree_map(jnp.mean, loss_metrics)
        return training_state, state, loss_metrics

    training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME)

    # Note that this is NOT a pure jittable method.
    def training_epoch_with_timing(
        training_state: TrainingState, env_state: envs.State, key: PRNGKey
    ) -> Tuple[TrainingState, envs.State, Metrics]:
        nonlocal training_walltime
        t = time.time()
        training_state, env_state = _strip_weak_type((training_state, env_state))
        result = training_epoch(training_state, env_state, key)
        training_state, env_state, metrics = _strip_weak_type(result)

        metrics = jax.tree_util.tree_map(jnp.mean, metrics)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)

        epoch_training_time = time.time() - t
        training_walltime += epoch_training_time
        sps = (
            num_training_steps_per_epoch
            * env_step_per_training_step
            * max(num_resets_per_eval, 1)
        ) / epoch_training_time
        metrics = {
            "training/sps": sps,
            "training/walltime": training_walltime,
            **{f"training/{name}": value for name, value in metrics.items()},
        }
        return (
            training_state,
            env_state,
            metrics,
        )  # pytype: disable=bad-return-type  # py311-upgrade

    init_params = ppo_losses.PPONetworkParams(
        policy=ppo_network.policy_network.init(key_policy),
        value=ppo_network.value_network.init(key_value),
    )
    training_state = TrainingState(  # pytype: disable=wrong-arg-types  # jax-ndarray
        optimizer_state=optimizer.init(
            init_params
        ),  # pytype: disable=wrong-arg-types  # numpy-scalars
        params=init_params,
        normalizer_params=running_statistics.init_state(
            specs.Array(env_state.obs.shape[-1:], jnp.dtype("float32"))
        ),
        env_steps=0,
    )

    # Load from checkpoint
    if restore_checkpoint_path is not None:
        if epath.Path(restore_checkpoint_path).exists():
            logging.info("restoring from checkpoint %s", restore_checkpoint_path)
            orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
            # Must match the tuple layout written by checkpoint_manager.save below.
            target = (
                training_state.normalizer_params,
                init_params,
                training_state.env_steps,
            )
            (normalizer_params, load_params, env_steps) = orbax_checkpointer.restore(
                restore_checkpoint_path,
                item=target,
                restore_args=orbax_utils.restore_args_from_target(target, mesh=None),
            )
            if freeze_mask is not None:
                # Keep the restored decoder but start the encoder from its fresh
                # initialization (the encoder is the part trained under the mask).
                load_params.policy["params"]["encoder"] = init_params.policy[
                    "params"
                ]["encoder"]
            init_params = init_params.replace(
                policy=load_params.policy, value=load_params.value
            )
            training_state = TrainingState(  # pytype: disable=wrong-arg-types  # jax-ndarray
                optimizer_state=optimizer.init(
                    init_params
                ),  # pytype: disable=wrong-arg-types  # numpy-scalars
                params=init_params,
                normalizer_params=normalizer_params,
                env_steps=env_steps,
            )
        else:
            logging.warning(
                "restore_checkpoint_path %s does not exist; training from scratch",
                restore_checkpoint_path,
            )

    training_state = jax.device_put_replicated(
        training_state, jax.local_devices()[:local_devices_to_use]
    )

    if eval_env is None:
        eval_env = environment
    if randomization_fn is not None:
        v_randomization_fn = functools.partial(
            randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
        )
    eval_env = wrap_for_training(
        eval_env,
        episode_length=episode_length,
        action_repeat=action_repeat,
        randomization_fn=v_randomization_fn,
    )

    evaluator = acting.Evaluator(
        eval_env,
        functools.partial(make_policy, deterministic=deterministic_eval),
        num_eval_envs=num_eval_envs,
        episode_length=episode_length,
        action_repeat=action_repeat,
        key=eval_key,
    )

    # Run initial eval
    metrics = {}
    if process_id == 0 and num_evals > 1:
        metrics = evaluator.run_evaluation(
            _unpmap((training_state.normalizer_params, training_state.params.policy)),
            training_metrics={},
        )
        logging.info(metrics)
        progress_fn(0, metrics)

    training_metrics = {}
    training_walltime = 0
    current_step = 0
    for it in range(num_evals_after_init):
        logging.info("starting iteration %s %s", it, time.time() - xt)

        for _ in range(max(num_resets_per_eval, 1)):
            # optimization
            epoch_key, local_key = jax.random.split(local_key)
            epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
            (training_state, env_state, training_metrics) = (
                training_epoch_with_timing(training_state, env_state, epoch_keys)
            )
            current_step = int(_unpmap(training_state.env_steps))

            key_envs = jax.vmap(
                lambda x, s: jax.random.split(x[0], s), in_axes=(0, None)
            )(key_envs, key_envs.shape[1])
            # TODO: move extra reset logic to the AutoResetWrapper.
            env_state = reset_fn(key_envs) if num_resets_per_eval > 0 else env_state

        if process_id == 0:
            # Run evals.
            metrics = evaluator.run_evaluation(
                _unpmap(
                    (training_state.normalizer_params, training_state.params.policy)
                ),
                training_metrics,
            )
            logging.info(metrics)
            progress_fn(current_step, metrics)
            params = _unpmap(
                (
                    training_state.normalizer_params,
                    training_state.params,
                    training_state.env_steps,
                )
            )
            # Save checkpoint
            save_args = orbax_utils.save_args_from_target(params)
            checkpoint_manager.save(it, params, save_kwargs={"save_args": save_args})
            _, policy_params_fn_key = jax.random.split(policy_params_fn_key)
            policy_params_fn(current_step, make_policy, params, policy_params_fn_key)

    total_steps = current_step
    assert total_steps >= num_timesteps

    # If there was no mistakes the training_state should still be identical on all
    # devices.
    pmap.assert_is_replicated(training_state)
    params = _unpmap((training_state.normalizer_params, training_state.params.policy))
    logging.info("total steps: %s", total_steps)
    pmap.synchronize_hosts()
    return (make_policy, params, metrics)