diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 5b522c26..a92f1dee 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -15,7 +15,7 @@ jobs: timeout-minutes: 30 strategy: matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12.7"] steps: - name: Clone repo @@ -28,7 +28,9 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements/required.txt -r requirements/test.txt -r requirements/optional.txt + #pip install -r requirements/required.txt -r requirements/test.txt -r requirements/optional.txt + pip install -r requirements/test.txt -r requirements/optional.txt + pip install -e .[torchgeo] pip install git+https://github.com/NASA-IMPACT/Prithvi-WxC.git pip install git+https://github.com/IBM/granite-wxc.git - name: List pip dependencies diff --git a/docs/quick_start.md b/docs/quick_start.md index 6c4e8ee2..c95d6749 100644 --- a/docs/quick_start.md +++ b/docs/quick_start.md @@ -19,28 +19,27 @@ In the simplest case, we might only want access a backbone and code all the rest from terratorch import BACKBONE_REGISTRY # find available prithvi models -print([model_name for model_name in BACKBONE_REGISTRY if "prithvi" in model_name]) ->>> ['timm_prithvi_eo_tiny', 'timm_prithvi_eo_v1_100', 'timm_prithvi_eo_v2_300', 'timm_prithvi_eo_v2_300_tl', 'timm_prithvi_eo_v2_600', - 'timm_prithvi_eo_v2_600_tl', 'timm_prithvi_swin_B', 'timm_prithvi_swin_L', 'timm_prithvi_vit_100', 'timm_prithvi_vit_tiny'] +print([model_name for model_name in BACKBONE_REGISTRY if "terratorch_prithvi" in model_name]) +>>> ['terratorch_prithvi_eo_tiny', 'terratorch_prithvi_eo_v1_100', 'terratorch_prithvi_eo_v2_300', 'terratorch_prithvi_eo_v2_600', 'terratorch_prithvi_eo_v2_300_tl', 'terratorch_prithvi_eo_v2_600_tl'] # show all models with list(BACKBONE_REGISTRY) # check a model is in the registry -"timm_prithvi_swin_B" in BACKBONE_REGISTRY +"terratorch_prithvi_eo_v2_300" in BACKBONE_REGISTRY >>> True # without the prefix, all internal registries will be searched until the first match is found -"prithvi_swin_B" in BACKBONE_REGISTRY +"prithvi_eo_v1_100" in BACKBONE_REGISTRY >>> True # instantiate your desired model -# the backbone registry prefix (in this case 'timm') is optional -# in this case, the underlying registry is timm, so we can pass timm arguments to it -model = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", num_frames=1, pretrained=True) +# the backbone registry prefix (e.g. `terratorch` or `timm`) is optional +# in this case, the underlying registry is terratorch. 
+model = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", pretrained=True) -# instantiate your model with more options, for instance, passing weights of your own through timm +# instantiate your model with more options, for instance, passing weights from your own file model = BACKBONE_REGISTRY.build( - "prithvi_vit_100", num_frames=1, pretrained=True, pretrained_cfg_overlay={"file": ""} + "prithvi_eo_v2_300", num_frames=1, ckpt_path='path/to/model.pt' ) # Rest of your PyTorch / PyTorchLightning code @@ -68,25 +67,25 @@ model_factory = EncoderDecoderFactory() # Parameters prefixed with decoder_ get passed to the decoder # Parameters prefixed with head_ get passed to the head -model = model_factory.build_model(task="segmentation", - backbone="prithvi_vit_100", - decoder="FCNDecoder", - backbone_bands=[ - HLSBands.BLUE, - HLSBands.GREEN, - HLSBands.RED, - HLSBands.NIR_NARROW, - HLSBands.SWIR_1, - HLSBands.SWIR_2, - ], - necks=[{"name": "SelectIndices", "indices": [-1]}, - {"name": "ReshapeTokensToImage"}], - num_classes=4, - backbone_pretrained=True, - backbone_num_frames=1, - decoder_channels=128, - head_dropout=0.2 - ) +model = model_factory.build_model( + task="segmentation", + backbone="prithvi_eo_v2_300", + backbone_pretrained=True, + backbone_bands=[ + HLSBands.BLUE, + HLSBands.GREEN, + HLSBands.RED, + HLSBands.NIR_NARROW, + HLSBands.SWIR_1, + HLSBands.SWIR_2, + ], + necks=[{"name": "SelectIndices", "indices": [-1]}, + {"name": "ReshapeTokensToImage"}], + decoder="FCNDecoder", + decoder_channels=128, + head_dropout=0.1, + num_classes=4, +) # Rest of your PyTorch / PyTorchLightning code . @@ -102,8 +101,9 @@ At the highest level of abstraction, you can directly obtain a LightningModule r ```python title="Building a full Pixel-Wise Regression task" model_args = dict( - backbone="prithvi_vit_100", - decoder="FCNDecoder", + backbone="prithvi_eo_v2_300", + backbone_pretrained=True, + backbone_num_frames=1, backbone_bands=[ HLSBands.BLUE, HLSBands.GREEN, @@ -114,10 +114,9 @@ model_args = dict( ], necks=[{"name": "SelectIndices", "indices": [-1]}, {"name": "ReshapeTokensToImage"}], - backbone_pretrained=True, - backbone_num_frames=1, + decoder="FCNDecoder", decoder_channels=128, - head_dropout=0.2 + head_dropout=0.1 ) task = PixelwiseRegressionTask( @@ -175,14 +174,11 @@ data: model: class_path: terratorch.tasks.SemanticSegmentationTask init_args: + model_factory: EncoderDecoderFactory model_args: - decoder: UperNetDecoder + backbone: prithvi_eo_v2_300 + backbone_img_size: 512 backbone_pretrained: True - backbone: prithvi_vit_100 - backbone_pretrain_img_size: 512 - decoder_scale_modules: True - decoder_channels: 256 - backbone_in_channels: 6 backbone_bands: - BLUE - GREEN @@ -190,39 +186,30 @@ model: - NIR_NARROW - SWIR_1 - SWIR_2 - num_frames: 1 - num_classes: 2 - head_dropout: 0.1 - head_channel_list: - - 256 - post_backbone_ops: + necks: - name: SelectIndices - indices: - - 5 - - 11 - - 17 - - 23 + indices: [5, 11, 17, 23] - name: ReshapeTokensToImage - loss: ce - + - name: LearnedInterpolateToPyramidal + decoder: UperNetDecoder + decoder_channels: 256 + head_channel_list: [256] + head_dropout: 0.1 + num_classes: 2 + loss: dice ignore_index: -1 - class_weights: - - 0.3 - - 0.7 freeze_backbone: false - freeze_decoder: false - model_factory: EncoderDecoderFactory + freeze_decoder: false optimizer: class_path: torch.optim.AdamW init_args: - lr: 6.e-5 - weight_decay: 0.05 + lr: 1.e-4 + weight_decay: 0.1 lr_scheduler: class_path: ReduceLROnPlateau init_args: monitor: val/loss - ``` To run this 
training task using the YAML, simply execute: diff --git a/docs/registry.md b/docs/registry.md index 06ceb2bc..ba342eb7 100644 --- a/docs/registry.md +++ b/docs/registry.md @@ -13,27 +13,27 @@ To create the desired instance, registries expose a `build` method, which accept from terratorch import BACKBONE_REGISTRY # find available prithvi models -print([model_name for model_name in BACKBONE_REGISTRY if "prithvi" in model_name]) ->>> ['timm_prithvi_swin_B', 'timm_prithvi_swin_L', 'timm_prithvi_vit_100', 'timm_prithvi_vit_300', 'timm_prithvi_vit_tiny'] +print([model_name for model_name in BACKBONE_REGISTRY if "terratorch_prithvi" in model_name]) +>>> ['terratorch_prithvi_eo_tiny', 'terratorch_prithvi_eo_v1_100', 'terratorch_prithvi_eo_v2_300', 'terratorch_prithvi_eo_v2_600', 'terratorch_prithvi_eo_v2_300_tl', 'terratorch_prithvi_eo_v2_600_tl'] # show all models with list(BACKBONE_REGISTRY) # check a model is in the registry -"timm_prithvi_swin_B" in BACKBONE_REGISTRY +"terratorch_prithvi_eo_v2_300" in BACKBONE_REGISTRY >>> True # without the prefix, all internal registries will be searched until the first match is found -"prithvi_swin_B" in BACKBONE_REGISTRY +"prithvi_eo_v1_100" in BACKBONE_REGISTRY >>> True # instantiate your desired model -# the backbone registry prefix (in this case 'timm') is optional -# in this case, the underlying registry is timm, so we can pass timm arguments to it -model = BACKBONE_REGISTRY.build("prithvi_vit_100", num_frames=1, pretrained=True) +# the backbone registry prefix (e.g. `terratorch` or `timm`) is optional +# in this case, the underlying registry is terratorch. +model = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", pretrained=True) -# instantiate your model with more options, for instance, passing weights of your own through timm +# instantiate your model with more options, for instance, passing weights from your own file model = BACKBONE_REGISTRY.build( - "prithvi_vit_100", num_frames=1, pretrained=True, pretrained_cfg_overlay={"file": ""} + "prithvi_eo_v2_300", num_frames=1, ckpt_path='path/to/model.pt' ) # Rest of your PyTorch / PyTorchLightning code diff --git a/examples/confs/burn_scars.yaml b/examples/confs/burn_scars.yaml index 702357dc..2e46fd1f 100644 --- a/examples/confs/burn_scars.yaml +++ b/examples/confs/burn_scars.yaml @@ -29,7 +29,7 @@ trainer: # dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 4 num_workers: 8 @@ -56,9 +56,7 @@ data: init_args: height: 224 width: 224 - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 + - class_path: albumentations.D4 - class_path: ToTensorV2 no_data_replace: 0 no_label_replace: -1 @@ -89,11 +87,11 @@ data: model: class_path: terratorch.tasks.SemanticSegmentationTask init_args: + model_factory: EncoderDecoderFactory model_args: - decoder: FCNDecoder + backbone: prithvi_eo_v2_300 backbone_pretrained: true - backbone: prithvi_vit_100 - decoder_channels: 256 + backbone_drop_path: 0.1 backbone_bands: - BLUE - GREEN @@ -101,23 +99,30 @@ model: - NIR_NARROW - SWIR_1 - SWIR_2 - num_classes: 2 + necks: + - name: SelectIndices +# indices: [2, 5, 8, 11] # 100M models + indices: [5, 11, 17, 23] # 300M models +# indices: [7, 15, 23, 31] # 600M models + - name: ReshapeTokensToImage + - name: LearnedInterpolateToPyramidal + decoder: UNetDecoder + decoder_channels: [512, 256, 128, 64] + head_channel_list: [256] 
head_dropout: 0.1 - decoder_num_convs: 4 - head_channel_list: - - 256 + num_classes: 2 loss: dice plot_on_val: 10 ignore_index: -1 freeze_backbone: false freeze_decoder: false - model_factory: EncoderDecoderFactory tiled_inference_parameters: h_crop: 512 h_stride: 496 w_crop: 512 w_stride: 496 average_patches: true + optimizer: class_path: torch.optim.Adam init_args: diff --git a/examples/confs/burnscars_smp.yaml b/examples/confs/burnscars_smp.yaml index 5c78a7f1..07a9a888 100644 --- a/examples/confs/burnscars_smp.yaml +++ b/examples/confs/burnscars_smp.yaml @@ -32,7 +32,7 @@ trainer: default_root_dir: output/BurnScars data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 4 num_workers: 8 diff --git a/examples/confs/eurosat.yaml b/examples/confs/eurosat.yaml index fbcbadca..50a5bff3 100644 --- a/examples/confs/eurosat.yaml +++ b/examples/confs/eurosat.yaml @@ -57,7 +57,7 @@ model: model_args: decoder: IdentityDecoder backbone_pretrained: true - backbone: prithvi_vit_100 + backbone: prithvi_eo_v1_300 head_dim_list: - 384 - 128 diff --git a/examples/confs/forestnet_timm.yaml b/examples/confs/forestnet_timm.yaml index 6499c842..d86019dd 100644 --- a/examples/confs/forestnet_timm.yaml +++ b/examples/confs/forestnet_timm.yaml @@ -33,7 +33,7 @@ trainer: default_root_dir: output/ForestNet data: - class_path: GenericNonGeoClassificationDataModule + class_path: terratorch.datamodules.GenericNonGeoClassificationDataModule init_args: batch_size: 16 num_workers: 8 diff --git a/examples/confs/multi_temporal_crop.yaml b/examples/confs/multi_temporal_crop.yaml index 5ea76acf..dbcaf054 100644 --- a/examples/confs/multi_temporal_crop.yaml +++ b/examples/confs/multi_temporal_crop.yaml @@ -90,7 +90,7 @@ model: model_args: decoder: FCNDecoder backbone_pretrained: true - backbone: prithvi_vit_100 + backbone: prithvi_eo_v2_300 backbone_in_channels: 6 rescale: False backbone_bands: diff --git a/examples/confs/multimae_sen1floods11.yaml b/examples/confs/multimae_sen1floods11.yaml index 9d779ab3..5a1b515d 100644 --- a/examples/confs/multimae_sen1floods11.yaml +++ b/examples/confs/multimae_sen1floods11.yaml @@ -28,7 +28,7 @@ trainer: default_root_dir: output/multimae_sen1floods11/ data: - class_path: GenericMultiModalDataModule + class_path: terratorch.datamodules.GenericMultiModalDataModule init_args: task: 'segmentation' batch_size: 4 diff --git a/examples/confs/multimodal_prithvi_sen1floods11.yaml b/examples/confs/multimodal_prithvi_sen1floods11.yaml index 0a7633c8..f8f38723 100644 --- a/examples/confs/multimodal_prithvi_sen1floods11.yaml +++ b/examples/confs/multimodal_prithvi_sen1floods11.yaml @@ -29,7 +29,7 @@ trainer: default_root_dir: output/multimodal_prithvi_sen1floods11/ data: - class_path: GenericMultiModalDataModule + class_path: terratorch.datamodules.GenericMultiModalDataModule init_args: task: 'segmentation' batch_size: 16 @@ -124,7 +124,7 @@ model: init_args: model_factory: EncoderDecoderFactory model_args: - backbone: prithvi_vit_100 + backbone: prithvi_eo_v2_300 backbone_pretrained: false backbone_bands: - COASTAL_AEROSOL @@ -141,23 +141,15 @@ model: - SWIR_2 - VV - VH - decoder: FCNDecoder # FCNDecoder - decoder_num_convs: 4 # only for FCNDecoder - # decoder_scale_modules: True # only for UperNetDecoder + decoder: FCNDecoder + decoder_num_convs: 4 decoder_channels: 256 num_classes: 2 head_dropout: 0.1 head_channel_list: - 256 - loss: dice ignore_index: -1 - class_weights: - - 0.3 - - 0.7 - 
class_names: - - Others - - Flood freeze_backbone: false freeze_decoder: false diff --git a/examples/confs/sen1floods11_vit.yaml b/examples/confs/sen1floods11_vit.yaml index fc30711a..da890f55 100644 --- a/examples/confs/sen1floods11_vit.yaml +++ b/examples/confs/sen1floods11_vit.yaml @@ -19,7 +19,7 @@ trainer: enable_checkpointing: true default_root_dir: data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 16 num_workers: 8 @@ -82,11 +82,11 @@ data: model: class_path: terratorch.tasks.SemanticSegmentationTask init_args: + model_factory: EncoderDecoderFactory model_args: - decoder: FCNDecoder + backbone: prithvi_eo_v2_300 backbone_pretrained: true - backbone: prithvi_vit_100 - decoder_channels: 256 + backbone_drop_path: 0.1 backbone_bands: - BLUE - GREEN @@ -94,17 +94,19 @@ model: - NIR_NARROW - SWIR_1 - SWIR_2 - num_classes: 2 - head_dropout: 0.1 - decoder_num_convs: 4 - head_channel_list: - - 256 necks: - name: SelectIndices - indices: - - -1 +# indices: [2, 5, 8, 11] # 100M models + indices: [5, 11, 17, 23] # 300M models +# indices: [7, 15, 23, 31] # 600M models - name: ReshapeTokensToImage - loss: ce + - name: LearnedInterpolateToPyramidal + decoder: UNetDecoder + decoder_channels: [512, 256, 128, 64] + head_channel_list: [256] + head_dropout: 0.1 + num_classes: 2 + loss: dice aux_heads: - name: aux_head decoder: FCNDecoder @@ -113,25 +115,20 @@ model: decoder_in_index: -1 decoder_num_convs: 2 head_dropout: 0.1 - # head_channel_list: - # - 64 aux_loss: aux_head: 1.0 ignore_index: -1 - class_weights: - - 0.3 - - 0.7 freeze_backbone: false freeze_decoder: false - model_factory: EncoderDecoderFactory + optimizer: class_path: torch.optim.AdamW init_args: - lr: 6.e-5 - weight_decay: 0.05 + lr: 1.e-4 + weight_decay: 0.1 lr_scheduler: class_path: ReduceLROnPlateau init_args: monitor: val/loss - - + patience: 5 + factor: 0.5 diff --git a/examples/confs/sen1floods11_vit_dual_lr.yaml b/examples/confs/sen1floods11_vit_dual_lr.yaml index 47a6eaa7..b909630c 100644 --- a/examples/confs/sen1floods11_vit_dual_lr.yaml +++ b/examples/confs/sen1floods11_vit_dual_lr.yaml @@ -19,7 +19,7 @@ trainer: enable_checkpointing: true default_root_dir: data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 16 num_workers: 8 @@ -85,7 +85,7 @@ model: model_args: decoder: FCNDecoder backbone_pretrained: true - backbone: prithvi_vit_100 + backbone: prithvi_eo_v1_100 decoder_channels: 256 backbone_bands: - BLUE diff --git a/examples/confs/sen1floods11_vit_local_ckpt.yaml b/examples/confs/sen1floods11_vit_local_ckpt.yaml index 5eb1a2c2..52d4a4a8 100644 --- a/examples/confs/sen1floods11_vit_local_ckpt.yaml +++ b/examples/confs/sen1floods11_vit_local_ckpt.yaml @@ -19,7 +19,7 @@ trainer: enable_checkpointing: true default_root_dir: data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 16 num_workers: 8 @@ -82,13 +82,12 @@ data: model: class_path: terratorch.tasks.SemanticSegmentationTask init_args: + model_factory: EncoderDecoderFactory model_args: - decoder: FCNDecoder - backbone_pretrained: true - backbone_pretrained_cfg_overlay: - file: examples/Prithvi_100M.pt - backbone: prithvi_vit_100 - decoder_channels: 256 + backbone: prithvi_eo_v2_300 + backbone_pretrained: false + backbone_ckpt_path: examples/Prithvi_100M.pt + 
backbone_drop_path: 0.1 backbone_bands: - BLUE - GREEN @@ -96,17 +95,19 @@ model: - NIR_NARROW - SWIR_1 - SWIR_2 - num_classes: 2 - head_dropout: 0.1 - decoder_num_convs: 4 - head_channel_list: - - 256 necks: - name: SelectIndices - indices: - - -1 +# indices: [2, 5, 8, 11] # 100M models + indices: [5, 11, 17, 23] # 300M models +# indices: [7, 15, 23, 31] # 600M models - name: ReshapeTokensToImage - loss: ce + - name: LearnedInterpolateToPyramidal + decoder: UNetDecoder + decoder_channels: [512, 256, 128, 64] + head_channel_list: [256] + head_dropout: 0.1 + num_classes: 2 + loss: dice aux_heads: - name: aux_head decoder: FCNDecoder @@ -115,25 +116,20 @@ model: decoder_in_index: -1 decoder_num_convs: 2 head_dropout: 0.1 - # head_channel_list: - # - 64 aux_loss: aux_head: 1.0 ignore_index: -1 - class_weights: - - 0.3 - - 0.7 freeze_backbone: false freeze_decoder: false - model_factory: EncoderDecoderFactory + optimizer: class_path: torch.optim.AdamW init_args: - lr: 6.e-5 - weight_decay: 0.05 + lr: 1.e-4 + weight_decay: 0.1 lr_scheduler: class_path: ReduceLROnPlateau init_args: monitor: val/loss - - + patience: 5 + factor: 0.5 diff --git a/examples/confs/sen1floods11_vit_mmseg.yaml b/examples/confs/sen1floods11_vit_mmseg.yaml deleted file mode 100644 index 6a65edf5..00000000 --- a/examples/confs/sen1floods11_vit_mmseg.yaml +++ /dev/null @@ -1,122 +0,0 @@ -# lightning.pytorch==2.1.1 -seed_everything: 0 -trainer: - accelerator: auto - strategy: auto - devices: auto - num_nodes: 1 - precision: 16-mixed - logger: True # will use tensorboardlogger - callbacks: - - class_path: RichProgressBar - - class_path: LearningRateMonitor - init_args: - logging_interval: epoch - - max_epochs: 200 - check_val_every_n_epoch: 1 - log_every_n_steps: 50 - enable_checkpointing: true - default_root_dir: -data: - class_path: GenericNonGeoSegmentationDataModule - init_args: - batch_size: 16 - num_workers: 8 - constant_scale: 0.0001 - dataset_bands: - - COASTAL_AEROSOL - - BLUE - - GREEN - - RED - - RED_EDGE_1 - - RED_EDGE_2 - - RED_EDGE_3 - - NIR_BROAD - - NIR_NARROW - - WATER_VAPOR - - CIRRUS - - SWIR_1 - - SWIR_2 - output_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - rgb_indices: - - 2 - - 1 - - 0 - train_data_root: /v1.1/data/flood_events/HandLabeled/S2Hand/ - train_label_data_root: /v1.1/data/flood_events/HandLabeled/LabelHand - val_data_root: /v1.1/data/flood_events/HandLabeled/S2Hand/ - val_label_data_root: /v1.1/data/flood_events/HandLabeled/LabelHand - test_data_root: /v1.1/data/flood_events/HandLabeled/S2Hand/ - test_label_data_root: /v1.1/data/flood_events/HandLabeled/LabelHand - # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files - train_split: /v1.1/splits/flood_handlabeled/flood_train_data.txt - test_split: /v1.1/splits/flood_handlabeled/flood_test_data.txt - val_split: /v1.1/splits/flood_handlabeled/flood_valid_data.txt - img_grep: "*_S2Hand.tif" - label_grep: "*_LabelHand.tif" - no_label_replace: -1 - no_data_replace: 0 - means: - - 0.1412956 - - 0.13795798 - - 0.12353792 - - 0.30902815 - - 0.2044958 - - 0.11912015 - stds: - - 0.07406382 - - 0.07370365 - - 0.08692279 - - 0.11798815 - - 0.09772074 - - 0.07659938 - num_classes: 2 -model: - class_path: terratorch.tasks.SemanticSegmentationTask - init_args: - model_args: - decoder: FCNHead - backbone_pretrained: True - backbone: prithvi_vit_100 - backbone_pretrain_img_size: 512 - decoder_num_convs: 4 - decoder_channels: 256 - decoder_dropout_ratio: 0.1 - 
backbone_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - num_classes: 2 - necks: - - name: ReshapeTokensToImage - - name: SelectIndices - indices: - - -1 - loss: ce - - ignore_index: -1 - class_weights: - - 0.3 - - 0.7 - freeze_backbone: true - freeze_decoder: false - model_factory: EncoderDecoderFactory -optimizer: - class_path: torch.optim.AdamW - init_args: - lr: 6.e-5 - weight_decay: 0.05 -lr_scheduler: - class_path: ReduceLROnPlateau - init_args: - monitor: val/loss diff --git a/examples/confs/sen1floods11_vit_peft.yaml b/examples/confs/sen1floods11_vit_peft.yaml index 3404e12b..68b4cc60 100644 --- a/examples/confs/sen1floods11_vit_peft.yaml +++ b/examples/confs/sen1floods11_vit_peft.yaml @@ -19,7 +19,7 @@ trainer: enable_checkpointing: true default_root_dir: data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 16 num_workers: 8 @@ -85,7 +85,7 @@ model: model_args: decoder: FCNDecoder backbone_pretrained: true - backbone: prithvi_vit_100 + backbone: prithvi_eo_v1_100 decoder_channels: 256 backbone_bands: - BLUE diff --git a/examples/confs/sen1floods11_vit_smp.yaml b/examples/confs/sen1floods11_vit_smp.yaml index f0c89f02..5ec07a6d 100644 --- a/examples/confs/sen1floods11_vit_smp.yaml +++ b/examples/confs/sen1floods11_vit_smp.yaml @@ -19,7 +19,7 @@ trainer: enable_checkpointing: true default_root_dir: data: - class_path: GenericNonGeoSegmentationDataModule + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule init_args: batch_size: 16 num_workers: 8 @@ -83,7 +83,7 @@ model: init_args: model_args: backbone_pretrained: True - backbone: prithvi_vit_100 + backbone: prithvi_eo_v1_100 backbone_pretrain_img_size: 512 backbone_bands: - BLUE diff --git a/examples/confs/sen4agri.yaml b/examples/confs/sen4agri.yaml deleted file mode 100644 index 4c041456..00000000 --- a/examples/confs/sen4agri.yaml +++ /dev/null @@ -1,59 +0,0 @@ -benchmark_suffix: benchmark -experiment_name: benchmark -precision: 16-mixed -backbone: your_model_here - -tasks: - - name: cashew - type: segmentation - loss: ce - bands: - - 12 - num_classes: 20 - max_epochs: 300 - direction: max - datamodule: - class_path: terratorch.datamodules.Sen4AgriNetDataModule - init_args: - data_root: "/dccstor/geofm-finetuning/datasets/Sen4AgriNet/S4A" - batch_size: 16 - num_workers: 6 - val_transform: - - class_path: FlattenTemporalIntoChannels - - class_path: ToTensorV2 - train_transform: - - class_path: FlattenTemporalIntoChannels - - class_path: ToTensorV2 - test_transform: - - class_path: FlattenTemporalIntoChannels - - class_path: ToTensorV2 - - decoder: UperNetDecoder - decoder_args: - channels: 128 - scale_modules: True - metric: val/Multiclass_Jaccard_Index - early_stop_patience: 50 - -n_trials: 16 -save_models: False -storage_uri: /path/to/storage -ray_storage_path: /path/to/ray/storage -optimization_space: - # decoder: - # - UperNetDecoder - # - UperNetDecoder - lr: - min: 1e-6 - max: 1e-3 - type: real - log: true - batch_size: - - 4 - - 8 - - 16 - - 32 - decoder_channels: - - 64 - - 128 - - 256 diff --git a/examples/confs/sen4map_ViT-L.yaml b/examples/confs/sen4map_ViT-L.yaml index e04c9a38..3ab959a8 100644 --- a/examples/confs/sen4map_ViT-L.yaml +++ b/examples/confs/sen4map_ViT-L.yaml @@ -25,7 +25,7 @@ trainer: default_root_dir: data: - class_path: Sen4MapLucasDataModule + class_path: terratorch.datamodules.Sen4MapLucasDataModule init_args: batch_size: 10 num_workers: 8 @@ -74,7 +74,7 @@ 
model: model_args: decoder: IdentityDecoder pretrained: true - backbone: prithvi_vit_300 + backbone: prithvi_eo_v2_300 backbone_pretrained_cfg_overlay: file: backbone_patch_size: 16 @@ -101,7 +101,7 @@ model: loss: ce freeze_backbone: false # freeze_decoder: false - model_factory: PrithviModelFactory + model_factory: EncoderDecoderFactory optimizer: class_path: torch.optim.AdamW diff --git a/examples/confs/smp_model_factory.yaml b/examples/confs/smp_model_factory.yaml deleted file mode 100644 index 7d9a1d53..00000000 --- a/examples/confs/smp_model_factory.yaml +++ /dev/null @@ -1,93 +0,0 @@ -benchmark_suffix: smp_test -experiment_name: smp_test -backbone: - backbone: resnet18 - backbone_args: - pretrained: False - output_stride: 2 - smp_decoder_channels: 512 - smp_encoder_depth: 5 - - # backbone: swin3d.swin3d_backbone.Swin3dBackbone - # backbone_args: - # pretrained: False - # output_stride: 2 - # out_channels: - # - 192 - # - 384 - # - 768 - # - 768 - # smp_decoder_channels: 768 - # smp_encoder_depth: 5 - - -tasks: - - name: cashew - type: segmentation - loss: ce - model_factory: SMPModelFactory - bands: - - RED - - GREEN - - BLUE - num_classes: 7 - max_epochs: 60 - direction: max - datamodule: - class_path: terratorch.datamodules.MBeninSmallHolderCashewsNonGeoDataModule - init_args: - batch_size: 16 - num_workers: 4 - train_transform: - - class_path: albumentations.Resize - init_args: - always_apply: True - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - always_apply: True - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - data_root: "/dccstor/geofm-finetuning/geobench/segmentation_v1.0" - bands: - - "RED" - - "GREEN" - - "BLUE" - decoder: IdentityDecoder - decoder_args: - channels: 128 - metric: val/Multiclass Jaccard Index - -n_trials: 16 -save_models: False -storage_uri: /path/to/storage -optimization_space: - model: - - DeepLabV3 - lr: - min: 6e-5 - max: 1e-3 - type: real - log: true - batch_size: - - 8 - - 16 - - 32 - decoder_channels: - - 32 - - 64 - - 128 - head_dropout: - min: 0.2 - max: 0.8 - type: real \ No newline at end of file diff --git a/examples/confs/wxc-gravity-wave-ccc.yaml b/examples/confs/wxc-gravity-wave-ccc.yaml new file mode 100644 index 00000000..52a4af3a --- /dev/null +++ b/examples/confs/wxc-gravity-wave-ccc.yaml @@ -0,0 +1,85 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: + name: fire_scars + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 40 + + max_epochs: 200 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: + +# dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars +data: + class_path: terratorch.datamodules.era5.ERA5DataModule + init_args: + train_data_path: /dccstor/terratorch/users/rkie/gitco/terratorch + valid_data_path: /dccstor/terratorch/users/rkie/gitco/terratorch + file_glob_pattern: "wxc_input_u_v_t_p_output_theta_uw_vw_*.nc" + +model: + class_path: WxCTask + init_args: + model_args: + in_channels: 1280 + input_size_time: 1 + n_lats_px: 64 + n_lons_px: 
128 + patch_size_px: [2, 2] + mask_unit_size_px: [8, 16] + mask_ratio_inputs: 0.5 + embed_dim: 2560 + n_blocks_encoder: 12 + n_blocks_decoder: 2 + mlp_multiplier: 4 + n_heads: 16 + dropout: 0.0 + drop_path: 0.05 + parameter_dropout: 0.0 + residual: none + masking_mode: both + decoder_shifting: False + positional_encoding: absolute + checkpoint_encoder: [3, 6, 9, 12, 15, 18, 21, 24] + checkpoint_decoder: [1, 3] + in_channels_static: 3 + input_scalers_mu: torch.tensor([0] * 1280) + input_scalers_sigma: torch.tensor([1] * 1280) + input_scalers_epsilon: 0 + static_input_scalers_mu: torch.tensor([0] * 3) + static_input_scalers_sigma: torch.tensor([1] * 3) + static_input_scalers_epsilon: 0 + output_scalers: torch.tensor([0] * 1280) + backbone_weights: magnet-flux-uvtp122-epoch-99-loss-0.1022.pt + backbone: prithviwxc + aux_decoders: unetpincer + skip_connection: True + model_factory: WxCModelFactory + mode: eval +optimizer: + class_path: torch.optim.Adam + init_args: + lr: 1.5e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss diff --git a/examples/confs/wxc-gravity-wave.yaml b/examples/confs/wxc-gravity-wave.yaml new file mode 100644 index 00000000..b504df0e --- /dev/null +++ b/examples/confs/wxc-gravity-wave.yaml @@ -0,0 +1,81 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: + name: fire_scars + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 40 + + max_epochs: 200 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: + +# dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars +data: + class_path: terratorch.datamodules.era5.ERA5DataModule + +model: + class_path: WxCTask + init_args: + model_args: + in_channels: 1280 + input_size_time: 1 + n_lats_px: 64 + n_lons_px: 128 + patch_size_px: [2, 2] + mask_unit_size_px: [8, 16] + mask_ratio_inputs: 0.5 + embed_dim: 2560 + n_blocks_encoder: 12 + n_blocks_decoder: 2 + mlp_multiplier: 4 + n_heads: 16 + dropout: 0.0 + drop_path: 0.05 + parameter_dropout: 0.0 + residual: none + masking_mode: both + decoder_shifting: False + positional_encoding: absolute + checkpoint_encoder: [3, 6, 9, 12, 15, 18, 21, 24] + checkpoint_decoder: [1, 3] + in_channels_static: 3 + input_scalers_mu: torch.tensor([0] * 1280) + input_scalers_sigma: torch.tensor([1] * 1280) + input_scalers_epsilon: 0 + static_input_scalers_mu: torch.tensor([0] * 3) + static_input_scalers_sigma: torch.tensor([1] * 3) + static_input_scalers_epsilon: 0 + output_scalers: torch.tensor([0] * 1280) + backbone_weights: magnet-flux-uvtp122-epoch-99-loss-0.1022.pt + backbone: prithviwxc + aux_decoders: unetpincer + skip_connection: True + model_factory: WxCModelFactory + mode: eval +optimizer: + class_path: torch.optim.Adam + init_args: + lr: 1.5e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss diff --git a/examples/notebooks/Tutorial.ipynb b/examples/notebooks/Tutorial.ipynb index 8a704fe2..699a6c71 100644 --- a/examples/notebooks/Tutorial.ipynb +++ b/examples/notebooks/Tutorial.ipynb @@ -2,16 +2,31 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, "id": "5d049232-f4b1-473d-aac3-0b3539905b03", - 
"metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-22T10:44:22.839382Z", + "start_time": "2025-01-22T10:44:18.410638Z" + } + }, "source": [ "import os\n", "import torch\n", "\n", "from terratorch import BACKBONE_REGISTRY" - ] + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.0 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations\n", + "/opt/homebrew/lib/python3.11/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n", + " @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)\n" + ] + } + ], + "execution_count": 1 }, { "cell_type": "markdown", @@ -31,48 +46,35 @@ }, { "cell_type": "code", - "execution_count": 5, "id": "8dcdfa85-8e43-4db0-9ddf-cb11c5544942", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-22T10:44:59.384413Z", + "start_time": "2025-01-22T10:44:59.380583Z" + } + }, + "source": "print([model_name for model_name in BACKBONE_REGISTRY if \"terratorch_prithvi\" in model_name])", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "['timm_prithvi_swin_B', 'timm_prithvi_swin_L', 'timm_prithvi_vit_100', 'timm_prithvi_vit_300', 'timm_prithvi_vit_tiny']\n" + "['terratorch_prithvi_eo_tiny', 'terratorch_prithvi_eo_v1_100', 'terratorch_prithvi_eo_v2_300', 'terratorch_prithvi_eo_v2_600', 'terratorch_prithvi_eo_v2_300_tl', 'terratorch_prithvi_eo_v2_600_tl', 'terratorch_prithvi_vit_tiny', 'terratorch_prithvi_vit_100']\n" ] } ], - "source": [ - "print([model_name for model_name in BACKBONE_REGISTRY if \"prithvi\" in model_name])" - ] + "execution_count": 5 }, { "cell_type": "code", - "execution_count": 2, "id": "338c6071", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-22T10:45:05.471003Z", + "start_time": "2025-01-22T10:45:05.466191Z" } - ], - "source": [ - "\"timm_prithvi_vit_100\" in BACKBONE_REGISTRY" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "13e3ed35", - "metadata": {}, + }, + "source": "\"prithvi_vit_100\" in BACKBONE_REGISTRY", "outputs": [ { "data": { @@ -80,24 +82,37 @@ "True" ] }, - "execution_count": 3, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "\"prithvi_vit_100\" in BACKBONE_REGISTRY" - ] + "execution_count": 6 }, { "cell_type": "code", - "execution_count": 4, "id": "38db3f3c", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-22T10:44:34.736849Z", + "start_time": "2025-01-22T10:44:34.220986Z" + } + }, "source": [ "model = BACKBONE_REGISTRY.build(\"prithvi_vit_100\")" - ] + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/BLU/repos/terratorch_ibm/terratorch/models/backbones/prithvi_vit.py:354: FutureWarning: The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. prithvi_vit_100 will be removed in a future version.\n", + " warnings.warn(\"The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. \"\n", + "INFO:terratorch.models.backbones.prithvi_vit:Model bands not passed. Assuming bands are ordered in the same way as [, , , , , ]. 
Pretrained patch_embed layer may be misaligned with current bands\n" + ] + } + ], + "execution_count": 4 }, { "cell_type": "markdown", @@ -667,17 +682,17 @@ ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┃\u001B[1m \u001B[0m\u001B[1m Test metric \u001B[0m\u001B[1m \u001B[0m┃\u001B[1m \u001B[0m\u001B[1m DataLoader 0 \u001B[0m\u001B[1m \u001B[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test/Multiclass_Accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.807342529296875 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/Multiclass_F1_Score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.807342529296875 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/Multiclass_Jaccard_Index \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.4036712646484375 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36mtest/Multiclass_Jaccard_Index_Micro\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.676927387714386 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5365139245986938 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/multiclassaccuracy_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/multiclassaccuracy_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/multiclassjaccardindex_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.807342529296875 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/multiclassjaccardindex_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/Multiclass_Accuracy \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.807342529296875 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/Multiclass_F1_Score \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.807342529296875 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/Multiclass_Jaccard_Index \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.4036712646484375 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36mtest/Multiclass_Jaccard_Index_Micro\u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.676927387714386 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/loss \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.5365139245986938 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/multiclassaccuracy_0 \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 1.0 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/multiclassaccuracy_1 \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.0 \u001B[0m\u001B[35m \u001B[0m│\n", + "│\u001B[36m \u001B[0m\u001B[36m test/multiclassjaccardindex_0 \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.807342529296875 \u001B[0m\u001B[35m \u001B[0m│\n", + 
"│\u001B[36m \u001B[0m\u001B[36m test/multiclassjaccardindex_1 \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.0 \u001B[0m\u001B[35m \u001B[0m│\n", "└─────────────────────────────────────┴─────────────────────────────────────┘\n" ] }, diff --git a/examples/scripts/instantiate_satmae_backbone.py b/examples/scripts/instantiate_satmae_backbone.py deleted file mode 100644 index 79d4cc3d..00000000 --- a/examples/scripts/instantiate_satmae_backbone.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -import numpy as np - -from models_mae import MaskedAutoencoderViT - -kwargs = {"img_size":224, - "patch_size":16, - "in_chans":3, - "embed_dim":1024, - "depth":24, - "num_heads":16, - "decoder_embed_dim":512, - "decoder_depth":8, - "decoder_num_heads":16, - "mlp_ratio":4.} - -vit_mae = MaskedAutoencoderViT(**kwargs) - -mask_ratio = 0.75 -data = torch.from_numpy(np.random.rand(4, 3, 224, 224).astype("float32")) -latent, _, ids_restore = vit_mae.forward_encoder(data, mask_ratio) -reconstructed = vit_mae.forward_decoder(latent, ids_restore) - - -print(f"Output shape: {latent.shape}") -print("Done.") - -_, reconstructed, _ = vit_mae.forward(data, mask_ratio) - -print(f"Output shape: {reconstructed.shape}") -print("Done.") - diff --git a/examples/scripts/open_local_model.py b/examples/scripts/open_local_model.py index 8651c07c..dc970f16 100644 --- a/examples/scripts/open_local_model.py +++ b/examples/scripts/open_local_model.py @@ -1,10 +1,10 @@ from terratorch.io.file import open_generic_torch_model -from models_mae_temporal import MaskedAutoencoderViT +from terratorch.models.backbones.prithvi_mae import PrithviMAE from torch import nn # Path for a downloaded model model_weights_path = "./pretrain-vit-base-e199.pth" -model_template = MaskedAutoencoderViT +model_template = PrithviMAE model_kwargs = { 'img_size': 224, diff --git a/integrationtests/test_smoke.py b/integrationtests/test_smoke.py new file mode 100644 index 00000000..afda48f2 --- /dev/null +++ b/integrationtests/test_smoke.py @@ -0,0 +1,4 @@ +import pytest + +def test_smoke(): + assert True diff --git a/mkdocs.yml b/mkdocs.yml index 3105bbf0..0ffd9b8b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -15,9 +15,9 @@ plugins: paths: [src] # search packages in the src folde options: show_root_heading: true -extra: - version: - provider: mike + #extra: + # version: + # provider: mike site_url: https://ibm.github.io/terratorch/ repo_url: https://github.com/IBM/terratorch diff --git a/pyproject.toml b/pyproject.toml index 6a477ac4..30d3cc0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "torchgeo>=0.6.0", "rioxarray>=0.15.0", # see issue #64 - "albumentations>=1.3.1, <=1.4.10", + "albumentations>=1.3.1, <=1.4.21", "albucore<=0.0.16", "rasterio>=1.3.9", "torchmetrics<=1.3.1", @@ -44,11 +44,29 @@ dependencies = [ "mlflow>=2.12.1", # broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977 "lightning[pytorch-extra]>=2,!=2.3.*", - "segmentation-models-pytorch>=0.3" + "segmentation-models-pytorch>=0.3", + "jsonargparse<=4.35.0", # Dependencies not available on PyPI ] [project.optional-dependencies] +torchgeo = [ + "torch==2.4.1", + "torchvision==0.19.1", + "torchgeo @ git+https://github.com/microsoft/torchgeo.git@fedf99375535f801565856cd774bfa9e5a251d55", + "rioxarray>=0.15.0", + "albumentations==1.3.1", + "albucore<=0.0.16", + "rasterio>=1.3.9", + "torchmetrics<=1.3.1", + "geopandas>=0.14.4", + "lightly>=1.4.25", + "h5py>=3.10.0", + "mlflow>=2.12.1", + "lightning[pytorch-extra]>=2,!=2.3.*", 
+ "segmentation-models-pytorch>=0.3" +] + dev = [ "black", "mkdocs-material", diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index a00f2ad3..21f22c40 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -20,6 +20,9 @@ import rasterio import torch +import random +import string + # Allows classes to be referenced using only the class name import torchgeo.datamodules import yaml @@ -153,10 +156,15 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batc if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) - pred_batch, filename_batch = prediction - - for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False): - save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype) + if isinstance(prediction, torch.Tensor): + filename_batch = ''.join(random.choices(string.ascii_letters + string.digits, k=8)) + torch.save(prediction, os.path.join(output_dir, f"{filename_batch}.pt")) + elif isinstance(prediction, tuple): + pred_batch, filename_batch = prediction + for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False): + save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype) + else: + raise TypeError(f"Unknown type for prediction{type(prediction)}") def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): # noqa: ARG002 # this will create N (num processes) files in `output_dir` each containing diff --git a/terratorch/datamodules/__init__.py b/terratorch/datamodules/__init__.py index f97c75fe..bef29be2 100644 --- a/terratorch/datamodules/__init__.py +++ b/terratorch/datamodules/__init__.py @@ -27,6 +27,7 @@ from terratorch.datamodules.multi_temporal_crop_classification import MultiTemporalCropClassificationDataModule from terratorch.datamodules.open_sentinel_map import OpenSentinelMapDataModule from terratorch.datamodules.pastis import PASTISDataModule +from terratorch.datamodules.era5 import ERA5DataModule try: wxc_present = True diff --git a/terratorch/io/file.py b/terratorch/io/file.py index 27efebcd..942a09e0 100644 --- a/terratorch/io/file.py +++ b/terratorch/io/file.py @@ -1,4 +1,5 @@ import os +import torch import importlib from torch import nn import numpy as np @@ -22,6 +23,7 @@ def open_generic_torch_model(model: type | str = None, return load_torch_weights(model=model, save_dir=dirname, name=filename) + def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = None, device: str = None) -> None: print(f"Trying to load for {device}") @@ -30,29 +32,21 @@ def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = N if device != None: model.load_state_dict( torch.load( - os.path.join(save_dir, name + ".pth"), + os.path.join(save_dir, name), map_location=torch.device(device), ) ) else: - #try: - # path = os.path.join(save_dir, name) - # checkpoint = torch.load(path, map_location='cpu') - # model = checkpoint['model'] - # state_dict = model.state_dict() - # msg = model.load_state_dict(model, strict=False) - - #except Exception: - - model.load_state_dict(torch.load(os.path.join(save_dir, name))) + model.load_state_dict(torch.load(os.path.join(save_dir, name), map_location='cpu')) except Exception: print( - f"It was not possible to load from {os.path.join(save_dir, name + '.pth')}" + f"It was not possible to load from {os.path.join(save_dir, name)}" ) return model + def load_from_file_or_attribute(value: list[float]|str): if isinstance(value, list): 
diff --git a/terratorch/models/backbones/dofa_vit.py b/terratorch/models/backbones/dofa_vit.py index 61913ced..a4d7ef5e 100644 --- a/terratorch/models/backbones/dofa_vit.py +++ b/terratorch/models/backbones/dofa_vit.py @@ -61,7 +61,7 @@ def __init__(self, dofa_model, wavelengths, weights=None, out_indices=None) -> N self.wavelengths = wavelengths self.out_indices = out_indices if out_indices else [-1] - self.out_channels = [self.dofa_model.embed_dim] * len(self.out_indices) + self.out_channels = [self.dofa_model.patch_embed.embed_dim] * len(self.out_indices) def forward(self, x: List[torch.Tensor], **kwargs) -> torch.Tensor: diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index c209b25d..bf0dde12 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -17,9 +17,7 @@ # transformers: https://github.com/huggingface/transformers # -------------------------------------------------------- -from functools import partial -from typing import List, Tuple - +import warnings import logging import numpy as np import torch @@ -135,8 +133,8 @@ class PatchEmbed(nn.Module): """3D version of timm.models.vision_transformer.PatchEmbed""" def __init__( self, - input_size: Tuple[int, int, int] = (1, 224, 224), - patch_size: Tuple[int, int, int] = (1, 16, 16), + input_size: tuple[int, int, int] = (1, 224, 224), + patch_size: tuple[int, int, int] = (1, 16, 16), in_chans: int = 3, embed_dim: int = 768, norm_layer: nn.Module | None = None, @@ -153,16 +151,13 @@ def __init__( self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - self.log_warning = True def forward(self, x): B, C, T, H, W = x.shape - if (self.log_warning and - (T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1)): - logger.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." - f"The border will be ignored, add backbone_padding for pixel-wise tasks.") - self.log_warning = False + if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1: + warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." 
+ f"The border will be ignored, add backbone_padding for pixel-wise tasks.") x = self.proj(x) if self.flatten: @@ -237,8 +232,8 @@ def forward(self, location_coords: torch.Tensor): class PrithviViT(nn.Module): """ Prithvi ViT Encoder""" def __init__(self, - img_size: int | Tuple[int, int] = 224, - patch_size: int | Tuple[int, int, int] = (1, 16, 16), + img_size: int | tuple[int, int] = 224, + patch_size: int | tuple[int, int, int] = (1, 16, 16), num_frames: int = 1, in_chans: int = 3, embed_dim: int = 1024, @@ -246,18 +241,17 @@ def __init__(self, num_heads: int = 16, mlp_ratio: float = 4., norm_layer: nn.Module = nn.LayerNorm, - coords_encoding: List[str] | None = None, + coords_encoding: list[str] | None = None, coords_scale_learn: bool = False, - encoder_only: bool = True, # needed for timm + drop_path: float = 0., ** kwargs, ): super().__init__() - self.feature_info = [] - self.encoder_only = encoder_only self.in_chans = in_chans self.num_frames = num_frames self.embed_dim = embed_dim + self.out_channels = [embed_dim] * depth self.img_size = to_2tuple(img_size) if isinstance(patch_size, int): patch_size = (1, patch_size, patch_size) @@ -286,10 +280,8 @@ def __init__(self, # Transformer layers self.blocks = [] for i in range(depth): - self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)) - self.feature_info.append( - {"num_chs": embed_dim * self.patch_embed.grid_size[0], "reduction": 1, "module": f"blocks.{i}"} - ) + self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, + drop_path=drop_path,)) self.blocks = nn.ModuleList(self.blocks) self.norm = norm_layer(embed_dim) @@ -344,21 +336,33 @@ def random_masking(self, sequence, mask_ratio, noise=None): return sequence_unmasked, mask, ids_restore - def _get_pos_embed(self, x): - t, h, w = x.shape[-3:] - - pos_embed = torch.from_numpy(get_3d_sincos_pos_embed( - self.embed_dim, - ( - t // self.patch_embed.patch_size[0], - h // self.patch_embed.patch_size[1], - w // self.patch_embed.patch_size[2], - ), - add_cls_token=True, - )).float().unsqueeze(0).to(x) - - return pos_embed - + def interpolate_pos_encoding(self, x, t, w, h): + """ + Adapted from: + - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding, + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194 + """ + if x.shape[1] == self.pos_embed.shape[1] and w == h: + # No interpolation needed + return self.pos_embed + + class_pos_embed = self.pos_embed[:, :1] + patch_pos_embed = self.pos_embed[:, 1:] + t_patches = t // self.patch_embed.patch_size[0] + w_patches = w // self.patch_embed.patch_size[1] + h_patches = h // self.patch_embed.patch_size[2] + + n_sqrt = int((patch_pos_embed.shape[1] / t_patches) ** 0.5) + patch_pos_embed = patch_pos_embed.reshape(t_patches, n_sqrt, n_sqrt, self.embed_dim).permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(h_patches, w_patches), + mode='bicubic', + align_corners=True, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, self.embed_dim) + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward( self, x: torch.Tensor, @@ -366,15 +370,15 @@ def forward( location_coords: None | torch.Tensor = None, mask_ratio=0.75 ): - if x.shape[-3:] != self.patch_embed.input_size: - # changed input size - pos_embed = self._get_pos_embed(x) - else: - pos_embed = self.pos_embed + if len(x.shape) == 4 and 
self.patch_embed.input_size[0] == 1: + # add time dim + x = x.unsqueeze(2) + t, h, w = x.shape[-3:] # embed patches x = self.patch_embed(x) + pos_embed = self.interpolate_pos_encoding(x, t, h, w) # add pos embed w/o cls token x = x + pos_embed[:, 1:, :] @@ -410,15 +414,12 @@ def forward_features( if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: # add time dim x = x.unsqueeze(2) - - if x.shape[-3:] != self.patch_embed.input_size: - pos_embed = self._get_pos_embed(x) - else: - pos_embed = self.pos_embed + t, h, w = x.shape[-3:] # embed patches x = self.patch_embed(x) + pos_embed = self.interpolate_pos_encoding(x, t, h, w) # add pos embed w/o cls token x = x + pos_embed[:, 1:, :] @@ -467,8 +468,8 @@ def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list class MAEDecoder(nn.Module): """ Transformer Decoder used in the Prithvi MAE""" def __init__(self, - patch_size: int | Tuple[int, int, int] = (1, 16, 16), - grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14), + patch_size: int | tuple[int, int, int] = (1, 16, 16), + grid_size: list[int] | tuple[int, int, int] = (3, 14, 14), in_chans: int = 3, encoder_embed_dim: int = 1024, decoder_embed_dim: int = 512, @@ -476,7 +477,7 @@ def __init__(self, num_heads: int = 16, mlp_ratio: float = 4., norm_layer: nn.Module = nn.LayerNorm, - coords_encoding: List[str] | None = None, + coords_encoding: list[str] | None = None, coords_scale_learn: bool = False, ): super().__init__() @@ -525,6 +526,34 @@ def initialize_weights(self): torch.nn.init.normal_(self.mask_token, std=0.02) self.apply(_init_weights) + def interpolate_pos_encoding(self, x, t, w, h): + """ + Adapted from: + - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding, + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194 + """ + if x.shape[1] == self.decoder_pos_embed.shape[1] and w == h: + # No interpolation needed + return self.decoder_pos_embed + + class_pos_embed = self.decoder_pos_embed[:, :1] + patch_pos_embed = self.decoder_pos_embed[:, 1:] + t_patches = t // self.patch_size[0] + w_patches = w // self.patch_size[1] + h_patches = h // self.patch_size[2] + + n_sqrt = int((patch_pos_embed.shape[1] / t_patches) ** 0.5) + patch_pos_embed = patch_pos_embed.reshape(t_patches, n_sqrt, n_sqrt, self.decoder_embed_dim).permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(h_patches, w_patches), + mode='bicubic', + align_corners=True, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, self.decoder_embed_dim) + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + def forward( self, hidden_states: torch.Tensor, @@ -535,44 +564,32 @@ def forward( ): # embed tokens x = self.decoder_embed(hidden_states) - - t, h, w = input_size[-3:] - decoder_pos_embed = torch.from_numpy( - get_3d_sincos_pos_embed( - self.decoder_embed_dim, - ( - t // self.patch_size[0], - h // self.patch_size[1], - w // self.patch_size[2], - ), - add_cls_token=True, - ) - ).to(x) + cls_token = x[:, :1, :] # append mask tokens to sequence mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) - x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + x = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token # unshuffle - x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device)) - x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token - # 
add pos embed - x = x + decoder_pos_embed + x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device)) - # remove cls token - x_ = x[:, 1:, :] + # add pos embed + t, h, w = input_size[-3:] + decoder_pos_embed = self.interpolate_pos_encoding(x, t, w, h) + cls_token = cls_token + decoder_pos_embed[:, :1, :] + x = x + decoder_pos_embed[:, 1:, :] if self.temporal_encoding and temporal_coords is not None: - num_tokens_per_frame = x_.shape[1] // self.num_frames + num_tokens_per_frame = x.shape[1] // self.num_frames temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame) # Add temporal encoding w/o cls token - x_ = x_ + temporal_encoding + x = x + temporal_encoding if self.location_encoding and location_coords is not None: location_encoding = self.location_embed_dec(location_coords) # Add location encoding w/o cls token - x_ = x_ + location_encoding + x = x + location_encoding # append cls token - x = torch.cat([x[:, :1, :], x_], dim=1) + x = torch.cat([cls_token, x], dim=1) # apply Transformer layers (blocks) for block in self.decoder_blocks: @@ -592,8 +609,8 @@ class PrithviMAE(nn.Module): """ Prithvi Masked Autoencoder""" def __init__(self, - img_size: int | Tuple[int, int] = 224, - patch_size: int | Tuple[int, int, int] = (1, 16, 16), + img_size: int | tuple[int, int] = 224, + patch_size: int | tuple[int, int, int] = (1, 16, 16), num_frames: int = 4, in_chans: int = 6, embed_dim: int = 768, @@ -605,9 +622,10 @@ def __init__(self, mlp_ratio: float = 4., norm_layer: nn.Module = nn.LayerNorm, norm_pix_loss: bool = False, - coords_encoding: List[str] | None = None, + coords_encoding: list[str] | None = None, coords_scale_learn: bool = False, - encoder_only: bool = False, + drop_path: float = 0., + mask_ratio: float = 0.75, **kwargs, ): super().__init__() @@ -624,28 +642,26 @@ def __init__(self, norm_layer=norm_layer, coords_encoding=coords_encoding, coords_scale_learn=coords_scale_learn, + drop_path=drop_path, ) - self.encoder_only = encoder_only - - if not encoder_only: - self.decoder = MAEDecoder( - patch_size=patch_size, - grid_size=self.encoder.patch_embed.grid_size, - in_chans=in_chans, - encoder_embed_dim=embed_dim, - decoder_embed_dim=decoder_embed_dim, - depth=decoder_depth, - num_heads=decoder_num_heads, - mlp_ratio=mlp_ratio, - norm_layer=norm_layer, - coords_encoding=coords_encoding, - coords_scale_learn=coords_scale_learn, - ) - else: - self.decoder = nn.Identity() + self.decoder = MAEDecoder( + patch_size=patch_size, + grid_size=self.encoder.patch_embed.grid_size, + in_chans=in_chans, + encoder_embed_dim=embed_dim, + decoder_embed_dim=decoder_embed_dim, + depth=decoder_depth, + num_heads=decoder_num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + coords_encoding=coords_encoding, + coords_scale_learn=coords_scale_learn, + ) + self.mask_ratio = mask_ratio self.norm_pix_loss = norm_pix_loss + self.out_channels = self.encoder.out_channels def patchify(self, pixel_values): """ @@ -667,13 +683,13 @@ def patchify(self, pixel_values): return patchified_pixel_values - def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None): + def unpatchify(self, patchified_pixel_values, image_size: tuple[int, int] | None = None): """ Args: patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: Patchified pixel values. 
- image_size (`Tuple[int, int]`, *optional*): + image_size (`tuple[int, int]`, *optional*): Original image size. Returns: @@ -721,12 +737,13 @@ def forward( pixel_values: torch.Tensor, temporal_coords: None | torch.Tensor = None, location_coords: None | torch.Tensor = None, - mask_ratio: float = 0.75 + mask_ratio: float = None, ): if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1: # add time dim pixel_values = pixel_values.unsqueeze(2) + mask_ratio = mask_ratio or self.mask_ratio latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio) pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape) loss = self.forward_loss(pixel_values, pred, mask) @@ -737,5 +754,5 @@ def forward_features( x: torch.Tensor, temporal_coords: None | torch.Tensor = None, location_coords: None | torch.Tensor = None, - ) -> List[torch.Tensor]: + ) -> list[torch.Tensor]: return self.encoder.forward_features(x, temporal_coords, location_coords) diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 9918d5f0..1ce0a872 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -1,15 +1,15 @@ # Copyright contributors to the Terratorch project +import warnings import torch import logging from torch import nn, Tensor -from timm.models import (FeatureInfo, load_model_config_from_hf, build_model_with_cfg, generate_default_cfgs, - register_model) - from terratorch.datasets import HLSBands from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights from terratorch.datasets.utils import generate_bands_intervals from terratorch.models.backbones.prithvi_mae import PrithviViT, PrithviMAE +from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY +from huggingface_hub import hf_hub_download logger = logging.getLogger(__name__) @@ -60,14 +60,13 @@ def _cfg(**kwargs): "prithvi_eo_v2_300": _cfg(embed_dim=1024, depth=24, num_heads=16), "prithvi_eo_v2_300_tl": _cfg(embed_dim=1024, depth=24, num_heads=16, coords_encoding=["time", "location"], coords_scale_learn=True), - "prithvi_eo_v2_600": _cfg(embed_dim=1280, depth=32, num_heads=16), - "prithvi_eo_v2_600_tl": _cfg(embed_dim=1280, depth=32, num_heads=16, + "prithvi_eo_v2_600": _cfg(embed_dim=1280, depth=32, num_heads=16, patch_size=[1, 14, 14]), + "prithvi_eo_v2_600_tl": _cfg(embed_dim=1280, depth=32, num_heads=16, patch_size=[1, 14, 14], coords_encoding=["time", "location"], coords_scale_learn=True), } # Timm pretrained configs -default_cfgs = generate_default_cfgs( - { +pretrained_weights = { "prithvi_eo_v1_100": { "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-1.0-100M", "hf_hub_filename": "Prithvi_EO_V1_100M.pt", @@ -89,7 +88,6 @@ def _cfg(**kwargs): "hf_hub_filename": "Prithvi_EO_V2_600M_TL.pt", }, } -) def checkpoint_filter_fn_vit( @@ -99,6 +97,9 @@ def checkpoint_filter_fn_vit( clean_dict = {} for k, v in state_dict.items(): + if "_timm_module." in k: # Backwards compatibility for old model checkpoints + k = k.replace("_timm_module.", "") + if "pos_embed" in k: v = model.pos_embed # pos_embed depends on num_frames and is fixed. if "decoder" in k or "_dec" in k or k == "mask_token": @@ -128,6 +129,9 @@ def checkpoint_filter_fn_mae( clean_dict = {} for k, v in state_dict.items(): + if "_timm_module." 
in k: # Backwards compatibility for old model checkpoints + k = k.replace("_timm_module.", "") + # pos_embed depends on num_frames and is fixed. if "decoder_pos_embed" in k: v = model.decoder.decoder_pos_embed @@ -169,220 +173,227 @@ def pad_images(imgs: Tensor,patch_size: int, padding:str) -> Tensor: def _create_prithvi( variant: str, - pretrained: bool = False, # noqa: FBT001, FBT002 - pretrained_bands: list[HLSBands] | None = None, + pretrained: bool = False, # noqa: FBT001, FBT002 model_bands: list[HLSBands | int] | None = None, + ckpt_path: str = None, + pretrained_bands: list[HLSBands | str | int] | None = None, + num_frames: int = 1, + encoder_only: bool = True, **kwargs, -) -> PrithviViT: - if pretrained_bands is None: - pretrained_bands = PRETRAINED_BANDS +) -> PrithviViT | PrithviMAE: + """ + Build PrithviViT and PrithviMAE models. + By default, encoder_only is set to True and a ViT is returned. + """ + + # Load default config + model_args = prithvi_cfgs[variant].copy() + + # Backwards compatibility from timm (pretrained_cfg_overlay={"file": ""}) TODO: Remove before v1.0 + if "pretrained_cfg_overlay" in kwargs: + warnings.warn(f"pretrained_cfg_overlay is deprecated and will be removed in a future version, " + f"use ckpt_path= instead.", DeprecationWarning, stacklevel=2) + if ckpt_path is not None: + warnings.warn(f"pretrained_cfg_overlay and ckpt_path are provided, ignoring pretrained_cfg_overlay.") + elif "file" not in kwargs["pretrained_cfg_overlay"]: + warnings.warn("pretrained_cfg_overlay does not include 'file path', ignoring pretrained_cfg_overlay.") + else: + ckpt_path = kwargs.pop("pretrained_cfg_overlay")["file"] + pretrained_bands = pretrained_bands or model_args.get("bands", PRETRAINED_BANDS) if model_bands is None: model_bands: list[HLSBands | int] = pretrained_bands - logger.info( - f"Model bands not passed. Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\ - Pretrained patch_embed layer may be misaligned with current bands" - ) + logger.info(f"Model bands not passed. Assuming bands are ordered in the same way as {pretrained_bands}." + f"Pretrained patch_embed layer may be misaligned with current bands") else: model_bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in model_bands] - - # Little hack because VIT does not support timm's features_only - encoder_only = kwargs.pop("features_only", False) - - model_bands = generate_bands_intervals(model_bands) + model_bands = generate_bands_intervals(model_bands) kwargs["in_chans"] = len(model_bands) + kwargs["num_frames"] = num_frames + model_args.update(kwargs) if encoder_only: prithvi_model_class = PrithviViT - def checkpoint_filter_wrapper_fn(state_dict, model): - return checkpoint_filter_fn_vit(state_dict, model, pretrained_bands, model_bands) + checkpoint_filter_wrapper_fn = checkpoint_filter_fn_vit else: prithvi_model_class = PrithviMAE - def checkpoint_filter_wrapper_fn(state_dict, model): - return checkpoint_filter_fn_mae(state_dict, model, pretrained_bands, model_bands) + checkpoint_filter_wrapper_fn = checkpoint_filter_fn_mae if pretrained: - assert variant in default_cfgs, (f"No pre-trained model found for variant {variant} " - f"(pretrained models: {default_cfgs.keys()})") - # Load pre-trained config from hf - try: - model_args = load_model_config_from_hf(default_cfgs[variant].default.hf_hub_id)[0] - model_args.update(kwargs) - except: - logger.warning(f"No pretrained configuration was found on HuggingFace for the model {variant}." 
- f"Using random initialization.") - model_args = prithvi_cfgs[variant].copy() - model_args.update(kwargs) - else: - # Load default config - model_args = prithvi_cfgs[variant].copy() - model_args.update(kwargs) - - try: - model = build_model_with_cfg( - prithvi_model_class, - variant, - pretrained, - pretrained_filter_fn=checkpoint_filter_wrapper_fn, - pretrained_strict=True, - **model_args, - ) - except RuntimeError as e: - if pretrained: - logger.error(f"Failed to initialize the pre-trained model {variant} via timm, " - f"consider running the code with pretrained=False.") - else: - logger.error(f"Failed to initialize the model {variant} via timm.") - raise e + assert variant in pretrained_weights, (f"No pre-trained model found for variant {variant} " + f"(pretrained models: {pretrained_weights.keys()})") + model = prithvi_model_class(**model_args) + + if ckpt_path is not None: + # Load model from checkpoint + state_dict = torch.load(ckpt_path, map_location="cpu") + state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands) + model.load_state_dict(state_dict, strict=False) + elif pretrained: + try: + # Download config.json to count model downloads + _ = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"], filename="config.json") + # Load model from Hugging Face + pretrained_path = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"], + filename=pretrained_weights[variant]["hf_hub_filename"]) + state_dict = torch.load(pretrained_path, map_location="cpu") + state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands) + model.load_state_dict(state_dict, strict=True) + except RuntimeError as e: + logger.error(f"Failed to load the pre-trained weights for {variant}.") + raise e + + assert encoder_only or "out_indices" not in kwargs, "out_indices provided for a MAE model." 
if encoder_only: default_out_indices = list(range(len(model.blocks))) out_indices = kwargs.pop("out_indices", default_out_indices) - model.feature_info = FeatureInfo(model.feature_info, out_indices) - model.encode_decode_forward = model.forward + def forward_filter_indices(*args, **kwargs): features = model.forward_features(*args, **kwargs) return [features[i] for i in out_indices] + model.forward = forward_filter_indices + model.out_indices = out_indices model.model_bands = model_bands model.pretrained_bands = pretrained_bands - padding = kwargs.get("padding", "none") - patch_size = kwargs.get("patch_size", 16) - if isinstance(patch_size, list): - patch_size = patch_size[-1] + return model - if padding != "none": - original_forward = model.forward - original_forward_features = model.forward_features - def pad_and_forward(forward_fn, patch_size, padding, *args, **kwargs): - inputs = pad_images(args[0], patch_size, padding) - return forward_fn(inputs, **kwargs) +@ TERRATORCH_BACKBONE_REGISTRY.register +def prithvi_eo_tiny( + pretrained: bool = False, # noqa: FBT001, FBT002 + bands: list[HLSBands] | None = None, + **kwargs, +) -> PrithviViT: - def forward_pad_images(*args, **kwargs): - return pad_and_forward(original_forward, patch_size, padding, *args, **kwargs) + return _create_prithvi("prithvi_eo_tiny", pretrained=pretrained, model_bands=bands, **kwargs) - def forward_features_pad_images(*args, **kwargs): - return pad_and_forward(original_forward_features, patch_size, padding, *args, **kwargs) - model.forward = forward_pad_images - model.forward_features = forward_features_pad_images +@ TERRATORCH_BACKBONE_REGISTRY.register +def prithvi_eo_v1_100( + pretrained: bool = False, # noqa: FBT001, FBT002 + bands: list[HLSBands] | None = None, + **kwargs, +) -> PrithviViT: - return model + return _create_prithvi("prithvi_eo_v1_100", pretrained=pretrained, model_bands=bands, **kwargs) -def create_prithvi_from_config( - model_name: str, +@ TERRATORCH_BACKBONE_REGISTRY.register +def prithvi_eo_v2_300( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - pretrained_bands = PRETRAINED_BANDS - if bands is None: - bands = pretrained_bands - logger.info( - f"Model bands not passed. Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\ - Pretrained patch_embed layer may be misaligned with current bands" - ) - - kwargs['num_frames'] = kwargs.pop('num_frames', 1) # Set num frames to 1 if not present - - model = _create_prithvi( - model_name, - pretrained=pretrained, - model_bands=bands, - pretrained_bands=pretrained_bands, - **kwargs, - ) - - return model + return _create_prithvi("prithvi_eo_v2_300", pretrained=pretrained, model_bands=bands, **kwargs) -@register_model -def prithvi_vit_tiny( + +@ TERRATORCH_BACKBONE_REGISTRY.register +def prithvi_eo_v2_600( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - logger.warning(f'The model prithvi_vit_tiny was renamed to prithvi_eo_tiny. 
' - f'prithvi_vit_tiny will be removed in a future version.') + return _create_prithvi("prithvi_eo_v2_600", pretrained=pretrained, model_bands=bands, **kwargs) - return prithvi_eo_tiny(pretrained=pretrained, bands=bands, **kwargs) - -@register_model -def prithvi_eo_tiny( +@ TERRATORCH_BACKBONE_REGISTRY.register +def prithvi_eo_v2_300_tl( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_from_config("prithvi_eo_tiny", pretrained, bands, **kwargs) + return _create_prithvi("prithvi_eo_v2_300_tl", pretrained=pretrained, model_bands=bands, **kwargs) -@register_model -def prithvi_vit_100( +@ TERRATORCH_BACKBONE_REGISTRY.register +def prithvi_eo_v2_600_tl( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - logger.warning(f'The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. ' - f'prithvi_vit_100 will be removed in a future version.') + return _create_prithvi("prithvi_eo_v2_600_tl", pretrained=pretrained, model_bands=bands, **kwargs) - return prithvi_eo_v1_100(pretrained=pretrained, bands=bands, **kwargs) - -@register_model -def prithvi_eo_v1_100( +# TODO: Remove prithvi_vit_tiny and prithvi_vit_100 before version 1.0. +@ TERRATORCH_BACKBONE_REGISTRY.register +def prithvi_vit_tiny( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_from_config("prithvi_eo_v1_100", pretrained, bands, **kwargs) + warnings.warn(f"The model prithvi_vit_tiny was renamed to prithvi_eo_tiny. " + f"prithvi_vit_tiny will be removed in a future version.", FutureWarning) + return prithvi_eo_tiny(pretrained=pretrained, model_bands=bands, **kwargs) -@register_model -def prithvi_eo_v2_300( + +@ TERRATORCH_BACKBONE_REGISTRY.register +def prithvi_vit_100( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: + warnings.warn("The model prithvi_vit_100 was renamed to prithvi_eo_v1_100. " + "prithvi_vit_100 will be removed in a future version.", FutureWarning) - return create_prithvi_from_config("prithvi_eo_v2_300", pretrained, bands, **kwargs) + return prithvi_eo_v1_100(pretrained=pretrained, model_bands=bands, **kwargs) -@register_model -def prithvi_eo_v2_600( +# TODO: Remove timm_ errors before version v1.0. +@ TERRATORCH_BACKBONE_REGISTRY.register +def timm_prithvi_eo_v1_100( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, -) -> PrithviViT: +) -> None: + raise ValueError("The Prithvi models were moved to the terratorch registry. " + "Please remove the timm_ prefix from the model name.") - return create_prithvi_from_config("prithvi_eo_v2_600", pretrained, bands, **kwargs) +@ TERRATORCH_BACKBONE_REGISTRY.register +def timm_prithvi_eo_v2_300( + pretrained: bool = False, # noqa: FBT001, FBT002 + bands: list[HLSBands] | None = None, + **kwargs, +) -> None: + raise ValueError("The Prithvi models were moved to the terratorch registry. " + "Please remove the timm_ prefix from the model name.") -@register_model -def prithvi_eo_v2_300_tl( +@ TERRATORCH_BACKBONE_REGISTRY.register +def timm_prithvi_eo_v2_600( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, -) -> PrithviViT: +) -> None: + raise ValueError("The Prithvi models were moved to the terratorch registry. 
" + "Please remove the timm_ prefix from the model name.") - return create_prithvi_from_config("prithvi_eo_v2_300_tl", pretrained, bands, **kwargs) +@ TERRATORCH_BACKBONE_REGISTRY.register +def timm_prithvi_eo_v2_300_tl( + pretrained: bool = False, # noqa: FBT001, FBT002 + bands: list[HLSBands] | None = None, + **kwargs, +) -> None: + raise ValueError("The Prithvi models were moved to the terratorch registry. " + "Please remove the timm_ prefix from the model name.") -@register_model -def prithvi_eo_v2_600_tl( +@ TERRATORCH_BACKBONE_REGISTRY.register +def timm_prithvi_eo_v2_600_tl( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, -) -> PrithviViT: - - return create_prithvi_from_config("prithvi_eo_v2_600_tl", pretrained, bands, **kwargs) +) -> None: + raise ValueError("The Prithvi models were moved to the terratorch registry. " + "Please remove the timm_ prefix from the model name.") diff --git a/terratorch/models/backbones/select_patch_embed_weights.py b/terratorch/models/backbones/select_patch_embed_weights.py index 38dde626..9fdf9029 100644 --- a/terratorch/models/backbones/select_patch_embed_weights.py +++ b/terratorch/models/backbones/select_patch_embed_weights.py @@ -19,7 +19,7 @@ def patch_embed_weights_are_compatible(model_patch_embed: torch.Tensor, checkpoi return model_shape == checkpoint_shape def select_patch_embed_weights( - state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int | OpticalBands| SARBands], model_bands: list[HLSBands | int | OpticalBands| SARBands], custom_proj_key: str = None + state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int | OpticalBands| SARBands], model_bands: list[HLSBands | int | OpticalBands| SARBands], proj_key: str | None = None ) -> dict: """Filter out the patch embedding weights according to the bands being used. If a band exists in the pretrained_bands, but not in model_bands, drop it. @@ -31,39 +31,25 @@ def select_patch_embed_weights( model (nn.Module): Model to load the weights onto. pretrained_bands (list[HLSBands | int]): List of bands the model was pretrained on, in the correct order. model_bands (list[HLSBands | int]): List of bands the model is going to be finetuned on, in the correct order + proj_key (str, optional): Key to patch embedding projection weight in state_dict. 
Returns: dict: New state dict """ - if (type(pretrained_bands) == type(model_bands)) | (type(pretrained_bands) == int) | (type(model_bands) == int): + if (type(pretrained_bands) == type(model_bands)) | (type(pretrained_bands) == int) | (type(model_bands) == int): - if custom_proj_key is None: - _possible_keys_for_proj_weight = { - "patch_embed.proj.weight", - "module.patch_embed.proj.weight", - "patch_embed.projection.weight", - "module.patch_embed.projection.weight", - } - else: - _possible_keys_for_proj_weight = {custom_proj_key} - - patch_embed_proj_weight_key = state_dict.keys() & _possible_keys_for_proj_weight if (type(state_dict) in [collections.OrderedDict, dict]) else state_dict().keys() & _possible_keys_for_proj_weight - if len(patch_embed_proj_weight_key) == 0: - msg = "Could not find key for patch embed weight" - raise Exception(msg) - if len(patch_embed_proj_weight_key) > 1: - msg = "Too many matches for key for patch embed weight" - raise Exception(msg) - - # extract the single element from the set - if isinstance(patch_embed_proj_weight_key, tuple): - (patch_embed_proj_weight_key,) = patch_embed_proj_weight_key - elif isinstance(patch_embed_proj_weight_key, set): - patch_embed_proj_weight_key = list(patch_embed_proj_weight_key)[0] + if proj_key is None: + # Search for patch embedding weight in state dict + for key in state_dict.keys(): + if key.endswith('patch_embed.proj.weight') or key.endswith('patch_embed.projection.weight'): + proj_key = key + break + if proj_key is None or proj_key not in state_dict: + raise Exception("Could not find key for patch embed weight in state_dict.") - patch_embed_weight = state_dict[patch_embed_proj_weight_key] + patch_embed_weight = state_dict[proj_key] - temp_weight = model.state_dict()[patch_embed_proj_weight_key].clone() + temp_weight = model.state_dict()[proj_key].clone() # only do this if the patch size and tubelet size match. 
If not, start with random weights if patch_embed_weights_are_compatible(temp_weight, patch_embed_weight): @@ -80,6 +66,6 @@ def select_patch_embed_weights( stacklevel=1, ) - state_dict[patch_embed_proj_weight_key] = temp_weight + state_dict[proj_key] = temp_weight return state_dict diff --git a/terratorch/models/wxc_model_factory.py b/terratorch/models/wxc_model_factory.py index f446509a..f8d6fd19 100644 --- a/terratorch/models/wxc_model_factory.py +++ b/terratorch/models/wxc_model_factory.py @@ -61,6 +61,7 @@ def build_model( raise #remove parameters not meant for the backbone but for other parts of the model + logger.trace(kwargs) skip_connection = kwargs.pop('skip_connection') backbone = prithviwxc.PrithviWxC(**kwargs) diff --git a/terratorch/tasks/__init__.py b/terratorch/tasks/__init__.py index 790c10ec..782b0f08 100644 --- a/terratorch/tasks/__init__.py +++ b/terratorch/tasks/__init__.py @@ -1,3 +1,4 @@ +import logging from terratorch.tasks.classification_tasks import ClassificationTask from terratorch.tasks.regression_tasks import PixelwiseRegressionTask from terratorch.tasks.segmentation_tasks import SemanticSegmentationTask @@ -6,6 +7,8 @@ try: wxc_present = True from terratorch.tasks.wxc_downscaling_task import WxCDownscalingTask + from terratorch.tasks.wxc_task import WxCTask + logging.getLogger('terratorch').debug('wxc_downscaling found.') except ImportError as e: import logging logging.getLogger('terratorch').debug('wxc_downscaling not installed') @@ -21,4 +24,4 @@ ) if wxc_present: - __all__.__add__(("WxCDownscalingTask", )) + __all__.__add__(("WxCDownscalingTask", "WxCTask",)) diff --git a/terratorch/tasks/base_task.py b/terratorch/tasks/base_task.py index 9c7abcc7..c04d7f22 100644 --- a/terratorch/tasks/base_task.py +++ b/terratorch/tasks/base_task.py @@ -91,8 +91,9 @@ def on_validation_epoch_end(self) -> None: self.val_metrics.reset() def on_test_epoch_end(self) -> None: - self.log_dict(self.test_metrics.compute(), sync_dist=True) - self.test_metrics.reset() + for metrics in self.test_metrics: + self.log_dict(metrics.compute(), sync_dist=True) + metrics.reset() def _do_plot_samples(self, batch_index): if not self.plot_on_val: # dont plot if self.plot_on_val is 0 diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py index ccbe3d25..60b33235 100644 --- a/terratorch/tasks/classification_tasks.py +++ b/terratorch/tasks/classification_tasks.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Any import logging import lightning @@ -34,6 +35,7 @@ class ClassificationTask(TerraTorchTask): - Does not have any callbacks by default (TorchGeo tasks do early stopping by default) - Allows the setting of optimizers in the constructor - It provides mIoU with both Micro and Macro averaging + - Allows to evaluate on multiple test dataloaders .. note:: * 'Micro' averaging suits overall performance evaluation but may not reflect @@ -64,6 +66,7 @@ def __init__( freeze_decoder: bool = False, # noqa: FBT002, FBT001 freeze_head: bool = False, # noqa: FBT002, FBT001 class_names: list[str] | None = None, + test_dataloaders_names: list[str] | None = None, lr_overrides: dict[str, float] | None = None, ) -> None: """Constructor @@ -101,6 +104,9 @@ def __init__( freeze_head (bool, optional): Whether to freeze the segmentation_head. Defaults to False. class_names (list[str] | None, optional): List of class names passed to metrics for better naming. Defaults to numeric ordering. 
+ test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when + multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, + which assumes only one test dataloader is used. lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific parameters. The key should be a substring of the parameter names (it will check the substring is contained in the parameter name)and the value should be the new lr. Defaults to None. @@ -123,7 +129,9 @@ def __init__( self.model = model self.train_loss_handler = LossHandler(self.train_metrics.prefix) - self.test_loss_handler = LossHandler(self.test_metrics.prefix) + self.test_loss_handler: list[LossHandler] = [] + for metrics in self.test_metrics: + self.test_loss_handler.append(LossHandler(metrics.prefix)) self.val_loss_handler = LossHandler(self.val_metrics.prefix) self.monitor = f"{self.val_metrics.prefix}loss" @@ -193,7 +201,12 @@ def configure_metrics(self) -> None: ) self.train_metrics = metrics.clone(prefix="train/") self.val_metrics = metrics.clone(prefix="val/") - self.test_metrics = metrics.clone(prefix="test/") + if self.hparams["test_dataloaders_names"] is not None: + self.test_metrics = nn.ModuleList( + [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]] + ) + else: + self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")]) def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the train loss and additional metrics. @@ -247,10 +260,17 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None other_keys = batch.keys() - {"image", "label", "filename"} rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) - loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) - self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) + if dataloader_idx >= len(self.test_loss_handler): + msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names." + raise ValueError(msg) + loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss) + self.test_loss_handler[dataloader_idx].log_loss( + partial(self.log, add_dataloader_idx=False), # We don't need the dataloader idx as prefixes are different + loss_dict=loss, + batch_size=x.shape[0], + ) y_hat_hard = to_class_prediction(model_output) - self.test_metrics.update(y_hat_hard, y) + self.test_metrics[dataloader_idx].update(y_hat_hard, y) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the predicted class probabilities. 
diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index 3c9157f1..9849b0b2 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -1,6 +1,7 @@ """This module contains the regression task and its auxiliary classes.""" from collections.abc import Sequence +from functools import partial from typing import Any import logging @@ -130,7 +131,8 @@ class PixelwiseRegressionTask(TerraTorchTask): - Accepts the specification of a model factory - Logs metrics per class - Does not have any callbacks by default (TorchGeo tasks do early stopping by default) - - Allows the setting of optimizers in the constructor""" + - Allows the setting of optimizers in the constructor + - Allows to evaluate on multiple test dataloaders""" def __init__( self, @@ -154,6 +156,7 @@ def __init__( freeze_head: bool = False, # noqa: FBT001, FBT002 plot_on_val: bool | int = 10, tiled_inference_parameters: TiledInferenceParameters | None = None, + test_dataloaders_names: list[str] | None = None, lr_overrides: dict[str, float] | None = None, ) -> None: """Constructor @@ -190,6 +193,9 @@ def __init__( If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs. tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters used to determine if inference is done on the whole image or through tiling. + test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when + multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, + which assumes only one test dataloader is used. lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific parameters. The key should be a substring of the parameter names (it will check the substring is contained in the parameter name)and the value should be the new lr. Defaults to None. @@ -213,7 +219,9 @@ def __init__( self.model = model self.train_loss_handler = LossHandler(self.train_metrics.prefix) - self.test_loss_handler = LossHandler(self.test_metrics.prefix) + self.test_loss_handler: list[LossHandler] = [] + for metrics in self.test_metrics: + self.test_loss_handler.append(LossHandler(metrics.prefix)) self.val_loss_handler = LossHandler(self.val_metrics.prefix) self.monitor = f"{self.val_metrics.prefix}loss" self.plot_on_val = int(plot_on_val) @@ -260,7 +268,17 @@ def wrap_metrics_with_ignore_index(metrics): self.train_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="train/") self.val_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="val/") - self.test_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/") + if self.hparams["test_dataloaders_names"] is not None: + self.test_metrics = nn.ModuleList( + [ + MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix=f"test/{dl_name}/") + for dl_name in self.hparams["test_dataloaders_names"] + ] + ) + else: + self.test_metrics = nn.ModuleList( + [MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")] + ) def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the train loss and additional metrics. 
@@ -338,10 +356,17 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None other_keys = batch.keys() - {"image", "mask", "filename"} rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) - loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) - self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) + if dataloader_idx >= len(self.test_loss_handler): + msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names." + raise ValueError(msg) + loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss) + self.test_loss_handler[dataloader_idx].log_loss( + partial(self.log, add_dataloader_idx=False), # We don't need the dataloader idx as prefixes are different + loss_dict=loss, + batch_size=x.shape[0], + ) y_hat = model_output.output - self.test_metrics.update(y_hat, y) + self.test_metrics[dataloader_idx].update(y_hat, y) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the predicted class probabilities. diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 4a689e0a..f8ab82af 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -270,11 +270,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None y_hat_hard = to_segmentation_prediction(model_output) self.test_metrics[dataloader_idx].update(y_hat_hard, y) - def on_test_epoch_end(self) -> None: - for metrics in self.test_metrics: - self.log_dict(metrics.compute(), sync_dist=True) - metrics.reset() - def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the validation loss and additional metrics. Args: diff --git a/terratorch/tasks/wxc_task.py b/terratorch/tasks/wxc_task.py index 87312a9a..3e4b4421 100644 --- a/terratorch/tasks/wxc_task.py +++ b/terratorch/tasks/wxc_task.py @@ -1,17 +1,19 @@ - - from torchgeo.trainers import BaseTask import torch.nn as nn import torch import logging logger = logging.getLogger(__name__) +from terratorch.registry import MODEL_FACTORY_REGISTRY + class WxCTask(BaseTask): - def __init__(self, model_factory, model_args: dict, mode, learning_rate=0.1): + def __init__(self, model_factory, model_args: dict, mode:str='train', learning_rate=0.1): if mode not in ['train', 'eval']: raise ValueError(f'mode {mode} is not supported. 
(train, eval)') self.model_args = model_args - self.model_factory = model_factory + + self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory) + self.learning_rate = learning_rate super().__init__() @@ -34,4 +36,4 @@ def training_step(self, batch, batch_idx): def train_dataloader(self): return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True) - \ No newline at end of file + diff --git a/tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml b/tests/resources/configs/manufactured-finetune_prithvi_eo_v1_100.yaml similarity index 99% rename from tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml rename to tests/resources/configs/manufactured-finetune_prithvi_eo_v1_100.yaml index bb652415..49f8186b 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_eo_v1_100.yaml @@ -96,7 +96,7 @@ model: model_args: decoder: UperNetDecoder pretrained: false - backbone: prithvi_vit_100 + backbone: prithvi_eo_v1_100 #backbone_pretrained_cfg_overlay: #file: tests/all_ecos_random/version_0/checkpoints/epoch=0_state_dict.ckpt #tests/prithvi_vit_100.pt backbone_drop_path_rate: 0.3 diff --git a/tests/manufactured-finetune_prithvi_swin_B_segmentation.yaml b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml similarity index 100% rename from tests/manufactured-finetune_prithvi_swin_B_segmentation.yaml rename to tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml diff --git a/tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml b/tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml deleted file mode 100644 index 3e44a1c5..00000000 --- a/tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml +++ /dev/null @@ -1,150 +0,0 @@ -# lightning.pytorch==2.1.1 -seed_everything: 42 -trainer: - accelerator: cpu - strategy: auto - devices: auto - num_nodes: 1 - # precision: 16-mixed - logger: - class_path: TensorBoardLogger - init_args: - save_dir: tests/ - name: all_ecos_random - callbacks: - - class_path: RichProgressBar - - class_path: LearningRateMonitor - init_args: - logging_interval: epoch - - class_path: EarlyStopping - init_args: - monitor: val/loss - patience: 100 - max_epochs: 2 - check_val_every_n_epoch: 1 - log_every_n_steps: 20 - enable_checkpointing: true - default_root_dir: tests/ -data: - class_path: GenericNonGeoPixelwiseRegressionDataModule - init_args: - batch_size: 2 - num_workers: 4 - train_transform: - #- class_path: albumentations.HorizontalFlip - # init_args: - # p: 0.5 - #- class_path: albumentations.Rotate - # init_args: - # limit: 30 - # border_mode: 0 # cv2.BORDER_CONSTANT - # value: 0 - # # mask_value: 1 - # p: 0.5 - - class_path: ToTensorV2 - dataset_bands: - - 0 - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - - 1 - - 2 - - 3 - - 4 - output_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - rgb_indices: - - 2 - - 1 - - 0 - train_data_root: tests/resources/inputs - train_label_data_root: tests/resources/inputs - val_data_root: tests/resources/inputs - val_label_data_root: tests/resources/inputs - test_data_root: tests/resources/inputs - test_label_data_root: tests/resources/inputs - img_grep: "regression*input*.tif" - label_grep: "regression*label*.tif" - means: - - 547.36707 - - 898.5121 - - 1020.9082 - - 2665.5352 - - 2340.584 - - 1610.1407 - stds: - - 411.4701 - - 558.54065 - - 815.94025 - - 812.4403 - - 1113.7145 - - 1067.641 - 
no_label_replace: -1 - no_data_replace: 0 - -model: - class_path: terratorch.tasks.PixelwiseRegressionTask - init_args: - model_args: - decoder: UperNetDecoder - pretrained: false - backbone: prithvi_eo_v2_300 - # backbone_pretrained_cfg_overlay: - # file: tests/prithvi_vit_300.pt - backbone_drop_path_rate: 0.3 - # backbone_window_size: 8 - decoder_channels: 64 - num_frames: 1 - in_channels: 6 - bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - head_dropout: 0.5708022831486758 - head_final_act: torch.nn.ReLU - head_learned_upscale_layers: 2 - loss: rmse - #aux_heads: - # - name: aux_head - # decoder: IdentityDecoder - # decoder_args: - # decoder_out_index: 2 - # head_dropout: 0,5 - # head_channel_list: - # - 64 - # head_final_act: torch.nn.ReLU - #aux_loss: - # aux_head: 0.4 - ignore_index: -1 - freeze_backbone: true - freeze_decoder: false - model_factory: PrithviModelFactory - - # uncomment this block for tiled inference - # tiled_inference_parameters: - # h_crop: 224 - # h_stride: 192 - # w_crop: 224 - # w_stride: 192 - # average_patches: true -optimizer: - class_path: torch.optim.AdamW - init_args: - lr: 0.00013524680528283027 - weight_decay: 0.047782217873995426 -lr_scheduler: - class_path: ReduceLROnPlateau - init_args: - monitor: val/loss - diff --git a/tests/test_backbones.py b/tests/test_backbones.py index fd562adc..546250d9 100644 --- a/tests/test_backbones.py +++ b/tests/test_backbones.py @@ -35,7 +35,7 @@ def input_386(): return torch.ones((1, NUM_CHANNELS, 386, 386)) -@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"]) +@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"]) @pytest.mark.parametrize("test_input", ["input_224", "input_512"]) def test_can_create_backbones_from_timm(model_name, test_input, request): backbone = timm.create_model(model_name, pretrained=False) @@ -43,7 +43,7 @@ def test_can_create_backbones_from_timm(model_name, test_input, request): backbone(input_tensor) gc.collect() -@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"]) +@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"]) @pytest.mark.parametrize("test_input", ["input_224", "input_512"]) def test_can_create_backbones_from_timm_features_only(model_name, test_input, request): backbone = timm.create_model(model_name, pretrained=False, features_only=True) @@ -51,36 +51,37 @@ def test_can_create_backbones_from_timm_features_only(model_name, test_input, re backbone(input_tensor) gc.collect() -@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"]) +@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_L", "prithvi_swin_B"]) @pytest.mark.parametrize("prefix", ["", "timm_"]) def test_can_create_timm_backbones_from_registry(model_name, input_224, prefix): backbone = BACKBONE_REGISTRY.build(prefix+model_name, pretrained=False) backbone(input_224) gc.collect() + @pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"]) -def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal): - backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES) - backbone(input_224_multitemporal) +def test_can_create_backbones_from_registry(model_name, input_224): + backbone = 
BACKBONE_REGISTRY.build(model_name, pretrained=False) + backbone(input_224) gc.collect() + @pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"]) -def test_vit_models_non_divisible_input(model_name, input_non_divisible): - #padding 'none','constant', 'reflect', 'replicate' or 'circular' default is 'none' - backbone = timm.create_model(model_name, pretrained=False, features_only=True, num_frames=NUM_FRAMES, padding='constant') - backbone(input_non_divisible) +def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal): + backbone = BACKBONE_REGISTRY.build(model_name, pretrained=False, num_frames=NUM_FRAMES) + backbone(input_224_multitemporal) gc.collect() + @pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"]) @pytest.mark.parametrize("patch_size", [8, 16]) @pytest.mark.parametrize("patch_size_time", [1, 2, 4]) def test_vit_models_different_patch_tubelet_sizes(model_name, patch_size, patch_size_time, input_224_multitemporal): - backbone = timm.create_model( + backbone = BACKBONE_REGISTRY.build( model_name, pretrained=False, num_frames=NUM_FRAMES, patch_size=[patch_size_time, patch_size, patch_size], - features_only=True, ) embedding = backbone(input_224_multitemporal) processed_embedding = backbone.prepare_features_for_image_model(embedding) @@ -105,10 +106,9 @@ def test_vit_models_different_patch_tubelet_sizes(model_name, patch_size, patch_ gc.collect() @pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"]) def test_out_indices(model_name, input_224): - # out_indices = [2, 4, 8, 10] out_indices = (2, 4, 8, 10) - backbone = timm.create_model(model_name, pretrained=False, features_only=True, out_indices=out_indices) - assert backbone.feature_info.out_indices == out_indices + backbone = BACKBONE_REGISTRY.build(model_name, pretrained=False, out_indices=out_indices) + assert backbone.out_indices == out_indices output = backbone(input_224) full_output = backbone.forward_features(input_224) @@ -116,18 +116,8 @@ def test_out_indices(model_name, input_224): for filtered_index, full_index in enumerate(out_indices): assert torch.allclose(full_output[full_index], output[filtered_index]) gc.collect() -@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"]) -def test_out_indices_non_divisible(model_name, input_non_divisible): - out_indices = [2, 4, 8, 10] - backbone = timm.create_model(model_name, pretrained=False, features_only=True, num_frames=NUM_FRAMES, out_indices=out_indices, padding='constant') - assert backbone.feature_info.out_indices == tuple(out_indices) - output = backbone(input_non_divisible) - full_output = backbone.forward_features(input_non_divisible) - for filtered_index, full_index in enumerate(out_indices): - assert torch.allclose(full_output[full_index], output[filtered_index]) - gc.collect() @pytest.mark.parametrize("model_name", ["vit_base_patch16", "vit_large_patch16"]) def test_scale_mae(model_name): # out_indices = [2, 4, 8, 10] @@ -139,6 +129,8 @@ def test_scale_mae(model_name): assert len(output) == len(out_indices) gc.collect() + + @pytest.mark.parametrize("model_name", ["vit_base_patch16", "vit_large_patch16"]) @pytest.mark.parametrize("bands", [2, 4, 6]) def test_scale_mae_new_channels(model_name, bands): diff --git a/tests/test_finetune.py b/tests/test_finetune.py index 9c06e8da..c1652173 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -6,10 +6,11 @@ import torch from terratorch.cli_tools import build_lightning_cli +from 
terratorch.registry import BACKBONE_REGISTRY @pytest.fixture(autouse=True) def setup_and_cleanup(model_name): - model_instance = timm.create_model(model_name) + model_instance = BACKBONE_REGISTRY.build(model_name) state_dict = model_instance.state_dict() @@ -22,7 +23,7 @@ def setup_and_cleanup(model_name): if os.path.isdir(os.path.join("tests", "all_ecos_random")): shutil.rmtree(os.path.join("tests", "all_ecos_random")) -@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_eo_v2_300", "prithvi_eo_v2_600"]) +@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v2_600"]) @pytest.mark.parametrize("case", ["fit", "test", "validate"]) def test_finetune_multiple_backbones(model_name, case): command_list = [case, "-c", f"tests/resources/configs/manufactured-finetune_{model_name}.yaml"] diff --git a/tests/test_prithvi_vit.py b/tests/test_prithvi_vit.py index 659812f5..86316599 100644 --- a/tests/test_prithvi_vit.py +++ b/tests/test_prithvi_vit.py @@ -3,26 +3,25 @@ from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights +from terratorch.registry import BACKBONE_REGISTRY import gc @pytest.mark.parametrize("patch_size", [4, 8, 16]) @pytest.mark.parametrize("patch_size_time,num_frames", [(1, 1), (1, 2), (1, 3), (2, 2), (3,3)]) def test_prithvi_vit_patch_embed_loading_compatible(patch_size, patch_size_time, num_frames): - model = timm.create_model( + model = BACKBONE_REGISTRY.build( "prithvi_eo_v1_100", pretrained=False, num_frames=num_frames, patch_size=[patch_size_time, 16, 16], - features_only=True, ) - weights = timm.create_model( + weights = BACKBONE_REGISTRY.build( "prithvi_eo_v1_100", pretrained=False, num_frames=num_frames, patch_size=[patch_size_time, 16, 16], - features_only=True, ).state_dict() select_patch_embed_weights(weights, model, PRETRAINED_BANDS, PRETRAINED_BANDS) @@ -31,20 +30,18 @@ def test_prithvi_vit_patch_embed_loading_compatible(patch_size, patch_size_time, @pytest.mark.parametrize("patch_size_time,patch_size_time_other", [(1, 2), (2, 4)]) def test_prithvi_vit_patch_embed_loading_time_patch_size_other(patch_size_time,patch_size_time_other): - model = timm.create_model( + model = BACKBONE_REGISTRY.build( "prithvi_eo_v1_100", pretrained=False, num_frames=4, patch_size=[patch_size_time, 16, 16], - features_only=True, ) - weights = timm.create_model( + weights = BACKBONE_REGISTRY.build( "prithvi_eo_v1_100", pretrained=False, num_frames=4, patch_size=[patch_size_time_other, 16, 16], - features_only=True, ).state_dict() # assert warning produced @@ -55,20 +52,18 @@ def test_prithvi_vit_patch_embed_loading_time_patch_size_other(patch_size_time,p @pytest.mark.parametrize("patch_size,patch_size_other", [(2, 4), (4, 8), (16, 4)]) def test_prithvi_vit_patch_embed_loading_not_compatible_patch(patch_size, patch_size_other): - model = timm.create_model( + model = BACKBONE_REGISTRY.build( "prithvi_eo_v1_100", pretrained=False, num_frames=1, patch_size=patch_size, - features_only=True, ) - weights = timm.create_model( + weights = BACKBONE_REGISTRY.build( "prithvi_eo_v1_100", pretrained=False, num_frames=1, patch_size=patch_size_other, - features_only=True, ).state_dict() with pytest.warns(UserWarning):
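Closing out, a small sketch of the reworked `select_patch_embed_weights` from earlier in this diff: with `proj_key` left as `None` it now searches the state dict for a key ending in `patch_embed.proj.weight` or `patch_embed.projection.weight`, and the new `proj_key` argument simply pins that lookup. The explicit key shown below is illustrative and has to match the model's own parameter naming:

# Sketch only: filter pretrained patch-embed weights onto a freshly built backbone.
from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS
from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights
from terratorch.registry import BACKBONE_REGISTRY

model = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", pretrained=False)
weights = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", pretrained=False).state_dict()

# Automatic key search (proj_key=None)...
state_dict = select_patch_embed_weights(weights, model, PRETRAINED_BANDS, PRETRAINED_BANDS)

# ...or pin the projection key explicitly (it must exist in both state dicts).
state_dict = select_patch_embed_weights(
    weights, model, PRETRAINED_BANDS, PRETRAINED_BANDS, proj_key="patch_embed.proj.weight"
)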