custom_networks.py (forked from talmolab/Brax-Rodent-Run)
from typing import Sequence

from brax.training import networks
from brax.training import types
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn

# Note: this file defines its own MLP (with LayerNorm) below, so brax's
# networks.MLP is deliberately not imported; the local class would shadow it.


class VariationalLayer(nn.Module):
    """Projects features to the mean and log-variance of a Gaussian latent."""

    latent_size: int

    @nn.compact
    def __call__(self, x):
        mean_x = nn.Dense(self.latent_size, name="mean")(x)
        logvar_x = nn.Dense(self.latent_size, name="logvar")(x)
        return mean_x, logvar_x


class MLP(nn.Module):
    """MLP that applies LayerNorm after each hidden activation."""

    layer_sizes: Sequence[int]
    activation: networks.ActivationFn = nn.relu
    kernel_init: networks.Initializer = jax.nn.initializers.lecun_uniform()
    activate_final: bool = False
    bias: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        for i, hidden_size in enumerate(self.layer_sizes):
            x = nn.Dense(
                hidden_size,
                name=f"hidden_{i}",
                kernel_init=self.kernel_init,
                use_bias=self.bias,
            )(x)
            # Activate and normalize every layer except the last, unless
            # activate_final is set.
            if i != len(self.layer_sizes) - 1 or self.activate_final:
                x = self.activation(x)
                x = nn.LayerNorm()(x)
        return x


def reparameterize(rng, mean, logvar):
    """Samples z ~ N(mean, exp(logvar)) via the reparameterization trick."""
    std = jnp.exp(0.5 * logvar)
    eps = random.normal(rng, logvar.shape)
    return mean + eps * std
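
# Sanity sketch (illustration only, not part of the training code): with
# logvar = 0 the standard deviation is exp(0.5 * 0) = 1, so the sample is just
# mean plus unit Gaussian noise, and gradients flow through mean and logvar.
#
#   z = reparameterize(random.PRNGKey(0), jnp.zeros(3), jnp.zeros(3))
#   # z ~ N(0, I)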


class IntentionNetwork(nn.Module):
    """Full VAE model: encodes the reference trajectory, samples a latent
    intention, then decodes it together with the remaining observations
    into an action.
    """

    encoder_layers: Sequence[int]
    decoder_layers: Sequence[int]
    reference_obs_size: int
    latents: int = 60

    def setup(self):
        self.encoder = MLP(layer_sizes=self.encoder_layers, activate_final=True)
        self.latent = VariationalLayer(latent_size=self.latents)
        self.decoder = MLP(layer_sizes=self.decoder_layers)

    def __call__(self, obs, key):
        _, encoder_rng = jax.random.split(key)
        # The first reference_obs_size entries of obs hold the reference
        # trajectory; the remainder are the egocentric/state observations.
        traj = obs[..., : self.reference_obs_size]
        latent_mean, latent_logvar = self.latent(self.encoder(traj))
        z = reparameterize(encoder_rng, latent_mean, latent_logvar)
        action = self.decoder(
            jnp.concatenate([z, obs[..., self.reference_obs_size :]], axis=-1)
        )
        return action, (latent_mean, latent_logvar)
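
# Usage sketch for direct construction (illustration only; the sizes are
# placeholders, not values from the original project):
#
#   net = IntentionNetwork(
#       encoder_layers=[64], decoder_layers=[64, 8], reference_obs_size=10)
#   obs = jnp.zeros((1, 15))  # 10 reference dims + 5 state dims
#   params = net.init(jax.random.PRNGKey(0), obs, jax.random.PRNGKey(1))
#   action, (mu, logvar) = net.apply(params, obs, jax.random.PRNGKey(2))
#   # action.shape == (1, 8); mu.shape == logvar.shape == (1, 60)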


class EncoderDecoderNetwork(nn.Module):
    """Encoder and decoder with a deterministic bottleneck (no sampling)."""

    encoder_layers: Sequence[int]
    decoder_layers: Sequence[int]
    reference_obs_size: int
    latents: int = 60

    def setup(self):
        self.encoder = MLP(layer_sizes=self.encoder_layers, activate_final=True)
        # Flax requires submodules to be created in setup() (or in an
        # @nn.compact method), so the bottleneck is defined here rather than
        # inline in __call__.
        self.bottleneck = nn.Dense(self.latents, name="bottleneck")
        self.decoder = MLP(layer_sizes=self.decoder_layers)

    def __call__(self, obs, key):
        # key is unused; it is accepted only to match IntentionNetwork's
        # interface.
        del key
        traj = obs[..., : self.reference_obs_size]
        z = self.bottleneck(self.encoder(traj))
        action = self.decoder(
            jnp.concatenate([z, obs[..., self.reference_obs_size :]], axis=-1)
        )
        return action, z


def make_intention_policy(
    param_size: int,
    latent_size: int,
    total_obs_size: int,
    reference_obs_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    encoder_hidden_layer_sizes: Sequence[int] = (1024, 1024),
    decoder_hidden_layer_sizes: Sequence[int] = (1024, 1024),
) -> networks.FeedForwardNetwork:
    """Creates an intention (VAE) policy network."""
    policy_module = IntentionNetwork(
        encoder_layers=list(encoder_hidden_layer_sizes),
        decoder_layers=list(decoder_hidden_layer_sizes) + [param_size],
        reference_obs_size=reference_obs_size,
        latents=latent_size,
    )

    def apply(processor_params, policy_params, obs, key):
        obs = preprocess_observations_fn(obs, processor_params)
        return policy_module.apply(policy_params, obs=obs, key=key)

    dummy_total_obs = jnp.zeros((1, total_obs_size))
    dummy_key = jax.random.PRNGKey(0)

    return networks.FeedForwardNetwork(
        init=lambda key: policy_module.init(key, dummy_total_obs, dummy_key),
        apply=apply,
    )
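
# Usage sketch (illustration only; sizes are placeholders). With the default
# identity observation preprocessor the processor_params argument is unused,
# so None can be passed:
#
#   policy = make_intention_policy(
#       param_size=8, latent_size=60, total_obs_size=15, reference_obs_size=10)
#   params = policy.init(jax.random.PRNGKey(0))
#   action, (mu, logvar) = policy.apply(
#       None, params, jnp.zeros((1, 15)), jax.random.PRNGKey(1))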


def make_encoderdecoder_policy(
    param_size: int,
    latent_size: int,
    total_obs_size: int,
    reference_obs_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    encoder_hidden_layer_sizes: Sequence[int] = (1024, 1024),
    decoder_hidden_layer_sizes: Sequence[int] = (1024, 1024),
) -> networks.FeedForwardNetwork:
    """Creates an encoder-decoder policy network (no sampling in the bottleneck)."""
    policy_module = EncoderDecoderNetwork(
        encoder_layers=list(encoder_hidden_layer_sizes),
        decoder_layers=list(decoder_hidden_layer_sizes) + [param_size],
        reference_obs_size=reference_obs_size,
        latents=latent_size,
    )

    def apply(processor_params, policy_params, obs, key):
        obs = preprocess_observations_fn(obs, processor_params)
        return policy_module.apply(policy_params, obs=obs, key=key)

    dummy_total_obs = jnp.zeros((1, total_obs_size))
    dummy_key = jax.random.PRNGKey(0)

    return networks.FeedForwardNetwork(
        init=lambda key: policy_module.init(key, dummy_total_obs, dummy_key),
        apply=apply,
    )
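

if __name__ == "__main__":
    # Minimal smoke test (not part of the original file; the sizes below are
    # made-up placeholders). Both factories expose the same
    # FeedForwardNetwork init/apply interface.
    obs = jnp.zeros((1, 15))  # e.g. 10 reference dims + 5 state dims
    for make in (make_intention_policy, make_encoderdecoder_policy):
        policy = make(
            param_size=8,
            latent_size=60,
            total_obs_size=15,
            reference_obs_size=10,
            encoder_hidden_layer_sizes=(64,),
            decoder_hidden_layer_sizes=(64,),
        )
        params = policy.init(jax.random.PRNGKey(0))
        action, extra = policy.apply(None, params, obs, jax.random.PRNGKey(1))
        print(make.__name__, action.shape)  # -> (1, 8) for both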