diff --git a/terratorch/models/__init__.py b/terratorch/models/__init__.py
index 3770465f..95bf674a 100644
--- a/terratorch/models/__init__.py
+++ b/terratorch/models/__init__.py
@@ -2,7 +2,6 @@
 
 import logging
 
-
 import terratorch.models.necks  # register necks # noqa: F401
 from terratorch.models.encoder_decoder_factory import EncoderDecoderFactory
 from terratorch.models.generic_unet_model_factory import GenericUnetModelFactory
@@ -11,24 +10,26 @@
 from terratorch.models.satmae_model_factory import SatMAEModelFactory
 from terratorch.models.smp_model_factory import SMPModelFactory
 from terratorch.models.timm_model_factory import TimmModelFactory
+from terratorch.models.full_model_factory import FullModelFactory
 
 try:
+    granitewcx = True
     from terratorch.models.wxc_model_factory import WxCModelFactory
 except ImportError:
-    import logging
     logging.getLogger("terratorch").debug("granitewcx not installed")
+    granitewcx = False
 
 __all__ = (
     "PrithviModelFactory",
     "ClayModelFactory",
     "SatMAEModelFactory",
-    "ScaleMAEModelFactory",
     "SMPModelFactory",
     "GenericUnetModelFactory",
     "TimmModelFactory",
-    "AuxiliaryHead",
-    "AuxiliaryHeadWithDecoderWithoutInstantiatedHead",
-    "UNet",
-    "WxCModelFactory",
-    "EncoderDecoderFactory"
+    "EncoderDecoderFactory",
+    "FullModelFactory",
 )
+
+if granitewcx:
+    # Export the optional factory by name, and only when its import succeeded.
+    __all__ += ("WxCModelFactory",)
diff --git a/terratorch/models/full_model_factory.py b/terratorch/models/full_model_factory.py
new file mode 100644
index 00000000..83839b2c
--- /dev/null
+++ b/terratorch/models/full_model_factory.py
@@ -0,0 +1,86 @@
+# Copyright contributors to the Terratorch project
+
+import warnings
+
+from torch import nn
+
+from terratorch.models.model import ModelFactory
+from terratorch.models.peft_utils import get_peft_backbone
+from terratorch.registry import FULL_MODEL_REGISTRY, MODEL_FACTORY_REGISTRY
+
+
+def _get_model(model: str | nn.Module, **model_kwargs) -> nn.Module:
+    """Return `model` unchanged if it is already an nn.Module, otherwise build it from FULL_MODEL_REGISTRY."""
+    if isinstance(model, nn.Module):
+        return model
+    return FULL_MODEL_REGISTRY.build(model, **model_kwargs)
+
+
+def _check_all_args_used(kwargs):
+    """Raise a ValueError if `kwargs` is not empty."""
+    if kwargs:
+        msg = f"arguments {kwargs} were passed but not used."
+        raise ValueError(msg)
+
+
+@MODEL_FACTORY_REGISTRY.register
+class FullModelFactory(ModelFactory):
+    def build_model(
+        self,
+        model: str | nn.Module,
+        rescale: bool = True,  # noqa: FBT002, FBT001
+        padding: str = "reflect",
+        peft_config: dict | None = None,
+        **kwargs,
+    ) -> nn.Module:
+        """Generic model factory that wraps any model.
+
+        All kwargs are passed to the model when it is built from a registry name. They are ignored if an
+        already instantiated nn.Module is given.
+
+        Args:
+            model (str, nn.Module): Model to be used. If a string, will look for such models in the different
+                registries supported (internal terratorch registry, ...). If a torch nn.Module, will use it
+                directly.
+            rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
+                is different from the ground truth. Only applicable to pixel wise models
+                (e.g. segmentation, pixel wise regression, reconstruction). Defaults to True.
+            padding (str): Padding method used if images are not divisible by the patch size. Defaults to "reflect".
+            peft_config (dict): Configuration options for using [PEFT](https://huggingface.co/docs/peft/index).
+                The dictionary should have the following keys:
+                - "method": Which PEFT method to use. Should be one implemented in PEFT, a list is available [here](https://huggingface.co/docs/peft/package_reference/peft_types#peft.PeftType).
+                - "replace_qkv": String containing a substring of the name of the submodules to replace with QKVSep.
+                    This should be used when the qkv matrices are merged together in a single linear layer and the PEFT
+                    method should be applied separately to query, key and value matrices (e.g. if LoRA is only desired in
+                    Q and V matrices). e.g. If using Prithvi this should be "qkv"
+                - "peft_config_kwargs": Dictionary containing keyword arguments which will be passed to [PeftConfig](https://huggingface.co/docs/peft/package_reference/config#peft.PeftConfig)
+
+        Returns:
+            nn.Module: Full model.
+        """
+
+        model = _get_model(model, **kwargs)
+
+        # If patch size is not provided in the config or by the model, it might lead to errors due to irregular images.
+        # NOTE(review): `patch_size`, `rescale` and `padding` are not used further in this method yet.
+        patch_size = kwargs.get("patch_size", None)
+
+        if patch_size is None:
+            # Infer patch size from model by checking all backbone modules
+            for module in model.modules():
+                if hasattr(module, "patch_size"):
+                    patch_size = module.patch_size
+                    break
+
+        if peft_config is not None:
+            if not kwargs.get("pretrained", False):
+                msg = (
+                    "You are using PEFT without a pretrained backbone. If you are loading a checkpoint afterwards "
+                    "this is probably fine, but if you are training a model check the pretrained parameter."
+                )
+                # stacklevel=2 attributes the warning to the caller of build_model.
+                warnings.warn(msg, stacklevel=2)
+
+            model = get_peft_backbone(peft_config, model)
+
+        return model