Commit

fix/tests(convblock): fix skip connections with same skip_dest, stricter tests
nkemnitz committed Nov 10, 2023
1 parent f5324da commit 6285b2c
Showing 3 changed files with 73 additions and 20 deletions.
44 changes: 35 additions & 9 deletions tests/unit/convnet/architecture/test_convblock.py
@@ -110,12 +110,38 @@ def not_test_forward_naive(mocker):
    )


-def test_forward_skips(mocker):
-    mocker.patch("torch.nn.Conv2d.forward", lambda _, x: x)
-    block = convnet.architecture.ConvBlock(
-        kernel_sizes=[2, 2], num_channels=[1, 2, 3, 4, 5], skips={"0": 3, "1": 3, "2": 4}
-    )
-    result = block(torch.ones([1, 1, 1, 1]))
-    assert_array_equal(
-        result.cpu().detach().numpy(), 6 * torch.ones([1, 1, 1, 1]).cpu().detach().numpy()
-    )
+@pytest.mark.parametrize(
+    "skips, expected",
+    [
+        # fmt: off
+        [
+            None,
+            1 * 2 * 2 * 2 * 2  # 4 convolutions
+        ],
+        [
+            {"0": 2, "1": 2, "2": 3},
+            (((1 * 2 * 2)   # first 2 convolutions
+              + 1           # Skip content "0": 2
+              + 1 * 2       # Skip content "1": 2
+              ) * 2         # third convolution
+             + (1 * 2 * 2)  # Skip content "2": 3
+             + 1            # also includes the skip content "0": 2
+             + 1 * 2        # also includes the skip content "1": 2
+            ) * 2           # fourth convolution
+        ],
+        [
+            {"0": 3, "1": 3, "2": 4},
+            ((1 * 2 * 2 * 2)  # first 3 convolutions
+             + 1              # Skip content "0": 3
+             + 1 * 2          # Skip content "1": 3
+            ) * 2             # fourth convolution
+            + (1 * 2 * 2)     # Skip content "2": 4
+        ],
+        # fmt: on
+    ],
+)
+def test_forward_skips(mocker, skips, expected):
+    mocker.patch("torch.nn.Conv2d.forward", lambda _, x: 2 * x)
+    block = convnet.architecture.ConvBlock(kernel_sizes=[1], num_channels=[1] * 5, skips=skips)
+    result = block(torch.ones((1, 1, 1, 1)))
+    assert_array_equal(result.detach().numpy(), torch.full((1, 1, 1, 1), expected).numpy())
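Note on the expected values: because the mocked Conv2d simply doubles its input, each parametrized case reduces to scalar arithmetic. The helper below is a hypothetical sketch, not part of zetta_utils, encoding my reading of the skip semantics these cases imply: a skip records the input of its source conv (after any skip content arriving at that conv has been added) and adds it to the input of the destination conv, and a destination index equal to the number of convs applies to the final output. It reproduces all three cases:

def simulate_convblock(num_convs, skips, x0=1):
    """Scalar model of the mocked ConvBlock (hypothetical, for illustration)."""
    skips = skips or {}
    skip_data_for = {}  # destination conv index -> accumulated skip content
    x = x0
    for i in range(num_convs):
        x += skip_data_for.pop(i, 0)  # add skip content arriving at conv i
        if str(i) in skips:  # conv i is a skip source
            dest = skips[str(i)]
            skip_data_for[dest] = skip_data_for.get(dest, 0) + x
        x *= 2  # the mocked conv doubles its input
    return x + skip_data_for.pop(num_convs, 0)  # skips aimed past the last conv

assert simulate_convblock(4, None) == 16
assert simulate_convblock(4, {"0": 2, "1": 2, "2": 3}) == 42
assert simulate_convblock(4, {"0": 3, "1": 3, "2": 4}) == 26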
47 changes: 37 additions & 10 deletions tests/unit/convnet/architecture/test_unet.py
@@ -185,19 +185,46 @@ def not_test_forward_naive(mocker):
    )


-def test_forward_skips(mocker):
-    mocker.patch("torch.nn.Conv2d.forward", lambda _, x: x)
+@pytest.mark.parametrize(
+    "skips, expected",
+    [
+        # fmt: off
+        [
+            None,
+            (1 * 2 * 2    # L0 convblock, left
+             * 2 * 2 * 2  # L1 convblock
+             + 1 * 2 * 2  # UNet skip connection
+            ) * 2 * 2     # L0 convblock, right
+        ],
+        [
+            {"0": 1},
+            ((((1 * 2 + 1) * 2   # L0 convblock, left --> 6
+               * 2 + 6) * 2 * 2  # L1 convblock --> 72
+              + 6                # UNet skip connection --> 78
+             ) * 2 + 78) * 2     # L0 convblock, right --> 468
+        ],
+        [
+            {"0": 2},
+            ((((1 * 2 * 2 + 1)   # L0 convblock, left --> 5
+               * 2 * 2 + 5) * 2  # L1 convblock --> 50
+              + 5                # UNet skip connection --> 55
+             ) * 2 * 2 + 55)     # L0 convblock, right --> 275
+        ],
+        # fmt: on
+    ],
+)
+def test_forward_skips(mocker, skips, expected):
+    mocker.patch("torch.nn.Conv2d.forward", lambda _, x: 2 * x)
     unet = convnet.architecture.UNet(
-        kernel_sizes=[3, 3],
+        kernel_sizes=[1],
         list_num_channels=[[1, 1, 1], [1, 1, 1, 1], [1, 1, 1]],
-        downsample=partial(torch.nn.AvgPool2d, kernel_size=2),
-        upsample=partial(torch.nn.Upsample, scale_factor=2),
-        skips={"0": 2},
+        downsample=partial(torch.nn.AvgPool2d, kernel_size=1),
+        upsample=partial(torch.nn.Upsample, scale_factor=1),
+        skips=skips,
+        unet_skip_mode="sum",
    )
-    result = unet.forward(torch.ones([1, 1, 2, 2]))
-    assert_array_equal(
-        result.cpu().detach().numpy(), 12 * torch.ones([1, 1, 2, 2]).cpu().detach().numpy()
-    )
+    result = unet.forward(torch.ones((1, 1, 1, 1)))
+    assert_array_equal(result.detach().numpy(), torch.full((1, 1, 1, 1), expected).numpy())


@pytest.mark.parametrize(
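The UNet cases compose the same arithmetic: with kernel_size=1 average pooling and scale_factor=1 upsampling, downsample and upsample act as identities, and unet_skip_mode="sum" adds the output of the left L0 convblock to the upsampled L1 output. A sketch under the same assumptions, reusing the hypothetical simulate_convblock from the previous sketch:

def simulate_unet(skips):
    left = simulate_convblock(2, skips)  # L0 convblock, left: [1, 1, 1] -> 2 convs
    x = simulate_convblock(3, skips, x0=left)  # L1 convblock: [1, 1, 1, 1] -> 3 convs
    x += left  # UNet skip connection, unet_skip_mode="sum"
    return simulate_convblock(2, skips, x0=x)  # L0 convblock, right

assert simulate_unet(None) == 144
assert simulate_unet({"0": 1}) == 468
assert simulate_unet({"0": 2}) == 275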
2 changes: 1 addition & 1 deletion zetta_utils/convnet/architecture/convblock.py
@@ -168,7 +168,7 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
                skip_dest = self.skips[str(conv_count)]
                if skip_dest in skip_data_for:
                    size = _get_size(result)
-                    skip_data_for[skip_dest] += crop_center(skip_data_for[skip_dest], size)
+                    skip_data_for[skip_dest] += crop_center(result, size)
                else:
                    skip_data_for[skip_dest] = result

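The one-line change above is the fix named in the commit title: when a second skip source wrote to an already-populated skip_dest, the old code added the stored skip content to itself instead of accumulating the cropped current result. With the old identity mock (lambda _, x: x), the stored content and the current result were both all-ones tensors at the point of collision, so the two lines could not be told apart; the doubling mock in the new tests makes them differ. A minimal numeric illustration, with plain ints standing in for tensors (cropping omitted):

skip_data_for = {3: 1}  # conv "0" already stored skip content 1 for destination 3
result = 2              # conv "1" also targets destination 3, with content 2

# old, buggy line: skip_data_for[3] += skip_data_for[3]  ->  2 (result is lost)
# fixed line:
skip_data_for[3] += result
assert skip_data_for[3] == 3  # sum of both skip contents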
