diff --git a/bpartis/models.py b/bpartis/models.py index 113b6b7..0382806 100644 --- a/bpartis/models.py +++ b/bpartis/models.py @@ -191,10 +191,10 @@ def init_output(self, n_sigma=1): print('initialize last layer with size: ', output_conv.weight.size()) - output_conv.weight[:, 0:2, :, :].fill_(0) + output_conv.weight[0:2, :, :, :].fill_(0) output_conv.bias[0:2].fill_(0) - output_conv.weight[:, 2:2+n_sigma, :, :].fill_(0) + output_conv.weight[2:2+n_sigma, :, :, :].fill_(0) output_conv.bias[2:2+n_sigma].fill_(1) def forward(self, input, only_encode=False):