From 124cd0230e3aaf891d6bf17481107fe0f69c65c5 Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Wed, 18 Dec 2024 12:04:47 +0100 Subject: [PATCH 01/40] Implement multiple test dataloaders for all tasks Signed-off-by: Francesc Marti Escofet --- terratorch/tasks/base_task.py | 5 ++- terratorch/tasks/classification_tasks.py | 46 ++++++++++++++++------- terratorch/tasks/regression_tasks.py | 48 ++++++++++++++++++------ terratorch/tasks/segmentation_tasks.py | 12 +++--- 4 files changed, 78 insertions(+), 33 deletions(-) diff --git a/terratorch/tasks/base_task.py b/terratorch/tasks/base_task.py index e59aaf39..69f744b7 100644 --- a/terratorch/tasks/base_task.py +++ b/terratorch/tasks/base_task.py @@ -71,8 +71,9 @@ def on_validation_epoch_end(self) -> None: self.val_metrics.reset() def on_test_epoch_end(self) -> None: - self.log_dict(self.test_metrics.compute(), sync_dist=True) - self.test_metrics.reset() + for metrics in self.test_metrics: + self.log_dict(metrics.compute(), sync_dist=True) + metrics.reset() def _do_plot_samples(self, batch_index): if not self.plot_on_val: # dont plot if self.plot_on_val is 0 diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py index 89974004..bf952ab8 100644 --- a/terratorch/tasks/classification_tasks.py +++ b/terratorch/tasks/classification_tasks.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Any import logging import lightning @@ -16,7 +17,8 @@ from terratorch.tasks.optimizer_factory import optimizer_factory from terratorch.tasks.base_task import TerraTorchTask -logger = logging.getLogger('terratorch') +logger = logging.getLogger("terratorch") + def to_class_prediction(y: ModelOutput) -> Tensor: y_hat = y.output @@ -33,6 +35,7 @@ class ClassificationTask(TerraTorchTask): - Does not have any callbacks by default (TorchGeo tasks do early stopping by default) - Allows the setting of optimizers in the constructor - It provides mIoU with both Micro and Macro averaging + - Allows evaluation on multiple test dataloaders .. note:: * 'Micro' averaging suits overall performance evaluation but may not reflect @@ -62,6 +65,7 @@ def __init__( freeze_backbone: bool = False, # noqa: FBT001, FBT002 freeze_decoder: bool = False, # noqa: FBT002, FBT001 class_names: list[str] | None = None, + test_dataloaders_names: list[str] | None = None, ) -> None: """Constructor @@ -97,6 +101,9 @@ def __init__( freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False. class_names (list[str] | None, optional): List of class names passed to metrics for better naming. Defaults to numeric ordering. + test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when + multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, + which assumes only one test dataloader is used. """ self.aux_loss = aux_loss self.aux_heads = aux_heads @@ -116,11 +123,12 @@ def __init__( self.model = model self.train_loss_handler = LossHandler(self.train_metrics.prefix) - self.test_loss_handler = LossHandler(self.test_metrics.prefix) + self.test_loss_handler: list[LossHandler] = [] + for metrics in self.test_metrics: + self.test_loss_handler.append(LossHandler(metrics.prefix)) self.val_loss_handler = LossHandler(self.val_metrics.prefix) self.monitor = f"{self.val_metrics.prefix}loss" - def configure_losses(self) -> None: """Initialize the loss criterion. 
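The base_task change above replaces the single test metrics collection with a list that is iterated at epoch end, and the classification task mirrors it with one LossHandler per collection. A minimal runnable sketch of the pattern, assuming torchmetrics; the metric choice and prefixes are illustrative, not terratorch defaults:

import torch
import torch.nn as nn
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy

# One collection per test dataloader; distinct prefixes keep logged keys from colliding.
test_metrics = nn.ModuleList(
    [MetricCollection({"Accuracy": MulticlassAccuracy(num_classes=4)}, prefix=f"test/loader_{i}/")
     for i in range(2)]
)

preds, target = torch.randint(0, 4, (8,)), torch.randint(0, 4, (8,))
for metrics in test_metrics:
    metrics.update(preds, target)

# What the new on_test_epoch_end does: compute and reset each collection in turn.
for metrics in test_metrics:
    print(metrics.compute())  # stands in for self.log_dict(..., sync_dist=True)
    metrics.reset()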
@@ -131,8 +139,8 @@ def configure_losses(self) -> None: ignore_index = self.hparams["ignore_index"] class_weights = ( - torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None - ) + torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None + ) if loss == "ce": ignore_value = -100 if ignore_index is None else ignore_index self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights) @@ -187,7 +195,12 @@ def configure_metrics(self) -> None: ) self.train_metrics = metrics.clone(prefix="train/") self.val_metrics = metrics.clone(prefix="val/") - self.test_metrics = metrics.clone(prefix="test/") + if self.hparams["test_dataloaders_names"] is not None: + self.test_metrics = nn.ModuleList( + [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]] + ) + else: + self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")]) def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the train loss and additional metrics. @@ -200,7 +213,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> x = batch["image"] y = batch["label"] other_keys = batch.keys() - {"image", "label", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) @@ -221,7 +234,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - x = batch["image"] y = batch["label"] other_keys = batch.keys() - {"image", "label", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) @@ -239,12 +252,19 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None x = batch["image"] y = batch["label"] other_keys = batch.keys() - {"image", "label", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) - loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) - self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) + if dataloader_idx >= len(self.test_loss_handler): + msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names." + raise ValueError(msg) + loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss) + self.test_loss_handler[dataloader_idx].log_loss( + partial(self.log, add_dataloader_idx=False), # We don't need the dataloader idx as prefixes are different + loss_dict=loss, + batch_size=x.shape[0], + ) y_hat_hard = to_class_prediction(model_output) - self.test_metrics.update(y_hat_hard, y) + self.test_metrics[dataloader_idx].update(y_hat_hard, y) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the predicted class probabilities. 
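The classification test_step above selects the loss handler and metric collection by dataloader_idx, and raises once more dataloaders arrive than names were configured. On the datamodule side this pairs with a test_dataloader that returns a list of loaders, ordered like test_dataloaders_names. A hedged sketch; the datamodule class, tensor shapes and region names are invented for illustration:

import torch
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch import LightningDataModule

class DictDataset(Dataset):
    # Yields dict samples with the "image"/"label" keys the tasks expect.
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return {"image": torch.rand(6, 224, 224), "label": torch.randint(0, 2, (1,)).squeeze()}

class TwoRegionDataModule(LightningDataModule):
    def test_dataloader(self):
        # Order must match test_dataloaders_names passed to the task.
        return [DataLoader(DictDataset(), batch_size=4), DataLoader(DictDataset(), batch_size=4)]

# task = ClassificationTask(..., test_dataloaders_names=["region_a", "region_b"])
# logs metric keys such as "test/region_a/loss" and "test/region_b/loss".

Losses are logged through partial(self.log, add_dataloader_idx=False) because Lightning would otherwise append a /dataloader_idx_n suffix to every key; the per-name prefixes already make the keys unique.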
@@ -260,7 +280,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T x = batch["image"] file_names = batch["filename"] other_keys = batch.keys() - {"image", "label", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) y_hat = self(x).output diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index 29bbc00f..f856e1b8 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -1,6 +1,7 @@ """This module contains the regression task and its auxiliary classes.""" from collections.abc import Sequence +from functools import partial from typing import Any import logging @@ -24,7 +25,8 @@ BATCH_IDX_FOR_VALIDATION_PLOTTING = 10 -logger = logging.getLogger('terratorch') +logger = logging.getLogger("terratorch") + class RootLossWrapper(nn.Module): def __init__(self, loss_function: nn.Module, reduction: None | str = "mean") -> None: @@ -129,7 +131,8 @@ class PixelwiseRegressionTask(TerraTorchTask): - Accepts the specification of a model factory - Logs metrics per class - Does not have any callbacks by default (TorchGeo tasks do early stopping by default) - - Allows the setting of optimizers in the constructor""" + - Allows the setting of optimizers in the constructor + - Allows evaluation on multiple test dataloaders""" def __init__( self, @@ -152,6 +155,7 @@ def __init__( freeze_decoder: bool = False, # noqa: FBT001, FBT002 plot_on_val: bool | int = 10, tiled_inference_parameters: TiledInferenceParameters | None = None, + test_dataloaders_names: list[str] | None = None, ) -> None: """Constructor @@ -186,6 +190,9 @@ def __init__( If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs. tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters used to determine if inference is done on the whole image or through tiling. + test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when + multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, + which assumes only one test dataloader is used. 
""" self.tiled_inference_parameters = tiled_inference_parameters self.aux_loss = aux_loss @@ -206,7 +213,9 @@ def __init__( self.model = model self.train_loss_handler = LossHandler(self.train_metrics.prefix) - self.test_loss_handler = LossHandler(self.test_metrics.prefix) + self.test_loss_handler: list[LossHandler] = [] + for metrics in self.test_metrics: + self.test_loss_handler.append(LossHandler(metrics.prefix)) self.val_loss_handler = LossHandler(self.val_metrics.prefix) self.monitor = f"{self.val_metrics.prefix}loss" self.plot_on_val = int(plot_on_val) @@ -253,7 +262,17 @@ def wrap_metrics_with_ignore_index(metrics): self.train_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="train/") self.val_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="val/") - self.test_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/") + if self.hparams["test_dataloaders_names"] is not None: + self.test_metrics = nn.ModuleList( + [ + MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix=f"test/{dl_name}/") + for dl_name in self.hparams["test_dataloaders_names"] + ] + ) + else: + self.test_metrics = nn.ModuleList( + [MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")] + ) def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the train loss and additional metrics. @@ -266,7 +285,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> x = batch["image"] y = batch["mask"] other_keys = batch.keys() - {"image", "mask", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) @@ -287,7 +306,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - x = batch["image"] y = batch["mask"] other_keys = batch.keys() - {"image", "mask", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0]) @@ -329,12 +348,19 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None x = batch["image"] y = batch["mask"] other_keys = batch.keys() - {"image", "mask", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) - loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) - self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) + if dataloader_idx >= len(self.test_loss_handler): + msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names." 
+ raise ValueError(msg) + loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss) + self.test_loss_handler[dataloader_idx].log_loss( + partial(self.log, add_dataloader_idx=False), # We don't need the dataloader idx as prefixes are different + loss_dict=loss, + batch_size=x.shape[0], + ) y_hat = model_output.output - self.test_metrics.update(y_hat, y) + self.test_metrics[dataloader_idx].update(y_hat, y) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the predicted class probabilities. @@ -350,7 +376,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T x = batch["image"] file_names = batch["filename"] other_keys = batch.keys() - {"image", "mask", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} def model_forward(x): return self(x).output diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 48e80221..0d999dd1 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -261,11 +261,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None y_hat_hard = to_segmentation_prediction(model_output) self.test_metrics[dataloader_idx].update(y_hat_hard, y) - def on_test_epoch_end(self) -> None: - for metrics in self.test_metrics: - self.log_dict(metrics.compute(), sync_dist=True) - metrics.reset() - def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the validation loss and additional metrics. Args: @@ -291,7 +286,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - batch["prediction"] = y_hat_hard if isinstance(batch["image"], dict): - if hasattr(datamodule, 'rgb_modality'): + if hasattr(datamodule, "rgb_modality"): # Generic multimodal dataset batch["image"] = batch["image"][datamodule.rgb_modality] else: @@ -337,7 +332,10 @@ def model_forward(x): if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference( # TODO: tiled inference does not work with additional input data (**rest) - model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters + model_forward, + x, + self.hparams["model_args"]["num_classes"], + self.tiled_inference_parameters, ) else: y_hat: Tensor = self(x, **rest).output From 8827ab4568ea6dbe310662fa977602eda2e04c9d Mon Sep 17 00:00:00 2001 From: Romeo Kienzler <5694071+romeokienzler@users.noreply.github.com> Date: Mon, 6 Jan 2025 12:24:08 +0000 Subject: [PATCH 02/40] Create smoke.py --- integrationtests/smoke.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 integrationtests/smoke.py diff --git a/integrationtests/smoke.py b/integrationtests/smoke.py new file mode 100644 index 00000000..1124e15a --- /dev/null +++ b/integrationtests/smoke.py @@ -0,0 +1,4 @@ +import pytest + +def test_smoke(): + None From cfff96ff9a0937b8ee42e40c1cd093d07aefc149 Mon Sep 17 00:00:00 2001 From: Romeo Kienzler <5694071+romeokienzler@users.noreply.github.com> Date: Mon, 6 Jan 2025 13:07:45 +0000 Subject: [PATCH 03/40] add annotation --- integrationtests/smoke.py | 1 + 1 file changed, 1 insertion(+) diff --git a/integrationtests/smoke.py b/integrationtests/smoke.py index 1124e15a..c579cb31 100644 --- a/integrationtests/smoke.py +++ b/integrationtests/smoke.py @@ -1,4 +1,5 @@ import pytest +@pytest.mark.integration def test_smoke(): None From 6a6b1a0320d9ce2fd4f1109b37b92ca90c878c13 Mon 
Sep 17 00:00:00 2001 From: Romeo Kienzler <5694071+romeokienzler@users.noreply.github.com> Date: Mon, 6 Jan 2025 13:09:01 +0000 Subject: [PATCH 04/40] add dummy code --- integrationtests/smoke.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrationtests/smoke.py b/integrationtests/smoke.py index c579cb31..ef2ba792 100644 --- a/integrationtests/smoke.py +++ b/integrationtests/smoke.py @@ -2,4 +2,4 @@ @pytest.mark.integration def test_smoke(): - None + assert True From 904c0cd14dd09bcc71a38015d855dc03c9608a54 Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Mon, 6 Jan 2025 08:10:44 -0500 Subject: [PATCH 05/40] fix naming of test file --- integrationtests/{smoke.py => test_smoke.py} | 1 - 1 file changed, 1 deletion(-) rename integrationtests/{smoke.py => test_smoke.py} (65%) diff --git a/integrationtests/smoke.py b/integrationtests/test_smoke.py similarity index 65% rename from integrationtests/smoke.py rename to integrationtests/test_smoke.py index ef2ba792..afda48f2 100644 --- a/integrationtests/smoke.py +++ b/integrationtests/test_smoke.py @@ -1,5 +1,4 @@ import pytest -@pytest.mark.integration def test_smoke(): assert True From 3e8f73973ce85c53270a079509b9708a7e918f4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 13 Jan 2025 09:39:58 -0300 Subject: [PATCH 06/40] Testing the repository using the installation from pyproject MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .github/workflows/test.yaml | 3 ++- pyproject.toml | 19 ++++++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 5b522c26..10ad4df9 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -28,7 +28,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements/required.txt -r requirements/test.txt -r requirements/optional.txt + #pip install -r requirements/required.txt -r requirements/test.txt -r requirements/optional.txt + pip install -e .[torchgeo] pip install git+https://github.com/NASA-IMPACT/Prithvi-WxC.git pip install git+https://github.com/IBM/granite-wxc.git - name: List pip dependencies diff --git a/pyproject.toml b/pyproject.toml index 6a477ac4..28a5130e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "torchgeo>=0.6.0", "rioxarray>=0.15.0", # see issue #64 - "albumentations>=1.3.1, <=1.4.10", + "albumentations>=1.3.1, <=1.4.21", "albucore<=0.0.16", "rasterio>=1.3.9", "torchmetrics<=1.3.1", @@ -49,6 +49,23 @@ dependencies = [ ] [project.optional-dependencies] +torchgeo = [ + "torch==2.4.1", + "torchvision==0.19.1", + "torchgeo @ git+https://github.com/microsoft/torchgeo.git@fedf99375535f801565856cd774bfa9e5a251d55", + "rioxarray>=0.15.0", + "albumentations>=1.3.1, <=1.4.21", + "albucore<=0.0.16", + "rasterio>=1.3.9", + "torchmetrics<=1.3.1", + "geopandas>=0.14.4", + "lightly>=1.4.25", + "h5py>=3.10.0", + "mlflow>=2.12.1", + "lightning[pytorch-extra]>=2,!=2.3.*", + "segmentation-models-pytorch>=0.3" +] + dev = [ "black", "mkdocs-material", From ea77654d3e685bcd58b678069ecee1c4d2c76c0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 13 Jan 2025 09:59:37 -0300 Subject: [PATCH 07/40] Additional dependencies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João 
Lucas de Sousa Almeida --- .github/workflows/test.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 10ad4df9..cc2fc6c4 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -29,6 +29,7 @@ jobs: run: | python -m pip install --upgrade pip #pip install -r requirements/required.txt -r requirements/test.txt -r requirements/optional.txt + pip install -r requirements/test.txt -r requirements/optional.txt pip install -e .[torchgeo] pip install git+https://github.com/NASA-IMPACT/Prithvi-WxC.git pip install git+https://github.com/IBM/granite-wxc.git From 9a286f20190016852bf6fdd864f156c4f9c48dc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 13 Jan 2025 10:53:11 -0300 Subject: [PATCH 08/40] pinning albumentations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 28a5130e..1d0c5645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ torchgeo = [ "torchvision==0.19.1", "torchgeo @ git+https://github.com/microsoft/torchgeo.git@fedf99375535f801565856cd774bfa9e5a251d55", "rioxarray>=0.15.0", - "albumentations>=1.3.1, <=1.4.21", + "albumentations==1.3.1", "albucore<=0.0.16", "rasterio>=1.3.9", "torchmetrics<=1.3.1", From 9ae2424fbbee1055e869165e1c777579c402083d Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Fri, 17 Jan 2025 21:35:09 +0100 Subject: [PATCH 09/40] add wxc conf --- examples/confs/wxc-gravity-wave.yaml | 135 +++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 examples/confs/wxc-gravity-wave.yaml diff --git a/examples/confs/wxc-gravity-wave.yaml b/examples/confs/wxc-gravity-wave.yaml new file mode 100644 index 00000000..bcef626e --- /dev/null +++ b/examples/confs/wxc-gravity-wave.yaml @@ -0,0 +1,135 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: + name: fire_scars + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 40 + + max_epochs: 200 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: + +# dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars +data: + class_path: GenericNonGeoSegmentationDataModule + init_args: + batch_size: 4 + num_workers: 8 + dataset_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + output_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + rgb_indices: + - 2 + - 1 + - 0 + train_transform: + - class_path: albumentations.RandomCrop + init_args: + height: 224 + width: 224 + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: ToTensorV2 + no_data_replace: 0 + no_label_replace: -1 + train_data_root: /training + train_label_data_root: /training + val_data_root: /validation + val_label_data_root: /validation + test_data_root: /validation + test_label_data_root: /validation + img_grep: "*_merged.tif" + label_grep: "*.mask.tif" + means: + - 0.033349706741586264 + - 0.05701185520536176 + - 0.05889748132001316 + - 0.2323245113436119 
+ - 0.1972854853760658 + - 0.11944914225186566 + stds: + - 0.02269135568823774 + - 0.026807560223070237 + - 0.04004109844362779 + - 0.07791732423672691 + - 0.08708738838140137 + - 0.07241979477437814 + num_classes: 2 + +model: + class_path: terratorch.models.wxc_model_factory.WxCModelFactory + init_args: + model_args: + in_channels: 1280 + input_size_time: 1 + n_lats_px: 64 + n_lons_px: 128 + patch_size_px: [2 2] + mask_unit_size_px: [8 16] + mask_ratio_inputs: 0.5 + embed_dim: 2560 + n_blocks_encoder: 12 + n_blocks_decoder: 2 + mlp_multiplier: 4 + n_heads: 16 + dropout: 0.0 + drop_path: 0.05 + parameter_dropout: 0.0 + residual: none + masking_mode: both + decoder_shifting: False + positional_encoding: absolute + checkpoint_encoder: [3 6 9 12 15 18 21 24] + checkpoint_decoder: [1 3] + in_channels_static: 3 + input_scalers_mu: torch.tensor([0] * 1280) + input_scalers_sigma: torch.tensor([1] * 1280) + input_scalers_epsilon: 0 + static_input_scalers_mu: torch.tensor([0] * 3) + static_input_scalers_sigma: torch.tensor([1] * 3) + static_input_scalers_epsilon: 0 + output_scalers: torch.tensor([0] * 1280) + backbone_weights: magnet-flux-uvtp122-epoch-99-loss-0.1022.pt + backbone: prithviwxc + aux_decoders: unetpincer + model_factory: WxCModelFactory + mode: eval +optimizer: + class_path: torch.optim.Adam + init_args: + lr: 1.5e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss From 270abffd57d6b0ab2155ef2c340c096ea4938b61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 17 Jan 2025 18:37:58 -0300 Subject: [PATCH 10/40] registry and minor changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- examples/confs/wxc-gravity-wave.yaml | 3 ++- terratorch/models/wxc_model_factory.py | 1 + terratorch/tasks/__init__.py | 5 ++++- terratorch/tasks/wxc_task.py | 12 +++++++----- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/examples/confs/wxc-gravity-wave.yaml b/examples/confs/wxc-gravity-wave.yaml index bcef626e..f481e45c 100644 --- a/examples/confs/wxc-gravity-wave.yaml +++ b/examples/confs/wxc-gravity-wave.yaml @@ -87,7 +87,7 @@ data: num_classes: 2 model: - class_path: terratorch.models.wxc_model_factory.WxCModelFactory + class_path: WxCTask init_args: model_args: in_channels: 1280 @@ -122,6 +122,7 @@ model: backbone_weights: magnet-flux-uvtp122-epoch-99-loss-0.1022.pt backbone: prithviwxc aux_decoders: unetpincer + skip_connection: True model_factory: WxCModelFactory mode: eval optimizer: diff --git a/terratorch/models/wxc_model_factory.py b/terratorch/models/wxc_model_factory.py index f446509a..abcbfeaa 100644 --- a/terratorch/models/wxc_model_factory.py +++ b/terratorch/models/wxc_model_factory.py @@ -61,6 +61,7 @@ def build_model( raise #remove parameters not meant for the backbone but for other parts of the model + print(kwargs) skip_connection = kwargs.pop('skip_connection') backbone = prithviwxc.PrithviWxC(**kwargs) diff --git a/terratorch/tasks/__init__.py b/terratorch/tasks/__init__.py index 790c10ec..782b0f08 100644 --- a/terratorch/tasks/__init__.py +++ b/terratorch/tasks/__init__.py @@ -1,3 +1,4 @@ +import logging from terratorch.tasks.classification_tasks import ClassificationTask from terratorch.tasks.regression_tasks import PixelwiseRegressionTask from terratorch.tasks.segmentation_tasks import SemanticSegmentationTask @@ -6,6 +7,8 @@ try: wxc_present = True from 
terratorch.tasks.wxc_downscaling_task import WxCDownscalingTask + from terratorch.tasks.wxc_task import WxCTask + logging.getLogger('terratorch').debug('wxc_downscaling found.') except ImportError as e: import logging logging.getLogger('terratorch').debug('wxc_downscaling not installed') @@ -21,4 +24,4 @@ ) if wxc_present: - __all__.__add__(("WxCDownscalingTask", )) + __all__.__add__(("WxCDownscalingTask", "WxCTask",)) diff --git a/terratorch/tasks/wxc_task.py b/terratorch/tasks/wxc_task.py index 87312a9a..3e4b4421 100644 --- a/terratorch/tasks/wxc_task.py +++ b/terratorch/tasks/wxc_task.py @@ -1,17 +1,19 @@ - - from torchgeo.trainers import BaseTask import torch.nn as nn import torch import logging logger = logging.getLogger(__name__) +from terratorch.registry import MODEL_FACTORY_REGISTRY + class WxCTask(BaseTask): - def __init__(self, model_factory, model_args: dict, mode, learning_rate=0.1): + def __init__(self, model_factory, model_args: dict, mode:str='train', learning_rate=0.1): if mode not in ['train', 'eval']: raise ValueError(f'mode {mode} is not supported. (train, eval)') self.model_args = model_args - self.model_factory = model_factory + + self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory) + self.learning_rate = learning_rate super().__init__() @@ -34,4 +36,4 @@ def training_step(self, batch, batch_idx): def train_dataloader(self): return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True) - \ No newline at end of file + From b53d07a9af54c57a196df8dca45d44fba7ece886 Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Mon, 20 Jan 2025 14:22:57 +0100 Subject: [PATCH 11/40] fix data module --- examples/confs/wxc-gravity-wave.yaml | 65 +++------------------------- terratorch/datamodules/__init__.py | 1 + 2 files changed, 6 insertions(+), 60 deletions(-) diff --git a/examples/confs/wxc-gravity-wave.yaml b/examples/confs/wxc-gravity-wave.yaml index f481e45c..b504df0e 100644 --- a/examples/confs/wxc-gravity-wave.yaml +++ b/examples/confs/wxc-gravity-wave.yaml @@ -29,62 +29,7 @@ trainer: # dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars data: - class_path: GenericNonGeoSegmentationDataModule - init_args: - batch_size: 4 - num_workers: 8 - dataset_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - output_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - rgb_indices: - - 2 - - 1 - - 0 - train_transform: - - class_path: albumentations.RandomCrop - init_args: - height: 224 - width: 224 - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - - class_path: ToTensorV2 - no_data_replace: 0 - no_label_replace: -1 - train_data_root: /training - train_label_data_root: /training - val_data_root: /validation - val_label_data_root: /validation - test_data_root: /validation - test_label_data_root: /validation - img_grep: "*_merged.tif" - label_grep: "*.mask.tif" - means: - - 0.033349706741586264 - - 0.05701185520536176 - - 0.05889748132001316 - - 0.2323245113436119 - - 0.1972854853760658 - - 0.11944914225186566 - stds: - - 0.02269135568823774 - - 0.026807560223070237 - - 0.04004109844362779 - - 0.07791732423672691 - - 0.08708738838140137 - - 0.07241979477437814 - num_classes: 2 + class_path: terratorch.datamodules.era5.ERA5DataModule model: class_path: WxCTask @@ -94,8 +39,8 @@ model: input_size_time: 1 n_lats_px: 64 n_lons_px: 128 - patch_size_px: [2 2] - mask_unit_size_px: [8 16] + patch_size_px: [2, 2] + mask_unit_size_px: [8, 16] mask_ratio_inputs: 0.5 embed_dim: 2560 
n_blocks_encoder: 12 @@ -109,8 +54,8 @@ model: masking_mode: both decoder_shifting: False positional_encoding: absolute - checkpoint_encoder: [3 6 9 12 15 18 21 24] - checkpoint_decoder: [1 3] + checkpoint_encoder: [3, 6, 9, 12, 15, 18, 21, 24] + checkpoint_decoder: [1, 3] in_channels_static: 3 input_scalers_mu: torch.tensor([0] * 1280) input_scalers_sigma: torch.tensor([1] * 1280) diff --git a/terratorch/datamodules/__init__.py b/terratorch/datamodules/__init__.py index f97c75fe..bef29be2 100644 --- a/terratorch/datamodules/__init__.py +++ b/terratorch/datamodules/__init__.py @@ -27,6 +27,7 @@ from terratorch.datamodules.multi_temporal_crop_classification import MultiTemporalCropClassificationDataModule from terratorch.datamodules.open_sentinel_map import OpenSentinelMapDataModule from terratorch.datamodules.pastis import PASTISDataModule +from terratorch.datamodules.era5 import ERA5DataModule try: wxc_present = True From 7f1105f87d3d19ad7c72fcd5488e4cc1cb45fbbd Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Mon, 20 Jan 2025 15:41:02 +0100 Subject: [PATCH 12/40] Trigger tests Signed-off-by: Francesc Marti Escofet From 9fd114bfa315c1006e8d8b6b4a789a40b07e4250 Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Tue, 21 Jan 2025 02:30:06 -0500 Subject: [PATCH 13/40] add ccc config --- examples/confs/wxc-gravity-wave-ccc.yaml | 85 ++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 examples/confs/wxc-gravity-wave-ccc.yaml diff --git a/examples/confs/wxc-gravity-wave-ccc.yaml b/examples/confs/wxc-gravity-wave-ccc.yaml new file mode 100644 index 00000000..52a4af3a --- /dev/null +++ b/examples/confs/wxc-gravity-wave-ccc.yaml @@ -0,0 +1,85 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: + name: fire_scars + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 40 + + max_epochs: 200 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: + +# dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars +data: + class_path: terratorch.datamodules.era5.ERA5DataModule + init_args: + train_data_path: /dccstor/terratorch/users/rkie/gitco/terratorch + valid_data_path: /dccstor/terratorch/users/rkie/gitco/terratorch + file_glob_pattern: "wxc_input_u_v_t_p_output_theta_uw_vw_*.nc" + +model: + class_path: WxCTask + init_args: + model_args: + in_channels: 1280 + input_size_time: 1 + n_lats_px: 64 + n_lons_px: 128 + patch_size_px: [2, 2] + mask_unit_size_px: [8, 16] + mask_ratio_inputs: 0.5 + embed_dim: 2560 + n_blocks_encoder: 12 + n_blocks_decoder: 2 + mlp_multiplier: 4 + n_heads: 16 + dropout: 0.0 + drop_path: 0.05 + parameter_dropout: 0.0 + residual: none + masking_mode: both + decoder_shifting: False + positional_encoding: absolute + checkpoint_encoder: [3, 6, 9, 12, 15, 18, 21, 24] + checkpoint_decoder: [1, 3] + in_channels_static: 3 + input_scalers_mu: torch.tensor([0] * 1280) + input_scalers_sigma: torch.tensor([1] * 1280) + input_scalers_epsilon: 0 + static_input_scalers_mu: torch.tensor([0] * 3) + static_input_scalers_sigma: torch.tensor([1] * 3) + static_input_scalers_epsilon: 0 + output_scalers: torch.tensor([0] * 1280) + backbone_weights: magnet-flux-uvtp122-epoch-99-loss-0.1022.pt + 
backbone: prithviwxc + aux_decoders: unetpincer + skip_connection: True + model_factory: WxCModelFactory + mode: eval +optimizer: + class_path: torch.optim.Adam + init_args: + lr: 1.5e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss From 3222bf1893fdc91d92d576ddfd4c816bb1097708 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 21 Jan 2025 14:18:03 +0100 Subject: [PATCH 14/40] Replace timm registry with terratorch registry Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_mae.py | 41 +++---- terratorch/models/backbones/prithvi_vit.py | 119 ++++++++++----------- 2 files changed, 70 insertions(+), 90 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index c209b25d..df766932 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -248,16 +248,14 @@ def __init__(self, norm_layer: nn.Module = nn.LayerNorm, coords_encoding: List[str] | None = None, coords_scale_learn: bool = False, - encoder_only: bool = True, # needed for timm ** kwargs, ): super().__init__() - self.feature_info = [] - self.encoder_only = encoder_only self.in_chans = in_chans self.num_frames = num_frames self.embed_dim = embed_dim + self.out_channels = [embed_dim] * depth self.img_size = to_2tuple(img_size) if isinstance(patch_size, int): patch_size = (1, patch_size, patch_size) @@ -287,9 +285,6 @@ def __init__(self, self.blocks = [] for i in range(depth): self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)) - self.feature_info.append( - {"num_chs": embed_dim * self.patch_embed.grid_size[0], "reduction": 1, "module": f"blocks.{i}"} - ) self.blocks = nn.ModuleList(self.blocks) self.norm = norm_layer(embed_dim) @@ -607,7 +602,6 @@ def __init__(self, norm_pix_loss: bool = False, coords_encoding: List[str] | None = None, coords_scale_learn: bool = False, - encoder_only: bool = False, **kwargs, ): super().__init__() @@ -626,24 +620,19 @@ def __init__(self, coords_scale_learn=coords_scale_learn, ) - self.encoder_only = encoder_only - - if not encoder_only: - self.decoder = MAEDecoder( - patch_size=patch_size, - grid_size=self.encoder.patch_embed.grid_size, - in_chans=in_chans, - encoder_embed_dim=embed_dim, - decoder_embed_dim=decoder_embed_dim, - depth=decoder_depth, - num_heads=decoder_num_heads, - mlp_ratio=mlp_ratio, - norm_layer=norm_layer, - coords_encoding=coords_encoding, - coords_scale_learn=coords_scale_learn, - ) - else: - self.decoder = nn.Identity() + self.decoder = MAEDecoder( + patch_size=patch_size, + grid_size=self.encoder.patch_embed.grid_size, + in_chans=in_chans, + encoder_embed_dim=embed_dim, + decoder_embed_dim=decoder_embed_dim, + depth=decoder_depth, + num_heads=decoder_num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + coords_encoding=coords_encoding, + coords_scale_learn=coords_scale_learn, + ) self.norm_pix_loss = norm_pix_loss @@ -730,8 +719,10 @@ def forward( latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio) pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape) loss = self.forward_loss(pixel_values, pred, mask) + # TODO: return loss? return loss, pred, mask + # TODO: forward_features still needed? 
def forward_features( self, x: torch.Tensor, diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 9918d5f0..eb3ce032 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -3,13 +3,12 @@ import torch import logging from torch import nn, Tensor -from timm.models import (FeatureInfo, load_model_config_from_hf, build_model_with_cfg, generate_default_cfgs, - register_model) - from terratorch.datasets import HLSBands from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights from terratorch.datasets.utils import generate_bands_intervals from terratorch.models.backbones.prithvi_mae import PrithviViT, PrithviMAE +from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY +from timm.models import load_model_config_from_hf, load_state_dict_from_hf logger = logging.getLogger(__name__) @@ -66,8 +65,7 @@ def _cfg(**kwargs): } # Timm pretrained configs -default_cfgs = generate_default_cfgs( - { +pretrained_weights = { "prithvi_eo_v1_100": { "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-1.0-100M", "hf_hub_filename": "Prithvi_EO_V1_100M.pt", @@ -89,7 +87,6 @@ def _cfg(**kwargs): "hf_hub_filename": "Prithvi_EO_V2_600M_TL.pt", }, } -) def checkpoint_filter_fn_vit( @@ -172,6 +169,7 @@ def _create_prithvi( pretrained: bool = False, # noqa: FBT001, FBT002 pretrained_bands: list[HLSBands] | None = None, model_bands: list[HLSBands | int] | None = None, + ckpt_path: str | None = None, **kwargs, ) -> PrithviViT: if pretrained_bands is None: @@ -188,7 +186,18 @@ def _create_prithvi( model_bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in model_bands] # Little hack because VIT does not support timm's features_only - encoder_only = kwargs.pop("features_only", False) + encoder_only = kwargs.pop("features_only", True) + + # Backwards compatibility from timm (pretrained_cfg_overlay={"file": ""}) TODO: Remove before v1.0 + if "pretrained_cfg_overlay" in kwargs: + logger.warning(f"pretrained_cfg_overlay is deprecated and will be removed in a future version, " + f"use ckpt_path= instead.") + if ckpt_path is not None: + logger.warning(f"pretrained_cfg_overlay and ckpt_path are provided, ignoring pretrained_cfg_overlay.") + elif "file" not in kwargs["pretrained_cfg_overlay"]: + logger.warning("pretrained_cfg_overlay does not include 'file path', ignoring pretrained_cfg_overlay.") + else: + ckpt_path = kwargs.pop("pretrained_cfg_overlay")["file"] model_bands = generate_bands_intervals(model_bands) @@ -204,11 +213,12 @@ def checkpoint_filter_wrapper_fn(state_dict, model): return checkpoint_filter_fn_mae(state_dict, model, pretrained_bands, model_bands) if pretrained: - assert variant in default_cfgs, (f"No pre-trained model found for variant {variant} " - f"(pretrained models: {default_cfgs.keys()})") + assert variant in pretrained_weights, (f"No pre-trained model found for variant {variant} " + f"(pretrained models: {pretrained_weights.keys()})") # Load pre-trained config from hf try: - model_args = load_model_config_from_hf(default_cfgs[variant].default.hf_hub_id)[0] + # TODO: Rename model suffix to .ckpt and remove config.json. + model_args = load_model_config_from_hf(pretrained_weights[variant]["hf_hub_id"])[0] model_args.update(kwargs) except: logger.warning(f"No pretrained configuration was found on HuggingFace for the model {variant}." 
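With the migration above, Prithvi backbones are built through the terratorch registry rather than timm.create_model, and local weights go through the new ckpt_path argument instead of the deprecated timm-style pretrained_cfg_overlay={"file": ...}. A usage sketch in line with the updated tests later in this series; the checkpoint filename is a placeholder:

import torch
from terratorch.registry import BACKBONE_REGISTRY

# Randomly initialised backbone; forward returns the per-block features selected by out_indices.
backbone = BACKBONE_REGISTRY.build("prithvi_eo_v2_300", pretrained=False)
features = backbone(torch.ones(1, 6, 224, 224))

# Loading local weights, new style (placeholder path):
# backbone = BACKBONE_REGISTRY.build("prithvi_eo_v2_300", pretrained=False, ckpt_path="Prithvi_EO_V2_300M.pt")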
@@ -221,56 +231,42 @@ def checkpoint_filter_wrapper_fn(state_dict, model): model_args.update(kwargs) try: - model = build_model_with_cfg( - prithvi_model_class, - variant, - pretrained, - pretrained_filter_fn=checkpoint_filter_wrapper_fn, - pretrained_strict=True, - **model_args, - ) + model = prithvi_model_class(**model_args) + + if ckpt_path is not None: + # Load model from checkpoint + state_dict = torch.load(ckpt_path, map_location="cpu") + state_dict = checkpoint_filter_wrapper_fn(state_dict, model) + model.load_state_dict(state_dict, strict=False) + elif pretrained: + # Load model from Hugging Face + state_dict = load_state_dict_from_hf(model_id=pretrained_weights[variant]["hf_hub_id"], + filename=pretrained_weights[variant]["hf_hub_filename"]) + state_dict = checkpoint_filter_wrapper_fn(state_dict, model) + model.load_state_dict(state_dict, strict=True) + except RuntimeError as e: if pretrained: - logger.error(f"Failed to initialize the pre-trained model {variant} via timm, " - f"consider running the code with pretrained=False.") + logger.error(f"Failed to initialize the pre-trained model {variant}, " + f"consider testing the code with pretrained=False.") else: - logger.error(f"Failed to initialize the model {variant} via timm.") + logger.error(f"Failed to initialize the model {variant}.") raise e if encoder_only: default_out_indices = list(range(len(model.blocks))) out_indices = kwargs.pop("out_indices", default_out_indices) - model.feature_info = FeatureInfo(model.feature_info, out_indices) - model.encode_decode_forward = model.forward + # TODO: Needed? + # model.encode_decode_forward = model.forward + def forward_filter_indices(*args, **kwargs): features = model.forward_features(*args, **kwargs) return [features[i] for i in out_indices] + model.forward = forward_filter_indices model.model_bands = model_bands model.pretrained_bands = pretrained_bands - padding = kwargs.get("padding", "none") - patch_size = kwargs.get("patch_size", 16) - if isinstance(patch_size, list): - patch_size = patch_size[-1] - - if padding != "none": - original_forward = model.forward - original_forward_features = model.forward_features - - def pad_and_forward(forward_fn, patch_size, padding, *args, **kwargs): - inputs = pad_images(args[0], patch_size, padding) - return forward_fn(inputs, **kwargs) - - def forward_pad_images(*args, **kwargs): - return pad_and_forward(original_forward, patch_size, padding, *args, **kwargs) - - def forward_features_pad_images(*args, **kwargs): - return pad_and_forward(original_forward_features, patch_size, padding, *args, **kwargs) - - model.forward = forward_pad_images - model.forward_features = forward_features_pad_images - return model @@ -281,14 +277,7 @@ def create_prithvi_from_config( **kwargs, ) -> PrithviViT: pretrained_bands = PRETRAINED_BANDS - if bands is None: - bands = pretrained_bands - logger.info( - f"Model bands not passed. 
Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\ - Pretrained patch_embed layer may be misaligned with current bands" - ) - - kwargs['num_frames'] = kwargs.pop('num_frames', 1) # Set num frames to 1 if not present + kwargs["num_frames"] = kwargs.pop("num_frames", 1) # Set num frames to 1 if not present model = _create_prithvi( model_name, @@ -301,20 +290,20 @@ def create_prithvi_from_config( return model -@register_model +@ TERRATORCH_BACKBONE_REGISTRY.register def prithvi_vit_tiny( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - logger.warning(f'The model prithvi_vit_tiny was renamed to prithvi_eo_tiny. ' - f'prithvi_vit_tiny will be removed in a future version.') + logger.warning(f"The model prithvi_vit_tiny was renamed to prithvi_eo_tiny. " + f"prithvi_vit_tiny will be removed in a future version.") return prithvi_eo_tiny(pretrained=pretrained, bands=bands, **kwargs) -@register_model +@ TERRATORCH_BACKBONE_REGISTRY.register def prithvi_eo_tiny( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, @@ -324,20 +313,20 @@ def prithvi_eo_tiny( return create_prithvi_from_config("prithvi_eo_tiny", pretrained, bands, **kwargs) -@register_model +@ TERRATORCH_BACKBONE_REGISTRY.register def prithvi_vit_100( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - logger.warning(f'The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. ' - f'prithvi_vit_100 will be removed in a future version.') + logger.warning(f"The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. " + f"prithvi_vit_100 will be removed in a future version.") return prithvi_eo_v1_100(pretrained=pretrained, bands=bands, **kwargs) -@register_model +@ TERRATORCH_BACKBONE_REGISTRY.register def prithvi_eo_v1_100( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, @@ -347,7 +336,7 @@ def prithvi_eo_v1_100( return create_prithvi_from_config("prithvi_eo_v1_100", pretrained, bands, **kwargs) -@register_model +@ TERRATORCH_BACKBONE_REGISTRY.register def prithvi_eo_v2_300( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, @@ -357,7 +346,7 @@ def prithvi_eo_v2_300( return create_prithvi_from_config("prithvi_eo_v2_300", pretrained, bands, **kwargs) -@register_model +@ TERRATORCH_BACKBONE_REGISTRY.register def prithvi_eo_v2_600( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, @@ -368,7 +357,7 @@ def prithvi_eo_v2_600( return create_prithvi_from_config("prithvi_eo_v2_600", pretrained, bands, **kwargs) -@register_model +@ TERRATORCH_BACKBONE_REGISTRY.register def prithvi_eo_v2_300_tl( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, @@ -378,7 +367,7 @@ def prithvi_eo_v2_300_tl( return create_prithvi_from_config("prithvi_eo_v2_300_tl", pretrained, bands, **kwargs) -@register_model +@ TERRATORCH_BACKBONE_REGISTRY.register def prithvi_eo_v2_600_tl( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, From 498fbf61ac48a8c72b809958336e5212d8c3911b Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 21 Jan 2025 18:51:39 +0100 Subject: [PATCH 15/40] Simplified prithvi code Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_vit.py | 105 +++++++++++++-------- 1 file changed, 65 insertions(+), 40 deletions(-) diff --git 
a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index eb3ce032..2cab150a 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -174,7 +174,7 @@ def _create_prithvi( ) -> PrithviViT: if pretrained_bands is None: pretrained_bands = PRETRAINED_BANDS - + kwargs["num_frames"] = kwargs.pop("num_frames", 1) # Set num frames to 1 if not present if model_bands is None: model_bands: list[HLSBands | int] = pretrained_bands @@ -253,43 +253,23 @@ def checkpoint_filter_wrapper_fn(state_dict, model): logger.error(f"Failed to initialize the model {variant}.") raise e + assert encoder_only or "out_indices" not in kwargs, "out_indices provided for a MAE model." if encoder_only: default_out_indices = list(range(len(model.blocks))) out_indices = kwargs.pop("out_indices", default_out_indices) - # TODO: Needed? - # model.encode_decode_forward = model.forward def forward_filter_indices(*args, **kwargs): features = model.forward_features(*args, **kwargs) return [features[i] for i in out_indices] model.forward = forward_filter_indices + model.out_indices = out_indices model.model_bands = model_bands model.pretrained_bands = pretrained_bands return model -def create_prithvi_from_config( - model_name: str, - pretrained: bool = False, # noqa: FBT001, FBT002 - bands: list[HLSBands] | None = None, - **kwargs, -) -> PrithviViT: - pretrained_bands = PRETRAINED_BANDS - kwargs["num_frames"] = kwargs.pop("num_frames", 1) # Set num frames to 1 if not present - - model = _create_prithvi( - model_name, - pretrained=pretrained, - model_bands=bands, - pretrained_bands=pretrained_bands, - **kwargs, - ) - - return model - - @ TERRATORCH_BACKBONE_REGISTRY.register def prithvi_vit_tiny( pretrained: bool = False, # noqa: FBT001, FBT002 @@ -310,68 +290,113 @@ def prithvi_eo_tiny( **kwargs, ) -> PrithviViT: - return create_prithvi_from_config("prithvi_eo_tiny", pretrained, bands, **kwargs) + return _create_prithvi("prithvi_eo_tiny", pretrained, bands, **kwargs) @ TERRATORCH_BACKBONE_REGISTRY.register -def prithvi_vit_100( +def prithvi_eo_v1_100( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - logger.warning(f"The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. 
" - f"prithvi_vit_100 will be removed in a future version.") + return _create_prithvi("prithvi_eo_v1_100", pretrained, bands, **kwargs) - return prithvi_eo_v1_100(pretrained=pretrained, bands=bands, **kwargs) + +@ TERRATORCH_BACKBONE_REGISTRY.register +def prithvi_eo_v2_300( + pretrained: bool = False, # noqa: FBT001, FBT002 + bands: list[HLSBands] | None = None, + **kwargs, +) -> PrithviViT: + + return _create_prithvi("prithvi_eo_v2_300", pretrained, bands, **kwargs) @ TERRATORCH_BACKBONE_REGISTRY.register -def prithvi_eo_v1_100( +def prithvi_eo_v2_600( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_from_config("prithvi_eo_v1_100", pretrained, bands, **kwargs) + return _create_prithvi("prithvi_eo_v2_600", pretrained, bands, **kwargs) @ TERRATORCH_BACKBONE_REGISTRY.register -def prithvi_eo_v2_300( +def prithvi_eo_v2_300_tl( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_from_config("prithvi_eo_v2_300", pretrained, bands, **kwargs) + return _create_prithvi("prithvi_eo_v2_300_tl", pretrained, bands, **kwargs) @ TERRATORCH_BACKBONE_REGISTRY.register -def prithvi_eo_v2_600( +def prithvi_eo_v2_600_tl( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: + return _create_prithvi("prithvi_eo_v2_600_tl", pretrained, bands, **kwargs) + - return create_prithvi_from_config("prithvi_eo_v2_600", pretrained, bands, **kwargs) +# TODO: Remove timm_ errors in before version v1.0. +@ TERRATORCH_BACKBONE_REGISTRY.register +def prithvi_vit_100( + pretrained: bool = False, # noqa: FBT001, FBT002 + bands: list[HLSBands] | None = None, + **kwargs, +) -> None: + raise ValueError("The model prithvi_vit_100 was renamed to prithvi_eo_v1_100.") @ TERRATORCH_BACKBONE_REGISTRY.register -def prithvi_eo_v2_300_tl( +def timm_prithvi_eo_v1_100( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, -) -> PrithviViT: +) -> None: + raise ValueError("The Prithvi models were moved to the terratorch registry. " + "Please remove the timm_ prefix from the model name.") + + +@ TERRATORCH_BACKBONE_REGISTRY.register +def timm_prithvi_eo_v2_300( + pretrained: bool = False, # noqa: FBT001, FBT002 + bands: list[HLSBands] | None = None, + **kwargs, +) -> None: + raise ValueError("The Prithvi models were moved to the terratorch registry. " + "Please remove the timm_ prefix from the model name.") - return create_prithvi_from_config("prithvi_eo_v2_300_tl", pretrained, bands, **kwargs) +@ TERRATORCH_BACKBONE_REGISTRY.register +def timm_prithvi_eo_v2_600( + pretrained: bool = False, # noqa: FBT001, FBT002 + bands: list[HLSBands] | None = None, + **kwargs, +) -> None: + raise ValueError("The Prithvi models were moved to the terratorch registry. " + "Please remove the timm_ prefix from the model name.") @ TERRATORCH_BACKBONE_REGISTRY.register -def prithvi_eo_v2_600_tl( +def timm_prithvi_eo_v2_300_tl( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, -) -> PrithviViT: +) -> None: + raise ValueError("The Prithvi models were moved to the terratorch registry. 
" + "Please remove the timm_ prefix from the model name.") + - return create_prithvi_from_config("prithvi_eo_v2_600_tl", pretrained, bands, **kwargs) +@ TERRATORCH_BACKBONE_REGISTRY.register +def timm_prithvi_eo_v2_600_tl( + pretrained: bool = False, # noqa: FBT001, FBT002 + bands: list[HLSBands] | None = None, + **kwargs, +) -> None: + raise ValueError("The Prithvi models were moved to the terratorch registry. " + "Please remove the timm_ prefix from the model name.") From 512634af01044e1c5370e72ec0bfd9821724b75e Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 21 Jan 2025 18:51:52 +0100 Subject: [PATCH 16/40] Updated prithvi tests Signed-off-by: Benedikt Blumenstiel --- ...ufactured-finetune_prithvi_eo_v1_100.yaml} | 2 +- ...-finetune_prithvi_swin_B_segmentation.yaml | 0 ...manufactured-finetune_prithvi_vit_300.yaml | 150 ------------------ tests/test_backbones.py | 42 ++--- tests/test_finetune.py | 5 +- tests/test_prithvi_vit.py | 19 +-- 6 files changed, 28 insertions(+), 190 deletions(-) rename tests/resources/configs/{manufactured-finetune_prithvi_vit_100.yaml => manufactured-finetune_prithvi_eo_v1_100.yaml} (99%) rename tests/{ => resources/configs}/manufactured-finetune_prithvi_swin_B_segmentation.yaml (100%) delete mode 100644 tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml diff --git a/tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml b/tests/resources/configs/manufactured-finetune_prithvi_eo_v1_100.yaml similarity index 99% rename from tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml rename to tests/resources/configs/manufactured-finetune_prithvi_eo_v1_100.yaml index bb652415..49f8186b 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_eo_v1_100.yaml @@ -96,7 +96,7 @@ model: model_args: decoder: UperNetDecoder pretrained: false - backbone: prithvi_vit_100 + backbone: prithvi_eo_v1_100 #backbone_pretrained_cfg_overlay: #file: tests/all_ecos_random/version_0/checkpoints/epoch=0_state_dict.ckpt #tests/prithvi_vit_100.pt backbone_drop_path_rate: 0.3 diff --git a/tests/manufactured-finetune_prithvi_swin_B_segmentation.yaml b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml similarity index 100% rename from tests/manufactured-finetune_prithvi_swin_B_segmentation.yaml rename to tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml diff --git a/tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml b/tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml deleted file mode 100644 index 3e44a1c5..00000000 --- a/tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml +++ /dev/null @@ -1,150 +0,0 @@ -# lightning.pytorch==2.1.1 -seed_everything: 42 -trainer: - accelerator: cpu - strategy: auto - devices: auto - num_nodes: 1 - # precision: 16-mixed - logger: - class_path: TensorBoardLogger - init_args: - save_dir: tests/ - name: all_ecos_random - callbacks: - - class_path: RichProgressBar - - class_path: LearningRateMonitor - init_args: - logging_interval: epoch - - class_path: EarlyStopping - init_args: - monitor: val/loss - patience: 100 - max_epochs: 2 - check_val_every_n_epoch: 1 - log_every_n_steps: 20 - enable_checkpointing: true - default_root_dir: tests/ -data: - class_path: GenericNonGeoPixelwiseRegressionDataModule - init_args: - batch_size: 2 - num_workers: 4 - train_transform: - #- class_path: albumentations.HorizontalFlip - # init_args: - 
# p: 0.5 - #- class_path: albumentations.Rotate - # init_args: - # limit: 30 - # border_mode: 0 # cv2.BORDER_CONSTANT - # value: 0 - # # mask_value: 1 - # p: 0.5 - - class_path: ToTensorV2 - dataset_bands: - - 0 - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - - 1 - - 2 - - 3 - - 4 - output_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - rgb_indices: - - 2 - - 1 - - 0 - train_data_root: tests/resources/inputs - train_label_data_root: tests/resources/inputs - val_data_root: tests/resources/inputs - val_label_data_root: tests/resources/inputs - test_data_root: tests/resources/inputs - test_label_data_root: tests/resources/inputs - img_grep: "regression*input*.tif" - label_grep: "regression*label*.tif" - means: - - 547.36707 - - 898.5121 - - 1020.9082 - - 2665.5352 - - 2340.584 - - 1610.1407 - stds: - - 411.4701 - - 558.54065 - - 815.94025 - - 812.4403 - - 1113.7145 - - 1067.641 - no_label_replace: -1 - no_data_replace: 0 - -model: - class_path: terratorch.tasks.PixelwiseRegressionTask - init_args: - model_args: - decoder: UperNetDecoder - pretrained: false - backbone: prithvi_eo_v2_300 - # backbone_pretrained_cfg_overlay: - # file: tests/prithvi_vit_300.pt - backbone_drop_path_rate: 0.3 - # backbone_window_size: 8 - decoder_channels: 64 - num_frames: 1 - in_channels: 6 - bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - head_dropout: 0.5708022831486758 - head_final_act: torch.nn.ReLU - head_learned_upscale_layers: 2 - loss: rmse - #aux_heads: - # - name: aux_head - # decoder: IdentityDecoder - # decoder_args: - # decoder_out_index: 2 - # head_dropout: 0,5 - # head_channel_list: - # - 64 - # head_final_act: torch.nn.ReLU - #aux_loss: - # aux_head: 0.4 - ignore_index: -1 - freeze_backbone: true - freeze_decoder: false - model_factory: PrithviModelFactory - - # uncomment this block for tiled inference - # tiled_inference_parameters: - # h_crop: 224 - # h_stride: 192 - # w_crop: 224 - # w_stride: 192 - # average_patches: true -optimizer: - class_path: torch.optim.AdamW - init_args: - lr: 0.00013524680528283027 - weight_decay: 0.047782217873995426 -lr_scheduler: - class_path: ReduceLROnPlateau - init_args: - monitor: val/loss - diff --git a/tests/test_backbones.py b/tests/test_backbones.py index fd562adc..546250d9 100644 --- a/tests/test_backbones.py +++ b/tests/test_backbones.py @@ -35,7 +35,7 @@ def input_386(): return torch.ones((1, NUM_CHANNELS, 386, 386)) -@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"]) +@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"]) @pytest.mark.parametrize("test_input", ["input_224", "input_512"]) def test_can_create_backbones_from_timm(model_name, test_input, request): backbone = timm.create_model(model_name, pretrained=False) @@ -43,7 +43,7 @@ def test_can_create_backbones_from_timm(model_name, test_input, request): backbone(input_tensor) gc.collect() -@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"]) +@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"]) @pytest.mark.parametrize("test_input", ["input_224", "input_512"]) def test_can_create_backbones_from_timm_features_only(model_name, test_input, request): backbone = timm.create_model(model_name, pretrained=False, features_only=True) @@ -51,36 +51,37 @@ def 
test_can_create_backbones_from_timm_features_only(model_name, test_input, re backbone(input_tensor) gc.collect() -@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"]) +@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_L", "prithvi_swin_B"]) @pytest.mark.parametrize("prefix", ["", "timm_"]) def test_can_create_timm_backbones_from_registry(model_name, input_224, prefix): backbone = BACKBONE_REGISTRY.build(prefix+model_name, pretrained=False) backbone(input_224) gc.collect() + @pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"]) -def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal): - backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES) - backbone(input_224_multitemporal) +def test_can_create_backbones_from_registry(model_name, input_224): + backbone = BACKBONE_REGISTRY.build(model_name, pretrained=False) + backbone(input_224) gc.collect() + @pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"]) -def test_vit_models_non_divisible_input(model_name, input_non_divisible): - #padding 'none','constant', 'reflect', 'replicate' or 'circular' default is 'none' - backbone = timm.create_model(model_name, pretrained=False, features_only=True, num_frames=NUM_FRAMES, padding='constant') - backbone(input_non_divisible) +def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal): + backbone = BACKBONE_REGISTRY.build(model_name, pretrained=False, num_frames=NUM_FRAMES) + backbone(input_224_multitemporal) gc.collect() + @pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"]) @pytest.mark.parametrize("patch_size", [8, 16]) @pytest.mark.parametrize("patch_size_time", [1, 2, 4]) def test_vit_models_different_patch_tubelet_sizes(model_name, patch_size, patch_size_time, input_224_multitemporal): - backbone = timm.create_model( + backbone = BACKBONE_REGISTRY.build( model_name, pretrained=False, num_frames=NUM_FRAMES, patch_size=[patch_size_time, patch_size, patch_size], - features_only=True, ) embedding = backbone(input_224_multitemporal) processed_embedding = backbone.prepare_features_for_image_model(embedding) @@ -105,10 +106,9 @@ def test_vit_models_different_patch_tubelet_sizes(model_name, patch_size, patch_ gc.collect() @pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"]) def test_out_indices(model_name, input_224): - # out_indices = [2, 4, 8, 10] out_indices = (2, 4, 8, 10) - backbone = timm.create_model(model_name, pretrained=False, features_only=True, out_indices=out_indices) - assert backbone.feature_info.out_indices == out_indices + backbone = BACKBONE_REGISTRY.build(model_name, pretrained=False, out_indices=out_indices) + assert backbone.out_indices == out_indices output = backbone(input_224) full_output = backbone.forward_features(input_224) @@ -116,18 +116,8 @@ def test_out_indices(model_name, input_224): for filtered_index, full_index in enumerate(out_indices): assert torch.allclose(full_output[full_index], output[filtered_index]) gc.collect() -@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"]) -def test_out_indices_non_divisible(model_name, input_non_divisible): - out_indices = [2, 4, 8, 10] - backbone = timm.create_model(model_name, pretrained=False, features_only=True, num_frames=NUM_FRAMES, out_indices=out_indices, padding='constant') - assert backbone.feature_info.out_indices 
== tuple(out_indices) - output = backbone(input_non_divisible) - full_output = backbone.forward_features(input_non_divisible) - for filtered_index, full_index in enumerate(out_indices): - assert torch.allclose(full_output[full_index], output[filtered_index]) - gc.collect() @pytest.mark.parametrize("model_name", ["vit_base_patch16", "vit_large_patch16"]) def test_scale_mae(model_name): # out_indices = [2, 4, 8, 10] @@ -139,6 +129,8 @@ def test_scale_mae(model_name): assert len(output) == len(out_indices) gc.collect() + + @pytest.mark.parametrize("model_name", ["vit_base_patch16", "vit_large_patch16"]) @pytest.mark.parametrize("bands", [2, 4, 6]) def test_scale_mae_new_channels(model_name, bands): diff --git a/tests/test_finetune.py b/tests/test_finetune.py index 9c06e8da..c1652173 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -6,10 +6,11 @@ import torch from terratorch.cli_tools import build_lightning_cli +from terratorch.registry import BACKBONE_REGISTRY @pytest.fixture(autouse=True) def setup_and_cleanup(model_name): - model_instance = timm.create_model(model_name) + model_instance = BACKBONE_REGISTRY.build(model_name) state_dict = model_instance.state_dict() @@ -22,7 +23,7 @@ def setup_and_cleanup(model_name): if os.path.isdir(os.path.join("tests", "all_ecos_random")): shutil.rmtree(os.path.join("tests", "all_ecos_random")) -@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_eo_v2_300", "prithvi_eo_v2_600"]) +@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v2_600"]) @pytest.mark.parametrize("case", ["fit", "test", "validate"]) def test_finetune_multiple_backbones(model_name, case): command_list = [case, "-c", f"tests/resources/configs/manufactured-finetune_{model_name}.yaml"] diff --git a/tests/test_prithvi_vit.py b/tests/test_prithvi_vit.py index 659812f5..86316599 100644 --- a/tests/test_prithvi_vit.py +++ b/tests/test_prithvi_vit.py @@ -3,26 +3,25 @@ from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights +from terratorch.registry import BACKBONE_REGISTRY import gc @pytest.mark.parametrize("patch_size", [4, 8, 16]) @pytest.mark.parametrize("patch_size_time,num_frames", [(1, 1), (1, 2), (1, 3), (2, 2), (3,3)]) def test_prithvi_vit_patch_embed_loading_compatible(patch_size, patch_size_time, num_frames): - model = timm.create_model( + model = BACKBONE_REGISTRY.build( "prithvi_eo_v1_100", pretrained=False, num_frames=num_frames, patch_size=[patch_size_time, 16, 16], - features_only=True, ) - weights = timm.create_model( + weights = BACKBONE_REGISTRY.build( "prithvi_eo_v1_100", pretrained=False, num_frames=num_frames, patch_size=[patch_size_time, 16, 16], - features_only=True, ).state_dict() select_patch_embed_weights(weights, model, PRETRAINED_BANDS, PRETRAINED_BANDS) @@ -31,20 +30,18 @@ def test_prithvi_vit_patch_embed_loading_compatible(patch_size, patch_size_time, @pytest.mark.parametrize("patch_size_time,patch_size_time_other", [(1, 2), (2, 4)]) def test_prithvi_vit_patch_embed_loading_time_patch_size_other(patch_size_time,patch_size_time_other): - model = timm.create_model( + model = BACKBONE_REGISTRY.build( "prithvi_eo_v1_100", pretrained=False, num_frames=4, patch_size=[patch_size_time, 16, 16], - features_only=True, ) - weights = timm.create_model( + weights = BACKBONE_REGISTRY.build( "prithvi_eo_v1_100", 
pretrained=False, num_frames=4, patch_size=[patch_size_time_other, 16, 16], - features_only=True, ).state_dict() # assert warning produced @@ -55,20 +52,18 @@ def test_prithvi_vit_patch_embed_loading_time_patch_size_other(patch_size_time,p @pytest.mark.parametrize("patch_size,patch_size_other", [(2, 4), (4, 8), (16, 4)]) def test_prithvi_vit_patch_embed_loading_not_compatible_patch(patch_size, patch_size_other): - model = timm.create_model( + model = BACKBONE_REGISTRY.build( "prithvi_eo_v1_100", pretrained=False, num_frames=1, patch_size=patch_size, - features_only=True, ) - weights = timm.create_model( + weights = BACKBONE_REGISTRY.build( "prithvi_eo_v1_100", pretrained=False, num_frames=1, patch_size=patch_size_other, - features_only=True, ).state_dict() with pytest.warns(UserWarning): From 10fcb83fbe89660e6b6c72292a94713579f315bb Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 21 Jan 2025 19:03:52 +0100 Subject: [PATCH 17/40] Fix DOFA model Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/dofa_vit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terratorch/models/backbones/dofa_vit.py b/terratorch/models/backbones/dofa_vit.py index 61913ced..a4d7ef5e 100644 --- a/terratorch/models/backbones/dofa_vit.py +++ b/terratorch/models/backbones/dofa_vit.py @@ -61,7 +61,7 @@ def __init__(self, dofa_model, wavelengths, weights=None, out_indices=None) -> N self.wavelengths = wavelengths self.out_indices = out_indices if out_indices else [-1] - self.out_channels = [self.dofa_model.embed_dim] * len(self.out_indices) + self.out_channels = [self.dofa_model.patch_embed.embed_dim] * len(self.out_indices) def forward(self, x: List[torch.Tensor], **kwargs) -> torch.Tensor: From eee546aa8728a3559653ee663163a7dbce71f4d8 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 21 Jan 2025 22:02:14 +0100 Subject: [PATCH 18/40] Update select patch embedding weights Signed-off-by: Benedikt Blumenstiel --- .../backbones/select_patch_embed_weights.py | 42 +++++++------------ 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/terratorch/models/backbones/select_patch_embed_weights.py b/terratorch/models/backbones/select_patch_embed_weights.py index 38dde626..9fdf9029 100644 --- a/terratorch/models/backbones/select_patch_embed_weights.py +++ b/terratorch/models/backbones/select_patch_embed_weights.py @@ -19,7 +19,7 @@ def patch_embed_weights_are_compatible(model_patch_embed: torch.Tensor, checkpoi return model_shape == checkpoint_shape def select_patch_embed_weights( - state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int | OpticalBands| SARBands], model_bands: list[HLSBands | int | OpticalBands| SARBands], custom_proj_key: str = None + state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int | OpticalBands| SARBands], model_bands: list[HLSBands | int | OpticalBands| SARBands], proj_key: str | None = None ) -> dict: """Filter out the patch embedding weights according to the bands being used. If a band exists in the pretrained_bands, but not in model_bands, drop it. @@ -31,39 +31,25 @@ def select_patch_embed_weights( model (nn.Module): Model to load the weights onto. pretrained_bands (list[HLSBands | int]): List of bands the model was pretrained on, in the correct order. model_bands (list[HLSBands | int]): List of bands the model is going to be finetuned on, in the correct order + proj_key (str, optional): Key to patch embedding projection weight in state_dict. 
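As an aside, the band-filtering behaviour this docstring describes is easy to illustrate outside the library. The following is a minimal sketch with made-up band names and shapes, not the actual terratorch implementation:

```python
import torch

def filter_patch_embed_weight(pretrained_w, target_w, pretrained_bands, model_bands):
    # Copy input-channel slices for bands present in both lists; bands that
    # were not pretrained keep the random initialization of target_w.
    new_w = target_w.clone()
    for tgt_idx, band in enumerate(model_bands):
        if band in pretrained_bands:
            new_w[:, tgt_idx] = pretrained_w[:, pretrained_bands.index(band)]
    return new_w

# Pretrained on six bands, finetuning on an RGB subset (hypothetical names):
pre = torch.randn(768, 6, 1, 16, 16)   # (embed_dim, in_chans, T, pH, pW)
tgt = torch.randn(768, 3, 1, 16, 16)
rgb = filter_patch_embed_weight(pre, tgt, ["B", "G", "R", "NIR", "SW1", "SW2"], ["R", "G", "B"])
assert rgb.shape == tgt.shape
```

Bands shared by both lists reuse pretrained kernels, while new bands keep their random initialization, which is what makes cross-sensor finetuning workable.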
Returns: dict: New state dict """ - if (type(pretrained_bands) == type(model_bands)) | (type(pretrained_bands) == int) | (type(model_bands) == int): + if (type(pretrained_bands) == type(model_bands)) | (type(pretrained_bands) == int) | (type(model_bands) == int): - if custom_proj_key is None: - _possible_keys_for_proj_weight = { - "patch_embed.proj.weight", - "module.patch_embed.proj.weight", - "patch_embed.projection.weight", - "module.patch_embed.projection.weight", - } - else: - _possible_keys_for_proj_weight = {custom_proj_key} - - patch_embed_proj_weight_key = state_dict.keys() & _possible_keys_for_proj_weight if (type(state_dict) in [collections.OrderedDict, dict]) else state_dict().keys() & _possible_keys_for_proj_weight - if len(patch_embed_proj_weight_key) == 0: - msg = "Could not find key for patch embed weight" - raise Exception(msg) - if len(patch_embed_proj_weight_key) > 1: - msg = "Too many matches for key for patch embed weight" - raise Exception(msg) - - # extract the single element from the set - if isinstance(patch_embed_proj_weight_key, tuple): - (patch_embed_proj_weight_key,) = patch_embed_proj_weight_key - elif isinstance(patch_embed_proj_weight_key, set): - patch_embed_proj_weight_key = list(patch_embed_proj_weight_key)[0] + if proj_key is None: + # Search for patch embedding weight in state dict + for key in state_dict.keys(): + if key.endswith('patch_embed.proj.weight') or key.endswith('patch_embed.projection.weight'): + proj_key = key + break + if proj_key is None or proj_key not in state_dict: + raise Exception("Could not find key for patch embed weight in state_dict.") - patch_embed_weight = state_dict[patch_embed_proj_weight_key] + patch_embed_weight = state_dict[proj_key] - temp_weight = model.state_dict()[patch_embed_proj_weight_key].clone() + temp_weight = model.state_dict()[proj_key].clone() # only do this if the patch size and tubelet size match. 
If not, start with random weights if patch_embed_weights_are_compatible(temp_weight, patch_embed_weight): @@ -80,6 +66,6 @@ def select_patch_embed_weights( stacklevel=1, ) - state_dict[patch_embed_proj_weight_key] = temp_weight + state_dict[proj_key] = temp_weight return state_dict From c376e6b27c9fb0f5528ad1ae08c0895dca73b324 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 21 Jan 2025 22:02:47 +0100 Subject: [PATCH 19/40] Add interpolated embeddings Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_mae.py | 165 ++++++++++++--------- 1 file changed, 91 insertions(+), 74 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index df766932..d2bc6262 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -17,9 +17,7 @@ # transformers: https://github.com/huggingface/transformers # -------------------------------------------------------- -from functools import partial -from typing import List, Tuple - +import warnings import logging import numpy as np import torch @@ -135,8 +133,8 @@ class PatchEmbed(nn.Module): """3D version of timm.models.vision_transformer.PatchEmbed""" def __init__( self, - input_size: Tuple[int, int, int] = (1, 224, 224), - patch_size: Tuple[int, int, int] = (1, 16, 16), + input_size: tuple[int, int, int] = (1, 224, 224), + patch_size: tuple[int, int, int] = (1, 16, 16), in_chans: int = 3, embed_dim: int = 768, norm_layer: nn.Module | None = None, @@ -153,16 +151,13 @@ def __init__( self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - self.log_warning = True def forward(self, x): B, C, T, H, W = x.shape - if (self.log_warning and - (T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1)): - logger.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." - f"The border will be ignored, add backbone_padding for pixel-wise tasks.") - self.log_warning = False + if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1: + warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}. "
+ f"The border will be ignored, add backbone_padding for pixel-wise tasks.") x = self.proj(x) if self.flatten: @@ -237,8 +232,8 @@ def forward(self, location_coords: torch.Tensor): class PrithviViT(nn.Module): """ Prithvi ViT Encoder""" def __init__(self, - img_size: int | Tuple[int, int] = 224, - patch_size: int | Tuple[int, int, int] = (1, 16, 16), + img_size: int | tuple[int, int] = 224, + patch_size: int | tuple[int, int, int] = (1, 16, 16), num_frames: int = 1, in_chans: int = 3, embed_dim: int = 1024, @@ -246,7 +241,7 @@ def __init__(self, num_heads: int = 16, mlp_ratio: float = 4., norm_layer: nn.Module = nn.LayerNorm, - coords_encoding: List[str] | None = None, + coords_encoding: list[str] | None = None, coords_scale_learn: bool = False, ** kwargs, ): @@ -339,21 +334,32 @@ def random_masking(self, sequence, mask_ratio, noise=None): return sequence_unmasked, mask, ids_restore - def _get_pos_embed(self, x): - t, h, w = x.shape[-3:] - - pos_embed = torch.from_numpy(get_3d_sincos_pos_embed( - self.embed_dim, - ( - t // self.patch_embed.patch_size[0], - h // self.patch_embed.patch_size[1], - w // self.patch_embed.patch_size[2], - ), - add_cls_token=True, - )).float().unsqueeze(0).to(x) - - return pos_embed - + def interpolate_pos_encoding(self, x, t, w, h): + """ + Adapted from: + - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding, + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194 + """ + if x.shape[1] == self.pos_embed.shape[1] and w == h: + # No interpolation needed + return self.pos_embed + + class_pos_embed = self.pos_embed[:, :1] + patch_pos_embed = self.pos_embed[:, 1:] + w_patches = w // self.patch_embed.patch_size[1] + h_patches = h // self.patch_embed.patch_size[2] + + n_sqrt = int((patch_pos_embed.shape[1] / t) ** 0.5) + patch_pos_embed = patch_pos_embed.reshape(t, n_sqrt, n_sqrt, self.embed_dim).permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(h_patches, w_patches), + mode='bicubic', + align_corners=True, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, self.embed_dim) + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward( self, x: torch.Tensor, @@ -361,15 +367,15 @@ def forward( location_coords: None | torch.Tensor = None, mask_ratio=0.75 ): - if x.shape[-3:] != self.patch_embed.input_size: - # changed input size - pos_embed = self._get_pos_embed(x) - else: - pos_embed = self.pos_embed + if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: + # add time dim + x = x.unsqueeze(2) + t, h, w = x.shape[-3:] # embed patches x = self.patch_embed(x) + pos_embed = self.interpolate_pos_encoding(x, t, h, w) # add pos embed w/o cls token x = x + pos_embed[:, 1:, :] @@ -405,15 +411,12 @@ def forward_features( if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: # add time dim x = x.unsqueeze(2) - - if x.shape[-3:] != self.patch_embed.input_size: - pos_embed = self._get_pos_embed(x) - else: - pos_embed = self.pos_embed + t, h, w = x.shape[-3:] # embed patches x = self.patch_embed(x) + pos_embed = self.interpolate_pos_encoding(x, t, h, w) # add pos embed w/o cls token x = x + pos_embed[:, 1:, :] @@ -462,8 +465,8 @@ def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list class MAEDecoder(nn.Module): """ Transformer Decoder used in the Prithvi MAE""" def __init__(self, - patch_size: int | Tuple[int, int, int] = (1, 16, 16), - grid_size: List[int] | Tuple[int, int, 
int] = (3, 14, 14), + patch_size: int | tuple[int, int, int] = (1, 16, 16), + grid_size: list[int] | tuple[int, int, int] = (3, 14, 14), in_chans: int = 3, encoder_embed_dim: int = 1024, decoder_embed_dim: int = 512, @@ -471,7 +474,7 @@ def __init__(self, num_heads: int = 16, mlp_ratio: float = 4., norm_layer: nn.Module = nn.LayerNorm, - coords_encoding: List[str] | None = None, + coords_encoding: list[str] | None = None, coords_scale_learn: bool = False, ): super().__init__() @@ -520,6 +523,33 @@ def initialize_weights(self): torch.nn.init.normal_(self.mask_token, std=0.02) self.apply(_init_weights) + def interpolate_pos_encoding(self, x, t, w, h): + """ + Adapted from: + - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding, + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194 + """ + if x.shape[1] == self.decoder_pos_embed.shape[1] and w == h: + # No interpolation needed + return self.decoder_pos_embed + + class_pos_embed = self.decoder_pos_embed[:, :1] + patch_pos_embed = self.decoder_pos_embed[:, 1:] + w_patches = w // self.patch_size[1] + h_patches = h // self.patch_size[2] + + n_sqrt = int((patch_pos_embed.shape[1] / t) ** 0.5) + patch_pos_embed = patch_pos_embed.reshape(t, n_sqrt, n_sqrt, self.decoder_embed_dim).permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(h_patches, w_patches), + mode='bicubic', + align_corners=True, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, self.decoder_embed_dim) + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + def forward( self, hidden_states: torch.Tensor, @@ -530,44 +560,32 @@ def forward( ): # embed tokens x = self.decoder_embed(hidden_states) - - t, h, w = input_size[-3:] - decoder_pos_embed = torch.from_numpy( - get_3d_sincos_pos_embed( - self.decoder_embed_dim, - ( - t // self.patch_size[0], - h // self.patch_size[1], - w // self.patch_size[2], - ), - add_cls_token=True, - ) - ).to(x) + cls_token = x[:, :1, :] # append mask tokens to sequence mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) - x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + x = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token # unshuffle - x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device)) - x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token - # add pos embed - x = x + decoder_pos_embed + x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device)) - # remove cls token - x_ = x[:, 1:, :] + # add pos embed + t, h, w = input_size[-3:] + decoder_pos_embed = self.interpolate_pos_encoding(x, t, w, h) + cls_token = cls_token + decoder_pos_embed[:, :1, :] + x = x + decoder_pos_embed[:, 1:, :] if self.temporal_encoding and temporal_coords is not None: - num_tokens_per_frame = x_.shape[1] // self.num_frames + num_tokens_per_frame = x.shape[1] // self.num_frames temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame) # Add temporal encoding w/o cls token - x_ = x_ + temporal_encoding + x = x + temporal_encoding if self.location_encoding and location_coords is not None: location_encoding = self.location_embed_dec(location_coords) # Add location encoding w/o cls token - x_ = x_ + location_encoding + x = x + location_encoding # append cls token - x = torch.cat([x[:, :1, :], x_], dim=1) + x = torch.cat([cls_token, x], dim=1) # 
apply Transformer layers (blocks) for block in self.decoder_blocks: @@ -587,8 +605,8 @@ class PrithviMAE(nn.Module): """ Prithvi Masked Autoencoder""" def __init__(self, - img_size: int | Tuple[int, int] = 224, - patch_size: int | Tuple[int, int, int] = (1, 16, 16), + img_size: int | tuple[int, int] = 224, + patch_size: int | tuple[int, int, int] = (1, 16, 16), num_frames: int = 4, in_chans: int = 6, embed_dim: int = 768, @@ -600,7 +618,7 @@ def __init__(self, mlp_ratio: float = 4., norm_layer: nn.Module = nn.LayerNorm, norm_pix_loss: bool = False, - coords_encoding: List[str] | None = None, + coords_encoding: list[str] | None = None, coords_scale_learn: bool = False, **kwargs, ): @@ -635,6 +653,7 @@ def __init__(self, ) self.norm_pix_loss = norm_pix_loss + self.out_channels = self.encoder.out_channels def patchify(self, pixel_values): """ @@ -656,13 +675,13 @@ def patchify(self, pixel_values): return patchified_pixel_values - def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None): + def unpatchify(self, patchified_pixel_values, image_size: tuple[int, int] | None = None): """ Args: patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: Patchified pixel values. - image_size (`Tuple[int, int]`, *optional*): + image_size (`tuple[int, int]`, *optional*): Original image size. Returns: @@ -719,14 +738,12 @@ def forward( latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio) pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape) loss = self.forward_loss(pixel_values, pred, mask) - # TODO: return loss? return loss, pred, mask - # TODO: forward_features still needed? 
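Taken in isolation, the position-embedding interpolation this patch introduces amounts to the following. This is a self-contained sketch with arbitrary sizes and a single time step; the real method also folds in the temporal dimension:

```python
# Bicubic interpolation of learned position embeddings for a new input size.
# Assumes a square pretraining grid and one cls token; sizes are illustrative.
import torch
import torch.nn.functional as F

embed_dim, grid = 768, 14                       # pretrained on a 14x14 patch grid
pos_embed = torch.randn(1, 1 + grid * grid, embed_dim)

h_patches, w_patches = 20, 16                   # patch grid of the new input
cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
patch_pos = patch_pos.reshape(1, grid, grid, embed_dim).permute(0, 3, 1, 2)
patch_pos = F.interpolate(patch_pos, size=(h_patches, w_patches),
                          mode="bicubic", align_corners=True)
patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
pos_embed_new = torch.cat([cls_pos, patch_pos], dim=1)
assert pos_embed_new.shape == (1, 1 + h_patches * w_patches, embed_dim)
```

Compared with recomputing a sincos grid from scratch, interpolating preserves whatever the embeddings learned during pretraining, at the cost of some blurring for large resolution changes.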
def forward_features( self, x: torch.Tensor, temporal_coords: None | torch.Tensor = None, location_coords: None | torch.Tensor = None, - ) -> List[torch.Tensor]: + ) -> list[torch.Tensor]: return self.encoder.forward_features(x, temporal_coords, location_coords) From c3d4e5be83a40d1be5018665f244258ffa2493d8 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 21 Jan 2025 22:03:06 +0100 Subject: [PATCH 20/40] Switched to warnings package Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_vit.py | 51 ++++++++++++---------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 2cab150a..8a0054e9 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -1,4 +1,5 @@ # Copyright contributors to the Terratorch project +import warnings import torch import logging @@ -190,12 +191,12 @@ def _create_prithvi( # Backwards compatibility from timm (pretrained_cfg_overlay={"file": ""}) TODO: Remove before v1.0 if "pretrained_cfg_overlay" in kwargs: - logger.warning(f"pretrained_cfg_overlay is deprecated and will be removed in a future version, " - f"use ckpt_path= instead.") + warnings.warn(f"pretrained_cfg_overlay is deprecated and will be removed in a future version, " + f"use ckpt_path= instead.", DeprecationWarning, stacklevel=2) if ckpt_path is not None: - logger.warning(f"pretrained_cfg_overlay and ckpt_path are provided, ignoring pretrained_cfg_overlay.") + warnings.warn(f"pretrained_cfg_overlay and ckpt_path are provided, ignoring pretrained_cfg_overlay.") elif "file" not in kwargs["pretrained_cfg_overlay"]: - logger.warning("pretrained_cfg_overlay does not include 'file path', ignoring pretrained_cfg_overlay.") + warnings.warn("pretrained_cfg_overlay does not include 'file path', ignoring pretrained_cfg_overlay.") else: ckpt_path = kwargs.pop("pretrained_cfg_overlay")["file"] @@ -217,12 +218,12 @@ def checkpoint_filter_wrapper_fn(state_dict, model): f"(pretrained models: {pretrained_weights.keys()})") # Load pre-trained config from hf try: - # TODO: Rename model suffix to .ckpt and remove config.json. + # TODO: Switch from timm to hf hub download and remove config.json download. model_args = load_model_config_from_hf(pretrained_weights[variant]["hf_hub_id"])[0] model_args.update(kwargs) except: - logger.warning(f"No pretrained configuration was found on HuggingFace for the model {variant}." - f"Using random initialization.") + warnings.warn(f"No pretrained configuration was found on HuggingFace for the model {variant}." + f"Using random initialization.", stacklevel=2) model_args = prithvi_cfgs[variant].copy() model_args.update(kwargs) else: @@ -270,19 +271,6 @@ def forward_filter_indices(*args, **kwargs): return model -@ TERRATORCH_BACKBONE_REGISTRY.register -def prithvi_vit_tiny( - pretrained: bool = False, # noqa: FBT001, FBT002 - bands: list[HLSBands] | None = None, - **kwargs, -) -> PrithviViT: - - logger.warning(f"The model prithvi_vit_tiny was renamed to prithvi_eo_tiny. " - f"prithvi_vit_tiny will be removed in a future version.") - - return prithvi_eo_tiny(pretrained=pretrained, bands=bands, **kwargs) - - @ TERRATORCH_BACKBONE_REGISTRY.register def prithvi_eo_tiny( pretrained: bool = False, # noqa: FBT001, FBT002 @@ -343,16 +331,33 @@ def prithvi_eo_v2_600_tl( return _create_prithvi("prithvi_eo_v2_600_tl", pretrained, bands, **kwargs) -# TODO: Remove timm_ errors in before version v1.0. 
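The rename shims below follow a standard deprecation pattern: the old entry point stays callable but emits a FutureWarning, which, unlike DeprecationWarning, Python shows to end users by default. A minimal sketch of the pattern, with the registry decorator omitted:

```python
import warnings

def prithvi_eo_v1_100(pretrained: bool = False, **kwargs):
    ...  # build and return the renamed model

def prithvi_vit_100(pretrained: bool = False, **kwargs):
    warnings.warn(
        "The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. "
        "prithvi_vit_100 will be removed in a future version.",
        FutureWarning,
        stacklevel=2,  # attribute the warning to the caller, not this shim
    )
    return prithvi_eo_v1_100(pretrained=pretrained, **kwargs)
```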
+# TODO: Remove prithvi_vit_tiny and prithvi_vit_100 before version 1.0. +@ TERRATORCH_BACKBONE_REGISTRY.register +def prithvi_vit_tiny( + pretrained: bool = False, # noqa: FBT001, FBT002 + bands: list[HLSBands] | None = None, + **kwargs, +) -> PrithviViT: + + warnings.warn(f"The model prithvi_vit_tiny was renamed to prithvi_eo_tiny. " + f"prithvi_vit_tiny will be removed in a future version.", DeprecationWarning) + + return prithvi_eo_tiny(pretrained=pretrained, bands=bands, **kwargs) + + @ TERRATORCH_BACKBONE_REGISTRY.register def prithvi_vit_100( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, -) -> None: - raise ValueError("The model prithvi_vit_100 was renamed to prithvi_eo_v1_100.") +) -> PrithviViT: + warnings.warn(f"The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. " + f"prithvi_vit_100 will be removed in a future version.", DeprecationWarning) + + return prithvi_eo_v1_100(pretrained=pretrained, bands=bands, **kwargs) +# TODO: Remove timm_ errors before version v1.0. @ TERRATORCH_BACKBONE_REGISTRY.register def timm_prithvi_eo_v1_100( pretrained: bool = False, # noqa: FBT001, FBT002 From ab1bb417a2890944321cf3479efd6f78bf9822d7 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 21 Jan 2025 22:16:24 +0100 Subject: [PATCH 21/40] Update warnings Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_vit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 8a0054e9..8f767e8f 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -340,7 +340,7 @@ def prithvi_vit_tiny( ) -> PrithviViT: warnings.warn(f"The model prithvi_vit_tiny was renamed to prithvi_eo_tiny. " - f"prithvi_vit_tiny will be removed in a future version.", DeprecationWarning) + f"prithvi_vit_tiny will be removed in a future version.", FutureWarning) return prithvi_eo_tiny(pretrained=pretrained, bands=bands, **kwargs) @@ -351,8 +351,8 @@ def prithvi_vit_100( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - warnings.warn(f"The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. " - f"prithvi_vit_100 will be removed in a future version.", DeprecationWarning) + warnings.warn("The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. 
" + "prithvi_vit_100 will be removed in a future version.", FutureWarning) return prithvi_eo_v1_100(pretrained=pretrained, bands=bands, **kwargs) From 36817f659fa6e2191bbc72df5e69d26f6000d804 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 21 Jan 2025 22:19:53 +0100 Subject: [PATCH 22/40] Fixed jsonargparse version Signed-off-by: Benedikt Blumenstiel --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6a477ac4..8343f6b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,8 @@ dependencies = [ "mlflow>=2.12.1", # broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977 "lightning[pytorch-extra]>=2,!=2.3.*", - "segmentation-models-pytorch>=0.3" + "segmentation-models-pytorch>=0.3", + "jsonargparse<=4.35.0", # Dependencies not available on PyPI ] From 0419716a17a334fb01f8379f67206798105ef83f Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Wed, 22 Jan 2025 07:36:47 +0100 Subject: [PATCH 23/40] support arbitrary return values --- terratorch/cli_tools.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index a00f2ad3..768c5d3a 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -153,6 +153,8 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batc if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) + print(type(prediction)) + print(prediction) pred_batch, filename_batch = prediction for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False): From d2fbe8b729f69870d90688cdb59cd2378c59787f Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Wed, 22 Jan 2025 09:47:25 +0100 Subject: [PATCH 24/40] Pin jsonargparse Signed-off-by: Francesc Marti Escofet --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6a477ac4..010efcf6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ # broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977 "lightning[pytorch-extra]>=2,!=2.3.*", "segmentation-models-pytorch>=0.3" + "jsonargparse<=4.35.0", # Dependencies not available on PyPI ] From 0c82fd1f058e4a197409ad0266faebd5d790753f Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Wed, 22 Jan 2025 10:38:34 +0100 Subject: [PATCH 25/40] Fix Signed-off-by: Francesc Marti Escofet --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 010efcf6..8343f6b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ "mlflow>=2.12.1", # broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977 "lightning[pytorch-extra]>=2,!=2.3.*", - "segmentation-models-pytorch>=0.3" + "segmentation-models-pytorch>=0.3", "jsonargparse<=4.35.0", # Dependencies not available on PyPI ] From 741cf8022bd76f1cfee590b040141d3ca40d9afc Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Wed, 22 Jan 2025 11:23:43 +0100 Subject: [PATCH 26/40] Fixed jsonargparse version Signed-off-by: Benedikt Blumenstiel --- requirements/required.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/required.txt b/requirements/required.txt index c4aec6ea..f298ec97 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -15,6 +15,7 @@ lightning==2.4.0 
git+https://github.com/qubvel-org/segmentation_models.pytorch.git@3952e1f8e9684a385a81e30381b8fb5b1ac086cf timm==1.0.11 numpy==1.26.4 +jsonargparse==4.32.0 # These dependencies are optional # and must be installed just in case From 7073be645250a5a7c084f33179cc9c77827bc853 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Wed, 22 Jan 2025 11:43:42 +0100 Subject: [PATCH 27/40] Added drop_path and mask ratio to prithvi Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_mae.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index d2bc6262..6cb65ea3 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -243,6 +243,7 @@ def __init__(self, norm_layer: nn.Module = nn.LayerNorm, coords_encoding: list[str] | None = None, coords_scale_learn: bool = False, + drop_path: float = 0., ** kwargs, ): super().__init__() @@ -279,7 +280,8 @@ def __init__(self, # Transformer layers self.blocks = [] for i in range(depth): - self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)) + self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, + drop_path=drop_path,)) self.blocks = nn.ModuleList(self.blocks) self.norm = norm_layer(embed_dim) @@ -620,6 +622,8 @@ def __init__(self, norm_pix_loss: bool = False, coords_encoding: list[str] | None = None, coords_scale_learn: bool = False, + drop_path: float = 0., + mask_ratio: float = 0.75, **kwargs, ): super().__init__() @@ -636,6 +640,7 @@ def __init__(self, norm_layer=norm_layer, coords_encoding=coords_encoding, coords_scale_learn=coords_scale_learn, + drop_path=drop_path, ) self.decoder = MAEDecoder( @@ -652,6 +657,7 @@ def __init__(self, coords_scale_learn=coords_scale_learn, ) + self.mask_ratio = mask_ratio self.norm_pix_loss = norm_pix_loss self.out_channels = self.encoder.out_channels @@ -729,12 +735,13 @@ def forward( pixel_values: torch.Tensor, temporal_coords: None | torch.Tensor = None, location_coords: None | torch.Tensor = None, - mask_ratio: float = 0.75 + mask_ratio: float = None, ): if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1: # add time dim pixel_values = pixel_values.unsqueeze(2) + mask_ratio = mask_ratio or self.mask_ratio latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio) pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape) loss = self.forward_loss(pixel_values, pred, mask) From 5c5e1dd0db760a2f9a1729ab1932bb1bf422370b Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Wed, 22 Jan 2025 12:06:46 +0100 Subject: [PATCH 28/40] Updated examples Signed-off-by: Benedikt Blumenstiel --- examples/confs/burn_scars.yaml | 29 +++-- examples/confs/burnscars_smp.yaml | 2 +- examples/confs/eurosat.yaml | 2 +- examples/confs/forestnet_timm.yaml | 2 +- examples/confs/multi_temporal_crop.yaml | 2 +- examples/confs/multimae_sen1floods11.yaml | 2 +- .../multimodal_prithvi_sen1floods11.yaml | 16 +-- examples/confs/sen1floods11_vit.yaml | 41 +++--- examples/confs/sen1floods11_vit_dual_lr.yaml | 4 +- .../confs/sen1floods11_vit_local_ckpt.yaml | 46 +++---- examples/confs/sen1floods11_vit_mmseg.yaml | 122 ------------------ examples/confs/sen1floods11_vit_peft.yaml | 4 +- examples/confs/sen1floods11_vit_smp.yaml | 4 +- examples/confs/sen4agri.yaml | 59 
--------- examples/confs/sen4map_ViT-L.yaml | 6 +- examples/confs/smp_model_factory.yaml | 93 ------------- examples/notebooks/Tutorial.ipynb | 115 ++++++++++------- 17 files changed, 140 insertions(+), 409 deletions(-) delete mode 100644 examples/confs/sen1floods11_vit_mmseg.yaml delete mode 100644 examples/confs/sen4agri.yaml delete mode 100644 examples/confs/smp_model_factory.yaml diff --git a/examples/confs/burn_scars.yaml b/examples/confs/burn_scars.yaml index 702357dc..2e46fd1f 100644 --- a/examples/confs/burn_scars.yaml +++ b/examples/confs/burn_scars.yaml @@ -29,7 +29,7 @@ trainer: # dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 4 num_workers: 8 @@ -56,9 +56,7 @@ data: init_args: height: 224 width: 224 - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 + - class_path: albumentations.D4 - class_path: ToTensorV2 no_data_replace: 0 no_label_replace: -1 @@ -89,11 +87,11 @@ data: model: class_path: terratorch.tasks.SemanticSegmentationTask init_args: + model_factory: EncoderDecoderFactory model_args: - decoder: FCNDecoder + backbone: prithvi_eo_v2_300 backbone_pretrained: true - backbone: prithvi_vit_100 - decoder_channels: 256 + backbone_drop_path: 0.1 backbone_bands: - BLUE - GREEN @@ -101,23 +99,30 @@ model: - NIR_NARROW - SWIR_1 - SWIR_2 - num_classes: 2 + necks: + - name: SelectIndices +# indices: [2, 5, 8, 11] # 100M models + indices: [5, 11, 17, 23] # 300M models +# indices: [7, 15, 23, 31] # 600M models + - name: ReshapeTokensToImage + - name: LearnedInterpolateToPyramidal + decoder: UNetDecoder + decoder_channels: [512, 256, 128, 64] + head_channel_list: [256] head_dropout: 0.1 - decoder_num_convs: 4 - head_channel_list: - - 256 + num_classes: 2 loss: dice plot_on_val: 10 ignore_index: -1 freeze_backbone: false freeze_decoder: false - model_factory: EncoderDecoderFactory tiled_inference_parameters: h_crop: 512 h_stride: 496 w_crop: 512 w_stride: 496 average_patches: true + optimizer: class_path: torch.optim.Adam init_args: diff --git a/examples/confs/burnscars_smp.yaml b/examples/confs/burnscars_smp.yaml index 5c78a7f1..07a9a888 100644 --- a/examples/confs/burnscars_smp.yaml +++ b/examples/confs/burnscars_smp.yaml @@ -32,7 +32,7 @@ trainer: default_root_dir: output/BurnScars data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 4 num_workers: 8 diff --git a/examples/confs/eurosat.yaml b/examples/confs/eurosat.yaml index fbcbadca..50a5bff3 100644 --- a/examples/confs/eurosat.yaml +++ b/examples/confs/eurosat.yaml @@ -57,7 +57,7 @@ model: model_args: decoder: IdentityDecoder backbone_pretrained: true - backbone: prithvi_vit_100 + backbone: prithvi_eo_v1_300 head_dim_list: - 384 - 128 diff --git a/examples/confs/forestnet_timm.yaml b/examples/confs/forestnet_timm.yaml index 6499c842..d86019dd 100644 --- a/examples/confs/forestnet_timm.yaml +++ b/examples/confs/forestnet_timm.yaml @@ -33,7 +33,7 @@ trainer: default_root_dir: output/ForestNet data: - class_path: GenericNonGeoClassificationDataModule + class_path: terratorch.datamodules.GenericNonGeoClassificationDataModule init_args: batch_size: 16 num_workers: 8 diff --git a/examples/confs/multi_temporal_crop.yaml b/examples/confs/multi_temporal_crop.yaml index 5ea76acf..dbcaf054 100644 --- 
a/examples/confs/multi_temporal_crop.yaml +++ b/examples/confs/multi_temporal_crop.yaml @@ -90,7 +90,7 @@ model: model_args: decoder: FCNDecoder backbone_pretrained: true - backbone: prithvi_vit_100 + backbone: prithvi_eo_v2_300 backbone_in_channels: 6 rescale: False backbone_bands: diff --git a/examples/confs/multimae_sen1floods11.yaml b/examples/confs/multimae_sen1floods11.yaml index 9d779ab3..5a1b515d 100644 --- a/examples/confs/multimae_sen1floods11.yaml +++ b/examples/confs/multimae_sen1floods11.yaml @@ -28,7 +28,7 @@ trainer: default_root_dir: output/multimae_sen1floods11/ data: - class_path: GenericMultiModalDataModule + class_path: terratorch.datamodules.GenericMultiModalDataModule init_args: task: 'segmentation' batch_size: 4 diff --git a/examples/confs/multimodal_prithvi_sen1floods11.yaml b/examples/confs/multimodal_prithvi_sen1floods11.yaml index 0a7633c8..f8f38723 100644 --- a/examples/confs/multimodal_prithvi_sen1floods11.yaml +++ b/examples/confs/multimodal_prithvi_sen1floods11.yaml @@ -29,7 +29,7 @@ trainer: default_root_dir: output/multimodal_prithvi_sen1floods11/ data: - class_path: GenericMultiModalDataModule + class_path: terratorch.datamodules.GenericMultiModalDataModule init_args: task: 'segmentation' batch_size: 16 @@ -124,7 +124,7 @@ model: init_args: model_factory: EncoderDecoderFactory model_args: - backbone: prithvi_vit_100 + backbone: prithvi_eo_v2_300 backbone_pretrained: false backbone_bands: - COASTAL_AEROSOL @@ -141,23 +141,15 @@ model: - SWIR_2 - VV - VH - decoder: FCNDecoder # FCNDecoder - decoder_num_convs: 4 # only for FCNDecoder - # decoder_scale_modules: True # only for UperNetDecoder + decoder: FCNDecoder + decoder_num_convs: 4 decoder_channels: 256 num_classes: 2 head_dropout: 0.1 head_channel_list: - 256 - loss: dice ignore_index: -1 - class_weights: - - 0.3 - - 0.7 - class_names: - - Others - - Flood freeze_backbone: false freeze_decoder: false diff --git a/examples/confs/sen1floods11_vit.yaml b/examples/confs/sen1floods11_vit.yaml index fc30711a..da890f55 100644 --- a/examples/confs/sen1floods11_vit.yaml +++ b/examples/confs/sen1floods11_vit.yaml @@ -19,7 +19,7 @@ trainer: enable_checkpointing: true default_root_dir: data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 16 num_workers: 8 @@ -82,11 +82,11 @@ data: model: class_path: terratorch.tasks.SemanticSegmentationTask init_args: + model_factory: EncoderDecoderFactory model_args: - decoder: FCNDecoder + backbone: prithvi_eo_v2_300 backbone_pretrained: true - backbone: prithvi_vit_100 - decoder_channels: 256 + backbone_drop_path: 0.1 backbone_bands: - BLUE - GREEN @@ -94,17 +94,19 @@ model: - NIR_NARROW - SWIR_1 - SWIR_2 - num_classes: 2 - head_dropout: 0.1 - decoder_num_convs: 4 - head_channel_list: - - 256 necks: - name: SelectIndices - indices: - - -1 +# indices: [2, 5, 8, 11] # 100M models + indices: [5, 11, 17, 23] # 300M models +# indices: [7, 15, 23, 31] # 600M models - name: ReshapeTokensToImage - loss: ce + - name: LearnedInterpolateToPyramidal + decoder: UNetDecoder + decoder_channels: [512, 256, 128, 64] + head_channel_list: [256] + head_dropout: 0.1 + num_classes: 2 + loss: dice aux_heads: - name: aux_head decoder: FCNDecoder @@ -113,25 +115,20 @@ model: decoder_in_index: -1 decoder_num_convs: 2 head_dropout: 0.1 - # head_channel_list: - # - 64 aux_loss: aux_head: 1.0 ignore_index: -1 - class_weights: - - 0.3 - - 0.7 freeze_backbone: false freeze_decoder: false - model_factory: 
EncoderDecoderFactory + optimizer: class_path: torch.optim.AdamW init_args: - lr: 6.e-5 - weight_decay: 0.05 + lr: 1.e-4 + weight_decay: 0.1 lr_scheduler: class_path: ReduceLROnPlateau init_args: monitor: val/loss - - + patience: 5 + factor: 0.5 diff --git a/examples/confs/sen1floods11_vit_dual_lr.yaml b/examples/confs/sen1floods11_vit_dual_lr.yaml index 47a6eaa7..b909630c 100644 --- a/examples/confs/sen1floods11_vit_dual_lr.yaml +++ b/examples/confs/sen1floods11_vit_dual_lr.yaml @@ -19,7 +19,7 @@ trainer: enable_checkpointing: true default_root_dir: data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 16 num_workers: 8 @@ -85,7 +85,7 @@ model: model_args: decoder: FCNDecoder backbone_pretrained: true - backbone: prithvi_vit_100 + backbone: prithvi_eo_v1_100 decoder_channels: 256 backbone_bands: - BLUE diff --git a/examples/confs/sen1floods11_vit_local_ckpt.yaml b/examples/confs/sen1floods11_vit_local_ckpt.yaml index 5eb1a2c2..52d4a4a8 100644 --- a/examples/confs/sen1floods11_vit_local_ckpt.yaml +++ b/examples/confs/sen1floods11_vit_local_ckpt.yaml @@ -19,7 +19,7 @@ trainer: enable_checkpointing: true default_root_dir: data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 16 num_workers: 8 @@ -82,13 +82,12 @@ data: model: class_path: terratorch.tasks.SemanticSegmentationTask init_args: + model_factory: EncoderDecoderFactory model_args: - decoder: FCNDecoder - backbone_pretrained: true - backbone_pretrained_cfg_overlay: - file: examples/Prithvi_100M.pt - backbone: prithvi_vit_100 - decoder_channels: 256 + backbone: prithvi_eo_v2_300 + backbone_pretrained: false + backbone_ckpt_path: examples/Prithvi_100M.pt + backbone_drop_path: 0.1 backbone_bands: - BLUE - GREEN @@ -96,17 +95,19 @@ model: - NIR_NARROW - SWIR_1 - SWIR_2 - num_classes: 2 - head_dropout: 0.1 - decoder_num_convs: 4 - head_channel_list: - - 256 necks: - name: SelectIndices - indices: - - -1 +# indices: [2, 5, 8, 11] # 100M models + indices: [5, 11, 17, 23] # 300M models +# indices: [7, 15, 23, 31] # 600M models - name: ReshapeTokensToImage - loss: ce + - name: LearnedInterpolateToPyramidal + decoder: UNetDecoder + decoder_channels: [512, 256, 128, 64] + head_channel_list: [256] + head_dropout: 0.1 + num_classes: 2 + loss: dice aux_heads: - name: aux_head decoder: FCNDecoder @@ -115,25 +116,20 @@ model: decoder_in_index: -1 decoder_num_convs: 2 head_dropout: 0.1 - # head_channel_list: - # - 64 aux_loss: aux_head: 1.0 ignore_index: -1 - class_weights: - - 0.3 - - 0.7 freeze_backbone: false freeze_decoder: false - model_factory: EncoderDecoderFactory + optimizer: class_path: torch.optim.AdamW init_args: - lr: 6.e-5 - weight_decay: 0.05 + lr: 1.e-4 + weight_decay: 0.1 lr_scheduler: class_path: ReduceLROnPlateau init_args: monitor: val/loss - - + patience: 5 + factor: 0.5 diff --git a/examples/confs/sen1floods11_vit_mmseg.yaml b/examples/confs/sen1floods11_vit_mmseg.yaml deleted file mode 100644 index 6a65edf5..00000000 --- a/examples/confs/sen1floods11_vit_mmseg.yaml +++ /dev/null @@ -1,122 +0,0 @@ -# lightning.pytorch==2.1.1 -seed_everything: 0 -trainer: - accelerator: auto - strategy: auto - devices: auto - num_nodes: 1 - precision: 16-mixed - logger: True # will use tensorboardlogger - callbacks: - - class_path: RichProgressBar - - class_path: LearningRateMonitor - init_args: - logging_interval: epoch - - max_epochs: 200 - 
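Zooming out from the config churn in this patch: the new model_args block used in the updated configs (feature-selecting necks feeding a UNetDecoder) corresponds to a direct factory call along these lines. This is a sketch assuming the EncoderDecoderFactory API implied by the configs; the feature indices depend on the backbone depth:

```python
from terratorch.models import EncoderDecoderFactory

factory = EncoderDecoderFactory()
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_eo_v2_300",
    backbone_pretrained=False,
    necks=[
        {"name": "SelectIndices", "indices": [5, 11, 17, 23]},  # 300M model
        {"name": "ReshapeTokensToImage"},
        {"name": "LearnedInterpolateToPyramidal"},
    ],
    decoder="UNetDecoder",
    decoder_channels=[512, 256, 128, 64],
    head_channel_list=[256],
    head_dropout=0.1,
    num_classes=2,
)
```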
check_val_every_n_epoch: 1 - log_every_n_steps: 50 - enable_checkpointing: true - default_root_dir: -data: - class_path: GenericNonGeoSegmentationDataModule - init_args: - batch_size: 16 - num_workers: 8 - constant_scale: 0.0001 - dataset_bands: - - COASTAL_AEROSOL - - BLUE - - GREEN - - RED - - RED_EDGE_1 - - RED_EDGE_2 - - RED_EDGE_3 - - NIR_BROAD - - NIR_NARROW - - WATER_VAPOR - - CIRRUS - - SWIR_1 - - SWIR_2 - output_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - rgb_indices: - - 2 - - 1 - - 0 - train_data_root: /v1.1/data/flood_events/HandLabeled/S2Hand/ - train_label_data_root: /v1.1/data/flood_events/HandLabeled/LabelHand - val_data_root: /v1.1/data/flood_events/HandLabeled/S2Hand/ - val_label_data_root: /v1.1/data/flood_events/HandLabeled/LabelHand - test_data_root: /v1.1/data/flood_events/HandLabeled/S2Hand/ - test_label_data_root: /v1.1/data/flood_events/HandLabeled/LabelHand - # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files - train_split: /v1.1/splits/flood_handlabeled/flood_train_data.txt - test_split: /v1.1/splits/flood_handlabeled/flood_test_data.txt - val_split: /v1.1/splits/flood_handlabeled/flood_valid_data.txt - img_grep: "*_S2Hand.tif" - label_grep: "*_LabelHand.tif" - no_label_replace: -1 - no_data_replace: 0 - means: - - 0.1412956 - - 0.13795798 - - 0.12353792 - - 0.30902815 - - 0.2044958 - - 0.11912015 - stds: - - 0.07406382 - - 0.07370365 - - 0.08692279 - - 0.11798815 - - 0.09772074 - - 0.07659938 - num_classes: 2 -model: - class_path: terratorch.tasks.SemanticSegmentationTask - init_args: - model_args: - decoder: FCNHead - backbone_pretrained: True - backbone: prithvi_vit_100 - backbone_pretrain_img_size: 512 - decoder_num_convs: 4 - decoder_channels: 256 - decoder_dropout_ratio: 0.1 - backbone_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - num_classes: 2 - necks: - - name: ReshapeTokensToImage - - name: SelectIndices - indices: - - -1 - loss: ce - - ignore_index: -1 - class_weights: - - 0.3 - - 0.7 - freeze_backbone: true - freeze_decoder: false - model_factory: EncoderDecoderFactory -optimizer: - class_path: torch.optim.AdamW - init_args: - lr: 6.e-5 - weight_decay: 0.05 -lr_scheduler: - class_path: ReduceLROnPlateau - init_args: - monitor: val/loss diff --git a/examples/confs/sen1floods11_vit_peft.yaml b/examples/confs/sen1floods11_vit_peft.yaml index 3404e12b..68b4cc60 100644 --- a/examples/confs/sen1floods11_vit_peft.yaml +++ b/examples/confs/sen1floods11_vit_peft.yaml @@ -19,7 +19,7 @@ trainer: enable_checkpointing: true default_root_dir: data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 16 num_workers: 8 @@ -85,7 +85,7 @@ model: model_args: decoder: FCNDecoder backbone_pretrained: true - backbone: prithvi_vit_100 + backbone: prithvi_eo_v1_100 decoder_channels: 256 backbone_bands: - BLUE diff --git a/examples/confs/sen1floods11_vit_smp.yaml b/examples/confs/sen1floods11_vit_smp.yaml index f0c89f02..5ec07a6d 100644 --- a/examples/confs/sen1floods11_vit_smp.yaml +++ b/examples/confs/sen1floods11_vit_smp.yaml @@ -19,7 +19,7 @@ trainer: enable_checkpointing: true default_root_dir: data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 16 num_workers: 8 @@ -83,7 +83,7 @@ model: init_args: model_args: backbone_pretrained: True - backbone: 
prithvi_vit_100 + backbone: prithvi_eo_v1_100 backbone_pretrain_img_size: 512 backbone_bands: - BLUE diff --git a/examples/confs/sen4agri.yaml b/examples/confs/sen4agri.yaml deleted file mode 100644 index 4c041456..00000000 --- a/examples/confs/sen4agri.yaml +++ /dev/null @@ -1,59 +0,0 @@ -benchmark_suffix: benchmark -experiment_name: benchmark -precision: 16-mixed -backbone: your_model_here - -tasks: - - name: cashew - type: segmentation - loss: ce - bands: - - 12 - num_classes: 20 - max_epochs: 300 - direction: max - datamodule: - class_path: terratorch.datamodules.Sen4AgriNetDataModule - init_args: - data_root: "/dccstor/geofm-finetuning/datasets/Sen4AgriNet/S4A" - batch_size: 16 - num_workers: 6 - val_transform: - - class_path: FlattenTemporalIntoChannels - - class_path: ToTensorV2 - train_transform: - - class_path: FlattenTemporalIntoChannels - - class_path: ToTensorV2 - test_transform: - - class_path: FlattenTemporalIntoChannels - - class_path: ToTensorV2 - - decoder: UperNetDecoder - decoder_args: - channels: 128 - scale_modules: True - metric: val/Multiclass_Jaccard_Index - early_stop_patience: 50 - -n_trials: 16 -save_models: False -storage_uri: /path/to/storage -ray_storage_path: /path/to/ray/storage -optimization_space: - # decoder: - # - UperNetDecoder - # - UperNetDecoder - lr: - min: 1e-6 - max: 1e-3 - type: real - log: true - batch_size: - - 4 - - 8 - - 16 - - 32 - decoder_channels: - - 64 - - 128 - - 256 diff --git a/examples/confs/sen4map_ViT-L.yaml b/examples/confs/sen4map_ViT-L.yaml index e04c9a38..3ab959a8 100644 --- a/examples/confs/sen4map_ViT-L.yaml +++ b/examples/confs/sen4map_ViT-L.yaml @@ -25,7 +25,7 @@ trainer: default_root_dir: data: - class_path: Sen4MapLucasDataModule + class_path: terratorch.datamodules.Sen4MapLucasDataModule init_args: batch_size: 10 num_workers: 8 @@ -74,7 +74,7 @@ model: model_args: decoder: IdentityDecoder pretrained: true - backbone: prithvi_vit_300 + backbone: prithvi_eo_v2_300 backbone_pretrained_cfg_overlay: file: backbone_patch_size: 16 @@ -101,7 +101,7 @@ model: loss: ce freeze_backbone: false # freeze_decoder: false - model_factory: PrithviModelFactory + model_factory: EncoderDecoderFactory optimizer: class_path: torch.optim.AdamW diff --git a/examples/confs/smp_model_factory.yaml b/examples/confs/smp_model_factory.yaml deleted file mode 100644 index 7d9a1d53..00000000 --- a/examples/confs/smp_model_factory.yaml +++ /dev/null @@ -1,93 +0,0 @@ -benchmark_suffix: smp_test -experiment_name: smp_test -backbone: - backbone: resnet18 - backbone_args: - pretrained: False - output_stride: 2 - smp_decoder_channels: 512 - smp_encoder_depth: 5 - - # backbone: swin3d.swin3d_backbone.Swin3dBackbone - # backbone_args: - # pretrained: False - # output_stride: 2 - # out_channels: - # - 192 - # - 384 - # - 768 - # - 768 - # smp_decoder_channels: 768 - # smp_encoder_depth: 5 - - -tasks: - - name: cashew - type: segmentation - loss: ce - model_factory: SMPModelFactory - bands: - - RED - - GREEN - - BLUE - num_classes: 7 - max_epochs: 60 - direction: max - datamodule: - class_path: terratorch.datamodules.MBeninSmallHolderCashewsNonGeoDataModule - init_args: - batch_size: 16 - num_workers: 4 - train_transform: - - class_path: albumentations.Resize - init_args: - always_apply: True - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - always_apply: True - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 
224 - width: 224 - - class_path: ToTensorV2 - data_root: "/dccstor/geofm-finetuning/geobench/segmentation_v1.0" - bands: - - "RED" - - "GREEN" - - "BLUE" - decoder: IdentityDecoder - decoder_args: - channels: 128 - metric: val/Multiclass Jaccard Index - -n_trials: 16 -save_models: False -storage_uri: /path/to/storage -optimization_space: - model: - - DeepLabV3 - lr: - min: 6e-5 - max: 1e-3 - type: real - log: true - batch_size: - - 8 - - 16 - - 32 - decoder_channels: - - 32 - - 64 - - 128 - head_dropout: - min: 0.2 - max: 0.8 - type: real \ No newline at end of file diff --git a/examples/notebooks/Tutorial.ipynb b/examples/notebooks/Tutorial.ipynb index 8a704fe2..699a6c71 100644 --- a/examples/notebooks/Tutorial.ipynb +++ b/examples/notebooks/Tutorial.ipynb @@ -2,16 +2,31 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, "id": "5d049232-f4b1-473d-aac3-0b3539905b03", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-22T10:44:22.839382Z", + "start_time": "2025-01-22T10:44:18.410638Z" + } + }, "source": [ "import os\n", "import torch\n", "\n", "from terratorch import BACKBONE_REGISTRY" - ] + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.0 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations\n", + "/opt/homebrew/lib/python3.11/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n", + " @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)\n" + ] + } + ], + "execution_count": 1 }, { "cell_type": "markdown", @@ -31,48 +46,35 @@ }, { "cell_type": "code", - "execution_count": 5, "id": "8dcdfa85-8e43-4db0-9ddf-cb11c5544942", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-22T10:44:59.384413Z", + "start_time": "2025-01-22T10:44:59.380583Z" + } + }, + "source": "print([model_name for model_name in BACKBONE_REGISTRY if \"terratorch_prithvi\" in model_name])", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "['timm_prithvi_swin_B', 'timm_prithvi_swin_L', 'timm_prithvi_vit_100', 'timm_prithvi_vit_300', 'timm_prithvi_vit_tiny']\n" + "['terratorch_prithvi_eo_tiny', 'terratorch_prithvi_eo_v1_100', 'terratorch_prithvi_eo_v2_300', 'terratorch_prithvi_eo_v2_600', 'terratorch_prithvi_eo_v2_300_tl', 'terratorch_prithvi_eo_v2_600_tl', 'terratorch_prithvi_vit_tiny', 'terratorch_prithvi_vit_100']\n" ] } ], - "source": [ - "print([model_name for model_name in BACKBONE_REGISTRY if \"prithvi\" in model_name])" - ] + "execution_count": 5 }, { "cell_type": "code", - "execution_count": 2, "id": "338c6071", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-22T10:45:05.471003Z", + "start_time": "2025-01-22T10:45:05.466191Z" } - ], - "source": [ - "\"timm_prithvi_vit_100\" in BACKBONE_REGISTRY" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "13e3ed35", - "metadata": {}, + }, + "source": "\"prithvi_vit_100\" in BACKBONE_REGISTRY", "outputs": [ { "data": { @@ -80,24 +82,37 @@ "True" ] }, - "execution_count": 3, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "\"prithvi_vit_100\" in BACKBONE_REGISTRY" - ] + "execution_count": 
6 }, { "cell_type": "code", - "execution_count": 4, "id": "38db3f3c", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-22T10:44:34.736849Z", + "start_time": "2025-01-22T10:44:34.220986Z" + } + }, "source": [ "model = BACKBONE_REGISTRY.build(\"prithvi_vit_100\")" - ] + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/BLU/repos/terratorch_ibm/terratorch/models/backbones/prithvi_vit.py:354: FutureWarning: The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. prithvi_vit_100 will be removed in a future version.\n", + " warnings.warn(\"The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. \"\n", + "INFO:terratorch.models.backbones.prithvi_vit:Model bands not passed. Assuming bands are ordered in the same way as [, , , , , ]. Pretrained patch_embed layer may be misaligned with current bands\n" + ] + } + ], + "execution_count": 4 }, { "cell_type": "markdown", @@ -667,17 +682,17 @@ ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┃\u001B[1m \u001B[0m\u001B[1m Test metric \u001B[0m\u001B[1m \u001B[0m┃\u001B[1m \u001B[0m\u001B[1m DataLoader 0 \u001B[0m\u001B[1m \u001B[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test/Multiclass_Accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.807342529296875 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/Multiclass_F1_Score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.807342529296875 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/Multiclass_Jaccard_Index \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.4036712646484375 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36mtest/Multiclass_Jaccard_Index_Micro\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.676927387714386 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5365139245986938 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/multiclassaccuracy_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/multiclassaccuracy_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/multiclassjaccardindex_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.807342529296875 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/multiclassjaccardindex_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/Multiclass_Accuracy \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.807342529296875 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/Multiclass_F1_Score \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.807342529296875 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/Multiclass_Jaccard_Index \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.4036712646484375 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m 
\u001B[0m\u001B[36mtest/Multiclass_Jaccard_Index_Micro\u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.676927387714386 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/loss \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.5365139245986938 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/multiclassaccuracy_0 \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 1.0 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/multiclassaccuracy_1 \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.0 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/multiclassjaccardindex_0 \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.807342529296875 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/multiclassjaccardindex_1 \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.0 \u001B[0m\u001B[35m \u001B[0m│\n", "└─────────────────────────────────────┴─────────────────────────────────────┘\n" ] }, From 151de08d24ae5b932bd79a371f79d687f5e232ea Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Wed, 22 Jan 2025 12:11:30 +0100 Subject: [PATCH 29/40] Fixed script for loading local model Signed-off-by: Benedikt Blumenstiel --- .../scripts/instantiate_satmae_backbone.py | 32 ------------------- examples/scripts/open_local_model.py | 4 +-- terratorch/io/file.py | 18 ++++------- 3 files changed, 8 insertions(+), 46 deletions(-) delete mode 100644 examples/scripts/instantiate_satmae_backbone.py diff --git a/examples/scripts/instantiate_satmae_backbone.py b/examples/scripts/instantiate_satmae_backbone.py deleted file mode 100644 index 79d4cc3d..00000000 --- a/examples/scripts/instantiate_satmae_backbone.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -import numpy as np - -from models_mae import MaskedAutoencoderViT - -kwargs = {"img_size":224, - "patch_size":16, - "in_chans":3, - "embed_dim":1024, - "depth":24, - "num_heads":16, - "decoder_embed_dim":512, - "decoder_depth":8, - "decoder_num_heads":16, - "mlp_ratio":4.} - -vit_mae = MaskedAutoencoderViT(**kwargs) - -mask_ratio = 0.75 -data = torch.from_numpy(np.random.rand(4, 3, 224, 224).astype("float32")) -latent, _, ids_restore = vit_mae.forward_encoder(data, mask_ratio) -reconstructed = vit_mae.forward_decoder(latent, ids_restore) - - -print(f"Output shape: {latent.shape}") -print("Done.") - -_, reconstructed, _ = vit_mae.forward(data, mask_ratio) - -print(f"Output shape: {reconstructed.shape}") -print("Done.") - diff --git a/examples/scripts/open_local_model.py b/examples/scripts/open_local_model.py index 8651c07c..dc970f16 100644 --- a/examples/scripts/open_local_model.py +++ b/examples/scripts/open_local_model.py @@ -1,10 +1,10 @@ from terratorch.io.file import open_generic_torch_model -from models_mae_temporal import MaskedAutoencoderViT +from terratorch.models.backbones.prithvi_mae import PrithviMAE from torch import nn # Path for a downloaded model model_weights_path = "./pretrain-vit-base-e199.pth" -model_template = MaskedAutoencoderViT +model_template = PrithviMAE model_kwargs = { 'img_size': 224, diff --git a/terratorch/io/file.py b/terratorch/io/file.py index 27efebcd..942a09e0 100644 --- a/terratorch/io/file.py +++ b/terratorch/io/file.py @@ -1,4 +1,5 @@ import os +import torch import importlib from torch import nn import numpy as np @@ -22,6 +23,7 @@ def open_generic_torch_model(model: type | str = None, return load_torch_weights(model=model, save_dir=dirname, 
name=filename) + def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = None, device: str = None) -> None: print(f"Trying to load for {device}") @@ -30,29 +32,21 @@ def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = N if device != None: model.load_state_dict( torch.load( - os.path.join(save_dir, name + ".pth"), + os.path.join(save_dir, name), map_location=torch.device(device), ) ) else: - #try: - # path = os.path.join(save_dir, name) - # checkpoint = torch.load(path, map_location='cpu') - # model = checkpoint['model'] - # state_dict = model.state_dict() - # msg = model.load_state_dict(model, strict=False) - - #except Exception: - - model.load_state_dict(torch.load(os.path.join(save_dir, name))) + model.load_state_dict(torch.load(os.path.join(save_dir, name), map_location='cpu')) except Exception: print( - f"It was not possible to load from {os.path.join(save_dir, name + '.pth')}" + f"It was not possible to load from {os.path.join(save_dir, name)}" ) return model + def load_from_file_or_attribute(value: list[float]|str): if isinstance(value, list): From 8eda7afc442b930173fc867c97607c18b08cd881 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Wed, 22 Jan 2025 12:25:24 +0100 Subject: [PATCH 30/40] Update prithvi related docs Signed-off-by: Benedikt Blumenstiel --- docs/quick_start.md | 109 +++++++++++++++++++------------------------- docs/registry.md | 18 ++++---- 2 files changed, 57 insertions(+), 70 deletions(-) diff --git a/docs/quick_start.md b/docs/quick_start.md index 6c4e8ee2..c95d6749 100644 --- a/docs/quick_start.md +++ b/docs/quick_start.md @@ -19,28 +19,27 @@ In the simplest case, we might only want access a backbone and code all the rest from terratorch import BACKBONE_REGISTRY # find available prithvi models -print([model_name for model_name in BACKBONE_REGISTRY if "prithvi" in model_name]) ->>> ['timm_prithvi_eo_tiny', 'timm_prithvi_eo_v1_100', 'timm_prithvi_eo_v2_300', 'timm_prithvi_eo_v2_300_tl', 'timm_prithvi_eo_v2_600', - 'timm_prithvi_eo_v2_600_tl', 'timm_prithvi_swin_B', 'timm_prithvi_swin_L', 'timm_prithvi_vit_100', 'timm_prithvi_vit_tiny'] +print([model_name for model_name in BACKBONE_REGISTRY if "terratorch_prithvi" in model_name]) +>>> ['terratorch_prithvi_eo_tiny', 'terratorch_prithvi_eo_v1_100', 'terratorch_prithvi_eo_v2_300', 'terratorch_prithvi_eo_v2_600', 'terratorch_prithvi_eo_v2_300_tl', 'terratorch_prithvi_eo_v2_600_tl'] # show all models with list(BACKBONE_REGISTRY) # check a model is in the registry -"timm_prithvi_swin_B" in BACKBONE_REGISTRY +"terratorch_prithvi_eo_v2_300" in BACKBONE_REGISTRY >>> True # without the prefix, all internal registries will be searched until the first match is found -"prithvi_swin_B" in BACKBONE_REGISTRY +"prithvi_eo_v1_100" in BACKBONE_REGISTRY >>> True # instantiate your desired model -# the backbone registry prefix (in this case 'timm') is optional -# in this case, the underlying registry is timm, so we can pass timm arguments to it -model = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", num_frames=1, pretrained=True) +# the backbone registry prefix (e.g. `terratorch` or `timm`) is optional +# in this case, the underlying registry is terratorch. 
+model = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", pretrained=True) -# instantiate your model with more options, for instance, passing weights of your own through timm +# instantiate your model with more options, for instance, passing weights from your own file model = BACKBONE_REGISTRY.build( - "prithvi_vit_100", num_frames=1, pretrained=True, pretrained_cfg_overlay={"file": ""} + "prithvi_eo_v2_300", num_frames=1, ckpt_path='path/to/model.pt' ) # Rest of your PyTorch / PyTorchLightning code @@ -68,25 +67,25 @@ model_factory = EncoderDecoderFactory() # Parameters prefixed with decoder_ get passed to the decoder # Parameters prefixed with head_ get passed to the head -model = model_factory.build_model(task="segmentation", - backbone="prithvi_vit_100", - decoder="FCNDecoder", - backbone_bands=[ - HLSBands.BLUE, - HLSBands.GREEN, - HLSBands.RED, - HLSBands.NIR_NARROW, - HLSBands.SWIR_1, - HLSBands.SWIR_2, - ], - necks=[{"name": "SelectIndices", "indices": [-1]}, - {"name": "ReshapeTokensToImage"}], - num_classes=4, - backbone_pretrained=True, - backbone_num_frames=1, - decoder_channels=128, - head_dropout=0.2 - ) +model = model_factory.build_model( + task="segmentation", + backbone="prithvi_eo_v2_300", + backbone_pretrained=True, + backbone_bands=[ + HLSBands.BLUE, + HLSBands.GREEN, + HLSBands.RED, + HLSBands.NIR_NARROW, + HLSBands.SWIR_1, + HLSBands.SWIR_2, + ], + necks=[{"name": "SelectIndices", "indices": [-1]}, + {"name": "ReshapeTokensToImage"}], + decoder="FCNDecoder", + decoder_channels=128, + head_dropout=0.1, + num_classes=4, +) # Rest of your PyTorch / PyTorchLightning code . @@ -102,8 +101,9 @@ At the highest level of abstraction, you can directly obtain a LightningModule r ```python title="Building a full Pixel-Wise Regression task" model_args = dict( - backbone="prithvi_vit_100", - decoder="FCNDecoder", + backbone="prithvi_eo_v2_300", + backbone_pretrained=True, + backbone_num_frames=1, backbone_bands=[ HLSBands.BLUE, HLSBands.GREEN, @@ -114,10 +114,9 @@ model_args = dict( ], necks=[{"name": "SelectIndices", "indices": [-1]}, {"name": "ReshapeTokensToImage"}], - backbone_pretrained=True, - backbone_num_frames=1, + decoder="FCNDecoder", decoder_channels=128, - head_dropout=0.2 + head_dropout=0.1 ) task = PixelwiseRegressionTask( @@ -175,14 +174,11 @@ data: model: class_path: terratorch.tasks.SemanticSegmentationTask init_args: + model_factory: EncoderDecoderFactory model_args: - decoder: UperNetDecoder + backbone: prithvi_eo_v2_300 + backbone_img_size: 512 backbone_pretrained: True - backbone: prithvi_vit_100 - backbone_pretrain_img_size: 512 - decoder_scale_modules: True - decoder_channels: 256 - backbone_in_channels: 6 backbone_bands: - BLUE - GREEN @@ -190,39 +186,30 @@ model: - NIR_NARROW - SWIR_1 - SWIR_2 - num_frames: 1 - num_classes: 2 - head_dropout: 0.1 - head_channel_list: - - 256 - post_backbone_ops: + necks: - name: SelectIndices - indices: - - 5 - - 11 - - 17 - - 23 + indices: [5, 11, 17, 23] - name: ReshapeTokensToImage - loss: ce - + - name: LearnedInterpolateToPyramidal + decoder: UperNetDecoder + decoder_channels: 256 + head_channel_list: [256] + head_dropout: 0.1 + num_classes: 2 + loss: dice ignore_index: -1 - class_weights: - - 0.3 - - 0.7 freeze_backbone: false - freeze_decoder: false - model_factory: EncoderDecoderFactory + freeze_decoder: false optimizer: class_path: torch.optim.AdamW init_args: - lr: 6.e-5 - weight_decay: 0.05 + lr: 1.e-4 + weight_decay: 0.1 lr_scheduler: class_path: ReduceLROnPlateau init_args: monitor: val/loss - ``` To run this 
training task using the YAML, simply execute: diff --git a/docs/registry.md b/docs/registry.md index 06ceb2bc..ba342eb7 100644 --- a/docs/registry.md +++ b/docs/registry.md @@ -13,27 +13,27 @@ To create the desired instance, registries expose a `build` method, which accept from terratorch import BACKBONE_REGISTRY # find available prithvi models -print([model_name for model_name in BACKBONE_REGISTRY if "prithvi" in model_name]) ->>> ['timm_prithvi_swin_B', 'timm_prithvi_swin_L', 'timm_prithvi_vit_100', 'timm_prithvi_vit_300', 'timm_prithvi_vit_tiny'] +print([model_name for model_name in BACKBONE_REGISTRY if "terratorch_prithvi" in model_name]) +>>> ['terratorch_prithvi_eo_tiny', 'terratorch_prithvi_eo_v1_100', 'terratorch_prithvi_eo_v2_300', 'terratorch_prithvi_eo_v2_600', 'terratorch_prithvi_eo_v2_300_tl', 'terratorch_prithvi_eo_v2_600_tl'] # show all models with list(BACKBONE_REGISTRY) # check a model is in the registry -"timm_prithvi_swin_B" in BACKBONE_REGISTRY +"terratorch_prithvi_eo_v2_300" in BACKBONE_REGISTRY >>> True # without the prefix, all internal registries will be searched until the first match is found -"prithvi_swin_B" in BACKBONE_REGISTRY +"prithvi_eo_v1_100" in BACKBONE_REGISTRY >>> True # instantiate your desired model -# the backbone registry prefix (in this case 'timm') is optional -# in this case, the underlying registry is timm, so we can pass timm arguments to it -model = BACKBONE_REGISTRY.build("prithvi_vit_100", num_frames=1, pretrained=True) +# the backbone registry prefix (e.g. `terratorch` or `timm`) is optional +# in this case, the underlying registry is terratorch. +model = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", pretrained=True) -# instantiate your model with more options, for instance, passing weights of your own through timm +# instantiate your model with more options, for instance, passing weights from your own file model = BACKBONE_REGISTRY.build( - "prithvi_vit_100", num_frames=1, pretrained=True, pretrained_cfg_overlay={"file": ""} + "prithvi_eo_v2_300", num_frames=1, ckpt_path='path/to/model.pt' ) # Rest of your PyTorch / PyTorchLightning code From 025c2e47b70ad2bfa1fd002b7fd94ce5a74add0c Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Wed, 22 Jan 2025 12:37:01 +0100 Subject: [PATCH 31/40] Pin jsonargparse in required.txt Signed-off-by: Francesc Marti Escofet --- requirements/required.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/required.txt b/requirements/required.txt index c4aec6ea..986de635 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -15,6 +15,7 @@ lightning==2.4.0 git+https://github.com/qubvel-org/segmentation_models.pytorch.git@3952e1f8e9684a385a81e30381b8fb5b1ac086cf timm==1.0.11 numpy==1.26.4 +jsonargparse<=4.35.0 # These dependencies are optional # and must be installed just in case From 73b8dd541ffda0d0093953a3cf96730db945b402 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Wed, 22 Jan 2025 13:01:15 +0100 Subject: [PATCH 32/40] Fixed interpolated pos embeddings Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_mae.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 6cb65ea3..bf0dde12 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -348,11 +348,12 @@ def interpolate_pos_encoding(self, x, t, w, h): class_pos_embed = self.pos_embed[:, :1] patch_pos_embed = 
self.pos_embed[:, 1:] + t_patches = t // self.patch_embed.patch_size[0] w_patches = w // self.patch_embed.patch_size[1] h_patches = h // self.patch_embed.patch_size[2] - n_sqrt = int((patch_pos_embed.shape[1] / t) ** 0.5) - patch_pos_embed = patch_pos_embed.reshape(t, n_sqrt, n_sqrt, self.embed_dim).permute(0, 3, 1, 2) + n_sqrt = int((patch_pos_embed.shape[1] / t_patches) ** 0.5) + patch_pos_embed = patch_pos_embed.reshape(t_patches, n_sqrt, n_sqrt, self.embed_dim).permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, @@ -537,11 +538,12 @@ def interpolate_pos_encoding(self, x, t, w, h): class_pos_embed = self.decoder_pos_embed[:, :1] patch_pos_embed = self.decoder_pos_embed[:, 1:] + t_patches = t // self.patch_size[0] w_patches = w // self.patch_size[1] h_patches = h // self.patch_size[2] - n_sqrt = int((patch_pos_embed.shape[1] / t) ** 0.5) - patch_pos_embed = patch_pos_embed.reshape(t, n_sqrt, n_sqrt, self.decoder_embed_dim).permute(0, 3, 1, 2) + n_sqrt = int((patch_pos_embed.shape[1] / t_patches) ** 0.5) + patch_pos_embed = patch_pos_embed.reshape(t_patches, n_sqrt, n_sqrt, self.decoder_embed_dim).permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, From e2240aa32680e7f4fb796bb20467982ec3dfe63d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Wed, 22 Jan 2025 14:03:17 -0300 Subject: [PATCH 33/40] Pinning 3.12.7 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 5b522c26..0f68793f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -15,7 +15,7 @@ jobs: timeout-minutes: 30 strategy: matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12.7"] steps: - name: Clone repo From f802ab1d3e028c50a15c082c316d3dc3c55c2e8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Wed, 22 Jan 2025 14:47:18 -0300 Subject: [PATCH 34/40] pinning jsonargparse MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- mkdocs.yml | 6 +++--- pyproject.toml | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mkdocs.yml b/mkdocs.yml index 3105bbf0..0ffd9b8b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -15,9 +15,9 @@ plugins: paths: [src] # search packages in the src folde options: show_root_heading: true -extra: - version: - provider: mike + #extra: + # version: + # provider: mike site_url: https://ibm.github.io/terratorch/ repo_url: https://github.com/IBM/terratorch diff --git a/pyproject.toml b/pyproject.toml index 1d0c5645..30d3cc0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,8 @@ dependencies = [ "mlflow>=2.12.1", # broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977 "lightning[pytorch-extra]>=2,!=2.3.*", - "segmentation-models-pytorch>=0.3" + "segmentation-models-pytorch>=0.3", + "jsonargparse<=4.35.0", # Dependencies not available on PyPI ] From 01f70cbde99de34a0f1b630d267d9f0e995b1795 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Wed, 22 Jan 2025 21:22:13 +0100 Subject: [PATCH 35/40] Restructure prithvi init Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_vit.py | 120 +++++++++------------ 1 file 
changed, 53 insertions(+), 67 deletions(-) diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 8f767e8f..0d221291 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -9,7 +9,7 @@ from terratorch.datasets.utils import generate_bands_intervals from terratorch.models.backbones.prithvi_mae import PrithviViT, PrithviMAE from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY -from timm.models import load_model_config_from_hf, load_state_dict_from_hf +from huggingface_hub import hf_hub_download logger = logging.getLogger(__name__) @@ -60,8 +60,8 @@ def _cfg(**kwargs): "prithvi_eo_v2_300": _cfg(embed_dim=1024, depth=24, num_heads=16), "prithvi_eo_v2_300_tl": _cfg(embed_dim=1024, depth=24, num_heads=16, coords_encoding=["time", "location"], coords_scale_learn=True), - "prithvi_eo_v2_600": _cfg(embed_dim=1280, depth=32, num_heads=16), - "prithvi_eo_v2_600_tl": _cfg(embed_dim=1280, depth=32, num_heads=16, + "prithvi_eo_v2_600": _cfg(embed_dim=1280, depth=32, num_heads=16, patch_size=[1, 14, 14]), + "prithvi_eo_v2_600_tl": _cfg(embed_dim=1280, depth=32, num_heads=16, patch_size=[1, 14, 14], coords_encoding=["time", "location"], coords_scale_learn=True), } @@ -167,27 +167,21 @@ def pad_images(imgs: Tensor,patch_size: int, padding:str) -> Tensor: def _create_prithvi( variant: str, - pretrained: bool = False, # noqa: FBT001, FBT002 - pretrained_bands: list[HLSBands] | None = None, + pretrained: bool = False, # noqa: FBT001, FBT002 model_bands: list[HLSBands | int] | None = None, - ckpt_path: str | None = None, + ckpt_path: str = None, + pretrained_bands: list[HLSBands | str | int] | None = None, + num_frames: int = 1, + encoder_only: bool = True, **kwargs, -) -> PrithviViT: - if pretrained_bands is None: - pretrained_bands = PRETRAINED_BANDS - kwargs["num_frames"] = kwargs.pop("num_frames", 1) # Set num frames to 1 if not present +) -> PrithviViT | PrithviMAE: + """ + Build PrithviViT and PrithviMAE models. + By default, encoder_only is set to True and a ViT is returned. + """ - if model_bands is None: - model_bands: list[HLSBands | int] = pretrained_bands - logger.info( - f"Model bands not passed. Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\ - Pretrained patch_embed layer may be misaligned with current bands" - ) - else: - model_bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in model_bands] - - # Little hack because VIT does not support timm's features_only - encoder_only = kwargs.pop("features_only", True) + # Load default config + model_args = prithvi_cfgs[variant].copy() # Backwards compatibility from timm (pretrained_cfg_overlay={"file": ""}) TODO: Remove before v1.0 if "pretrained_cfg_overlay" in kwargs: @@ -200,59 +194,51 @@ def _create_prithvi( else: ckpt_path = kwargs.pop("pretrained_cfg_overlay")["file"] - model_bands = generate_bands_intervals(model_bands) + pretrained_bands = pretrained_bands or model_args.get("bands", PRETRAINED_BANDS) + + if model_bands is None: + model_bands: list[HLSBands | int] = pretrained_bands + logger.info(f"Model bands not passed. Assuming bands are ordered in the same way as {pretrained_bands}." 
+ f"Pretrained patch_embed layer may be misaligned with current bands") + else: + model_bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in model_bands] + model_bands = generate_bands_intervals(model_bands) kwargs["in_chans"] = len(model_bands) + kwargs["num_frames"] = num_frames + model_args.update(kwargs) if encoder_only: prithvi_model_class = PrithviViT - def checkpoint_filter_wrapper_fn(state_dict, model): - return checkpoint_filter_fn_vit(state_dict, model, pretrained_bands, model_bands) + checkpoint_filter_wrapper_fn = checkpoint_filter_fn_vit else: prithvi_model_class = PrithviMAE - def checkpoint_filter_wrapper_fn(state_dict, model): - return checkpoint_filter_fn_mae(state_dict, model, pretrained_bands, model_bands) + checkpoint_filter_wrapper_fn = checkpoint_filter_fn_mae if pretrained: assert variant in pretrained_weights, (f"No pre-trained model found for variant {variant} " f"(pretrained models: {pretrained_weights.keys()})") - # Load pre-trained config from hf - try: - # TODO: Switch from timm to hf hub download and remove config.json download. - model_args = load_model_config_from_hf(pretrained_weights[variant]["hf_hub_id"])[0] - model_args.update(kwargs) - except: - warnings.warn(f"No pretrained configuration was found on HuggingFace for the model {variant}." - f"Using random initialization.", stacklevel=2) - model_args = prithvi_cfgs[variant].copy() - model_args.update(kwargs) - else: - # Load default config - model_args = prithvi_cfgs[variant].copy() - model_args.update(kwargs) - try: - model = prithvi_model_class(**model_args) + model = prithvi_model_class(**model_args) - if ckpt_path is not None: - # Load model from checkpoint - state_dict = torch.load(ckpt_path, map_location="cpu") - state_dict = checkpoint_filter_wrapper_fn(state_dict, model) - model.load_state_dict(state_dict, strict=False) - elif pretrained: + if ckpt_path is not None: + # Load model from checkpoint + state_dict = torch.load(ckpt_path, map_location="cpu") + state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands) + model.load_state_dict(state_dict, strict=False) + elif pretrained: + try: + # Download config.json to count model downloads + _ = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"], filename="config.json") # Load model from Hugging Face - state_dict = load_state_dict_from_hf(model_id=pretrained_weights[variant]["hf_hub_id"], - filename=pretrained_weights[variant]["hf_hub_filename"]) - state_dict = checkpoint_filter_wrapper_fn(state_dict, model) + pretrained_path = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"], + filename=pretrained_weights[variant]["hf_hub_filename"]) + state_dict = torch.load(pretrained_path, map_location="cpu") + state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands) model.load_state_dict(state_dict, strict=True) - - except RuntimeError as e: - if pretrained: - logger.error(f"Failed to initialize the pre-trained model {variant}, " - f"consider testing the code with pretrained=False.") - else: - logger.error(f"Failed to initialize the model {variant}.") - raise e + except RuntimeError as e: + logger.error(f"Failed to load the pre-trained weights for {variant}.") + raise e assert encoder_only or "out_indices" not in kwargs, "out_indices provided for a MAE model." 
if encoder_only: @@ -278,7 +264,7 @@ def prithvi_eo_tiny( **kwargs, ) -> PrithviViT: - return _create_prithvi("prithvi_eo_tiny", pretrained, bands, **kwargs) + return _create_prithvi("prithvi_eo_tiny", pretrained=pretrained, model_bands=bands, **kwargs) @ TERRATORCH_BACKBONE_REGISTRY.register @@ -288,7 +274,7 @@ def prithvi_eo_v1_100( **kwargs, ) -> PrithviViT: - return _create_prithvi("prithvi_eo_v1_100", pretrained, bands, **kwargs) + return _create_prithvi("prithvi_eo_v1_100", pretrained=pretrained, model_bands=bands, **kwargs) @ TERRATORCH_BACKBONE_REGISTRY.register @@ -298,7 +284,7 @@ def prithvi_eo_v2_300( **kwargs, ) -> PrithviViT: - return _create_prithvi("prithvi_eo_v2_300", pretrained, bands, **kwargs) + return _create_prithvi("prithvi_eo_v2_300", pretrained=pretrained, model_bands=bands, **kwargs) @ TERRATORCH_BACKBONE_REGISTRY.register @@ -308,7 +294,7 @@ def prithvi_eo_v2_600( **kwargs, ) -> PrithviViT: - return _create_prithvi("prithvi_eo_v2_600", pretrained, bands, **kwargs) + return _create_prithvi("prithvi_eo_v2_600", pretrained=pretrained, model_bands=bands, **kwargs) @ TERRATORCH_BACKBONE_REGISTRY.register @@ -318,7 +304,7 @@ def prithvi_eo_v2_300_tl( **kwargs, ) -> PrithviViT: - return _create_prithvi("prithvi_eo_v2_300_tl", pretrained, bands, **kwargs) + return _create_prithvi("prithvi_eo_v2_300_tl", pretrained=pretrained, model_bands=bands, **kwargs) @ TERRATORCH_BACKBONE_REGISTRY.register @@ -328,7 +314,7 @@ def prithvi_eo_v2_600_tl( **kwargs, ) -> PrithviViT: - return _create_prithvi("prithvi_eo_v2_600_tl", pretrained, bands, **kwargs) + return _create_prithvi("prithvi_eo_v2_600_tl", pretrained=pretrained, model_bands=bands, **kwargs) # TODO: Remove prithvi_vit_tiny and prithvi_vit_100 before version 1.0. @@ -342,7 +328,7 @@ def prithvi_vit_tiny( warnings.warn(f"The model prithvi_vit_tiny was renamed to prithvi_eo_tiny. " f"prithvi_vit_tiny will be removed in a future version.", FutureWarning) - return prithvi_eo_tiny(pretrained=pretrained, bands=bands, **kwargs) + return prithvi_eo_tiny(pretrained=pretrained, model_bands=bands, **kwargs) @ TERRATORCH_BACKBONE_REGISTRY.register @@ -354,7 +340,7 @@ def prithvi_vit_100( warnings.warn("The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. " "prithvi_vit_100 will be removed in a future version.", FutureWarning) - return prithvi_eo_v1_100(pretrained=pretrained, bands=bands, **kwargs) + return prithvi_eo_v1_100(pretrained=pretrained, model_bands=bands, **kwargs) # TODO: Remove timm_ errors before version v1.0. 
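A minimal usage sketch for the restructured init above, exercised through the registry. It assumes that `encoder_only` and `ckpt_path` are forwarded as extra keyword arguments by `BACKBONE_REGISTRY.build`, as the new `_create_prithvi` signature suggests; all paths are placeholders.

    from terratorch import BACKBONE_REGISTRY

    # Encoder-only PrithviViT (the default) with pretrained Hugging Face weights.
    encoder = BACKBONE_REGISTRY.build("prithvi_eo_v2_300", pretrained=True)

    # encoder_only=False builds the full PrithviMAE (encoder plus decoder);
    # assumed to pass through the registry wrappers unchanged via **kwargs.
    mae = BACKBONE_REGISTRY.build("prithvi_eo_v2_300", pretrained=True, encoder_only=False)

    # A local checkpoint takes precedence over the Hugging Face download.
    model = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", ckpt_path="path/to/model.pt")

Note that after this patch both the local-checkpoint and hub paths run the same checkpoint filter with `pretrained_bands` and `model_bands` passed explicitly, so band re-mapping behaves identically wherever the weights come from.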
From b2b90cefe86dad68a96151933dad44953982ecea Mon Sep 17 00:00:00 2001
From: Romeo Kienzler
Date: Thu, 23 Jan 2025 03:35:09 -0500
Subject: [PATCH 36/40] support for torch.Tensor as prediction output

---
 terratorch/cli_tools.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py
index 768c5d3a..3861a663 100644
--- a/terratorch/cli_tools.py
+++ b/terratorch/cli_tools.py
@@ -20,6 +20,9 @@
 import rasterio
 import torch
 
+import random
+import string
+
 # Allows classes to be referenced using only the class name
 import torchgeo.datamodules
 import yaml
@@ -153,12 +156,16 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batc
         if not os.path.exists(output_dir):
             os.makedirs(output_dir, exist_ok=True)
 
-        print(type(prediction))
-        print(prediction)
-        pred_batch, filename_batch = prediction
-
-        for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False):
-            save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
+        if isinstance(prediction, torch.Tensor):
+            filename_batch = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
+            torch.save(prediction, os.path.join(output_dir, f"{filename_batch}.pt"))
+        elif isinstance(prediction, tuple):
+            pred_batch, filename_batch = prediction
+            for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False):
+                print(prediction, file_name, output_dir, trainer.out_dtype)
+                save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
+        else:
+            raise TypeError(f"Unknown type for prediction: {type(prediction)}")
 
     def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):  # noqa: ARG002
         # this will create N (num processes) files in `output_dir` each containing

From 93d8828e9f88ff5528c73eec47f7e6791528be21 Mon Sep 17 00:00:00 2001
From: Romeo Kienzler
Date: Thu, 23 Jan 2025 09:43:49 +0100
Subject: [PATCH 37/40] cleanup

---
 terratorch/cli_tools.py                | 19 +++++++++++++------
 terratorch/models/wxc_model_factory.py |  2 +-
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py
index 768c5d3a..3861a663 100644
--- a/terratorch/cli_tools.py
+++ b/terratorch/cli_tools.py
@@ -20,6 +20,9 @@
 import rasterio
 import torch
 
+import random
+import string
+
 # Allows classes to be referenced using only the class name
 import torchgeo.datamodules
 import yaml
@@ -153,12 +156,16 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batc
         if not os.path.exists(output_dir):
             os.makedirs(output_dir, exist_ok=True)
 
-        print(type(prediction))
-        print(prediction)
-        pred_batch, filename_batch = prediction
-
-        for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False):
-            save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
+        if isinstance(prediction, torch.Tensor):
+            filename_batch = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
+            torch.save(prediction, os.path.join(output_dir, f"{filename_batch}.pt"))
+        elif isinstance(prediction, tuple):
+            pred_batch, filename_batch = prediction
+            for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False):
+                print(prediction, file_name, output_dir, trainer.out_dtype)
+                save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
+        else:
+            raise TypeError(f"Unknown type for prediction: {type(prediction)}")
 
     def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):  # noqa: ARG002
         # this will create N (num processes) files in `output_dir` each containing
diff --git a/terratorch/models/wxc_model_factory.py b/terratorch/models/wxc_model_factory.py
index abcbfeaa..f8d6fd19 100644
--- a/terratorch/models/wxc_model_factory.py
+++ b/terratorch/models/wxc_model_factory.py
@@ -61,7 +61,7 @@ def build_model(
             raise
 
     #remove parameters not meant for the backbone but for other parts of the model
-    print(kwargs)
+    logger.debug(kwargs)
     skip_connection = kwargs.pop('skip_connection')
 
     backbone = prithviwxc.PrithviWxC(**kwargs)

From e641e7cf63cc6fa059d76163187f931771d6cd1a Mon Sep 17 00:00:00 2001
From: Francesc Marti Escofet
Date: Thu, 23 Jan 2025 10:49:31 +0100
Subject: [PATCH 38/40] Remove pin

Signed-off-by: Francesc Marti Escofet
---
 requirements/required.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/requirements/required.txt b/requirements/required.txt
index 986de635..c4aec6ea 100644
--- a/requirements/required.txt
+++ b/requirements/required.txt
@@ -15,7 +15,6 @@ lightning==2.4.0
 git+https://github.com/qubvel-org/segmentation_models.pytorch.git@3952e1f8e9684a385a81e30381b8fb5b1ac086cf
 timm==1.0.11
 numpy==1.26.4
-jsonargparse<=4.35.0
 
 # These dependencies are optional
 # and must be installed just in case

From 4d6f0ed1a45a7a4a78aa8d87fc0a81a4a6a5d832 Mon Sep 17 00:00:00 2001
From: Benedikt Blumenstiel
Date: Thu, 23 Jan 2025 14:46:24 +0100
Subject: [PATCH 39/40] Remove _timm_module from prithvi weights

Signed-off-by: Benedikt Blumenstiel
---
 terratorch/models/backbones/prithvi_vit.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py
index 0d221291..1ce0a872 100644
--- a/terratorch/models/backbones/prithvi_vit.py
+++ b/terratorch/models/backbones/prithvi_vit.py
@@ -97,6 +97,9 @@ def checkpoint_filter_fn_vit(
 
     clean_dict = {}
     for k, v in state_dict.items():
+        if "_timm_module." in k:  # Backwards compatibility for old model checkpoints
+            k = k.replace("_timm_module.", "")
+
         if "pos_embed" in k:
             v = model.pos_embed  # pos_embed depends on num_frames and is fixed.
         if "decoder" in k or "_dec" in k or k == "mask_token":
@@ -126,6 +129,9 @@ def checkpoint_filter_fn_mae(
 
     clean_dict = {}
     for k, v in state_dict.items():
+        if "_timm_module." in k:  # Backwards compatibility for old model checkpoints
+            k = k.replace("_timm_module.", "")
+
         # pos_embed depends on num_frames and is fixed.
         if "decoder_pos_embed" in k:
             v = model.decoder.decoder_pos_embed

From 6bc62936cd7d6d64b7cc9b0d5e17bcebf8b7f919 Mon Sep 17 00:00:00 2001
From: Romeo Kienzler <5694071+romeokienzler@users.noreply.github.com>
Date: Mon, 27 Jan 2025 08:38:02 +0000
Subject: [PATCH 40/40] Update cli_tools.py

---
 terratorch/cli_tools.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py
index 3861a663..21f22c40 100644
--- a/terratorch/cli_tools.py
+++ b/terratorch/cli_tools.py
@@ -162,7 +162,6 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batc
         elif isinstance(prediction, tuple):
             pred_batch, filename_batch = prediction
             for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False):
-                print(prediction, file_name, output_dir, trainer.out_dtype)
                 save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
         else:
             raise TypeError(f"Unknown type for prediction: {type(prediction)}")
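Taken together, patches 36 and 40 leave `write_on_batch_end` writing raw tensor predictions to randomly named `.pt` files, while `(tensor, filenames)` tuples keep the per-sample `save_prediction` path. A short sketch for reading the tensor outputs back; the only assumption is the placeholder output directory, since the file stems are random 8-character strings chosen in `write_on_batch_end`:

    import glob
    import os

    import torch

    output_dir = "path/to/predict_output"  # placeholder for the callback's output_dir
    # Stems are random, so discover prediction files by extension rather than by name.
    for path in sorted(glob.glob(os.path.join(output_dir, "*.pt"))):
        batch_pred = torch.load(path, map_location="cpu")  # one tensor per predicted batch
        print(os.path.basename(path), tuple(batch_pred.shape))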