simplify model loading from local ckpt
Carlos Gomes committed May 17, 2024
1 parent 0ef8d9a commit c0cf729
Showing 2 changed files with 15 additions and 83 deletions.
examples/confs/sen1floods11_vit_local_ckpt.yaml (4 changes: 2 additions & 2 deletions)
@@ -87,8 +87,8 @@ model:
backbone_patch_size: 16
backbone_in_chans: 6
backbone_num_frames: 1
backbone_checkpoint_path: examples/Prithvi_100M.pt
# backbone_window_size: 8
backbone_pretrained_cfg_overlay:
file: examples/Prithvi_100M.pt
decoder_channels: 256
in_channels: 6
bands:
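On the config side, the local checkpoint is no longer pointed to through the factory-specific backbone_checkpoint_path key; it is now passed through timm's standard pretrained_cfg_overlay mechanism, whose file entry tells timm to load the pretrained weights from a local path instead of downloading them. A minimal sketch of what this amounts to once the backbone_ prefix is stripped (the backbone name prithvi_vit_100 is illustrative, and terratorch is assumed to register the Prithvi backbones with timm on import):

import timm
import terratorch  # noqa: F401 -- assumed to register the Prithvi backbones with timm

backbone = timm.create_model(
    "prithvi_vit_100",  # illustrative; in practice taken from the config's backbone key
    pretrained=True,  # weights are only loaded when pretrained=True
    pretrained_cfg_overlay={"file": "examples/Prithvi_100M.pt"},  # load from the local file
)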
src/terratorch/models/prithvi_model_factory.py (94 changes: 13 additions & 81 deletions)
@@ -1,10 +1,8 @@
import importlib
from collections.abc import Callable
from typing import Any

import timm
import torch
from torch import nn
import torch

import terratorch.models.decoders as decoder_registry
from terratorch.datasets import HLSBands
@@ -15,11 +13,8 @@
ModelFactory,
register_factory,
)

from terratorch.models.pixel_wise_model import PixelWiseModel
from terratorch.models.scalar_output_model import ScalarOutputModel
from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder
from terratorch.models.backbones.prithvi_vit import checkpoint_filter_fn

PIXEL_WISE_TASKS = ["segmentation", "regression"]
SCALAR_TASKS = ["classification"]
@@ -29,40 +24,6 @@
class DecoderNotFoundError(Exception):
pass

class AttrInfo:

def __init__(self, channels:callable):

self.channels = channels

class ModelWrapper(nn.Module):

def __init__(self, **kwargs) -> None:

super(ModelWrapper, self).__init__()

# Little hack because VIT does not support timm's features_only
# so we do it ourselves
encoder_only = kwargs.get("features_only", False)
if "features_only" in kwargs:
kwargs = {k: v for k, v in kwargs.items() if k != "features_only"}

self.config = kwargs
self.model = TemporalViTEncoder(**kwargs, encoder_only=True)

# Sharing methods between model and wrapper
self.forward = self.model.forward_features
self.encode_decode_forward = self.model.forward
self.prepare_features_for_image_model = self.model.prepare_features_for_image_model

# Adaptation to allow the wrapper to acess information about
# the model
self.feature_info = AttrInfo(channels=self.channels)

def channels(self):
return self.config["num_heads"]*[self.config["embed_dim"]]


@register_factory
class PrithviModelFactory(ModelFactory):
def build_model(
@@ -106,14 +67,12 @@ def build_model
aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model.
These decoders take the input from the encoder as well.
rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.
is different from the ground truth. Only applicable to pixel wise models
(e.g. segmentation, pixel wise regression). Defaults to True.
Raises:
NotImplementedError: _description_
DecoderNotFoundException: _description_
Returns:
nn.Module: _description_
nn.Module: Full model with encoder, decoder and head.
"""
if not torch.cuda.is_available():
self.CPU_ONLY = True
@@ -136,39 +95,15 @@ def build_model

backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_")

if "checkpoint_path" in backbone_kwargs:
checkpoint_path = backbone_kwargs.pop("checkpoint_path")
else:
checkpoint_path = None

# Currently the Prithvi restoring needs access to some
# registries which are not always available on all the
# system. The alternative is to load from a local
# checkpoint.
try:
backbone: nn.Module = timm.create_model(
backbone,
pretrained=pretrained,
in_chans=in_channels, # this can be removed, can be derived from bands. But is a breaking change.
num_frames=num_frames,
bands=bands,
features_only=True,
**backbone_kwargs,
)
except Exception:

print(f"Trying to load from the path defined in the config file, {checkpoint_path}.")
backbone = ModelWrapper(**backbone_kwargs, features_only=True)

if self.CPU_ONLY:
model_dict = torch.load(checkpoint_path, map_location="cpu")
else:
model_dict = torch.load(checkpoint_path)

model_dict = checkpoint_filter_fn(model_dict, model=backbone.model, pretrained_bands=bands, model_bands=bands)

backbone.model.load_state_dict(model_dict, strict=False)

backbone: nn.Module = timm.create_model(
backbone,
pretrained=pretrained,
in_chans=in_channels, # this can be removed, can be derived from bands. But is a breaking change.
num_frames=num_frames,
bands=bands,
features_only=True,
**backbone_kwargs,
)
# allow decoder to be a module passed directly
decoder_cls = _get_decoder(decoder)

@@ -192,13 +127,10 @@ def build_model
aux_decoder_cls: nn.Module = _get_decoder(aux_decoder.decoder)
aux_decoder_kwargs, kwargs = _extract_prefix_keys(args, "decoder_")
aux_decoder_instance = aux_decoder_cls(backbone.feature_info.channels(), **aux_decoder_kwargs)
# aux_decoder_instance = aux_decoder_cls([128, 256, 512, 1024], **decoder_kwargs)

aux_head_kwargs, kwargs = _extract_prefix_keys(args, "head_")
if num_classes:
aux_head_kwargs["num_classes"] = num_classes
# aux_head: nn.Module = _get_head(task, aux_decoder_instance, num_classes=num_classes, **head_kwargs)
# aux_decoder.decoder = nn.Sequential(aux_decoder_instance, aux_head)
to_be_aux_decoders.append(
AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
)
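On the factory side, the ModelWrapper/AttrInfo shim, the try/except fallback and the manual torch.load + checkpoint_filter_fn + load_state_dict path are gone: the backbone is now always built with timm.create_model, and every backbone_-prefixed key from the config (including backbone_pretrained_cfg_overlay above) is forwarded to it as a keyword argument. A minimal sketch of that key-forwarding idea, with a simplified stand-in for terratorch's private _extract_prefix_keys helper (assumed behaviour, not the library's exact code):

import timm
import terratorch  # noqa: F401 -- assumed to register the Prithvi backbones with timm


def _extract_prefix_keys(kwargs: dict, prefix: str) -> tuple[dict, dict]:
    # Split kwargs into (prefixed keys with the prefix stripped, everything else).
    extracted = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    remaining = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return extracted, remaining


config_kwargs = {
    "backbone_pretrained_cfg_overlay": {"file": "examples/Prithvi_100M.pt"},
    "decoder_channels": 256,
}
backbone_kwargs, kwargs = _extract_prefix_keys(config_kwargs, "backbone_")
# backbone_kwargs == {"pretrained_cfg_overlay": {"file": "examples/Prithvi_100M.pt"}}
# timm resolves the "file" overlay itself, so no manual torch.load or checkpoint filtering is needed.
backbone = timm.create_model("prithvi_vit_100", pretrained=True, **backbone_kwargs)  # backbone name illustrative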
