Merge pull request #3 from IBM/feature/load_from_local_ckpt
simplify model loading from local ckpt
CarlosGomes98 authored May 17, 2024
2 parents 0ef8d9a + 80b3b6c commit 5de7e68
Showing 4 changed files with 33 additions and 83 deletions.
7 changes: 7 additions & 0 deletions docs/models.md
@@ -34,6 +34,13 @@ We also provide a model factory that can build a task specific model for a downs

By passing the list of bands being used to the constructor, we automatically filter out unused bands and randomly initialize weights for new bands that the model was not pretrained on.

!!! info

To load weights from your own path with the PrithviModelFactory, you can make use of timm's `pretrained_cfg_overlay`.
For example, to point to a local file, pass `backbone_pretrained_cfg_overlay = {"file": "<local_path>"}` to the model factory.

Besides `file`, you can also pass `url` or `hf_hub_id`, among others. Check timm's documentation for full details.
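
For example, a minimal sketch in Python (the backbone and decoder names, band selection, and exact factory argument names here are illustrative and may differ in your terratorch version):

```python
from terratorch.datasets import HLSBands
from terratorch.models import PrithviModelFactory  # import path may vary by version

factory = PrithviModelFactory()

# Build a segmentation model whose backbone weights come from a local
# checkpoint instead of the default pretrained source.
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_vit_100",  # illustrative backbone name
    decoder="FCNDecoder",        # illustrative decoder name
    in_channels=6,
    bands=[
        HLSBands.BLUE,
        HLSBands.GREEN,
        HLSBands.RED,
        HLSBands.NIR_NARROW,
        HLSBands.SWIR_1,
        HLSBands.SWIR_2,
    ],
    num_classes=2,
    pretrained=True,
    # timm resolves the "file" entry to a local checkpoint path
    backbone_pretrained_cfg_overlay={"file": "examples/Prithvi_100M.pt"},
)
```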

:::terratorch.models.backbones.prithvi_select_patch_embed_weights

## Decoders
11 changes: 11 additions & 0 deletions docs/quick_start.md
@@ -105,6 +105,17 @@ task = PixelwiseRegressionTask(

At this level of abstraction, you can also provide a configuration file (see [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html#lightning-cli)) with all the details of the training. See an example for semantic segmentation below:

!!! info

To load weights from your own path with the PrithviModelFactory, you can make use of timm's `pretrained_cfg_overlay`.
For example, to point to a local file, add the following under `model_args`:

```yaml
backbone_pretrained_cfg_overlay:
  file: <local_path>
```

Besides `file`, you can also pass `url` or `hf_hub_id`, among others. Check timm's documentation for full details.
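
For context, a sketch of where this sits in a full config file (the class path, `model_factory` key, and backbone name are illustrative, modeled on the example configs in this repository):

```yaml
model:
  class_path: terratorch.tasks.SemanticSegmentationTask  # illustrative task class
  init_args:
    model_factory: PrithviModelFactory
    model_args:
      backbone: prithvi_vit_100  # illustrative backbone name
      backbone_pretrained_cfg_overlay:
        file: <local_path>
```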

```yaml title="Configuration file for a Semantic Segmentation Task"
# lightning.pytorch==2.1.1
seed_everything: 0
4 changes: 2 additions & 2 deletions examples/confs/sen1floods11_vit_local_ckpt.yaml
@@ -87,8 +87,8 @@ model:
backbone_patch_size: 16
backbone_in_chans: 6
backbone_num_frames: 1
backbone_checkpoint_path: examples/Prithvi_100M.pt
# backbone_window_size: 8
backbone_pretrained_cfg_overlay:
file: examples/Prithvi_100M.pt
decoder_channels: 256
in_channels: 6
bands:
94 changes: 13 additions & 81 deletions src/terratorch/models/prithvi_model_factory.py
@@ -1,10 +1,8 @@
import importlib
from collections.abc import Callable
from typing import Any

import timm
import torch
from torch import nn
import torch

import terratorch.models.decoders as decoder_registry
from terratorch.datasets import HLSBands
@@ -15,11 +13,8 @@
ModelFactory,
register_factory,
)

from terratorch.models.pixel_wise_model import PixelWiseModel
from terratorch.models.scalar_output_model import ScalarOutputModel
from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder
from terratorch.models.backbones.prithvi_vit import checkpoint_filter_fn

PIXEL_WISE_TASKS = ["segmentation", "regression"]
SCALAR_TASKS = ["classification"]
@@ -29,40 +24,6 @@
class DecoderNotFoundError(Exception):
pass

class AttrInfo:

def __init__(self, channels:callable):

self.channels = channels

class ModelWrapper(nn.Module):

def __init__(self, **kwargs) -> None:

super(ModelWrapper, self).__init__()

# Little hack because VIT does not support timm's features_only
# so we do it ourselves
encoder_only = kwargs.get("features_only", False)
if "features_only" in kwargs:
kwargs = {k: v for k, v in kwargs.items() if k != "features_only"}

self.config = kwargs
self.model = TemporalViTEncoder(**kwargs, encoder_only=True)

# Sharing methods between model and wrapper
self.forward = self.model.forward_features
self.encode_decode_forward = self.model.forward
self.prepare_features_for_image_model = self.model.prepare_features_for_image_model

# Adaptation to allow the wrapper to access information about
# the model
self.feature_info = AttrInfo(channels=self.channels)

def channels(self):
return self.config["num_heads"]*[self.config["embed_dim"]]


@register_factory
class PrithviModelFactory(ModelFactory):
def build_model(
@@ -106,14 +67,12 @@ def build_model(
aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model.
These decoders take the input from the encoder as well.
rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.
is different from the ground truth. Only applicable to pixel wise models
(e.g. segmentation, pixel wise regression). Defaults to True.
Raises:
NotImplementedError: _description_
DecoderNotFoundException: _description_
Returns:
nn.Module: _description_
nn.Module: Full model with encoder, decoder and head.
"""
if not torch.cuda.is_available():
self.CPU_ONLY = True
@@ -136,39 +95,15 @@

backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_")

if "checkpoint_path" in backbone_kwargs:
checkpoint_path = backbone_kwargs.pop("checkpoint_path")
else:
checkpoint_path = None

# Currently the Prithvi restoring needs access to some
# registries which are not always available on all the
# system. The alternative is to load from a local
# checkpoint.
try:
backbone: nn.Module = timm.create_model(
backbone,
pretrained=pretrained,
in_chans=in_channels, # this can be removed, can be derived from bands. But is a breaking change.
num_frames=num_frames,
bands=bands,
features_only=True,
**backbone_kwargs,
)
except Exception:

print(f"Trying to load from the path defined in the config file, {checkpoint_path}.")
backbone = ModelWrapper(**backbone_kwargs, features_only=True)

if self.CPU_ONLY:
model_dict = torch.load(checkpoint_path, map_location="cpu")
else:
model_dict = torch.load(checkpoint_path)

model_dict = checkpoint_filter_fn(model_dict, model=backbone.model, pretrained_bands=bands, model_bands=bands)

backbone.model.load_state_dict(model_dict, strict=False)

backbone: nn.Module = timm.create_model(
backbone,
pretrained=pretrained,
in_chans=in_channels, # this can be removed, can be derived from bands. But is a breaking change.
num_frames=num_frames,
bands=bands,
features_only=True,
**backbone_kwargs,
)
# allow decoder to be a module passed directly
decoder_cls = _get_decoder(decoder)

@@ -192,13 +127,10 @@
aux_decoder_cls: nn.Module = _get_decoder(aux_decoder.decoder)
aux_decoder_kwargs, kwargs = _extract_prefix_keys(args, "decoder_")
aux_decoder_instance = aux_decoder_cls(backbone.feature_info.channels(), **aux_decoder_kwargs)
# aux_decoder_instance = aux_decoder_cls([128, 256, 512, 1024], **decoder_kwargs)

aux_head_kwargs, kwargs = _extract_prefix_keys(args, "head_")
if num_classes:
aux_head_kwargs["num_classes"] = num_classes
# aux_head: nn.Module = _get_head(task, aux_decoder_instance, num_classes=num_classes, **head_kwargs)
# aux_decoder.decoder = nn.Sequential(aux_decoder_instance, aux_head)
to_be_aux_decoders.append(
AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
)
