Skip to content

Commit

Permalink
Merge pull request #320 from IBM/removed-pretrained-fallback
Browse files Browse the repository at this point in the history
Remove the silent fallback to random initialization when pretrained weights fail to load; log an error and re-raise instead
  • Loading branch information
romeokienzler authored Dec 11, 2024
2 parents 1cdc2d4 + d191e29 commit 16e5af9
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions terratorch/models/backbones/prithvi_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def checkpoint_filter_wrapper_fn(state_dict, model):
model_args = prithvi_cfgs[variant].copy()
model_args.update(kwargs)

# When the pretrained configuration is not available in HF, we shift to pretrained=False
try:
model = build_model_with_cfg(
prithvi_model_class,
Expand All @@ -230,16 +229,13 @@ def checkpoint_filter_wrapper_fn(state_dict, model):
pretrained_strict=True,
**model_args,
)
except RuntimeError:
logger.warning(f"No pretrained configuration was found for the model {variant}. Using random initialization.")
model = build_model_with_cfg(
prithvi_model_class,
variant,
False,
pretrained_filter_fn=checkpoint_filter_wrapper_fn,
pretrained_strict=True,
**model_args,
)
except RuntimeError as e:
if pretrained:
logger.error(f"Failed to initialize the pre-trained model {variant} via timm, "
f"consider running the code with pretrained=False.")
else:
logger.error(f"Failed to initialize the model {variant} via timm.")
raise e

if encoder_only:
default_out_indices = list(range(len(model.blocks)))
Expand Down

0 comments on commit 16e5af9

Please sign in to comment.