diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py
index c1da9237..f9ba8908 100644
--- a/micro_sam/instance_segmentation.py
+++ b/micro_sam/instance_segmentation.py
@@ -4,6 +4,7 @@
 """
 
 import os
+import warnings
 from abc import ABC
 from copy import deepcopy
 from collections import OrderedDict
@@ -747,6 +748,7 @@ def get_unetr(
     decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
     device: Optional[Union[str, torch.device]] = None,
     out_channels: int = 3,
+    flexible_load_checkpoint: bool = False,
 ) -> torch.nn.Module:
     """Get UNETR model for automatic instance segmentation.
 
@@ -756,6 +758,8 @@
         decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
         device: The device.
         out_channels: The number of output channels.
+        flexible_load_checkpoint: Whether to allow reinitialization of parameters
+            which could not be found in the provided decoder state.
 
     Returns:
         The UNETR model.
@@ -775,7 +779,18 @@
         unetr_state_dict = unetr.state_dict()
         for k, v in unetr_state_dict.items():
             if not k.startswith("encoder"):
-                unetr_state_dict[k] = decoder_state[k]
+                if flexible_load_checkpoint:  # Flexible mode: reinitialize parameters missing from the decoder state.
+                    if k in decoder_state:  # First check whether the key is available in the provided decoder state.
+                        unetr_state_dict[k] = decoder_state[k]
+                    else:  # Otherwise, keep the default initialization and warn about it.
+                        warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
+                        unetr_state_dict[k] = v
+
+                else:  # Strict mode: every decoder parameter must be present in the decoder state.
+                    if k not in decoder_state:
+                        raise RuntimeError(f"The parameters for '{k}' could not be found.")
+                    unetr_state_dict[k] = decoder_state[k]
+
         unetr.load_state_dict(unetr_state_dict)
 
     unetr.to(device)
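
For reference, the loading behaviour introduced in get_unetr can be illustrated outside of micro_sam. The sketch below is a minimal, self-contained example; TinyDecoder and load_decoder_state are purely illustrative stand-ins, not library APIs. It mirrors the branch structure of the diff: with flexible_load_checkpoint=True, parameters missing from the checkpoint keep their fresh initialization and a warning is emitted, while the strict path raises a RuntimeError.

import warnings
from collections import OrderedDict

import torch


class TinyDecoder(torch.nn.Module):
    """Toy module standing in for the UNETR decoder (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(8, 3, kernel_size=3, padding=1)


def load_decoder_state(model, decoder_state, flexible_load_checkpoint=False):
    """Copy matching parameters from 'decoder_state' into 'model'."""
    model_state = model.state_dict()
    for k, v in model_state.items():
        if flexible_load_checkpoint:
            if k in decoder_state:  # Use the pretrained parameter when available.
                model_state[k] = decoder_state[k]
            else:  # Otherwise keep the model's own (fresh) initialization.
                warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
                model_state[k] = v
        else:  # Strict mode: every parameter must be present in the checkpoint.
            if k not in decoder_state:
                raise RuntimeError(f"The parameters for '{k}' could not be found.")
            model_state[k] = decoder_state[k]
    model.load_state_dict(model_state)


if __name__ == "__main__":
    model = TinyDecoder()
    # A partial checkpoint that is missing the parameters of 'conv2'.
    partial_state = OrderedDict(
        (k, v) for k, v in TinyDecoder().state_dict().items() if k.startswith("conv1")
    )

    load_decoder_state(model, partial_state, flexible_load_checkpoint=True)  # Warns, then succeeds.
    try:
        load_decoder_state(model, partial_state, flexible_load_checkpoint=False)
    except RuntimeError as e:
        print("Strict loading failed as expected:", e)

Since flexible_load_checkpoint defaults to False, missing parameters still fail loudly for existing callers; the only change there is an explicit RuntimeError with a message instead of a bare KeyError from the dictionary lookup.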