From 95a31de4142f0fa93cf58e22e69bf54846f25d6b Mon Sep 17 00:00:00 2001 From: Nico Kemnitz Date: Fri, 10 Nov 2023 13:00:15 +0100 Subject: [PATCH] fix/tests(convblock): fix skip connections with same skip_dest, stricter tests --- .../convnet/architecture/test_convblock.py | 44 +++++++++++++---- tests/unit/convnet/architecture/test_unet.py | 47 +++++++++++++++---- zetta_utils/convnet/architecture/convblock.py | 2 +- 3 files changed, 73 insertions(+), 20 deletions(-) diff --git a/tests/unit/convnet/architecture/test_convblock.py b/tests/unit/convnet/architecture/test_convblock.py index 91a37d30a..61a989b9e 100644 --- a/tests/unit/convnet/architecture/test_convblock.py +++ b/tests/unit/convnet/architecture/test_convblock.py @@ -110,12 +110,38 @@ def not_test_forward_naive(mocker): ) -def test_forward_skips(mocker): - mocker.patch("torch.nn.Conv2d.forward", lambda _, x: x) - block = convnet.architecture.ConvBlock( - kernel_sizes=[2, 2], num_channels=[1, 2, 3, 4, 5], skips={"0": 2, "1": 2, "2": 3} - ) - result = block(torch.ones([1, 1, 1, 1])) - assert_array_equal( - result.cpu().detach().numpy(), 6 * torch.ones([1, 1, 1, 1]).cpu().detach().numpy() - ) +@pytest.mark.parametrize( + "skips, expected", + [ + # fmt: off + [ + None, + 1 * 2 * 2 * 2 * 2 # 4 convolutions + ], + [ + {"0": 2, "1": 2, "2": 3}, + (((1 * 2 * 2) # first 2 convolutions + + 1 # Skip content "0": 2 + + 1 * 2 # Skip content "1": 2 + ) * 2 # third convolution + + (1 * 2 * 2) # Skip content "2": 3 + + 1 # also includes the skip content "0": 2 + + 1 * 2 # also includes the skip content "1": 2 + ) * 2 # fourth convolution + ], + [ + {"0": 3, "1": 3, "2": 4}, + ((1 * 2 * 2 * 2) # first 3 convolutions + + 1 # Skip content "0": 3 + + 1 * 2 # Skip content "1": 3 + ) * 2 # fourth convolution + + (1 * 2 * 2) # Skip content "2": 4 + ], + # fmt: on + ], +) +def test_forward_skips(mocker, skips, expected): + mocker.patch("torch.nn.Conv2d.forward", lambda _, x: 2 * x) + block = convnet.architecture.ConvBlock(kernel_sizes=[1], num_channels=[1] * 5, skips=skips) + result = block(torch.ones((1, 1, 1, 1))) + assert_array_equal(result.detach().numpy(), torch.full((1, 1, 1, 1), expected).numpy()) diff --git a/tests/unit/convnet/architecture/test_unet.py b/tests/unit/convnet/architecture/test_unet.py index 84512ae15..82729edaf 100644 --- a/tests/unit/convnet/architecture/test_unet.py +++ b/tests/unit/convnet/architecture/test_unet.py @@ -185,19 +185,46 @@ def not_test_forward_naive(mocker): ) -def test_forward_skips(mocker): - mocker.patch("torch.nn.Conv2d.forward", lambda _, x: x) +@pytest.mark.parametrize( + "skips, expected", + [ + # fmt: off + [ + None, + (1 * 2 * 2 # L0 convblock, left + * 2 * 2 * 2 # L1 convblock + + 1 * 2 * 2 # UNet skip connection + ) * 2 * 2 # L0 convblock, right + ], + [ + {"0": 1}, + ((((1 * 2 + 1) * 2 # L0 convblock, left --> 6 + * 2 + 6) * 2 * 2 # L1 convblock --> 72 + + 6 # UNet skip connection --> 78 + ) * 2 + 78) * 2 # L0 convblock, right --> 468 + ], + [ + {"0": 2}, + ((((1 * 2 * 2 + 1) # L0 convblock, left --> 5 + * 2 * 2 + 5) * 2 # L1 convblock --> 50 + + 5 # UNet skip connection --> 55 + ) * 2 * 2 + 55) # L0 convblock, right --> 275 + ], + # fmt: on + ], +) +def test_forward_skips(mocker, skips, expected): + mocker.patch("torch.nn.Conv2d.forward", lambda _, x: 2 * x) unet = convnet.architecture.UNet( - kernel_sizes=[3, 3], + kernel_sizes=[1], list_num_channels=[[1, 1, 1], [1, 1, 1, 1], [1, 1, 1]], - downsample=partial(torch.nn.AvgPool2d, kernel_size=2), - upsample=partial(torch.nn.Upsample, scale_factor=2), - skips={"0": 1}, - ) - result = unet.forward(torch.ones([1, 1, 2, 2])) - assert_array_equal( - result.cpu().detach().numpy(), 12 * torch.ones([1, 1, 2, 2]).cpu().detach().numpy() + downsample=partial(torch.nn.AvgPool2d, kernel_size=1), + upsample=partial(torch.nn.Upsample, scale_factor=1), + skips=skips, + unet_skip_mode="sum", ) + result = unet.forward(torch.ones((1, 1, 1, 1))) + assert_array_equal(result.detach().numpy(), torch.full((1, 1, 1, 1), expected).numpy()) @pytest.mark.parametrize( diff --git a/zetta_utils/convnet/architecture/convblock.py b/zetta_utils/convnet/architecture/convblock.py index 768290ab6..10a524196 100644 --- a/zetta_utils/convnet/architecture/convblock.py +++ b/zetta_utils/convnet/architecture/convblock.py @@ -168,7 +168,7 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: skip_dest = self.skips[str(conv_count)] if skip_dest in skip_data_for: size = _get_size(result) - skip_data_for[skip_dest] += crop_center(skip_data_for[skip_dest], size) + skip_data_for[skip_dest] += crop_center(result, size) else: skip_data_for[skip_dest] = result