forked from thu-ml/tianshou
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathatari_network.py
322 lines (281 loc) · 10.5 KB
/
atari_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
from collections.abc import Callable, Sequence
from typing import Any
import numpy as np
import torch
from torch import nn
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import (
TDevice,
)
from tianshou.highlevel.module.intermediate import (
IntermediateModule,
IntermediateModuleFactory,
)
from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.common import NetBase
from tianshou.utils.net.discrete import Actor, NoisyLinear
def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
    """Initialize ``layer`` in place with orthogonal weights and a constant bias.

    :param layer: a module exposing a ``weight`` tensor and optionally a ``bias``
        tensor (e.g. ``nn.Linear`` or ``nn.Conv2d``).
    :param std: gain passed to :func:`torch.nn.init.orthogonal_`.
    :param bias_const: value every bias entry is set to.
    :return: the same ``layer`` object, so the call can be used inline.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with ``bias=False`` have ``bias is None``; there is nothing to fill then.
    if getattr(layer, "bias", None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer
class ScaledObsInputModule(torch.nn.Module):
def __init__(self, module: NetBase, denom: float = 255.0) -> None:
super().__init__()
self.module = module
self.denom = denom
# This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim)
self.output_dim = module.output_dim
def forward(
self,
obs: np.ndarray | torch.Tensor,
state: Any | None = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
if info is None:
info = {}
return self.module.forward(obs / self.denom, state, info)
def scale_obs(module: NetBase, denom: float = 255.0) -> ScaledObsInputModule:
    """Return ``module`` wrapped so that its observations are divided by ``denom``."""
    wrapper = ScaledObsInputModule(module=module, denom=denom)
    return wrapper
class DQN(NetBase[Any]):
    """Reference: Human-level control through deep reinforcement learning.

    A three-layer convolutional torso (each ``Conv2d`` followed by ``ReLU``, then
    ``Flatten``), optionally followed by fully connected layers:

    * ``features_only=False``: ``Linear(cnn_dim, 512) -> ReLU -> Linear(512, prod(action_shape))``;
    * ``features_only=True`` with ``output_dim_added_layer``: one ``Linear -> ReLU`` of that size;
    * ``features_only=True`` without it: the flattened CNN features are returned as-is.

    The size of the resulting output is stored in ``self.output_dim``.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int] | int,
        device: str | int | torch.device = "cpu",
        features_only: bool = False,
        output_dim_added_layer: int | None = None,
        layer_init: Callable[[nn.Module], nn.Module] = lambda x: x,
    ) -> None:
        """Build the network.

        :param c: number of input channels.
        :param h: input height.
        :param w: input width.
        :param action_shape: action shape; its product is the output size of the
            Q-value head (unused when ``features_only`` is True).
        :param device: device that observations are moved to in :meth:`forward`.
        :param features_only: if True, omit the Q-value head (see class docstring).
        :param output_dim_added_layer: size of an extra ``Linear -> ReLU`` layer appended
            to the CNN features; may only be given together with ``features_only=True``.
        :param layer_init: applied to every ``Conv2d``/``Linear`` layer as it is created;
            identity by default.
        :raises ValueError: if ``output_dim_added_layer`` is given while
            ``features_only`` is False.
        """
        if not features_only and output_dim_added_layer is not None:
            raise ValueError(
                "Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is false.",
            )
        super().__init__()
        self.device = device
        self.net = nn.Sequential(
            layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)),
            nn.ReLU(inplace=True),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(inplace=True),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.ReLU(inplace=True),
            nn.Flatten(),
        )
        # Infer the flattened CNN output size with a dummy (1, c, h, w) forward pass
        # instead of hard-coding it for a particular input resolution.
        with torch.no_grad():
            base_cnn_output_dim = int(np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:]))
        if not features_only:
            action_dim = int(np.prod(action_shape))
            self.net = nn.Sequential(
                self.net,
                layer_init(nn.Linear(base_cnn_output_dim, 512)),
                nn.ReLU(inplace=True),
                layer_init(nn.Linear(512, action_dim)),
            )
            self.output_dim = action_dim
        elif output_dim_added_layer is not None:
            self.net = nn.Sequential(
                self.net,
                layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)),
                nn.ReLU(inplace=True),
            )
            self.output_dim = output_dim_added_layer
        else:
            self.output_dim = base_cnn_output_dim

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, Any]:
        r"""Mapping: s -> Q(s, \*).

        With ``features_only=True`` the output is a feature vector rather than Q-values.
        ``state`` is returned unchanged; ``info`` and ``kwargs`` are not used.
        """
        # Accept numpy arrays or tensors; move to the configured device as float32.
        obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
        return self.net(obs), state
class C51(DQN):
    """Reference: A distributional perspective on reinforcement learning.

    Produces, for every action, a softmax distribution over ``num_atoms`` atoms.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int] | int,
        num_atoms: int = 51,
        device: str | int | torch.device = "cpu",
    ) -> None:
        """
        :param c: number of input channels.
        :param h: input height.
        :param w: input width.
        :param action_shape: action shape; its product is the number of actions.
        :param num_atoms: number of atoms in each per-action distribution.
        :param device: device that observations are moved to.
        """
        # Computed first because it sizes the base network's output layer.
        self.action_num = int(np.prod(action_shape))
        super().__init__(c, h, w, [self.action_num * num_atoms], device)
        self.num_atoms = num_atoms

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*).

        :return: ``(probs, state)`` where ``probs`` has shape
            ``(batch, action_num, num_atoms)`` and is softmax-normalized over atoms.
        """
        # Forward state/info to the base network instead of silently dropping them.
        logits, state = super().forward(obs, state, info)
        probs = logits.view(-1, self.num_atoms).softmax(dim=-1)
        return probs.view(-1, self.action_num, self.num_atoms), state
class Rainbow(DQN):
    """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.

    Uses the :class:`DQN` convolutional features (``features_only=True``) followed by a
    distributional Q head with ``num_atoms`` atoms per action. The fully connected layers
    are optionally :class:`NoisyLinear`, and an optional dueling state-value head ``V``
    is added to the centered action advantages.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int] | int,
        num_atoms: int = 51,
        noisy_std: float = 0.5,
        device: str | int | torch.device = "cpu",
        is_dueling: bool = True,
        is_noisy: bool = True,
    ) -> None:
        """
        :param c: number of input channels.
        :param h: input height.
        :param w: input width.
        :param action_shape: action shape; its product is the number of actions.
        :param num_atoms: number of atoms in each per-action distribution.
        :param noisy_std: passed to :class:`NoisyLinear`; only used when ``is_noisy``.
        :param device: device that observations are moved to.
        :param is_dueling: if True, add a state-value head and combine it with the
            mean-centered action head.
        :param is_noisy: if True, use :class:`NoisyLinear` instead of ``nn.Linear``.
        """
        super().__init__(c, h, w, action_shape, device, features_only=True)
        self.action_num = int(np.prod(action_shape))
        self.num_atoms = num_atoms

        def linear(x: int, y: int) -> NoisyLinear | nn.Linear:
            # Single switch point between noisy and plain fully connected layers.
            if is_noisy:
                return NoisyLinear(x, y, noisy_std)
            return nn.Linear(x, y)

        # At this point ``self.output_dim`` is still the size of the CNN feature vector.
        self.Q = nn.Sequential(
            linear(self.output_dim, 512),
            nn.ReLU(inplace=True),
            linear(512, self.action_num * self.num_atoms),
        )
        self._is_dueling = is_dueling
        if self._is_dueling:
            self.V = nn.Sequential(
                linear(self.output_dim, 512),
                nn.ReLU(inplace=True),
                linear(512, self.num_atoms),
            )
        # From here on ``output_dim`` describes this network's flattened output size.
        self.output_dim = self.action_num * self.num_atoms

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*).

        :return: ``(probs, state)`` where ``probs`` has shape
            ``(batch, action_num, num_atoms)`` and is softmax-normalized over atoms.
        """
        # Forward state/info to the base network instead of silently dropping them.
        features, state = super().forward(obs, state, info)
        q = self.Q(features).view(-1, self.action_num, self.num_atoms)
        if self._is_dueling:
            v = self.V(features).view(-1, 1, self.num_atoms)
            # Dueling aggregation: center the action head over actions, then add the value head.
            logits = q - q.mean(dim=1, keepdim=True) + v
        else:
            logits = q
        probs = logits.softmax(dim=2)
        return probs, state
class QRDQN(DQN):
    """Reference: Distributional Reinforcement Learning with Quantile Regression.

    Produces ``num_quantiles`` outputs per action.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        *,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int] | int,
        num_quantiles: int = 200,
        device: str | int | torch.device = "cpu",
    ) -> None:
        """
        :param c: number of input channels.
        :param h: input height.
        :param w: input width.
        :param action_shape: action shape; its product is the number of actions.
        :param num_quantiles: number of outputs per action.
        :param device: device that observations are moved to.
        """
        # Computed first because it sizes the base network's output layer.
        self.action_num = int(np.prod(action_shape))
        super().__init__(c, h, w, [self.action_num * num_quantiles], device)
        self.num_quantiles = num_quantiles

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*).

        :return: ``(quantiles, state)`` where ``quantiles`` has shape
            ``(batch, action_num, num_quantiles)``.
        """
        # Forward state/info to the base network instead of silently dropping them.
        flat, state = super().forward(obs, state, info)
        return flat.view(-1, self.action_num, self.num_quantiles), state
class ActorFactoryAtariDQN(ActorFactory):
    """Actor factory building a discrete :class:`Actor` on top of the Atari :class:`DQN` network."""

    # Passed both to the Actor (``softmax_output``) and to the categorical distribution
    # factory (``is_probs_input``) so the two agree on whether outputs are probabilities.
    USE_SOFTMAX_OUTPUT = False

    def __init__(
        self,
        scale_obs: bool = True,
        features_only: bool = False,
        output_dim_added_layer: int | None = None,
    ) -> None:
        """
        :param scale_obs: if True, wrap the network with :func:`scale_obs` before handing
            it to the actor.
        :param features_only: forwarded to :class:`DQN`.
        :param output_dim_added_layer: forwarded to :class:`DQN`; only valid together
            with ``features_only=True``.
        """
        self.output_dim_added_layer = output_dim_added_layer
        self.scale_obs = scale_obs
        self.features_only = features_only

    def create_module(self, envs: Environments, device: TDevice) -> Actor:
        """Create the actor for ``envs`` and move it to ``device``."""
        c, h, w = envs.get_observation_shape()  # type: ignore # only right shape is a sequence of length 3
        action_shape = envs.get_action_shape()
        # Normalize any numpy integer scalar (not only int64) to a plain Python int.
        if isinstance(action_shape, np.integer):
            action_shape = int(action_shape)
        net: DQN | ScaledObsInputModule = DQN(
            c=c,
            h=h,
            w=w,
            action_shape=action_shape,
            device=device,
            features_only=self.features_only,
            output_dim_added_layer=self.output_dim_added_layer,
            layer_init=layer_init,
        )
        if self.scale_obs:
            net = scale_obs(net)
        # Give the actor the same normalized action shape the network was built with.
        return Actor(
            net,
            action_shape,
            device=device,
            softmax_output=self.USE_SOFTMAX_OUTPUT,
        ).to(device)

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
        """Create a categorical distribution function consistent with ``USE_SOFTMAX_OUTPUT``."""
        return DistributionFunctionFactoryCategorical(
            is_probs_input=self.USE_SOFTMAX_OUTPUT,
        ).create_dist_fn(envs)
class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
    """Intermediate-module factory based on the Atari :class:`DQN` network."""

    def __init__(self, features_only: bool = False, net_only: bool = False) -> None:
        """
        :param features_only: forwarded to :class:`DQN` (omit the Q-value head).
        :param net_only: if True, expose the DQN's inner ``net`` sequential module
            instead of the :class:`DQN` wrapper itself.
        """
        self.features_only = features_only
        self.net_only = net_only

    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
        """Build the DQN for ``envs`` on ``device`` and wrap it with its output dimension.

        :raises ValueError: if the observation shape does not have exactly 3 entries.
        """
        obs_shape = envs.get_observation_shape()
        if isinstance(obs_shape, int | np.integer):
            obs_shape = [int(obs_shape)]
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if len(obs_shape) != 3:
            raise ValueError(
                f"Expected a 3-dimensional (c, h, w) observation shape, got {tuple(obs_shape)}.",
            )
        c, h, w = obs_shape
        action_shape = envs.get_action_shape()
        # Normalize any numpy integer scalar (not only int64) to a plain Python int.
        if isinstance(action_shape, np.integer):
            action_shape = int(action_shape)
        dqn = DQN(
            c=c,
            h=h,
            w=w,
            action_shape=action_shape,
            device=device,
            features_only=self.features_only,
        ).to(device)
        module = dqn.net if self.net_only else dqn
        return IntermediateModule(module, dqn.output_dim)
class IntermediateModuleFactoryAtariDQNFeatures(IntermediateModuleFactoryAtariDQN):
    """Factory exposing only the DQN feature extractor (its ``net``, built without the Q-value head)."""

    def __init__(self) -> None:
        super().__init__(net_only=True, features_only=True)