diff --git a/docs/models.md b/docs/models.md
index 52071076..33b27008 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -34,6 +34,13 @@ We also provide a model factory that can build a task specific model for a downs
 By passing a list of bands being used to the constructor, we automatically filter out unused bands, and randomly initialize weights for new bands that were not pretrained on.
 
+!!! info
+
+    To load weights from your own path with the PrithviModelFactory, you can make use of timm's `pretrained_cfg_overlay`.
+    E.g. to load from a local file, pass the parameter `backbone_pretrained_cfg_overlay = {"file": "<path to weights>"}` to the model factory.
+
+    Besides `file`, you can also pass `url` or `hf_hub_id`, among others. Check timm's documentation for full details.
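+
+    For instance, a call to the model factory might look like the sketch below.
+    The backbone and decoder names, the band list and the number of classes are
+    illustrative and should be adapted to your task:
+
+    ```python
+    from terratorch.datasets import HLSBands
+    from terratorch.models import PrithviModelFactory
+
+    factory = PrithviModelFactory()
+    # the overlay redirects weight loading to a local checkpoint file
+    model = factory.build_model(
+        task="segmentation",
+        backbone="prithvi_vit_100",
+        decoder="FCNDecoder",
+        bands=[
+            HLSBands.BLUE,
+            HLSBands.GREEN,
+            HLSBands.RED,
+            HLSBands.NIR_NARROW,
+            HLSBands.SWIR_1,
+            HLSBands.SWIR_2,
+        ],
+        in_channels=6,
+        num_classes=2,
+        backbone_pretrained_cfg_overlay={"file": "<path to weights>"},
+    )
+    ```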
+
 :::terratorch.models.backbones.prithvi_select_patch_embed_weights
 
 ## Decoders
diff --git a/docs/quick_start.md b/docs/quick_start.md
index 7f0373a3..5f18c80c 100644
--- a/docs/quick_start.md
+++ b/docs/quick_start.md
@@ -105,6 +105,17 @@ task = PixelwiseRegressionTask(
 At this level of abstraction, you can also provide a configuration file (see [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html#lightning-cli)) with all the details of the training. See an example for semantic segmentation below:
 
+!!! info
+
+    To load weights from your own path with the PrithviModelFactory, you can make use of timm's `pretrained_cfg_overlay`.
+    E.g. to load from a local file, you can add the following under `model_args`:
+
+    ```yaml
+    backbone_pretrained_cfg_overlay:
+      file: <path to weights>
+    ```
+    Besides `file`, you can also pass `url` or `hf_hub_id`, among others. Check timm's documentation for full details.
+
 ```yaml title="Configuration file for a Semantic Segmentation Task"
 # lightning.pytorch==2.1.1
 seed_everything: 0
diff --git a/examples/confs/sen1floods11_vit_local_ckpt.yaml b/examples/confs/sen1floods11_vit_local_ckpt.yaml
index 6e1ac1ec..6cdcc3d9 100644
--- a/examples/confs/sen1floods11_vit_local_ckpt.yaml
+++ b/examples/confs/sen1floods11_vit_local_ckpt.yaml
@@ -87,8 +87,8 @@ model:
       backbone_patch_size: 16
       backbone_in_chans: 6
       backbone_num_frames: 1
-      backbone_checkpoint_path: examples/Prithvi_100M.pt
-      # backbone_window_size: 8
+      backbone_pretrained_cfg_overlay:
+        file: examples/Prithvi_100M.pt
       decoder_channels: 256
       in_channels: 6
       bands:
diff --git a/src/terratorch/models/prithvi_model_factory.py b/src/terratorch/models/prithvi_model_factory.py
index e1e9dcf3..ad424084 100644
--- a/src/terratorch/models/prithvi_model_factory.py
+++ b/src/terratorch/models/prithvi_model_factory.py
@@ -1,10 +1,8 @@
-import importlib
 from collections.abc import Callable
-from typing import Any
 
 import timm
+import torch
 from torch import nn
-import torch
 
 import terratorch.models.decoders as decoder_registry
 from terratorch.datasets import HLSBands
@@ -15,11 +13,8 @@
     ModelFactory,
     register_factory,
 )
-
 from terratorch.models.pixel_wise_model import PixelWiseModel
 from terratorch.models.scalar_output_model import ScalarOutputModel
-from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder
-from terratorch.models.backbones.prithvi_vit import checkpoint_filter_fn
 
 PIXEL_WISE_TASKS = ["segmentation", "regression"]
 SCALAR_TASKS = ["classification"]
@@ -29,40 +24,6 @@ class DecoderNotFoundError(Exception):
     pass
 
 
-class AttrInfo:
-
-    def __init__(self, channels:callable):
-
-        self.channels = channels
-
-class ModelWrapper(nn.Module):
-
-    def __init__(self, **kwargs) -> None:
-
-        super(ModelWrapper, self).__init__()
-
-        # Little hack because VIT does not support timm's features_only
-        # so we do it ourselves
-        encoder_only = kwargs.get("features_only", False)
-        if "features_only" in kwargs:
-            kwargs = {k: v for k, v in kwargs.items() if k != "features_only"}
-
-        self.config = kwargs
-        self.model = TemporalViTEncoder(**kwargs, encoder_only=True)
-
-        # Sharing methods between model and wrapper
-        self.forward = self.model.forward_features
-        self.encode_decode_forward = self.model.forward
-        self.prepare_features_for_image_model = self.model.prepare_features_for_image_model
-
-        # Adaptation to allow the wrapper to acess information about
-        # the model
-        self.feature_info = AttrInfo(channels=self.channels)
-
-    def channels(self):
-        return self.config["num_heads"]*[self.config["embed_dim"]]
-
-
 @register_factory
 class PrithviModelFactory(ModelFactory):
     def build_model(
@@ -106,14 +67,12 @@ def build_model(
-            aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead deciders to be added to the model.
+            aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model.
                 These decoders take the input from the encoder as well.
             rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
-                is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.
+                is different from the ground truth. Only applicable to pixel wise models
+                (e.g. segmentation, pixel wise regression). Defaults to True.
 
-        Raises:
-            NotImplementedError: _description_
-            DecoderNotFoundException: _description_
 
         Returns:
-            nn.Module: _description_
+            nn.Module: Full model with encoder, decoder and head.
         """
         if not torch.cuda.is_available():
             self.CPU_ONLY = True
@@ -136,39 +95,15 @@ def build_model(
 
         backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_")
-        if "checkpoint_path" in backbone_kwargs:
-            checkpoint_path = backbone_kwargs.pop("checkpoint_path")
-        else:
-            checkpoint_path = None
-
-        # Currently the Prithvi restoring needs access to some
-        # registries which are not always available on all the
-        # system. The alternative is to load from a local
-        # checkpoint.
-        try:
-            backbone: nn.Module = timm.create_model(
-                backbone,
-                pretrained=pretrained,
-                in_chans=in_channels,  # this can be removed, can be derived from bands. But is a breaking change.
-                num_frames=num_frames,
-                bands=bands,
-                features_only=True,
-                **backbone_kwargs,
-            )
-        except Exception:
-
-            print(f"Trying to load from the path defined in the config file, {checkpoint_path}.")
-            backbone = ModelWrapper(**backbone_kwargs, features_only=True)
-
-            if self.CPU_ONLY:
-                model_dict = torch.load(checkpoint_path, map_location="cpu")
-            else:
-                model_dict = torch.load(checkpoint_path)
-
-            model_dict = checkpoint_filter_fn(model_dict, model=backbone.model, pretrained_bands=bands, model_bands=bands)
-
-            backbone.model.load_state_dict(model_dict, strict=False)
-
+        backbone: nn.Module = timm.create_model(
+            backbone,
+            pretrained=pretrained,
+            in_chans=in_channels,  # this can be removed, can be derived from bands. But is a breaking change.
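+            # kwargs prefixed with backbone_ (e.g. backbone_pretrained_cfg_overlay)
+            # reach timm.create_model through **backbone_kwargs below; this is how
+            # a local checkpoint path is handed to timm's weight-loading logic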
+            num_frames=num_frames,
+            bands=bands,
+            features_only=True,
+            **backbone_kwargs,
+        )
 
 
         # allow decoder to be a module passed directly
         decoder_cls = _get_decoder(decoder)
@@ -192,13 +127,10 @@ def build_model(
             aux_decoder_cls: nn.Module = _get_decoder(aux_decoder.decoder)
             aux_decoder_kwargs, kwargs = _extract_prefix_keys(args, "decoder_")
             aux_decoder_instance = aux_decoder_cls(backbone.feature_info.channels(), **aux_decoder_kwargs)
-            # aux_decoder_instance = aux_decoder_cls([128, 256, 512, 1024], **decoder_kwargs)
 
             aux_head_kwargs, kwargs = _extract_prefix_keys(args, "head_")
             if num_classes:
                 aux_head_kwargs["num_classes"] = num_classes
-            # aux_head: nn.Module = _get_head(task, aux_decoder_instance, num_classes=num_classes, **head_kwargs)
-            # aux_decoder.decoder = nn.Sequential(aux_decoder_instance, aux_head)
             to_be_aux_decoders.append(
                 AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
             )
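
Note: with this change, a locally supplied checkpoint no longer needs the removed `ModelWrapper` fallback; it flows through timm's standard `pretrained_cfg` machinery. A rough sketch of the equivalent call at the timm level (assuming the `prithvi_vit_100` backbone id, which terratorch registers with timm when imported; adjust the id and path to your setup):

```python
import timm

import terratorch  # noqa: F401  # importing terratorch registers the Prithvi backbones with timm

# pretrained_cfg_overlay redirects weight loading to a local file instead of
# the backbone's default pretrained source; "url" or "hf_hub_id" keys work too
backbone = timm.create_model(
    "prithvi_vit_100",
    pretrained=True,
    features_only=True,
    pretrained_cfg_overlay={"file": "examples/Prithvi_100M.pt"},
)
```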