From 46e2b0eb9a7c8249177df775047fa098af26f6a6 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 18 Jul 2024 19:49:20 +0800 Subject: [PATCH] Fix load pretrain weight issue in ResNet (#7924) Fixes #7923 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/networks/nets/resnet.py | 7 +++---- tests/test_resnet.py | 8 +++----- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 6e61db07ca..d62722478e 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -510,7 +510,7 @@ def _resnet( # Check model bias_downsample and shortcut_type bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) if shortcut_type == kwargs.get("shortcut_type", "B") and ( - bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True + bias_downsample == kwargs.get("bias_downsample", True) ): # Download the MedicalNet pretrained model model_state_dict = get_pretrained_resnet_medicalnet( @@ -518,8 +518,7 @@ def _resnet( ) else: raise NotImplementedError( - f"Please set shortcut_type to {shortcut_type} and bias_downsample to" - f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}" + f"Please set shortcut_type to {shortcut_type} and bias_downsample to {bias_downsample} " f"when using pretrained MedicalNet resnet{resnet_depth}" ) else: @@ -681,7 +680,7 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int): # After testing # False: 10, 50, 101, 152, 200 # Any: 18, 34 - bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 + bias_downsample = resnet_depth in (18, 34) shortcut_type = "A" if resnet_depth in [18, 34] else "B" return bias_downsample, shortcut_type diff --git a/tests/test_resnet.py b/tests/test_resnet.py index e873f1238a..a55d18f5de 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -266,7 +266,7 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape): @parameterized.expand(PRETRAINED_TEST_CASES) @skip_if_quick @skip_if_no_cuda - def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape): + def test_resnet_pretrained(self, model, input_param, _input_shape, _expected_shape): net = model(**input_param).to(device) # Save ckpt torch.save(net.state_dict(), self.tmp_ckpt_filename) @@ -290,9 +290,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape and input_param.get("n_input_channels", 3) == 1 and input_param.get("feed_forward", True) is False and input_param.get("shortcut_type", "B") == shortcut_type - and ( - input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True - ) + and (input_param.get("bias_downsample", True) == bias_downsample) ): model(**cp_input_param) else: @@ -303,7 +301,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape cp_input_param["n_input_channels"] = 1 cp_input_param["feed_forward"] = False cp_input_param["shortcut_type"] = shortcut_type - cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample != -1 else True + cp_input_param["bias_downsample"] = bias_downsample if cp_input_param.get("spatial_dims", 3) == 3: with skip_if_downloading_fails(): pretrained_net = model(**cp_input_param).to(device)