diff --git a/README.md b/README.md
index ff222be2..a669877b 100644
--- a/README.md
+++ b/README.md
@@ -6,13 +6,27 @@
 TerraTorch is a library based on [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) and the [TorchGeo](https://github.com/microsoft/torchgeo) domain library for geospatial data.
 
-TerraTorch’s main purpose is to provide a flexible fine-tuning framework for Geospatial Foundation Models, which can be interacted with at different abstraction levels.
+TerraTorch’s main purpose is to provide a flexible fine-tuning framework for Geospatial Foundation Models, which can be interacted with at different abstraction levels. The library provides:
-The library provides:
-
-- Easy access to open source pre-trained Geospatial Foundation Model backbones (e.g., [Prithvi](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M), [SatMAE](https://sustainlab-group.github.io/SatMAE/) and [ScaleMAE](https://github.com/bair-climate-initiative/scale-mae), other backbones available in the [timm](https://github.com/huggingface/pytorch-image-models) (Pytorch image models) or [SMP](https://github.com/qubvel/segmentation_models.pytorch) (Pytorch Segmentation models with pre-training backbones) packages, as well as fine-tuned models such as [granite-geospatial-biomass](https://huggingface.co/ibm-granite/granite-geospatial-biomass)
-- Flexible trainers for Image Segmentation, Classification and Pixel Wise Regression fine-tuning tasks
-- Launching of fine-tuning tasks through flexible configuration files
+- Convenient modelling tools:
+    - Flexible trainers for Image Segmentation, Classification and Pixel Wise Regression fine-tuning tasks
+    - Model factories that make it easy to combine backbones and decoders for different tasks
+    - Ready-to-go datasets and datamodules that only require you to point to your data, with no need to create new custom classes
+    - Launching of fine-tuning tasks through the CLI and flexible configuration files, or via Jupyter notebooks
+- Easy access to:
+    - Open source pre-trained Geospatial Foundation Model backbones:
+        * [Prithvi](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M)
+        * [SatMAE](https://sustainlab-group.github.io/SatMAE/)
+        * [ScaleMAE](https://github.com/bair-climate-initiative/scale-mae)
+        * Satlas (as implemented in [TorchGeo](https://github.com/microsoft/torchgeo))
+        * DOFA (as implemented in [TorchGeo](https://github.com/microsoft/torchgeo))
+        * SSL4EO-L and SSL4EO-S12 models (as implemented in [TorchGeo](https://github.com/microsoft/torchgeo))
+        * Clay
+    - Backbones available in the [timm](https://github.com/huggingface/pytorch-image-models) (PyTorch image models) package
+    - Decoders available in the [SMP](https://github.com/qubvel/segmentation_models.pytorch) (PyTorch segmentation models with pre-trained backbones) and [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) packages
+    - Fine-tuned models such as [granite-geospatial-biomass](https://huggingface.co/ibm-granite/granite-geospatial-biomass)
+    - All GEO-Bench datasets and datamodules
+    - All [TorchGeo](https://github.com/microsoft/torchgeo) datasets and datamodules

## Install
### Pip

@@ -26,7 +40,15 @@
 To get the most recent version of the main branch, install the library with `pip install git+https://github.com/IBM/terratorch.git`.
 
 TerraTorch requires gdal to be installed, which can be quite a complex process. If you don't have GDAL set up on your system, we recommend using a conda environment and installing it with `conda install -c conda-forge gdal`.
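After install, a quick way to check that the package and its registries work is to build one of the bundled backbones. This is a minimal sketch, assuming the `TERRATORCH_BACKBONE_REGISTRY` used in this PR exposes a `build()` helper; the backbone name below is registered in `terratorch/models/backbones/torchgeo_resnet.py`:

```python
import torch
from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY

# Build a randomly initialised SSL4EO-S12 ResNet-50 for three RGB bands;
# pretrained=True would instead download the MoCo weights via torchgeo.
backbone = TERRATORCH_BACKBONE_REGISTRY.build(
    "ssl4eos12_resnet50_sentinel2_rgb_moco",
    model_bands=["RED", "GREEN", "BLUE"],
    pretrained=False,
)
features = backbone(torch.randn(1, 3, 224, 224))
print([f.shape for f in features])  # feature maps from the selected stages
```

-To install as a developer (e.g. 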
to extend the library) clone this repo, install dependencies using `pip install -r requirements/required.txt -r requirements/dev.txt` and run `pip install -e .` +To install as a developer (e.g. to extend the library): +``` +git clone https://github.com/IBM/terratorch.git +cd terratorch +pip install -r requirements/required.txt -r requirements/dev.txt +conda install -c conda-forge gdal +pip install -e . +``` + To install terratorch with partial (work in development) support for Weather Foundation Models, `pip install -e .[wxc]`, which currently works just for `Python >= 3.11`. ## Quick start diff --git a/examples/confs/dofa_sen1floods11_fcn.yaml b/examples/confs/dofa_sen1floods11_fcn.yaml new file mode 100644 index 00000000..a0c1ee2d --- /dev/null +++ b/examples/confs/dofa_sen1floods11_fcn.yaml @@ -0,0 +1,194 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + num_sanity_val_steps: 0 + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + # precision: bf16 + logger: + class_path: TensorBoardLogger + init_args: + save_dir: /dccstor/geofm-finetuning/carlosgomes/torchgeo_floods + name: dofa + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + + + max_epochs: 50 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: /dccstor/geofm-finetuning/carlosgomes/torchgeo_floods +data: + class_path: GenericNonGeoSegmentationDataModule + init_args: + batch_size: 16 + num_workers: 8 + constant_scale: 0.0001 + dataset_bands: + - COASTAL_AEROSOL + - BLUE + - GREEN + - RED + - RED_EDGE_1 + - RED_EDGE_2 + - RED_EDGE_3 + - NIR_BROAD + - NIR_NARROW + - WATER_VAPOR + - CIRRUS + - SWIR_1 + - SWIR_2 + output_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/ + train_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + val_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/ + val_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + test_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/ + test_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files + train_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_train_data_S2.txt + test_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_test_data_S2.txt + val_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_valid_data_S2.txt + img_grep: "*_S2Hand.tif" + label_grep: "*_LabelHand.tif" + no_label_replace: -1 + no_data_replace: 0 + means: + - 0.1412956 + - 0.13795798 + - 0.12353792 + - 0.30902815 + - 0.2044958 + - 0.11912015 + stds: + - 0.07406382 + - 0.07370365 + - 0.08692279 + - 0.11798815 + - 0.09772074 + - 0.07659938 + num_classes: 2 + # train_transform: + # - class_path: albumentations.RandomCrop + # init_args: + # height: 224 + # width: 224 + # - class_path: albumentations.HorizontalFlip + # init_args: + # 
p: 0.5 + # - class_path: ToTensorV2 + # val_transform: + # - class_path: albumentations.RandomCrop + # init_args: + # height: 224 + # width: 224 + # - class_path: ToTensorV2 + # test_transform: + # - class_path: albumentations.CenterCrop + # init_args: + # height: 224 + # width: 224 + # - class_path: ToTensorV2 + + + + # class_path: terratorch.datamodules.sen1floods11.Sen1Floods11NonGeoDataModule + # init_args: + # batch_size: 8 + # num_workers: 8 + # train_aug: + # - class_path: albumentations.RandomCrop + # init_args: + # height: 224 + # width: 224 + # - class_path: albumentations.HorizontalFlip + # init_args: + # p: 0.5 + # - class_path: ToTensorV2 + # val_aug: + # - class_path: albumentations.RandomCrop + # init_args: + # height: 224 + # width: 224 + # - class_path: ToTensorV2 + + # dict_kwargs: + # data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/ + # bands: + # - 1 + # - 2 + # - 3 + # - 8 + # - 11 + # - 12 + +model: + class_path: terratorch.tasks.SemanticSegmentationTask + init_args: + model_args: + decoder: FCNDecoder + backbone_pretrained: True + backbone_img_size: 512 + backbone: dofa_large_patch16_224 + # backbone_pretrain_img_size: 512 + # decoder_scale_modules: True + # decoder_in_channels: 1024 + decoder_channels: 256 + # backbone_in_channels: 6 + backbone_model_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 +# num_frames: 1 + num_classes: 2 + head_dropout: 0.1 + head_channel_list: + - 256 + necks: + - name: SelectIndices + indices: + - -1 + - name: ReshapeTokensToImage + loss: ce + + ignore_index: -1 + class_weights: + - 0.3 + - 0.7 + freeze_backbone: false + freeze_decoder: false + model_factory: EncoderDecoderFactory + tiled_inference_parameters: + h_crop: 224 + h_stride: 196 + w_crop: 224 + w_stride: 196 + average_patches: true +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 6.e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss diff --git a/examples/confs/dofa_sen1floods11_uper.yaml b/examples/confs/dofa_sen1floods11_uper.yaml new file mode 100644 index 00000000..8e881885 --- /dev/null +++ b/examples/confs/dofa_sen1floods11_uper.yaml @@ -0,0 +1,161 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + # precision: bf16 + logger: + class_path: TensorBoardLogger + init_args: + save_dir: /dccstor/geofm-finetuning/carlosgomes/torchgeo_floods + name: dofa + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + + max_epochs: 50 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: /dccstor/geofm-finetuning/carlosgomes/torchgeo_floods +data: + class_path: GenericNonGeoSegmentationDataModule + init_args: + batch_size: 16 + num_workers: 8 + constant_scale: 0.0001 + dataset_bands: + - COASTAL_AEROSOL + - BLUE + - GREEN + - RED + - RED_EDGE_1 + - RED_EDGE_2 + - RED_EDGE_3 + - NIR_BROAD + - NIR_NARROW + - WATER_VAPOR + - CIRRUS + - SWIR_1 + - SWIR_2 + output_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/ + train_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + val_data_root: 
/dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/ + val_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + test_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/ + test_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files + train_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_train_data_S2.txt + test_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_test_data_S2.txt + val_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_valid_data_S2.txt + img_grep: "*_S2Hand.tif" + label_grep: "*_LabelHand.tif" + no_label_replace: -1 + no_data_replace: 0 + means: + - 0.1412956 + - 0.13795798 + - 0.12353792 + - 0.30902815 + - 0.2044958 + - 0.11912015 + stds: + - 0.07406382 + - 0.07370365 + - 0.08692279 + - 0.11798815 + - 0.09772074 + - 0.07659938 + num_classes: 2 + train_transform: + - class_path: albumentations.RandomCrop + init_args: + height: 224 + width: 224 + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: ToTensorV2 + val_transform: + - class_path: albumentations.RandomCrop + init_args: + height: 224 + width: 224 + - class_path: ToTensorV2 + test_transform: + - class_path: albumentations.CenterCrop + init_args: + height: 224 + width: 224 + - class_path: ToTensorV2 + +model: + class_path: terratorch.tasks.SemanticSegmentationTask + init_args: + model_args: + decoder: UperNetDecoder + backbone_pretrained: True + backbone: dofa_large_patch16_224 + # backbone_pretrain_img_size: 512 + # decoder_scale_modules: True + # decoder_in_channels: 1024 + decoder_channels: 256 + # backbone_in_channels: 6 + backbone_model_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + backbone_out_indices: + - 5 + - 11 + - 17 + - 23 + # num_frames: 1 + num_classes: 2 + head_dropout: 0.1 + head_channel_list: + - 256 + necks: + # - name: SelectIndices + # indices: + # - 5 + # - 11 + # - 17 + # - 23 + - name: ReshapeTokensToImage + loss: ce + + ignore_index: -1 + class_weights: + - 0.3 + - 0.7 + freeze_backbone: true + freeze_decoder: false + model_factory: EncoderDecoderFactory +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 6.e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss diff --git a/examples/confs/satlas_sen1floods11_uper.yaml b/examples/confs/satlas_sen1floods11_uper.yaml new file mode 100644 index 00000000..c9ede1a9 --- /dev/null +++ b/examples/confs/satlas_sen1floods11_uper.yaml @@ -0,0 +1,162 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + # precision: bf16 + logger: + class_path: TensorBoardLogger + init_args: + save_dir: /dccstor/geofm-finetuning/carlosgomes/torchgeo_floods + name: satlas + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + + max_epochs: 50 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: /dccstor/geofm-finetuning/carlosgomes/torchgeo_floods +data: + class_path: 
GenericNonGeoSegmentationDataModule + init_args: + batch_size: 16 + num_workers: 8 + constant_scale: 0.0001 + dataset_bands: + - COASTAL_AEROSOL + - BLUE + - GREEN + - RED + - RED_EDGE_1 + - RED_EDGE_2 + - RED_EDGE_3 + - NIR_BROAD + - NIR_NARROW + - WATER_VAPOR + - CIRRUS + - SWIR_1 + - SWIR_2 + output_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/ + train_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + val_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/ + val_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + test_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/ + test_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files + train_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_train_data_S2.txt + test_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_test_data_S2.txt + val_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_valid_data_S2.txt + img_grep: "*_S2Hand.tif" + label_grep: "*_LabelHand.tif" + no_label_replace: -1 + no_data_replace: 0 + means: + - 0.1412956 + - 0.13795798 + - 0.12353792 + - 0.30902815 + - 0.2044958 + - 0.11912015 + stds: + - 0.07406382 + - 0.07370365 + - 0.08692279 + - 0.11798815 + - 0.09772074 + - 0.07659938 + num_classes: 2 + # train_transform: + # - class_path: albumentations.RandomCrop + # init_args: + # height: 224 + # width: 224 + # - class_path: albumentations.HorizontalFlip + # init_args: + # p: 0.5 + # - class_path: ToTensorV2 + # val_transform: + # - class_path: albumentations.RandomCrop + # init_args: + # height: 224 + # width: 224 + # - class_path: ToTensorV2 + # test_transform: + # - class_path: albumentations.CenterCrop + # init_args: + # height: 224 + # width: 224 + # - class_path: ToTensorV2 + +model: + class_path: terratorch.tasks.SemanticSegmentationTask + init_args: + model_args: + decoder: UperNetDecoder + backbone_pretrained: True + # backbone: satlas_swin_b_sentinel2_si_ms + backbone: ssl4eol_resnet18_landsat_oli_tirs_toa_moco + # backbone_pretrain_img_size: 512 + # decoder_scale_modules: True + # decoder_in_channels: 1024 + decoder_channels: 256 + # backbone_in_channels: 6 + backbone_model_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + backbone_out_indices: + - 1 + - 3 + - 5 + - 7 + # num_frames: 1 + num_classes: 2 + head_dropout: 0.1 + head_channel_list: + - 256 + necks: + # - name: SelectIndices + # indices: + # - 5 + # - 11 + # - 17 + # - 23 + - name: ReshapeTokensToImage + loss: ce + + ignore_index: -1 + class_weights: + - 0.3 + - 0.7 + freeze_backbone: true + freeze_decoder: false + model_factory: EncoderDecoderFactory +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 6.e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss diff --git a/requirements/required.txt b/requirements/required.txt index 
17e62964..c4aec6ea 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -1,17 +1,20 @@ -torchgeo==0.6.0 +#torchgeo==0.7.0.dev0 +git+https://github.com/microsoft/torchgeo.git@fedf99375535f801565856cd774bfa9e5a251d55 rioxarray==0.15.0 albumentations==1.3.1 albucore<=0.0.16 -rasterio==1.3.10 -torch==2.3.1 -torchvision==0.18.1 -torchmetrics==1.4.0 +rasterio==1.3.11 +torch==2.4.1 +torchvision==0.19.1 +torchmetrics==1.3.1 geopandas==0.14.4 -lightly==1.4.25 -h5py==3.12.1 +lightly==1.5.13 +h5py==3.10.0 mlflow==2.14.3 lightning==2.4.0 -segmentation-models-pytorch==0.3.4 +git+https://github.com/qubvel-org/segmentation_models.pytorch.git@3952e1f8e9684a385a81e30381b8fb5b1ac086cf +timm==1.0.11 +numpy==1.26.4 # These dependencies are optional # and must be installed just in case diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index 47c08aa0..2413672f 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -17,6 +17,7 @@ from torchgeo.datamodules import NonGeoDataModule from torchgeo.transforms import AugmentationSequential +from terratorch.datamodules.utils import wrap_in_compose_is_list from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset, HLSBands from terratorch.io.file import load_from_file_or_attribute diff --git a/terratorch/datamodules/sen1floods11.py b/terratorch/datamodules/sen1floods11.py index d2699076..b9e2ff68 100644 --- a/terratorch/datamodules/sen1floods11.py +++ b/terratorch/datamodules/sen1floods11.py @@ -45,7 +45,6 @@ "SWIR_2": 0.07659938, } - class Sen1Floods11NonGeoDataModule(NonGeoDataModule): """NonGeo Fire Scars data module implementation""" diff --git a/terratorch/datasets/__init__.py b/terratorch/datasets/__init__.py index faceed41..7006f5e0 100644 --- a/terratorch/datasets/__init__.py +++ b/terratorch/datasets/__init__.py @@ -36,9 +36,12 @@ from terratorch.datasets.pastis import PASTIS # GenericNonGeoRegressionDataset, -from terratorch.datasets.sen1floods11 import Sen1Floods11NonGeo + +from terratorch.datasets.sen1floods11_lat_lon import Sen1Floods11NonGeo +from terratorch.datasets.utils import HLSBands, OpticalBands, SARBands + +#from terratorch.datasets.sen1floods11 import Sen1Floods11NonGeo from terratorch.datasets.sen4agrinet import Sen4AgriNet -from terratorch.datasets.utils import HLSBands from terratorch.datasets.burn_intensity import BurnIntensityNonGeo from terratorch.datasets.carbonflux import CarbonFluxNonGeo @@ -98,5 +101,7 @@ "WSFEvolution", "HLSL30", "HLSS30", + "OpticalBands", + "SARBands", "OpenEarthMapNonGeo" ) diff --git a/terratorch/datasets/sen1floods11.py b/terratorch/datasets/sen1floods11.py index b36965c7..e156924d 100644 --- a/terratorch/datasets/sen1floods11.py +++ b/terratorch/datasets/sen1floods11.py @@ -23,21 +23,21 @@ class Sen1Floods11NonGeo(NonGeoDataset): """NonGeo dataset implementation for sen1floods11.""" - all_band_names = ( - "COASTAL_AEROSOL", - "BLUE", - "GREEN", - "RED", - "RED_EDGE_1", - "RED_EDGE_2", - "RED_EDGE_3", - "NIR_BROAD", - "NIR_NARROW", - "WATER_VAPOR", - "CIRRUS", - "SWIR_1", - "SWIR_2", - ) + all_band_names = ( + "COASTAL_AEROSOL", + "BLUE", + "GREEN", + "RED", + "RED_EDGE_1", + "RED_EDGE_2", + "RED_EDGE_3", + "NIR_BROAD", + "NIR_NARROW", + "WATER_VAPOR", + "CIRRUS", + "SWIR_1", + "SWIR_2", + ) rgb_bands = ("RED", "GREEN", "BLUE") diff --git a/terratorch/datasets/sen1floods11_lat_lon.py 
b/terratorch/datasets/sen1floods11_lat_lon.py
new file mode 100644
index 00000000..a3287fef
--- /dev/null
+++ b/terratorch/datasets/sen1floods11_lat_lon.py
@@ -0,0 +1,208 @@
+import glob
+import os
+from pathlib import Path
+from typing import Any
+
+import albumentations as A
+import geopandas
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import rioxarray
+import torch
+from matplotlib import cm
+from matplotlib.figure import Figure
+from matplotlib.patches import Rectangle
+from torch import Tensor
+from torchgeo.datasets import NonGeoDataset
+
+from terratorch.datasets.utils import default_transform, filter_valid_files
+
+
+class Sen1Floods11NonGeo(NonGeoDataset):
+    """NonGeo dataset implementation for sen1floods11 that also returns temporal and location metadata."""
+
+    all_band_names = (
+        "COASTAL_AEROSOL",
+        "BLUE",
+        "GREEN",
+        "RED",
+        "RED_EDGE_1",
+        "RED_EDGE_2",
+        "RED_EDGE_3",
+        "NIR_BROAD",
+        "NIR_NARROW",
+        "WATER_VAPOR",
+        "CIRRUS",
+        "SWIR_1",
+        "SWIR_2",
+    )
+
+    def __init__(
+        self,
+        data_root: str,
+        split="train",
+        bands: None | list[int] = None,
+        transform: A.Compose | None = None,
+    ) -> None:
+        super().__init__()
+        if split not in ["train", "test", "val"]:
+            msg = "Split must be one of train, test, val."
+            raise ValueError(msg)
+        if split == "val":
+            split = "valid"
+
+        self.bands = bands
+        data_root = Path(data_root)
+        data_directory = data_root / "data/data/flood_events/HandLabeled/"
+        input_directory = data_directory / "S2Hand"
+        label_directory = data_directory / "LabelHand"
+
+        # split_file = data_root / "splits/splits/flood_handlabeled/flood_bolivia_data_S2.txt"
+        split_file = data_root / f"splits/splits/flood_handlabeled/flood_{split}_data_S2_geodn.txt"
+        metadata_file = data_root / "Sen1Floods11_Metadata.geojson"
+        self.metadata = geopandas.read_file(metadata_file)
+
+        self.image_files = sorted(glob.glob(os.path.join(input_directory, "*.tif")))
+        self.segmentation_mask_files = sorted(glob.glob(os.path.join(label_directory, "*.tif")))
+
+        with open(split_file) as f:
+            split_lines = f.readlines()
+        valid_files = {rf"{substring.strip()}" for substring in split_lines}
+        self.image_files = filter_valid_files(
+            self.image_files,
+            valid_files=valid_files,
+            ignore_extensions=True,
+            allow_substring=True,
+        )
+        self.segmentation_mask_files = filter_valid_files(
+            self.segmentation_mask_files,
+            valid_files=valid_files,
+            ignore_extensions=True,
+            allow_substring=True,
+        )
+
+        self.rgb_indices = [2, 1, 0]
+        self.transform = transform if transform else default_transform
+
+    def __len__(self) -> int:
+        return len(self.image_files)
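A minimal usage sketch of this dataset (the `data_root` path is hypothetical; the directory layout and split files must follow the structure expected by `__init__` above):

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2

from terratorch.datasets import Sen1Floods11NonGeo

ds = Sen1Floods11NonGeo(
    "/path/to/sen1floods11",  # hypothetical root containing data/ and splits/
    split="val",
    transform=A.Compose([A.CenterCrop(height=224, width=224), ToTensorV2()]),
)
sample = ds[0]
# the image plus the extra metadata this variant returns
print(sample["image"].shape, sample["temporal_coords"], sample["location_coords"])
```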
+    def _get_date(self, index) -> np.ndarray:
+        # move this logic to the model?
+        file_name = self.image_files[index]
+        location = os.path.basename(file_name).split("_")[0]
+        if self.metadata[self.metadata["location"] == location].shape[0] != 1:
+            # fall back to a fixed placeholder date when the location has no unique metadata entry
+            date = pd.to_datetime("13-10-1998", dayfirst=True)
+        else:
+            date = pd.to_datetime(self.metadata[self.metadata["location"] == location]["s2_date"].item())
+        date_np = np.zeros((1, 3))
+        date_np[0, 0] = date.year
+        date_np[0, 1] = date.dayofyear - 1  # base 0
+        # date_np[0, 2] = date.hour
+        return date_np
+
+    def _get_coords(self, index) -> np.ndarray:
+        file_name = self.image_files[index]
+        image = rioxarray.open_rasterio(file_name)
+        # centre-pixel (lat, lon); index the spatial dims, not the band dim
+        lat_lon = np.array([image.y[image.sizes["y"] // 2], image.x[image.sizes["x"] // 2]])
+        return lat_lon
+
+    def __getitem__(self, index: int) -> dict[str, Any]:
+        image = self._load_file(self.image_files[index]).astype(np.float32) * 0.0001
+        if self.bands:
+            image = image[self.bands, ...]
+        image = np.moveaxis(image, 0, -1)
+        output = {
+            "image": image,
+            "mask": self._load_file(self.segmentation_mask_files[index]).astype(np.int64)[0],
+        }
+        if self.transform:
+            output = self.transform(**output)
+
+        output["location_coords"] = self._get_coords(index).astype(np.float32)
+        output["temporal_coords"] = self._get_date(index).astype(np.float32)
+        return output
+
+    def _load_file(self, path: Path):
+        data = rioxarray.open_rasterio(path)
+        return data.to_numpy()
+
+    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            suptitle: optional string to use as a suptitle
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+        """
+        image = sample["image"]
+        if torch.is_tensor(image):
+            image = image.numpy()
+        image = image.take(self.rgb_indices, axis=0)
+        image = np.transpose(image, (1, 2, 0))
+        image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
+        image = np.clip(image, 0, 1)
+
+        label_mask = sample["mask"]
+        label_mask = np.transpose(label_mask, (1, 2, 0))
+
+        showing_predictions = "prediction" in sample
+        if showing_predictions:
+            prediction_mask = sample["prediction"]
+
+        return self._plot_sample(
+            image,
+            label_mask,
+            num_classes=2,  # water / no water
+            prediction=prediction_mask if showing_predictions else None,
+            suptitle=suptitle,
+        )
+
+    @staticmethod
+    def _plot_sample(image, label, num_classes, prediction=None, suptitle=None, class_names=None):
+        num_images = 5 if prediction is not None else 4
+        fig, ax = plt.subplots(1, num_images, figsize=(8, 6))
+
+        # for legend
+        ax[0].axis("off")
+
+        norm = mpl.colors.Normalize(vmin=0, vmax=num_classes - 1)
+        ax[1].axis("off")
+        ax[1].title.set_text("Image")
+        ax[1].imshow(image)
+
+        ax[2].axis("off")
+        ax[2].title.set_text("Ground Truth Mask")
+        ax[2].imshow(label, cmap="jet", norm=norm)
+
+        ax[3].axis("off")
+        ax[3].title.set_text("GT Mask on Image")
+        ax[3].imshow(image)
+        ax[3].imshow(label, cmap="jet", alpha=0.3, norm=norm)
+
+        if prediction is not None:
+            ax[4].title.set_text("Predicted Mask")
+            ax[4].imshow(prediction, cmap="jet", norm=norm)
+
+        cmap = cm.get_cmap("jet")
+        legend_data = []
+        for i, _ in enumerate(range(num_classes)):
+            class_name = class_names[i] if class_names else str(i)
+            data = [i, cmap(norm(i)), class_name]
+            legend_data.append(data)
+        handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
+        labels = [n for k, c, n in legend_data]
+        ax[0].legend(handles, labels, loc="center")
+        if suptitle is not None:
+            plt.suptitle(suptitle)
+        return fig
diff --git a/terratorch/datasets/utils.py b/terratorch/datasets/utils.py
index 764864db..8f18ee74 100644
--- a/terratorch/datasets/utils.py
+++ b/terratorch/datasets/utils.py
@@ -34,6 +34,45 @@ def try_convert_to_hls_bands_enum(cls, x: Any):
         try:
             return cls(x)
         except ValueError:
             return x
 
+class OpticalBands(Enum):
+    COASTAL_AEROSOL = "COASTAL_AEROSOL"
+    BLUE = "BLUE"
+    GREEN = "GREEN"
+    RED = "RED"
+    RED_EDGE_1 = "RED_EDGE_1"
+    RED_EDGE_2 = "RED_EDGE_2"
+    RED_EDGE_3 = "RED_EDGE_3"
+    NIR_BROAD = "NIR_BROAD"
+    NIR_NARROW = "NIR_NARROW"
+    SWIR_1 = "SWIR_1"
+    SWIR_2 = "SWIR_2"
+    WATER_VAPOR = "WATER_VAPOR"
+    CIRRUS = "CIRRUS"
+    THERMAL_INFRARED_1 = "THERMAL_INFRARED_1"
+    THERMAL_INFRARED_2 = "THERMAL_INFRARED_2"
+
+    @classmethod
+    def try_convert_to_optical_bands_enum(cls, x: Any):
+        try:
+            return cls(x)
+        except ValueError:
+            return x
+
+class SARBands(Enum):
+    VV = "VV"
+    VH = "VH"
+    ASC_VV = "ASC_VV"
+    ASC_VH = "ASC_VH"
+    DSC_VV = "DSC_VV"
+    DSC_VH = "DSC_VH"
+    VV_VH = "VV_VH"
+
+    @classmethod
+    def try_convert_to_sar_bands_enum(cls, x: Any):
+        try:
+            return cls(x)
+        except ValueError:
+            return x
 
 class S1Bands(Enum):
     VV = 'VV'
diff --git a/terratorch/models/backbones/__init__.py b/terratorch/models/backbones/__init__.py
index 7dc33dc0..2b8d4691 100644
--- a/terratorch/models/backbones/__init__.py
+++ b/terratorch/models/backbones/__init__.py
@@ -5,6 +5,8 @@
 import terratorch.models.backbones.prithvi_vit
 import terratorch.models.backbones.clay_v1
 import terratorch.models.backbones.scalemae
+import terratorch.models.backbones.dofa_vit
+import terratorch.models.backbones.torchgeo_swin_satlas
+import terratorch.models.backbones.torchgeo_resnet
 import terratorch.models.backbones.multimae_register
-
 from terratorch.models.backbones.unet import UNet
diff --git a/terratorch/models/backbones/dofa_vit.py b/terratorch/models/backbones/dofa_vit.py
new file mode 100644
index 00000000..61913ced
--- /dev/null
+++ b/terratorch/models/backbones/dofa_vit.py
@@ -0,0 +1,187 @@
+# reference torchgeo https://torchgeo.readthedocs.io/en/latest/_modules/torchgeo/models/dofa.html#DOFA
+import logging
+from typing import List
+
+import huggingface_hub
+import torch
+import torch.nn as nn
+import torchgeo.models.dofa as dofa
+from torchvision.models._api import Weights, WeightsEnum
+
+from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY
+
+waves_list = {
+    "COASTAL_AEROSOL": 0.44,
+    "BLUE": 0.49,
+    "GREEN": 0.56,
+    "RED": 0.665,
+    "RED_EDGE_1": 0.705,
+    "RED_EDGE_2": 0.74,
+    "RED_EDGE_3": 0.783,
+    "NIR_BROAD": 0.832,
+    "NIR_NARROW": 0.864,
+    "WATER_VAPOR": 0.945,
+    "CIRRUS": 1.373,
+    "SWIR_1": 1.61,
+    "SWIR_2": 2.20,
+    "THERMAL_INFRARED_1": 10.90,
+    "THERMAL_INFRARED_2": 12.00,
+    "VV": 3.75,
+    "VH": 3.75,
+    "ASC_VV": 3.75,
+    "ASC_VH": 3.75,
+    "DSC_VV": 3.75,
+    "DSC_VH": 3.75,
+    "VV_VH": 3.75,
+}
+
+
+class DOFAEncoderWrapper(nn.Module):
+    """A wrapper around DOFA models from torchgeo that returns only the encoder's forward pass.
+
+    Attributes:
+        dofa_model (DOFA): The instantiated DOFA model.
+    Methods:
+        forward(x: torch.Tensor, **kwargs) -> tuple[torch.Tensor, ...]:
+            Forward pass returning the embeddings at the specified block indices.
+ """ + + def __init__(self, dofa_model, wavelengths, weights=None, out_indices=None) -> None: + """ + Args: + dofa_model (DOFA): The decoder module to be wrapped. + weights () + """ + super().__init__() + self.dofa_model = dofa_model + self.weights = weights + self.wavelengths = wavelengths + + self.out_indices = out_indices if out_indices else [-1] + self.out_channels = [self.dofa_model.embed_dim] * len(self.out_indices) + + def forward(self, x: List[torch.Tensor], **kwargs) -> torch.Tensor: + + wavelist = torch.tensor(self.wavelengths, device=x.device).float() + + x, _ = self.dofa_model.patch_embed(x, wavelist) + + x = x + self.dofa_model.pos_embed[:, 1:, :] + # append cls token + cls_token = self.dofa_model.cls_token + self.dofa_model.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + outs = [] + # apply Transformer blocks + for i, block in enumerate(self.dofa_model.blocks): + x = block(x) + if i in self.out_indices: + outs.append(x) + elif (i == (len(self.dofa_model.blocks)-1)) & (-1 in self.out_indices): + outs.append(x) + + return tuple(outs) + +def get_wavelenghts(model_bands): + + wavelengths = [waves_list[x.split('.')[-1]] for x in model_bands] + + return wavelengths + + +@TERRATORCH_BACKBONE_REGISTRY.register +def dofa_small_patch16_224(model_bands, input_size = 224, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = None, out_indices: list | None = None, **kwargs): + model = dofa.dofa_small_patch16_224(**kwargs) + input_size = kwargs["img_size"] if "img_size" in kwargs else 224 + if pretrained: + model = load_dofa_weights(model, ckpt_data, weights, input_size) + wavelengths = get_wavelenghts(model_bands) + + return DOFAEncoderWrapper(model, wavelengths, weights, out_indices) + +@TERRATORCH_BACKBONE_REGISTRY.register +def dofa_base_patch16_224(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = dofa.DOFABase16_Weights.DOFA_MAE, out_indices: list | None = None, **kwargs): + model = dofa.dofa_base_patch16_224(**kwargs) + input_size = kwargs["img_size"] if "img_size" in kwargs else 224 + if pretrained: + model = load_dofa_weights(model, ckpt_data, weights, input_size) + wavelengths = get_wavelenghts(model_bands) + + return DOFAEncoderWrapper(model, wavelengths, weights, out_indices) + +@TERRATORCH_BACKBONE_REGISTRY.register +def dofa_large_patch16_224(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = dofa.DOFALarge16_Weights.DOFA_MAE, out_indices: list | None = None, **kwargs): + model = dofa.dofa_large_patch16_224(**kwargs) + input_size = kwargs["img_size"] if "img_size" in kwargs else 224 + if pretrained: + model = load_dofa_weights(model, ckpt_data, weights, input_size) + wavelengths = get_wavelenghts(model_bands) + + return DOFAEncoderWrapper(model, wavelengths, weights, out_indices) + +@TERRATORCH_BACKBONE_REGISTRY.register +def dofa_huge_patch16_224(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = None, out_indices: list | None = None, **kwargs): + model = dofa.dofa_huge_patch16_224(**kwargs) + input_size = kwargs["img_size"] if "img_size" in kwargs else 224 + if pretrained: + model = load_dofa_weights(model, ckpt_data, weights, input_size) + wavelengths = get_wavelenghts(model_bands) + + return DOFAEncoderWrapper(model, wavelengths, weights, out_indices) + +def load_dofa_weights(model: nn.Module, ckpt_data: str | None = None, weights: Weights | None = None, input_size = 
+def load_dofa_weights(model: nn.Module, ckpt_data: str | None = None, weights: Weights | None = None, input_size=224) -> nn.Module:
+    state_dict = model.state_dict()
+    logging.info("Loading weights")
+    if ckpt_data is not None:
+        if ckpt_data.find("https://hf.co/") > -1:
+            repo_id = ckpt_data.split("/resolve/")[0].replace("https://hf.co/", "")
+            filename = ckpt_data.split("/")[-1]
+            ckpt_data = huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename)
+        checkpoint_model = torch.load(ckpt_data, map_location="cpu")
+
+        for k in ["head.weight", "head.bias"]:
+            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
+                logging.info(f"Removing key {k} from pretrained checkpoint")
+                del checkpoint_model[k]
+        if input_size != 224:
+            if "pos_embed" in checkpoint_model and checkpoint_model["pos_embed"].shape != state_dict["pos_embed"].shape:
+                logging.info("Removing key pos_embed from pretrained checkpoint")
+                del checkpoint_model["pos_embed"]
+
+        msg = model.load_state_dict(checkpoint_model, strict=False)
+        logging.info(msg)
+    else:
+        if weights is not None:
+            checkpoint_model = weights.get_state_dict(progress=True)
+            allowed_missing_keys = {"fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"}
+            if input_size != 224:
+                if "pos_embed" in checkpoint_model and checkpoint_model["pos_embed"].shape != state_dict["pos_embed"].shape:
+                    logging.info("Removing key pos_embed from pretrained checkpoint")
+                    del checkpoint_model["pos_embed"]
+                    allowed_missing_keys.add("pos_embed")
+            missing_keys, unexpected_keys = model.load_state_dict(checkpoint_model, strict=False)
+            logging.info("Weights loaded.")
+            # Both fc_norm and head are generated dynamically
+            assert set(missing_keys) <= allowed_missing_keys
+            assert not unexpected_keys
+        else:
+            logging.warning("No weights to load.")
+
+    return model
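`load_dofa_weights` resolves checkpoints given either a local path or a Hugging Face URL of a specific shape. A small sketch of the URL convention it parses (the URL itself is hypothetical):

```python
# "https://hf.co/<repo_id>/resolve/<revision>/<filename>" is split into the
# hf_hub_download() arguments used above.
url = "https://hf.co/some-org/dofa-checkpoints/resolve/main/dofa_large.pth"
repo_id = url.split("/resolve/")[0].replace("https://hf.co/", "")  # "some-org/dofa-checkpoints"
filename = url.split("/")[-1]                                      # "dofa_large.pth"
print(repo_id, filename)
```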
diff --git a/terratorch/models/backbones/mmearth_convnextv2.py b/terratorch/models/backbones/mmearth_convnextv2.py
new file mode 100644
index 00000000..a3c57210
--- /dev/null
+++ b/terratorch/models/backbones/mmearth_convnextv2.py
@@ -0,0 +1,256 @@
+# code from https://github.com/vishalned/MMEarth-train/blob/main/models/convnextv2.py
+# https://github.com/vishalned/MMEarth-train?tab=readme-ov-file
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from argparse import Namespace
+
+import torch
+import torch.nn as nn
+from timm.models.layers import trunc_normal_, DropPath
+from torch import Tensor
+
+from .norm_layers import LayerNorm, GRN
+
+
+class Block(nn.Module):
+    """ConvNeXtV2 Block.
+
+    Args:
+        dim (int): Number of input channels.
+        drop_path (float): Stochastic depth rate. Default: 0.0
+    """
+
+    def __init__(self, dim, drop_path=0.0):
+        super().__init__()
+        self.dwconv: nn.Module = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depth-wise conv
+        self.norm: nn.Module = LayerNorm(dim, eps=1e-6)
+        self.pwconv1: nn.Module = nn.Linear(dim, 4 * dim)  # point-wise/1x1 convs, implemented with linear layers
+        self.act: nn.Module = nn.GELU()
+        self.grn: nn.Module = GRN(4 * dim)
+        self.pwconv2: nn.Module = nn.Linear(4 * dim, dim)
+        self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        input = x
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.grn(x)
+        x = self.pwconv2(x)
+        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+        x = input + self.drop_path(x)
+        return x
+
+
+class ConvNeXtV2(nn.Module):
+    """ConvNeXt V2
+
+    Args:
+        in_chans (int): Number of input image channels. Default: 3
+        num_classes (int): Number of classes for classification head. Default: 1000
+        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+        drop_path_rate (float): Stochastic depth rate. Default: 0.
+        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+    """
+
+    def __init__(
+        self,
+        patch_size: int = 32,
+        img_size: int = 128,
+        in_chans: int = 3,
+        num_classes: int = 1000,
+        depths: list[int] = None,
+        dims: list[int] = None,
+        drop_path_rate: float = 0.0,
+        head_init_scale: float = 1.0,
+        use_orig_stem: bool = False,
+        args: Namespace = None,
+    ):
+        super().__init__()
+        self.depths = depths
+        if self.depths is None:  # set default value
+            self.depths = [3, 3, 9, 3]
+        depths = self.depths  # keep the local name in sync so the default is used below as well
+        self.img_size = img_size
+        self.use_orig_stem = use_orig_stem
+        self.num_stage = len(depths)
+        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
+        self.patch_size = patch_size
+        if dims is None:
+            dims = [96, 192, 384, 768]
+
+        if self.use_orig_stem:
+            self.stem_orig = nn.Sequential(
+                nn.Conv2d(
+                    in_chans,
+                    dims[0],
+                    kernel_size=patch_size // (2 ** (self.num_stage - 1)),
+                    stride=patch_size // (2 ** (self.num_stage - 1)),
+                ),
+                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+            )
+        else:
+            self.initial_conv = nn.Sequential(
+                nn.Conv2d(in_chans, dims[0], kernel_size=3, stride=1),
+                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+                nn.GELU(),
+            )
+            # depthwise conv for stem
+            self.stem = nn.Sequential(
+                nn.Conv2d(
+                    dims[0],
+                    dims[0],
+                    kernel_size=patch_size // (2 ** (self.num_stage - 1)),
+                    stride=patch_size // (2 ** (self.num_stage - 1)),
+                    padding=(patch_size // (2 ** (self.num_stage - 1))) // 2,
+                    groups=dims[0],
+                ),
+                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+            )
+
+        for i in range(3):
+            downsample_layer = nn.Sequential(
+                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
+            )
+            self.downsample_layers.append(downsample_layer)
+
+        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
+        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        cur = 0
+        for i in range(self.num_stage):
+            stage = nn.Sequential(
+                *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
+            )
+            self.stages.append(stage)
cur += depths[i] + + self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer + self.head = nn.Linear(dims[-1], num_classes) + + self.apply(self._init_weights) + self.head.weight.data.mul_(head_init_scale) + self.head.bias.data.mul_(head_init_scale) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + if self.use_orig_stem: + x = self.stem_orig(x) + else: + x = self.initial_conv(x) + x = self.stem(x) + + x = self.stages[0](x) + for i in range(3): + x = self.downsample_layers[i](x) + x = self.stages[i + 1](x) + + return self.norm( + x.mean([-2, -1]) + ) # global average pooling, (N, C, H, W) -> (N, C) + + def upsample_mask(self, mask, scale): + assert len(mask.shape) == 2 + p = int(mask.shape[1] ** 0.5) + return ( + mask.reshape(-1, p, p) + .repeat_interleave(scale, axis=1) + .repeat_interleave(scale, axis=2) + ) + + def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: + if mask is not None: # for the pretraining case + num_patches = mask.shape[1] + scale = int(self.img_size // (num_patches**0.5)) + mask = self.upsample_mask(mask, scale) + + mask = mask.unsqueeze(1).type_as(x) + x *= 1.0 - mask + if self.use_orig_stem: + x = self.stem_orig(x) + else: + x = self.initial_conv(x) + x = self.stem(x) + + x = self.stages[0](x) + for i in range(3): + x = self.downsample_layers[i](x) + x = self.stages[i + 1](x) + return x + + x = self.forward_features(x) + x = self.head(x) + return x + + +def convnextv2_atto(**kwargs): + model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs) + return model + + +def convnextv2_femto(**kwargs): + model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs) + return model + + +def convnext_pico(**kwargs): + model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs) + return model + + +def convnextv2_nano(**kwargs): + model = ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs) + return model + + +def convnextv2_tiny(**kwargs): + model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) + return model + + +def convnextv2_base(**kwargs): + model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + return model + + +def convnextv2_large(**kwargs): + model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) + return model + + +def convnextv2_huge(**kwargs): + model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs) + return model + + +checkpoints = { + "pt-S2_atto_1M_64_uncertainty_56-8": "https://sid.erda.dk/share_redirect/g23YOnaaTp/pt-S2_atto_1M_64_uncertainty_56-8/checkpoint-199.pth", + "pt-all_mod_atto_100k_128_uncertainty_112-16": "https://sid.erda.dk/share_redirect/g23YOnaaTp/pt-all_mod_atto_100k_128_uncertainty_112-16/checkpoint-199.pth", + "pt-all_mod_atto_1M_128_uncertainty_112-16": "https://sid.erda.dk/share_redirect/g23YOnaaTp/pt-all_mod_atto_1M_128_uncertainty_112-16/checkpoint-199.pth", + "pt-all_mod_atto_1M_64_uncertainty_56-8": "https://sid.erda.dk/share_redirect/g23YOnaaTp/pt-all_mod_atto_1M_64_uncertainty_56-8/checkpoint-199.pth", +} diff --git a/terratorch/models/backbones/select_patch_embed_weights.py b/terratorch/models/backbones/select_patch_embed_weights.py index 142814d1..b175140e 100644 --- a/terratorch/models/backbones/select_patch_embed_weights.py +++ b/terratorch/models/backbones/select_patch_embed_weights.py @@ -7,8 +7,8 @@ import torch from torch import nn 
-from terratorch.datasets import HLSBands
-
+from terratorch.datasets import HLSBands, OpticalBands, SARBands
+import collections
 
 def patch_embed_weights_are_compatible(model_patch_embed: torch.Tensor, checkpoint_patch_embed: torch.Tensor) -> bool:
     # check all dimensions are the same except for channel dimension
@@ -19,7 +19,7 @@
     return model_shape == checkpoint_shape
 
 def select_patch_embed_weights(
-    state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int], model_bands: list[HLSBands | int]
+    state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int | OpticalBands | SARBands], model_bands: list[HLSBands | int | OpticalBands | SARBands], custom_proj_key: str | None = None
 ) -> dict:
     """Filter out the patch embedding weights according to the bands being used.
     If a band exists in the pretrained_bands, but not in model_bands, drop it.
@@ -35,21 +49,49 @@
     Returns:
         dict: New state dict
     """
-    _possible_keys_for_proj_weight = {
-        "encoder.patch_embed.proj.weight",
-        "patch_embed.proj.weight",
-        "module.patch_embed.proj.weight",
-        "patch_embed.projection.weight",
-        "module.patch_embed.projection.weight",
-    }
-    patch_embed_proj_weight_key = state_dict.keys() & _possible_keys_for_proj_weight
-    if len(patch_embed_proj_weight_key) == 0:
-        msg = "Could not find key for patch embed weight"
-        raise Exception(msg)
-    if len(patch_embed_proj_weight_key) > 1:
-        msg = "Too many matches for key for patch embed weight"
-        raise Exception(msg)
+    if (type(pretrained_bands) == type(model_bands)) | (type(pretrained_bands) == int) | (type(model_bands) == int):
+
+        if custom_proj_key is None:
+            _possible_keys_for_proj_weight = {
+                "patch_embed.proj.weight",
+                "module.patch_embed.proj.weight",
+                "patch_embed.projection.weight",
+                "module.patch_embed.projection.weight",
+            }
+        else:
+            _possible_keys_for_proj_weight = {custom_proj_key}
+        patch_embed_proj_weight_key = state_dict.keys() & _possible_keys_for_proj_weight if (type(state_dict) in [collections.OrderedDict, dict]) else state_dict().keys() & _possible_keys_for_proj_weight
+        if len(patch_embed_proj_weight_key) == 0:
+            msg = "Could not find key for patch embed weight"
+            raise Exception(msg)
+        if len(patch_embed_proj_weight_key) > 1:
+            msg = "Too many matches for key for patch embed weight"
+            raise Exception(msg)
+
+        # extract the single element from the set
+        (patch_embed_proj_weight_key,) = patch_embed_proj_weight_key
+        patch_embed_weight = state_dict[patch_embed_proj_weight_key]
+
+        temp_weight = model.state_dict()[patch_embed_proj_weight_key].clone()
+
+        # only do this if the patch size and tubelet size match. If not, start with random weights
+        if patch_embed_weights_are_compatible(temp_weight, patch_embed_weight):
+            torch.nn.init.xavier_uniform_(temp_weight.view([temp_weight.shape[0], -1]))
+            for index, band in enumerate(model_bands):
+                if band in pretrained_bands:
+                    logging.debug(f"Loaded weights for {band} in position {index} of patch embed")
+                    temp_weight[:, index] = patch_embed_weight[:, pretrained_bands.index(band)]
+        else:
+            warnings.warn(
+                f"Incompatible shapes between patch embedding of model {temp_weight.shape} and\
+                of checkpoint {patch_embed_weight.shape}",
+                category=UserWarning,
+                stacklevel=1,
+            )
+
+        state_dict[patch_embed_proj_weight_key] = temp_weight
-    # extract the single element from the set
-    (patch_embed_proj_weight_key,) = patch_embed_proj_weight_key
-    patch_embed_weight = state_dict[patch_embed_proj_weight_key]
@@ -72,4 +100,5 @@
-    )
-    state_dict[patch_embed_proj_weight_key] = temp_weight
     return state_dict
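The band-remapping behaviour is easiest to see on a toy module. A self-contained sketch (the band lists, shapes, and the `Toy` class are hypothetical, not part of TerraTorch):

```python
import torch
import torch.nn as nn

from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights

class Toy(nn.Module):
    """Minimal module exposing the 'patch_embed.proj.weight' key the function looks for."""
    def __init__(self, in_chans):
        super().__init__()
        self.patch_embed = nn.Module()
        self.patch_embed.proj = nn.Conv2d(in_chans, 8, kernel_size=4, stride=4)

# checkpoint trained on three bands, model fine-tuned on two of them (reordered)
ckpt_sd = Toy(3).state_dict()
model = Toy(2)
new_sd = select_patch_embed_weights(ckpt_sd, model, ["RED", "GREEN", "BLUE"], ["GREEN", "RED"])
# channel 0 now holds the checkpoint's GREEN filter, channel 1 its RED filter
print(new_sd["patch_embed.proj.weight"].shape)  # torch.Size([8, 2, 4, 4])
```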
diff --git a/terratorch/models/backbones/torchgeo_resnet.py b/terratorch/models/backbones/torchgeo_resnet.py
new file mode 100644
index 00000000..b2ab79ed
--- /dev/null
+++ b/terratorch/models/backbones/torchgeo_resnet.py
@@ -0,0 +1,476 @@
+# reference torchgeo https://torchgeo.readthedocs.io/en/stable/_modules/torchgeo/models/resnet.html
+
+import logging
+from typing import List
+
+import huggingface_hub
+import torch
+import torch.nn as nn
+import torchgeo.models.resnet as resnet
+from torchgeo.models.resnet import ResNet, ResNet18_Weights, ResNet50_Weights, ResNet152_Weights, resnet18, resnet50, resnet152
+from torchvision.models._api import Weights, WeightsEnum
+
+from terratorch.datasets.utils import OpticalBands, SARBands
+from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights
+from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY
+
+
+class ResNetEncoderWrapper(nn.Module):
+    """A wrapper around ResNet models from torchgeo that returns only the encoder's forward pass.
+
+    Attributes:
+        resnet_model (ResNet): The instantiated ResNet model.
+        weights (Weights | None): The torchgeo weights enum used to initialise the model, if any.
+    Methods:
+        forward(x: torch.Tensor) -> list[torch.Tensor]:
+            Forward pass returning the feature maps at the specified stage indices.
+    """
+
+    def __init__(self, resnet_model, resnet_meta, weights=None, out_indices=None) -> None:
+        """
+        Args:
+            resnet_model (ResNet): The backbone module to be wrapped.
+ weights () + """ + super().__init__() + self.resnet_model = resnet_model + self.resnet_meta = resnet_meta + self.weights = weights + self.out_indices = out_indices if out_indices else [-1] + self.out_channels = [x['num_chs'] for x in self.resnet_model.feature_info] + self.resnet_meta['original_out_channels'] = self.out_channels + self.out_channels = [x for i, x in enumerate(self.out_channels) if (i in self.out_indices) | (i == (len(self.out_channels)-1)) & (-1 in self.out_indices)] + + + def forward(self, x: List[torch.Tensor], **kwargs) -> torch.Tensor: + + features = self.resnet_model.forward_intermediates(x, intermediates_only=True) + + outs = [] + for i, feature in enumerate(features): + if i in self.out_indices: + outs.append(feature) + elif (i == (len(self.resnet_meta["original_out_channels"])-1)) & (-1 in self.out_indices): + outs.append(feature) + + return outs + + +look_up_table = { + "B01": "COASTAL_AEROSOL", + "B02": "BLUE", + "B03": "GREEN", + "B04": "RED", + "B05": "RED_EDGE_1", + "B06": "RED_EDGE_2", + "B07": "RED_EDGE_3", + "B08": "NIR_BROAD", + "B8A": "NIR_NARROW", + "B09": "WATER_VAPOR", + "B10": "CIRRUS", + "B11": "SWIR_1", + "B12": "SWIR_2", + "VV": "VV", + "VH": "VH", + "R": "RED", + "G": "GREEN", + "B": "BLUE" +} + +resnet18_meta = { + "layers": (2, 2, 2, 2) + } + +resnet50_meta = { + "layers": (3, 4, 6, 3) +} + +resnet152_meta = { + "layers": (3, 8, 36, 3) +} + + +def get_pretrained_bands(model_bands): + + model_bands = [look_up_table[x.split('.')[-1]] for x in model_bands] + + return model_bands + + +#### resnet 18 +@TERRATORCH_BACKBONE_REGISTRY.register +def ssl4eol_resnet18_landsat_tm_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.LANDSAT_TM_TOA_MOCO, out_indices: list | None = None, **kwargs): + if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands) + model = resnet18(**kwargs) + if pretrained: + model = load_resnet_weights(model, model_bands, ckpt_data, weights) + return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices) + +@TERRATORCH_BACKBONE_REGISTRY.register +def ssl4eol_resnet18_landsat_tm_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.LANDSAT_TM_TOA_SIMCLR, out_indices: list | None = None, **kwargs): + if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands) + model = resnet18(**kwargs) + if pretrained: + model = load_resnet_weights(model, model_bands, ckpt_data, weights) + return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices) + + +@TERRATORCH_BACKBONE_REGISTRY.register +def ssl4eol_resnet18_landsat_etm_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.LANDSAT_ETM_TOA_MOCO, out_indices: list | None = None, **kwargs): + if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands) + model = resnet18(**kwargs) + if pretrained: + model = load_resnet_weights(model, model_bands, ckpt_data, weights) + return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices) + + +@TERRATORCH_BACKBONE_REGISTRY.register +def ssl4eol_resnet18_landsat_etm_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.LANDSAT_ETM_TOA_SIMCLR, out_indices: list | None = None, **kwargs): + if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands) + model = resnet18(**kwargs) + if pretrained: + model = load_resnet_weights(model, model_bands, ckpt_data, 
weights)
+    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet18_landsat_etm_sr_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.LANDSAT_ETM_SR_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet18(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet18_landsat_etm_sr_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.LANDSAT_ETM_SR_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet18(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet18_landsat_oli_tirs_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.LANDSAT_OLI_TIRS_TOA_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet18(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet18_landsat_oli_tirs_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.LANDSAT_OLI_TIRS_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet18(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet18_landsat_oli_sr_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.LANDSAT_OLI_SR_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet18(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet18_landsat_oli_sr_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.LANDSAT_OLI_SR_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet18(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eos12_resnet18_sentinel2_all_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.SENTINEL2_ALL_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet18(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eos12_resnet18_sentinel2_rgb_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.SENTINEL2_RGB_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet18(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def seco_resnet18_sentinel2_rgb_seco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet18_Weights.SENTINEL2_RGB_SECO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet18(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)
+
+
+#### resnet 50
+@TERRATORCH_BACKBONE_REGISTRY.register
+def fmow_resnet50_fmow_rgb_gassl(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.FMOW_RGB_GASSL, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet50_landsat_tm_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.LANDSAT_TM_TOA_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet50_landsat_tm_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.LANDSAT_TM_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet50_landsat_etm_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.LANDSAT_ETM_TOA_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet50_landsat_etm_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.LANDSAT_ETM_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet50_landsat_etm_sr_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.LANDSAT_ETM_SR_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet50_landsat_etm_sr_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.LANDSAT_ETM_SR_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet50_landsat_oli_tirs_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.LANDSAT_OLI_TIRS_TOA_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet50_landsat_oli_tirs_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.LANDSAT_OLI_TIRS_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet50_landsat_oli_sr_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.LANDSAT_OLI_SR_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_resnet50_landsat_oli_sr_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eos12_resnet50_sentinel1_all_decur(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.SENTINEL1_ALL_DECUR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        if weights is not None:
+            weights.meta['bands'] = ['VV', 'VH']
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eos12_resnet50_sentinel1_all_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.SENTINEL1_ALL_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        if weights is not None:
+            weights.meta['bands'] = ['VV', 'VH']
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eos12_resnet50_sentinel2_all_decur(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.SENTINEL2_ALL_DECUR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        if weights is not None:
+            weights.meta['bands'] = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eos12_resnet50_sentinel2_all_dino(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.SENTINEL2_ALL_DINO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        if weights is not None:
+            weights.meta['bands'] = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eos12_resnet50_sentinel2_all_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.SENTINEL2_ALL_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        if weights is not None:
+            weights.meta['bands'] = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eos12_resnet50_sentinel2_rgb_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.SENTINEL2_RGB_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def seco_resnet50_sentinel2_rgb_seco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.SENTINEL2_RGB_SECO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_resnet50_sentinel2_mi_ms_satlas(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.SENTINEL2_MI_MS_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_resnet50_sentinel2_mi_rgb_satlas(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.SENTINEL2_MI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_resnet50_sentinel2_si_ms_satlas(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.SENTINEL2_SI_MS_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_resnet50_sentinel2_si_rgb_satlas(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet50_Weights.SENTINEL2_SI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet50(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
+
+#### resnet152
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_resnet152_sentinel2_mi_ms(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet152_Weights.SENTINEL2_MI_MS_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet152(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet152_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_resnet152_sentinel2_mi_rgb(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet152_Weights.SENTINEL2_MI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet152(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet152_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_resnet152_sentinel2_si_ms_satlas(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet152_Weights.SENTINEL2_SI_MS_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet152(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet152_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_resnet152_sentinel2_si_rgb_satlas(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ResNet152_Weights.SENTINEL2_SI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = resnet152(**kwargs)
+    if pretrained:
+        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
+    return ResNetEncoderWrapper(model, resnet152_meta, weights, out_indices)
+
+
+#### build model and load weights
+def load_resnet_weights(model: nn.Module, model_bands, ckpt_data: str, weights: Weights, input_size: int = 224, custom_weight_proj: str = "conv1.weight") -> nn.Module:
+
+    pretrained_bands = get_pretrained_bands(weights.meta["bands"]) if "bands" in weights.meta else []
+    if ckpt_data is not None:
+        if ckpt_data.find("https://hf.co/") > -1:
+            repo_id = ckpt_data.split("/resolve/")[0].replace("https://hf.co/", '')
+            filename = ckpt_data.split("/")[-1]
+            ckpt_data = huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename)
+        checkpoint_model = torch.load(ckpt_data, map_location="cpu")
+        state_dict = model.state_dict()
+
+        # drop the classification head if its shape does not match the target model
+        for k in ["fc.weight", "fc.bias"]:
+            if (
+                k in checkpoint_model
+                and checkpoint_model[k].shape != state_dict[k].shape
+            ):
+                logging.info(f"Removing key {k} from pretrained checkpoint")
+                del checkpoint_model[k]
+
+        checkpoint_model = select_patch_embed_weights(checkpoint_model, model, pretrained_bands, model_bands, custom_weight_proj)
+        # load pre-trained model
+        msg = model.load_state_dict(checkpoint_model, strict=False)
+
+        logging.info(msg)
+    else:
+        if weights is not None:
+            checkpoint_model = weights.get_state_dict(progress=True)
+            checkpoint_model = select_patch_embed_weights(checkpoint_model, model, pretrained_bands, model_bands, custom_weight_proj)
+            missing_keys, unexpected_keys = model.load_state_dict(checkpoint_model, strict=False)
+            assert set(missing_keys) <= {'fc.weight', 'fc.bias'}
+            assert not unexpected_keys
+
+    return model
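+
+# Minimal usage sketch (illustrative comment only; assumes the multi-source
+# registry in terratorch.registry exposes a build() method and that band names
+# follow terratorch's OpticalBands convention):
+#
+#   from terratorch.registry import BACKBONE_REGISTRY
+#   backbone = BACKBONE_REGISTRY.build(
+#       "ssl4eol_resnet18_landsat_etm_sr_moco",
+#       model_bands=["RED", "GREEN", "BLUE"],
+#       pretrained=False,
+#   )
+#   feats = backbone(torch.ones(1, 3, 224, 224))  # tuple of feature maps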
diff --git a/terratorch/models/backbones/torchgeo_swin_satlas.py b/terratorch/models/backbones/torchgeo_swin_satlas.py
new file mode 100644
index 00000000..1de43453
--- /dev/null
+++ b/terratorch/models/backbones/torchgeo_swin_satlas.py
@@ -0,0 +1,265 @@
+# reference torchgeo https://torchgeo.readthedocs.io/en/stable/_modules/torchgeo/models/swin.html
+
+import torchgeo.models.swin as swin
+from torchgeo.models.swin import Swin_V2_T_Weights, Swin_V2_B_Weights
+import logging
+from collections.abc import Callable
+from functools import partial
+import huggingface_hub
+import torch.nn as nn
+from typing import List
+from torchvision.models._api import Weights, WeightsEnum
+from torchvision.models import swin_v2_t, swin_v2_b
+import torch
+from terratorch.datasets.utils import OpticalBands, SARBands
+from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights
+
+from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY
+
+class SwinEncoderWrapper(nn.Module):
+
+    """
+    A wrapper for Satlas Swin models from torchgeo that returns only the forward pass of the encoder.
+    Attributes:
+        swin_model (SwinTransformer): The instantiated Swin model.
+        weights (Weights): Weights used to initialize the model, if any.
+    Methods:
+        forward(x: List[torch.Tensor]) -> torch.Tensor:
+            Forward pass returning the embeddings at the specified indices.
+    """
+
+    def __init__(self, swin_model, swin_meta, weights=None, out_indices=None) -> None:
+        """
+        Args:
+            swin_model (SwinTransformer): The backbone module to be wrapped.
+            swin_meta (dict): dict containing the metadata for swin.
+            weights (Weights): Weights class for the swin model to be wrapped.
+            out_indices (list): List containing the feature indices to be returned.
+        """
+        super().__init__()
+        self.swin_model = swin_model
+        self.weights = weights
+        self.out_indices = out_indices if out_indices else [-1]
+
+        self.out_channels = []
+        for i in range(len(swin_meta["depths"])):
+            self.out_channels.append(swin_meta["embed_dim"] * 2**i)
+        self.out_channels = [elem for elem in self.out_channels for _ in range(2)]
+        self.out_channels = [x for i, x in enumerate(self.out_channels) if (i in self.out_indices) or (i == len(self.out_channels) - 1 and -1 in self.out_indices)]
+
+    def forward(self, x: List[torch.Tensor], **kwargs) -> torch.Tensor:
+
+        outs = []
+        for i, layer in enumerate(self.swin_model.features):
+            x = layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+            elif i == len(self.swin_model.features) - 1 and -1 in self.out_indices:
+                outs.append(x)
+
+        return tuple(outs)
+
+look_up_table = {
+    "B01": "COASTAL_AEROSOL",
+    "B02": "BLUE",
+    "B03": "GREEN",
+    "B04": "RED",
+    "B05": "RED_EDGE_1",
+    "B06": "RED_EDGE_2",
+    "B07": "RED_EDGE_3",
+    "B08": "NIR_BROAD",
+    "B8A": "NIR_NARROW",
+    "B09": "WATER_VAPOR",
+    "B10": "CIRRUS",
+    "B11": "SWIR_1",
+    "B12": "SWIR_2",
+    "VV": "VV",
+    "VH": "VH",
+    "R": "RED",
+    "G": "GREEN",
+    "B": "BLUE"
+}
+
+swin_v2_t_meta = {
+    "patch_size": [4, 4],
+    "embed_dim": 96,
+    "depths": [2, 2, 6, 2],
+    "num_heads": [3, 6, 12, 24],
+    "window_size": [8, 8],
+    "stochastic_depth_prob": 0.2
+}
+
+swin_v2_b_meta = {
+    "patch_size": [4, 4],
+    "embed_dim": 128,
+    "depths": [2, 2, 18, 2],
+    "num_heads": [4, 8, 16, 32],
+    "window_size": [8, 8],
+    "stochastic_depth_prob": 0.5
+}
+
+def get_pretrained_bands(model_bands):
+
+    model_bands = [look_up_table[x.split('.')[-1]] for x in model_bands]
+
+    return model_bands
+
+def load_model(load_function, swin_meta, **kwargs):
+
+    # rebuild the patch embedding so the model accepts the requested number of channels
+    in_chans = kwargs.pop('in_chans')
+    model = load_function(**kwargs)
+    model.features[0][0] = torch.nn.Conv2d(in_chans, swin_meta["embed_dim"], kernel_size=swin_meta["patch_size"], stride=swin_meta["patch_size"])
+    return model
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_t_sentinel2_mi_ms(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_T_Weights.SENTINEL2_MI_MS_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_t, swin_v2_t_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_t_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_t_sentinel2_mi_rgb(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_T_Weights.SENTINEL2_MI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_t, swin_v2_t_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_t_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_t_sentinel2_si_ms(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_T_Weights.SENTINEL2_SI_MS_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_t, swin_v2_t_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_t_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_t_sentinel2_si_rgb(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_T_Weights.SENTINEL2_SI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_t, swin_v2_t_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_t_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_b_sentinel2_mi_ms(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_B_Weights.SENTINEL2_MI_MS_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_b_sentinel2_mi_rgb(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_B_Weights.SENTINEL2_MI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_b_sentinel2_si_ms(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_B_Weights.SENTINEL2_SI_MS_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_b_sentinel2_si_rgb(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_B_Weights.SENTINEL2_SI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_b_naip_mi_rgb(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_B_Weights.NAIP_RGB_MI_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_b_naip_si_rgb(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_b_landsat_mi_ms(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_B_Weights.LANDSAT_MI_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_b_landsat_si_ms(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_B_Weights.LANDSAT_SI_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_b_sentinel1_mi(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_B_Weights.SENTINEL1_MI_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def satlas_swin_b_sentinel1_si(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = Swin_V2_B_Weights.SENTINEL1_SI_SATLAS, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
+    if pretrained:
+        model = load_swin_weights(model, model_bands, ckpt_data, weights)
+    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)
+
+
+def load_swin_weights(model: nn.Module, model_bands, ckpt_data: str, weights: Weights, input_size: int = 224, custom_weight_proj: str = "features.0.0.weight") -> nn.Module:
+
+    pretrained_bands = get_pretrained_bands(weights.meta["bands"])
+
+    if ckpt_data is not None:
+        if ckpt_data.find("https://hf.co/") > -1:
+            repo_id = ckpt_data.split("/resolve/")[0].replace("https://hf.co/", '')
+            filename = ckpt_data.split("/")[-1]
+            ckpt_data = huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename)
+        checkpoint_model = torch.load(ckpt_data, map_location="cpu")
+        state_dict = model.state_dict()
+
+        # drop the classification head if its shape does not match the target model
+        for k in ["head.weight", "head.bias"]:
+            if (
+                k in checkpoint_model
+                and checkpoint_model[k].shape != state_dict[k].shape
+            ):
+                logging.info(f"Removing key {k} from pretrained checkpoint")
+                del checkpoint_model[k]
+
+        checkpoint_model = select_patch_embed_weights(checkpoint_model, model, pretrained_bands, model_bands, custom_weight_proj)
+        # load pre-trained model
+        msg = model.load_state_dict(checkpoint_model, strict=False)
+
+        logging.info(msg)
+    else:
+        if weights is not None:
+            checkpoint_model = weights.get_state_dict(progress=True)
+            checkpoint_model = select_patch_embed_weights(checkpoint_model, model, pretrained_bands, model_bands, custom_weight_proj)
+            missing_keys, unexpected_keys = model.load_state_dict(checkpoint_model, strict=False)
+            assert not missing_keys
+            assert not unexpected_keys
+
+    return model
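+
+# Minimal usage sketch (illustrative comment only; assumes the registry exposes
+# a build() method, and relies on torchvision Swin keeping features channels-last):
+#
+#   from terratorch.registry import BACKBONE_REGISTRY
+#   backbone = BACKBONE_REGISTRY.build(
+#       "satlas_swin_b_sentinel2_si_ms",
+#       model_bands=["RED", "GREEN", "BLUE", "NIR_NARROW", "SWIR_1", "SWIR_2"],
+#       pretrained=False,
+#       out_indices=[1, 3, 5, 7],
+#   )
+#   feats = backbone(torch.ones(1, 6, 256, 256))  # tuple of (B, H, W, C) maps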
diff --git a/terratorch/models/backbones/torchgeo_vit.py b/terratorch/models/backbones/torchgeo_vit.py
new file mode 100644
index 00000000..dbe9d97b
--- /dev/null
+++ b/terratorch/models/backbones/torchgeo_vit.py
@@ -0,0 +1,210 @@
+# reference torchgeo https://torchgeo.readthedocs.io/en/stable/_modules/torchgeo/models/vit.html
+
+from torchgeo.models.vit import vit_small_patch16_224
+from torchgeo.models.vit import ViTSmall16_Weights
+import logging
+from collections.abc import Callable
+from functools import partial
+import huggingface_hub
+import torch.nn as nn
+from typing import List
+from torchvision.models._api import Weights, WeightsEnum
+import torch
+from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights
+
+from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY
+
+class ViTEncoderWrapper(nn.Module):
+
+    """
+    A wrapper for ViT models from torchgeo that returns only the forward pass of the encoder.
+    Attributes:
+        vit_model (VisionTransformer): The instantiated ViT model.
+        weights (Weights): Weights used to initialize the model, if any.
+    Methods:
+        forward(x: List[torch.Tensor]) -> torch.Tensor:
+            Forward pass returning the intermediate embeddings.
+    """
+
+    def __init__(self, vit_model, vit_meta, weights=None, out_indices=None) -> None:
+        """
+        Args:
+            vit_model (VisionTransformer): The backbone module to be wrapped.
+            vit_meta (dict): dict containing the metadata for the ViT variant.
+            weights (Weights): Weights class for the ViT model to be wrapped.
+            out_indices (list): List containing the feature indices to be returned
+                (currently unused; all intermediates are returned).
+        """
+        super().__init__()
+        self.vit_model = vit_model
+        self.weights = weights
+        self.out_channels = [x['num_chs'] for x in self.vit_model.feature_info]
+        self.vit_meta = vit_meta
+
+    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+        return self.vit_model.forward_intermediates(x, intermediates_only=True)
+
+look_up_table = {
+    "B01": "COASTAL_AEROSOL",
+    "B02": "BLUE",
+    "B03": "GREEN",
+    "B04": "RED",
+    "B05": "RED_EDGE_1",
+    "B06": "RED_EDGE_2",
+    "B07": "RED_EDGE_3",
+    "B08": "NIR_BROAD",
+    "B8A": "NIR_NARROW",
+    "B09": "WATER_VAPOR",
+    "B10": "CIRRUS",
+    "B11": "SWIR_1",
+    "B12": "SWIR_2",
+    "VV": "VV",
+    "VH": "VH",
+    "R": "RED",
+    "G": "GREEN",
+    "B": "BLUE"
+}
+
+def get_pretrained_bands(model_bands):
+
+    model_bands = [look_up_table[x.split('.')[-1]] for x in model_bands]
+
+    return model_bands
+
+vit_s_meta = {
+    'patch_size': 16,
+    'embed_dim': 384,
+    'depth': 12,
+    'num_heads': 6
+}
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_vit_small_patch16_224_landsat_tm_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ViTSmall16_Weights.LANDSAT_TM_TOA_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = vit_small_patch16_224(**kwargs)
+    if pretrained:
+        model = load_vit_weights(model, model_bands, ckpt_data, weights)
+    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)
+
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_vit_small_patch16_224_landsat_tm_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ViTSmall16_Weights.LANDSAT_TM_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = vit_small_patch16_224(**kwargs)
+    if pretrained:
+        model = load_vit_weights(model, model_bands, ckpt_data, weights)
+    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_vit_small_patch16_224_landsat_etm_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ViTSmall16_Weights.LANDSAT_ETM_TOA_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = vit_small_patch16_224(**kwargs)
+    if pretrained:
+        model = load_vit_weights(model, model_bands, ckpt_data, weights)
+    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_vit_small_patch16_224_landsat_etm_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ViTSmall16_Weights.LANDSAT_ETM_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = vit_small_patch16_224(**kwargs)
+    if pretrained:
+        model = load_vit_weights(model, model_bands, ckpt_data, weights)
+    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_vit_small_patch16_224_landsat_etm_sr_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ViTSmall16_Weights.LANDSAT_ETM_SR_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = vit_small_patch16_224(**kwargs)
+    if pretrained:
+        model = load_vit_weights(model, model_bands, ckpt_data, weights)
+    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_vit_small_patch16_224_landsat_etm_sr_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ViTSmall16_Weights.LANDSAT_ETM_SR_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = vit_small_patch16_224(**kwargs)
+    if pretrained:
+        model = load_vit_weights(model, model_bands, ckpt_data, weights)
+    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_vit_small_patch16_224_landsat_oli_tirs_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ViTSmall16_Weights.LANDSAT_OLI_TIRS_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = vit_small_patch16_224(**kwargs)
+    if pretrained:
+        model = load_vit_weights(model, model_bands, ckpt_data, weights)
+    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_vit_small_patch16_224_landsat_oli_sr_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = vit_small_patch16_224(**kwargs)
+    if pretrained:
+        model = load_vit_weights(model, model_bands, ckpt_data, weights)
+    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eol_vit_small_patch16_224_landsat_oli_sr_simclr(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = vit_small_patch16_224(**kwargs)
+    if pretrained:
+        model = load_vit_weights(model, model_bands, ckpt_data, weights)
+    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eos12_vit_small_patch16_224_sentinel2_all_dino(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ViTSmall16_Weights.SENTINEL2_ALL_DINO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = vit_small_patch16_224(**kwargs)
+    if pretrained:
+        model = load_vit_weights(model, model_bands, ckpt_data, weights)
+    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)
+
+@TERRATORCH_BACKBONE_REGISTRY.register
+def ssl4eos12_vit_small_patch16_224_sentinel2_all_moco(model_bands, pretrained = False, ckpt_data: str | None = None, weights: Weights | None = ViTSmall16_Weights.SENTINEL2_ALL_MOCO, out_indices: list | None = None, **kwargs):
+    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
+    model = vit_small_patch16_224(**kwargs)
+    if pretrained:
+        model = load_vit_weights(model, model_bands, ckpt_data, weights)
+    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)
+
+
+#### build model and load weights
+def load_vit_weights(model: nn.Module, model_bands, ckpt_data: str, weights: Weights, input_size: int = 224, custom_weight_proj: str = "patch_embed.proj.weight") -> nn.Module:
+
+    pretrained_bands = get_pretrained_bands(weights.meta["bands"]) if "bands" in weights.meta else ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B09", "B10", "B11", "B12"]
+
+    logging.info("Loading weights")
+    if ckpt_data is not None:
+        if ckpt_data.find("https://hf.co/") > -1:
+            repo_id = ckpt_data.split("/resolve/")[0].replace("https://hf.co/", '')
+            filename = ckpt_data.split("/")[-1]
+            ckpt_data = huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename)
+        checkpoint_model = torch.load(ckpt_data, map_location="cpu")
+        state_dict = model.state_dict()
+
+        # drop the classification head if its shape does not match the target model
+        for k in ["head.weight", "head.bias"]:
+            if (
+                k in checkpoint_model
+                and checkpoint_model[k].shape != state_dict[k].shape
+            ):
+                logging.info(f"Removing key {k} from pretrained checkpoint")
+                del checkpoint_model[k]
+
+        checkpoint_model = select_patch_embed_weights(checkpoint_model, model, pretrained_bands, model_bands, custom_weight_proj)
+        # load pre-trained model
+        msg = model.load_state_dict(checkpoint_model, strict=False)
+
+        logging.info(msg)
+    else:
+        if weights is not None:
+            checkpoint_model = weights.get_state_dict(progress=True)
+            checkpoint_model = select_patch_embed_weights(checkpoint_model, model, pretrained_bands, model_bands, custom_weight_proj)
+            missing_keys, unexpected_keys = model.load_state_dict(checkpoint_model, strict=False)
+            assert set(missing_keys) <= {'head.weight', 'head.bias'}
+            assert not unexpected_keys
+
+    return model
diff --git a/terratorch/models/scalar_output_model.py b/terratorch/models/scalar_output_model.py
index 92866e09..76cf653e 100644
--- a/terratorch/models/scalar_output_model.py
+++ b/terratorch/models/scalar_output_model.py
@@ -77,7 +77,6 @@ def check_input_shape(self, x: torch.Tensor) -> bool:  # noqa: ARG002
 
     def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
         """Sequentially pass `x` through model`s encoder, decoder and heads"""
-        self.check_input_shape(x)
         features = self.encoder(x, **kwargs)
diff --git a/terratorch/registry/registry.py b/terratorch/registry/registry.py
index 855476e3..91235923 100644
--- a/terratorch/registry/registry.py
+++ b/terratorch/registry/registry.py
@@ -4,7 +4,6 @@
 from contextlib import suppress
 from reprlib import recursive_repr as _recursive_repr
-
 class BuildableRegistry(typing.Protocol):
     def __iter__(self): ...
     def __len__(self) -> int: ...
@@ -78,23 +77,6 @@ def __iter__(self):
     def __len__(self):
         return sum(len(source) for source in self._sources.values())
-    # def __getitem__(self, name):
-    #     parsed_prefix = self._parse_prefix(name)
-    #     if parsed_prefix:
-    #         prefix, name_without_prefix = parsed_prefix
-    #         registry = self._sources[prefix]
-    #         return registry[name_without_prefix]
-
-    #     # if no prefix is given, go through all sources in order
-    #     for source in self._sources.values():
-    #         try:
-    #             return source[name]
-    #         except Exception as e:
-    #             logging.debug(e)
-
-    #     msg = f"Could not find Model {name} not from any source."
-    #     raise KeyError(msg)
-
     def __getitem__(self, name):
         return self._sources[name]
diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py
index 48e80221..a6a540bc 100644
--- a/terratorch/tasks/segmentation_tasks.py
+++ b/terratorch/tasks/segmentation_tasks.py
@@ -226,6 +226,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
@@ -247,7 +248,9 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
+        rest = {k: batch[k] for k in other_keys}
+
         model_output: ModelOutput = self(x, **rest)
         if dataloader_idx >= len(self.test_loss_handler):
             msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
@@ -315,6 +318,44 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
         finally:
             plt.close()
 
+    def on_validation_epoch_end(self) -> None:
+        self.log_dict(self.val_metrics.compute(), sync_dist=True)
+        self.val_metrics.reset()
+        return super().on_validation_epoch_end()
+
+    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+        """Compute the test loss and additional metrics.
+
+        Args:
+            batch: The output of your DataLoader.
+            batch_idx: Integer displaying index of this batch.
+            dataloader_idx: Index of the current dataloader.
+        """
+        x = batch["image"]
+        y = batch["mask"]
+
+        other_keys = batch.keys() - {"image", "mask", "filename"}
+        rest = {k: batch[k] for k in other_keys}
+        model_output: ModelOutput = self(x, **rest)
+        if dataloader_idx >= len(self.test_loss_handler):
+            msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
+            raise ValueError(msg)
+        loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
+        self.test_loss_handler[dataloader_idx].log_loss(
+            partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
+            loss_dict=loss,
+            batch_size=y.shape[0],
+        )
+
+        y_hat_hard = to_segmentation_prediction(model_output)
+        self.test_metrics[dataloader_idx].update(y_hat_hard, y)
+
+    def on_test_epoch_end(self) -> None:
+        for metrics in self.test_metrics:
+            self.log_dict(metrics.compute(), sync_dist=True)
+            metrics.reset()
+        return super().on_test_epoch_end()
+
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         """Compute the predicted class probabilities.
@@ -329,8 +370,11 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
         x = batch["image"]
         file_names = batch["filename"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
+        rest = {k: batch[k] for k in other_keys}
+
+        model_output: ModelOutput = self(x, **rest)
+
         def model_forward(x):
             return self(x).output
diff --git a/tests/test_backbones.py b/tests/test_backbones.py
index df60fad4..0d4abdb6 100644
--- a/tests/test_backbones.py
+++ b/tests/test_backbones.py
@@ -105,7 +105,8 @@ def test_vit_models_different_patch_tubelet_sizes(model_name, patch_size, patch_
     gc.collect()
 
 @pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
 def test_out_indices(model_name, input_224):
-    out_indices = [2, 4, 8, 10]
+    out_indices = (2, 4, 8, 10)
     backbone = timm.create_model(model_name, pretrained=False, features_only=True, out_indices=out_indices)
     assert backbone.feature_info.out_indices == out_indices
@@ -129,8 +130,8 @@ def test_out_indices_non_divisible(model_name, input_non_divisible):
     gc.collect()
 
 @pytest.mark.parametrize("model_name", ["vit_base_patch16", "vit_large_patch16"])
 def test_scale_mae(model_name):
-    out_indices = [2, 4, 8, 10]
-
+    out_indices = (2, 4, 8, 10)
+
     # default should have 3 channels
     backbone = scalemae.create_model(model_name, out_indices=out_indices)
     input_tensor = torch.ones((1, 3, 224, 224))
diff --git a/tests/test_encoder_decoder_torchgeo_models.py b/tests/test_encoder_decoder_torchgeo_models.py
new file mode 100644
index 00000000..cc8a4e2c
--- /dev/null
+++ b/tests/test_encoder_decoder_torchgeo_models.py
@@ -0,0 +1,176 @@
+# Copyright contributors to the Terratorch project
+
+import pytest
+import torch
+
+from terratorch.models import EncoderDecoderFactory
+from terratorch.models.model import AuxiliaryHead
+
+NUM_CHANNELS = 6
+NUM_CLASSES = 10
+EXPECTED_SEGMENTATION_OUTPUT_SHAPE = (1, NUM_CLASSES, 224, 224)
+EXPECTED_REGRESSION_OUTPUT_SHAPE = (1, 224, 224)
+EXPECTED_CLASSIFICATION_OUTPUT_SHAPE = (1, NUM_CLASSES)
+
+PIXELWISE_TASK_EXPECTED_OUTPUT = [
+    ("regression", EXPECTED_REGRESSION_OUTPUT_SHAPE),
+    ("segmentation", EXPECTED_SEGMENTATION_OUTPUT_SHAPE),
+]
+
+VIT_UPERNET_NECK = [
+    {"name": "SelectIndices", "indices": [0, 1, 2, 3]},
+    {"name": "ReshapeTokensToImage"},
+    {"name": "LearnedInterpolateToPyramidal"},
+]
+
+PRETRAINED_BANDS = ["RED", "GREEN", "BLUE", "NIR_NARROW", "SWIR_1", "SWIR_2"]
+
+@pytest.fixture(scope="session")
+def model_factory() -> EncoderDecoderFactory:
+    return EncoderDecoderFactory()
+
+
+@pytest.fixture(scope="session")
+def model_input() -> torch.Tensor:
+    return torch.ones((1, NUM_CHANNELS, 224, 224))
+
+
+backbones = ["ssl4eos12_resnet50_sentinel2_all_decur"]
+pretrained = [False, True]
+@pytest.mark.parametrize("backbone", backbones)
+@pytest.mark.parametrize("backbone_pretrained", pretrained)
+def test_create_classification_model_resnet(backbone, model_factory: EncoderDecoderFactory, model_input, backbone_pretrained):
+    model = model_factory.build_model(
+        "classification",
+        backbone=backbone,
+        decoder="IdentityDecoder",
+        backbone_model_bands=PRETRAINED_BANDS,
+        backbone_pretrained=backbone_pretrained,
+        num_classes=NUM_CLASSES,
+    )
+    model.eval()
+
+    with torch.no_grad():
+        assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
+
+
+backbones = ["dofa_large_patch16_224"]
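+# DOFA is a plain ViT backbone: it emits token embeddings of shape
+# (batch, tokens, dim), so the test below adds a PermuteDims neck to move the
+# embedding dim onto the channel axis before the classification head.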
+@pytest.mark.parametrize("backbone", backbones)
+@pytest.mark.parametrize("backbone_pretrained", pretrained)
+def test_create_classification_model_dofa(backbone, model_factory: EncoderDecoderFactory, model_input, backbone_pretrained):
+    model = model_factory.build_model(
+        "classification",
+        backbone=backbone,
+        decoder="IdentityDecoder",
+        backbone_model_bands=PRETRAINED_BANDS,
+        backbone_pretrained=backbone_pretrained,
+        num_classes=NUM_CLASSES,
+        necks=[{"name": "PermuteDims", "new_order": [0, 2, 1]}],
+    )
+    model.eval()
+
+    with torch.no_grad():
+        assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
+
+backbones = ["satlas_swin_b_sentinel2_si_ms"]
+@pytest.mark.parametrize("backbone", backbones)
+@pytest.mark.parametrize("backbone_pretrained", pretrained)
+def test_create_classification_model_swin(backbone, model_factory: EncoderDecoderFactory, model_input, backbone_pretrained):
+    model = model_factory.build_model(
+        "classification",
+        backbone=backbone,
+        decoder="IdentityDecoder",
+        backbone_model_bands=PRETRAINED_BANDS,
+        backbone_pretrained=backbone_pretrained,
+        num_classes=NUM_CLASSES,
+        necks=[{"name": "PermuteDims", "new_order": [0, 3, 1, 2]}],
+    )
+    model.eval()
+
+    with torch.no_grad():
+        assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
+
+@pytest.mark.parametrize("backbone", ["ssl4eos12_resnet50_sentinel2_all_decur"])
+@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
+@pytest.mark.parametrize("decoder", ["FCNDecoder", "IdentityDecoder", "smp_Unet"])
+@pytest.mark.parametrize("backbone_pretrained", pretrained)
+def test_create_pixelwise_model_resnet(backbone, task, expected, decoder, model_factory: EncoderDecoderFactory, model_input, backbone_pretrained):
+    model_args = {
+        "task": task,
+        "backbone": backbone,
+        "decoder": decoder,
+        "backbone_model_bands": PRETRAINED_BANDS,
+        "backbone_pretrained": backbone_pretrained,
+        "backbone_out_indices": [0, 1, 2, 3, 4],
+    }
+
+    if decoder == "smp_Unet":
+        model_args["decoder_decoder_channels"] = [512, 256, 128, 64]
+
+    if task == "segmentation":
+        model_args["num_classes"] = NUM_CLASSES
+
+    model = model_factory.build_model(**model_args)
+    model.eval()
+
+    with torch.no_grad():
+        assert model(model_input).output.shape == expected
+
+
+@pytest.mark.parametrize("backbone", ["dofa_large_patch16_224"])
+@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
+@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
+@pytest.mark.parametrize("backbone_pretrained", pretrained)
+def test_create_pixelwise_model_dofa(backbone, task, expected, decoder, model_factory: EncoderDecoderFactory, model_input, backbone_pretrained):
+    model_args = {
+        "task": task,
+        "backbone": backbone,
+        "decoder": decoder,
+        "backbone_model_bands": PRETRAINED_BANDS,
+        "backbone_pretrained": backbone_pretrained,
+        "backbone_out_indices": [5, 11, 17, 23],
+    }
+
+    if task == "segmentation":
+        model_args["num_classes"] = NUM_CLASSES
+
+    model_args["necks"] = [{"name": "ReshapeTokensToImage"}]
+
+    model = model_factory.build_model(**model_args)
+    model.eval()
+
+    with torch.no_grad():
+        assert model(model_input).output.shape == expected
+
+
+@pytest.mark.parametrize("backbone", ["satlas_swin_b_sentinel2_si_ms"])
+@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
+@pytest.mark.parametrize("decoder", ["UperNetDecoder", "IdentityDecoder"])
+@pytest.mark.parametrize("backbone_pretrained", pretrained)
+def test_create_pixelwise_model_swin(backbone, task, expected, decoder, model_factory: EncoderDecoderFactory, model_input, backbone_pretrained):
+    model_args = {
+        "task": task,
+        "backbone": backbone,
+        "decoder": decoder,
+        "backbone_model_bands": PRETRAINED_BANDS,
+        "backbone_pretrained": backbone_pretrained,
+        "backbone_out_indices": [1, 3, 5, 7],
+    }
+
+    if task == "segmentation":
+        model_args["num_classes"] = NUM_CLASSES
+
+    model_args["necks"] = [{"name": "PermuteDims", "new_order": [0, 3, 1, 2]}]
+
+    model = model_factory.build_model(**model_args)
+    model.eval()
+
+    with torch.no_grad():
+        assert model(model_input).output.shape == expected
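+
+# A minimal end-to-end sketch of what these tests exercise (comment only;
+# reuses the constants defined above and does not download any weights):
+#
+#   factory = EncoderDecoderFactory()
+#   model = factory.build_model(
+#       "segmentation",
+#       backbone="ssl4eos12_resnet50_sentinel2_all_decur",
+#       decoder="FCNDecoder",
+#       backbone_model_bands=PRETRAINED_BANDS,
+#       backbone_pretrained=False,
+#       backbone_out_indices=[0, 1, 2, 3, 4],
+#       num_classes=NUM_CLASSES,
+#   )
+#   out = model(torch.ones((1, NUM_CHANNELS, 224, 224)))
+#   assert out.output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE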