diff --git a/model2_cpx.py b/model2_cpx.py index 328a6d8..c6eb8f9 100644 --- a/model2_cpx.py +++ b/model2_cpx.py @@ -105,7 +105,7 @@ def __init__( super(ResBlock, self).__init__() m = [] - for i in range(2): + for i in range(3): m.append(ComplexConv2d(n_feats, n_feats, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feats)) @@ -187,7 +187,7 @@ def forward(self, x): ################################################# y = x - for i in range(2): + for i in range(3): x = y new_k = torch.complex(x[:, 0, :, :], x[:, 1, :, :]) @@ -668,7 +668,7 @@ def forward(self, x): ################################################# y = x - for i in range(2): + for i in range(3): x = y new_k = torch.complex(x[:, 0, :, :], x[:, 1, :, :])