custom_ppo_networks.py
"""
Custom network definitions.
This is needed because we need to route the observations
to proper places in the network in the case of the VAE (CoMic, Hasenclever 2020)
"""
from typing import Sequence, Tuple

from brax.training import distribution
from brax.training import networks
from brax.training import types
from brax.training.types import PRNGKey
import flax
import jax

import custom_networks


@flax.struct.dataclass
class PPOImitationNetworks:
    """Container for the imitation PPO policy, value, and action-distribution networks."""

    policy_network: custom_networks.IntentionNetwork
    value_network: networks.FeedForwardNetwork
    parametric_action_distribution: distribution.ParametricDistribution


def make_inference_fn(ppo_networks: PPOImitationNetworks):
    """Creates an inference-function factory for the imitation PPO agent."""

    def make_policy(
        params: types.PolicyParams, deterministic: bool = False
    ) -> types.Policy:
        policy_network = ppo_networks.policy_network
        parametric_action_distribution = ppo_networks.parametric_action_distribution

        def policy(
            observations: types.Observation,
            key_sample: PRNGKey,
        ) -> Tuple[types.Action, types.Extra]:
            key_sample, key_network = jax.random.split(key_sample)
            logits, _ = policy_network.apply(*params, observations, key_network)
            if deterministic:
                return parametric_action_distribution.mode(logits), {}
            # Sample an action from the distribution parameterized by logits
            # (mean and log-variance), keeping the pre-tanh raw action and its
            # log-probability for the PPO loss.
            raw_actions = parametric_action_distribution.sample_no_postprocessing(
                logits, key_sample
            )
            log_prob = parametric_action_distribution.log_prob(logits, raw_actions)
            postprocessed_actions = parametric_action_distribution.postprocess(
                raw_actions
            )
            return postprocessed_actions, {
                "log_prob": log_prob,
                "raw_action": raw_actions,
                "logits": logits,
            }

        return policy

    return make_policy
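

# Usage sketch (a minimal illustration; `ppo_networks` and `params` are assumed
# to come from one of the factories below and from training, with `params`
# being the (preprocessor_params, policy_params) pair that
# policy_network.apply unpacks):
#
#     make_policy = make_inference_fn(ppo_networks)
#     policy = make_policy(params, deterministic=False)
#     action, extras = policy(observations, jax.random.PRNGKey(0))
#
# `extras` carries "log_prob", "raw_action", and "logits" for the PPO update.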


def make_intention_ppo_networks(
    observation_size: int,
    reference_obs_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    intention_latent_size: int = 60,
    encoder_hidden_layer_sizes: Sequence[int] = (1024,) * 2,
    decoder_hidden_layer_sizes: Sequence[int] = (1024,) * 2,
    value_hidden_layer_sizes: Sequence[int] = (1024,) * 2,
) -> PPOImitationNetworks:
    """Makes imitation PPO networks with a VAE-style intention policy and preprocessor."""
    parametric_action_distribution = distribution.NormalTanhDistribution(
        event_size=action_size
    )
    policy_network = custom_networks.make_intention_policy(
        parametric_action_distribution.param_size,
        latent_size=intention_latent_size,
        total_obs_size=observation_size,
        reference_obs_size=reference_obs_size,
        preprocess_observations_fn=preprocess_observations_fn,
        encoder_hidden_layer_sizes=encoder_hidden_layer_sizes,
        decoder_hidden_layer_sizes=decoder_hidden_layer_sizes,
    )
    value_network = networks.make_value_network(
        observation_size,
        preprocess_observations_fn=preprocess_observations_fn,
        hidden_layer_sizes=value_hidden_layer_sizes,
    )
    return PPOImitationNetworks(
        policy_network=policy_network,
        value_network=value_network,
        parametric_action_distribution=parametric_action_distribution,
    )


def make_encoderdecoder_ppo_networks(
    observation_size: int,
    reference_obs_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    intention_latent_size: int = 60,
    encoder_hidden_layer_sizes: Sequence[int] = (1024,) * 2,
    decoder_hidden_layer_sizes: Sequence[int] = (1024,) * 2,
    value_hidden_layer_sizes: Sequence[int] = (1024,) * 2,
) -> PPOImitationNetworks:
    """Makes imitation PPO networks with an encoder-decoder policy and preprocessor."""
    parametric_action_distribution = distribution.NormalTanhDistribution(
        event_size=action_size
    )
    policy_network = custom_networks.make_encoderdecoder_policy(
        parametric_action_distribution.param_size,
        latent_size=intention_latent_size,
        total_obs_size=observation_size,
        reference_obs_size=reference_obs_size,
        preprocess_observations_fn=preprocess_observations_fn,
        encoder_hidden_layer_sizes=encoder_hidden_layer_sizes,
        decoder_hidden_layer_sizes=decoder_hidden_layer_sizes,
    )
    value_network = networks.make_value_network(
        observation_size,
        preprocess_observations_fn=preprocess_observations_fn,
        hidden_layer_sizes=value_hidden_layer_sizes,
    )
    return PPOImitationNetworks(
        policy_network=policy_network,
        value_network=value_network,
        parametric_action_distribution=parametric_action_distribution,
    )
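

# Construction sketch (a minimal illustration; the argument values are
# hypothetical placeholders, not values taken from this repository):
#
#     ppo_networks = make_intention_ppo_networks(
#         observation_size=env_obs_size,      # length of the flat observation
#         reference_obs_size=ref_obs_size,    # size of the reference portion of the observation
#         action_size=env_action_size,
#     )
#     make_policy = make_inference_fn(ppo_networks)
#
# make_encoderdecoder_ppo_networks is called the same way; the two factories
# differ only in which policy constructor from custom_networks they use.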