diff --git a/terratorch/models/backbones/select_patch_embed_weights.py b/terratorch/models/backbones/select_patch_embed_weights.py index aa373ecc..fda7a060 100644 --- a/terratorch/models/backbones/select_patch_embed_weights.py +++ b/terratorch/models/backbones/select_patch_embed_weights.py @@ -12,7 +12,7 @@ def patch_embed_weights_are_compatible(model_patch_embed: torch.Tensor, checkpoint_patch_embed: torch.Tensor) -> bool: # check all dimensions are the same except for channel dimension if len(model_patch_embed.shape) != len(checkpoint_patch_embed.shape): - return False + return False model_shape = [model_patch_embed.shape[i] for i in range(len(model_patch_embed.shape)) if i != 1] checkpoint_shape = [checkpoint_patch_embed.shape[i] for i in range(len(checkpoint_patch_embed.shape)) if i != 1]