Fix timm pretrained error
Signed-off-by: Benedikt Blumenstiel <[email protected]>
blumenstiel committed Dec 11, 2024
1 parent 97a2f61 commit 48c7d7e
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions terratorch/models/backbones/prithvi_vit.py
@@ -51,7 +51,8 @@ def _cfg(**kwargs):
         **kwargs
     }
 
-prithvi_default_cfgs = {
+
+prithvi_cfgs = {
     "prithvi_eo_tiny": _cfg(num_frames=1, embed_dim=256, depth=4, num_heads=4,
                             decoder_embed_dim=128, decoder_depth=4, decoder_num_heads=4),
     "prithvi_eo_v1_100": _cfg(num_frames=3, mean=PRITHVI_V1_MEAN, std=PRITHVI_V1_STD),
@@ -64,8 +65,8 @@ def _cfg(**kwargs):
                             coords_encoding=["time", "location"], coords_scale_learn=True),
 }
 
-
-pretrained_cfgs = generate_default_cfgs(
+# Timm pretrained configs
+default_cfgs = generate_default_cfgs(
     {
         "prithvi_eo_v1_100": {
             "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
@@ -203,19 +204,20 @@ def checkpoint_filter_wrapper_fn(state_dict, model):
         return checkpoint_filter_fn_mae(state_dict, model, pretrained_bands, model_bands)
 
     if pretrained:
-        assert variant in pretrained_cfgs, (f"No pre-trained model found for variant {variant} "
-                                            f"(pretrained models: {pretrained_cfgs.keys()})")
+        assert variant in default_cfgs, (f"No pre-trained model found for variant {variant} "
+                                         f"(pretrained models: {default_cfgs.keys()})")
         # Load pre-trained config from hf
         try:
-            model_args, _ = load_model_config_from_hf(pretrained_cfgs[variant].default.hf_hub_id)
+            model_args, _ = load_model_config_from_hf(default_cfgs[variant].default.hf_hub_id)
             model_args.update(kwargs)
         except:
-            logger.warning(f"No pretrained configuration was found on HuggingFace for the model {variant}.")
-            model_args = prithvi_default_cfgs[variant].copy()
+            logger.warning(f"No pretrained configuration was found on HuggingFace for the model {variant}. "
+                           f"Using random initialization.")
+            model_args = prithvi_cfgs[variant].copy()
             model_args.update(kwargs)
     else:
         # Load default config
-        model_args = prithvi_default_cfgs[variant].copy()
+        model_args = prithvi_cfgs[variant].copy()
         model_args.update(kwargs)
 
     # When the pretrained configuration is not available in HF, we shift to pretrained=False
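
The resolution order implemented above: with pretrained=True the configuration published on the Hugging Face Hub is preferred; if that lookup fails, the code warns and falls back to the local prithvi_cfgs entry (i.e. random initialization); with pretrained=False the local entry is used directly, and caller kwargs override either source. A simplified, self-contained sketch of that pattern, where `fetch_hub_config` is a hypothetical stand-in for the hub lookup rather than a real timm function:

import logging

logger = logging.getLogger(__name__)

def resolve_model_args(variant, pretrained, hub_ids, local_cfgs, fetch_hub_config, **kwargs):
    """Simplified sketch of the config fallback order in the diff above."""
    if pretrained:
        try:
            # Prefer the config published on the Hugging Face Hub.
            model_args = dict(fetch_hub_config(hub_ids[variant]))
        except Exception:
            logger.warning(f"No pretrained configuration was found on HuggingFace for "
                           f"the model {variant}. Using random initialization.")
            model_args = local_cfgs[variant].copy()
    else:
        model_args = local_cfgs[variant].copy()
    model_args.update(kwargs)  # caller kwargs override any config defaults
    return model_args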
@@ -229,7 +231,7 @@ def checkpoint_filter_wrapper_fn(state_dict, model):
                 **model_args,
             )
         except RuntimeError:
-            logger.warning(f"No pretrained configuration was found for the model {variant}.")
+            logger.warning(f"No pretrained configuration was found for the model {variant}. Using random initialization.")
             model = build_model_with_cfg(
                 prithvi_model_class,
                 variant,
