Skip to content

Commit

Permalink
Fix load pretrain weight issue in ResNet (#7924)
Browse files Browse the repository at this point in the history
Fixes #7923


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
  • Loading branch information
KumoLiu and ericspod authored Jul 18, 2024
1 parent 7e4f141 commit 46e2b0e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
7 changes: 3 additions & 4 deletions monai/networks/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,16 +510,15 @@ def _resnet(
# Check model bias_downsample and shortcut_type
bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)
if shortcut_type == kwargs.get("shortcut_type", "B") and (
bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True
bias_downsample == kwargs.get("bias_downsample", True)
):
# Download the MedicalNet pretrained model
model_state_dict = get_pretrained_resnet_medicalnet(
resnet_depth, device=device, datasets23=True
)
else:
raise NotImplementedError(
f"Please set shortcut_type to {shortcut_type} and bias_downsample to"
f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}"
f"Please set shortcut_type to {shortcut_type} and bias_downsample to {bias_downsample} "
f"when using pretrained MedicalNet resnet{resnet_depth}"
)
else:
Expand Down Expand Up @@ -681,7 +680,7 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int):
# After testing
# False: 10, 50, 101, 152, 200
# Any: 18, 34
bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34
bias_downsample = resnet_depth in (18, 34)
shortcut_type = "A" if resnet_depth in [18, 34] else "B"
return bias_downsample, shortcut_type

Expand Down
8 changes: 3 additions & 5 deletions tests/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape):
@parameterized.expand(PRETRAINED_TEST_CASES)
@skip_if_quick
@skip_if_no_cuda
def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape):
def test_resnet_pretrained(self, model, input_param, _input_shape, _expected_shape):
net = model(**input_param).to(device)
# Save ckpt
torch.save(net.state_dict(), self.tmp_ckpt_filename)
Expand All @@ -290,9 +290,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape
and input_param.get("n_input_channels", 3) == 1
and input_param.get("feed_forward", True) is False
and input_param.get("shortcut_type", "B") == shortcut_type
and (
input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True
)
and (input_param.get("bias_downsample", True) == bias_downsample)
):
model(**cp_input_param)
else:
Expand All @@ -303,7 +301,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape
cp_input_param["n_input_channels"] = 1
cp_input_param["feed_forward"] = False
cp_input_param["shortcut_type"] = shortcut_type
cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample != -1 else True
cp_input_param["bias_downsample"] = bias_downsample
if cp_input_param.get("spatial_dims", 3) == 3:
with skip_if_downloading_fails():
pretrained_net = model(**cp_input_param).to(device)
Expand Down

0 comments on commit 46e2b0e

Please sign in to comment.