Skip to content

Commit

Permalink
Simplified prithvi code
Browse files Browse the repository at this point in the history
Signed-off-by: Benedikt Blumenstiel <[email protected]>
  • Loading branch information
blumenstiel committed Jan 21, 2025
1 parent 3222bf1 commit 498fbf6
Showing 1 changed file with 65 additions and 40 deletions.
105 changes: 65 additions & 40 deletions terratorch/models/backbones/prithvi_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def _create_prithvi(
) -> PrithviViT:
if pretrained_bands is None:
pretrained_bands = PRETRAINED_BANDS

kwargs["num_frames"] = kwargs.pop("num_frames", 1) # Set num frames to 1 if not present

if model_bands is None:
model_bands: list[HLSBands | int] = pretrained_bands
Expand Down Expand Up @@ -253,43 +253,23 @@ def checkpoint_filter_wrapper_fn(state_dict, model):
logger.error(f"Failed to initialize the model {variant}.")
raise e

assert encoder_only or "out_indices" not in kwargs, "out_indices provided for a MAE model."
if encoder_only:
default_out_indices = list(range(len(model.blocks)))
out_indices = kwargs.pop("out_indices", default_out_indices)
# TODO: Needed?
# model.encode_decode_forward = model.forward

def forward_filter_indices(*args, **kwargs):
features = model.forward_features(*args, **kwargs)
return [features[i] for i in out_indices]

model.forward = forward_filter_indices
model.out_indices = out_indices
model.model_bands = model_bands
model.pretrained_bands = pretrained_bands

return model


def create_prithvi_from_config(
model_name: str,
pretrained: bool = False, # noqa: FBT001, FBT002
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:
pretrained_bands = PRETRAINED_BANDS
kwargs["num_frames"] = kwargs.pop("num_frames", 1) # Set num frames to 1 if not present

model = _create_prithvi(
model_name,
pretrained=pretrained,
model_bands=bands,
pretrained_bands=pretrained_bands,
**kwargs,
)

return model


@ TERRATORCH_BACKBONE_REGISTRY.register
def prithvi_vit_tiny(
pretrained: bool = False, # noqa: FBT001, FBT002
Expand All @@ -310,68 +290,113 @@ def prithvi_eo_tiny(
**kwargs,
) -> PrithviViT:

return create_prithvi_from_config("prithvi_eo_tiny", pretrained, bands, **kwargs)
return _create_prithvi("prithvi_eo_tiny", pretrained, bands, **kwargs)


@ TERRATORCH_BACKBONE_REGISTRY.register
def prithvi_vit_100(
def prithvi_eo_v1_100(
pretrained: bool = False, # noqa: FBT001, FBT002
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:

logger.warning(f"The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. "
f"prithvi_vit_100 will be removed in a future version.")
return _create_prithvi("prithvi_eo_v1_100", pretrained, bands, **kwargs)

return prithvi_eo_v1_100(pretrained=pretrained, bands=bands, **kwargs)

@ TERRATORCH_BACKBONE_REGISTRY.register
def prithvi_eo_v2_300(
pretrained: bool = False, # noqa: FBT001, FBT002
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:

return _create_prithvi("prithvi_eo_v2_300", pretrained, bands, **kwargs)


@ TERRATORCH_BACKBONE_REGISTRY.register
def prithvi_eo_v1_100(
def prithvi_eo_v2_600(
pretrained: bool = False, # noqa: FBT001, FBT002
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:

return create_prithvi_from_config("prithvi_eo_v1_100", pretrained, bands, **kwargs)
return _create_prithvi("prithvi_eo_v2_600", pretrained, bands, **kwargs)


@ TERRATORCH_BACKBONE_REGISTRY.register
def prithvi_eo_v2_300(
def prithvi_eo_v2_300_tl(
pretrained: bool = False, # noqa: FBT001, FBT002
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:

return create_prithvi_from_config("prithvi_eo_v2_300", pretrained, bands, **kwargs)
return _create_prithvi("prithvi_eo_v2_300_tl", pretrained, bands, **kwargs)


@ TERRATORCH_BACKBONE_REGISTRY.register
def prithvi_eo_v2_600(
def prithvi_eo_v2_600_tl(
pretrained: bool = False, # noqa: FBT001, FBT002
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:

return _create_prithvi("prithvi_eo_v2_600_tl", pretrained, bands, **kwargs)


return create_prithvi_from_config("prithvi_eo_v2_600", pretrained, bands, **kwargs)
# TODO: Remove timm_ errors in before version v1.0.
@ TERRATORCH_BACKBONE_REGISTRY.register
def prithvi_vit_100(
pretrained: bool = False, # noqa: FBT001, FBT002
bands: list[HLSBands] | None = None,
**kwargs,
) -> None:
raise ValueError("The model prithvi_vit_100 was renamed to prithvi_eo_v1_100.")


@ TERRATORCH_BACKBONE_REGISTRY.register
def prithvi_eo_v2_300_tl(
def timm_prithvi_eo_v1_100(
pretrained: bool = False, # noqa: FBT001, FBT002
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:
) -> None:
raise ValueError("The Prithvi models were moved to the terratorch registry. "
"Please remove the timm_ prefix from the model name.")


@ TERRATORCH_BACKBONE_REGISTRY.register
def timm_prithvi_eo_v2_300(
pretrained: bool = False, # noqa: FBT001, FBT002
bands: list[HLSBands] | None = None,
**kwargs,
) -> None:
raise ValueError("The Prithvi models were moved to the terratorch registry. "
"Please remove the timm_ prefix from the model name.")

return create_prithvi_from_config("prithvi_eo_v2_300_tl", pretrained, bands, **kwargs)

@ TERRATORCH_BACKBONE_REGISTRY.register
def timm_prithvi_eo_v2_600(
pretrained: bool = False, # noqa: FBT001, FBT002
bands: list[HLSBands] | None = None,
**kwargs,
) -> None:
raise ValueError("The Prithvi models were moved to the terratorch registry. "
"Please remove the timm_ prefix from the model name.")

@ TERRATORCH_BACKBONE_REGISTRY.register
def prithvi_eo_v2_600_tl(
def timm_prithvi_eo_v2_300_tl(
pretrained: bool = False, # noqa: FBT001, FBT002
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:
) -> None:
raise ValueError("The Prithvi models were moved to the terratorch registry. "
"Please remove the timm_ prefix from the model name.")


return create_prithvi_from_config("prithvi_eo_v2_600_tl", pretrained, bands, **kwargs)
@ TERRATORCH_BACKBONE_REGISTRY.register
def timm_prithvi_eo_v2_600_tl(
pretrained: bool = False, # noqa: FBT001, FBT002
bands: list[HLSBands] | None = None,
**kwargs,
) -> None:
raise ValueError("The Prithvi models were moved to the terratorch registry. "
"Please remove the timm_ prefix from the model name.")

0 comments on commit 498fbf6

Please sign in to comment.