-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.py
executable file
·123 lines (93 loc) · 3.48 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as nninit
from torchvision import models
from distributions import Categorical, DiagGaussian
# A temporary backport of orthogonal init from the PyTorch master branch.
# https://github.com/pytorch/pytorch/blob/7752fe5d4e50052b3b0bbc9109e599f8157febc0/torch/nn/init.py#L312
# Remove after the next version of PyTorch gets released.
def orthogonal(tensor, gain=1):
    """Fill ``tensor`` in place with a (semi-)orthogonal matrix, scaled by ``gain``.

    Trailing dimensions beyond the first are flattened, a random Gaussian
    matrix is QR-factorized, and Q is made Haar-uniform following
    https://arxiv.org/pdf/math-ph/0609050.pdf before being copied back.

    Args:
        tensor: tensor with at least 2 dimensions to initialize in place.
        gain: multiplicative scale applied after orthogonalization.

    Returns:
        The same ``tensor``, modified in place.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    rows = tensor.size(0)
    cols = tensor[0].numel()
    # Draw the random matrix with the target's dtype/device. The original used
    # torch.Tensor(rows, cols), which is always float32-on-CPU and broke
    # double/CUDA tensors.
    flattened = tensor.new_empty(rows, cols).normal_(0, 1)

    if rows < cols:
        flattened.t_()

    # Compute the qr factorization (torch.qr is deprecated in favor of linalg).
    q, r = torch.linalg.qr(flattened)
    # Make Q uniform: multiply each column by the sign of r's diagonal.
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph.expand_as(q)

    if rows < cols:
        q.t_()

    tensor.view_as(q).copy_(q)
    tensor.mul_(gain)
    return tensor
def weights_init(m):
    """Orthogonally initialize Conv/Linear weights and zero their biases.

    The policy container classes themselves are skipped; ``Module.apply``
    recurses into their child layers anyway.
    """
    name = m.__class__.__name__
    if name in ("CNNPolicy", "MLPPolicy"):
        return
    if 'Conv' in name or 'Linear' in name:
        orthogonal(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0)
class FFPolicy(nn.Module):
    """Base class for feed-forward actor-critic policies.

    Subclasses implement ``forward`` returning ``(value, features)`` and
    provide a ``self.dist`` action-distribution head.
    """

    def __init__(self):
        super(FFPolicy, self).__init__()

    def forward(self, x):
        # The network itself is supplied by subclasses.
        raise NotImplementedError

    def act(self, inputs, deterministic=False):
        """Return (value, action) for ``inputs``; sample unless deterministic."""
        value, features = self(inputs)
        return value, self.dist.sample(features, deterministic=deterministic)

    def evaluate_actions(self, inputs, actions):
        """Return (value, action log-probs, distribution entropy) for given actions."""
        value, features = self(inputs)
        log_probs, entropy = self.dist.logprobs_and_entropy(features, actions)
        return value, log_probs, entropy
class CNNPolicy(FFPolicy):
    """Convolutional actor-critic policy with a discrete (Categorical) head.

    Three conv layers feed a 512-unit hidden linear layer, from which a
    scalar critic head and a Categorical actor head branch off. The flattened
    conv output is hard-coded to 32 * 7 * 7 (presumably 84x84 input frames
    with this stride pattern -- TODO confirm against callers).
    """

    def __init__(self, num_inputs, action_space_shape):
        """num_inputs: observation channels; action_space_shape: number of discrete actions."""
        super(CNNPolicy, self).__init__()
        self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 32, 3, stride=1)
        self.linear1 = nn.Linear(32 * 7 * 7, 512)
        self.critic_linear = nn.Linear(512, 1)
        num_outputs = action_space_shape
        self.dist = Categorical(512, num_outputs)
        self.train()
        self.reset_parameters()

    def reset_parameters(self):
        """Orthogonal init scaled by the ReLU gain; shrink actor-mean weights if Gaussian."""
        self.apply(weights_init)
        relu_gain = nn.init.calculate_gain('relu')
        self.conv1.weight.data.mul_(relu_gain)
        self.conv2.weight.data.mul_(relu_gain)
        self.conv3.weight.data.mul_(relu_gain)
        self.linear1.weight.data.mul_(relu_gain)
        # Only relevant for a continuous-action head; dead code with the
        # Categorical head built in __init__, kept for parity with MLPPolicy.
        if self.dist.__class__.__name__ == "DiagGaussian":
            self.dist.fc_mean.weight.data.mul_(0.01)

    def forward(self, inputs):
        """Return (value, features): critic value and 512-d actor features."""
        x = F.relu(self.conv1(inputs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(-1, 32 * 7 * 7)
        x = F.relu(self.linear1(x))
        return self.critic_linear(x), x

    def get_probs(self, inputs):
        """Return per-action probabilities for ``inputs``."""
        value, x = self(inputs)
        x = self.dist(x)
        # Explicit dim: implicit-dim F.softmax is deprecated; legacy inference
        # picked dim=1 for these 2-D logits, which equals dim=-1 here.
        probs = F.softmax(x, dim=-1)
        return probs
def weights_init_mlp(m):
    """Initialize Linear layers with Gaussian weights normalized to unit row
    norm, and zero biases. Non-Linear modules are left untouched."""
    name = type(m).__name__
    if 'Linear' not in name:
        return
    weight = m.weight.data
    weight.normal_(0, 1)
    weight *= 1 / torch.sqrt(weight.pow(2).sum(1, keepdim=True))
    if m.bias is not None:
        m.bias.data.fill_(0)