Skip to content

Commit

Permalink
Minor updates to UNETR model function (#843)
Browse files Browse the repository at this point in the history
Updates to unetr model checkpoint loading
  • Loading branch information
anwai98 authored Jan 23, 2025
1 parent 552dc55 commit d930618
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import os
import warnings
from abc import ABC
from copy import deepcopy
from collections import OrderedDict
Expand Down Expand Up @@ -747,6 +748,7 @@ def get_unetr(
decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
device: Optional[Union[str, torch.device]] = None,
out_channels: int = 3,
flexible_load_checkpoint: bool = False,
) -> torch.nn.Module:
"""Get UNETR model for automatic instance segmentation.
Expand All @@ -756,6 +758,8 @@ def get_unetr(
decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
device: The device.
out_channels: The number of output channels.
flexible_load_checkpoint: Whether to allow reinitialization of parameters
    that could not be found in the provided decoder state.
Returns:
The UNETR model.
Expand All @@ -775,7 +779,18 @@ def get_unetr(
unetr_state_dict = unetr.state_dict()
for k, v in unetr_state_dict.items():
if not k.startswith("encoder"):
unetr_state_dict[k] = decoder_state[k]
if flexible_load_checkpoint:  # Whether to allow reinitialization of params that are not found.
if k in decoder_state: # First check whether the key is available in the provided decoder state.
unetr_state_dict[k] = decoder_state[k]
else:  # Otherwise, fall back to the freshly initialized parameter.
warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
unetr_state_dict[k] = v

else:  # Otherwise, be strict about finding every parameter in the decoder state.
if k not in decoder_state:
raise RuntimeError(f"The parameters for '{k}' could not be found.")
unetr_state_dict[k] = decoder_state[k]

unetr.load_state_dict(unetr_state_dict)

unetr.to(device)
Expand Down

0 comments on commit d930618

Please sign in to comment.