merging
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Joao-L-S-Almeida committed Jan 6, 2025
2 parents 1c409e8 + b730eba commit 0d79f8e
Showing 23 changed files with 2,495 additions and 72 deletions.
36 changes: 29 additions & 7 deletions README.md
@@ -6,13 +6,27 @@
TerraTorch is a library based on [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) and the [TorchGeo](https://github.com/microsoft/torchgeo) domain library
for geospatial data.

TerraTorch’s main purpose is to provide a flexible fine-tuning framework for Geospatial Foundation Models, which can be interacted with at different abstraction levels.
TerraTorch’s main purpose is to provide a flexible fine-tuning framework for Geospatial Foundation Models, which can be interacted with at different abstraction levels. The library provides:

The library provides:

- Easy access to open source pre-trained Geospatial Foundation Model backbones (e.g., [Prithvi](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M), [SatMAE](https://sustainlab-group.github.io/SatMAE/) and [ScaleMAE](https://github.com/bair-climate-initiative/scale-mae), other backbones available in the [timm](https://github.com/huggingface/pytorch-image-models) (Pytorch image models) or [SMP](https://github.com/qubvel/segmentation_models.pytorch) (Pytorch Segmentation models with pre-training backbones) packages, as well as fine-tuned models such as [granite-geospatial-biomass](https://huggingface.co/ibm-granite/granite-geospatial-biomass)
- Flexible trainers for Image Segmentation, Classification and Pixel Wise Regression fine-tuning tasks
- Launching of fine-tuning tasks through flexible configuration files
- Convenient modelling tools:
  - Flexible trainers for Image Segmentation, Classification and Pixel Wise Regression fine-tuning tasks
  - Model factories that make it easy to combine backbones and decoders for different tasks
  - Ready-to-go datasets and datamodules that only need to be pointed at your data, with no need to create new custom classes
  - Launching of fine-tuning tasks through the CLI and flexible configuration files, or via Jupyter notebooks
- Easy access to:
  - Open source pre-trained Geospatial Foundation Model backbones:
    * [Prithvi](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M)
    * [SatMAE](https://sustainlab-group.github.io/SatMAE/)
    * [ScaleMAE](https://github.com/bair-climate-initiative/scale-mae)
    * Satlas (as implemented in [TorchGeo](https://github.com/microsoft/torchgeo))
    * DOFA (as implemented in [TorchGeo](https://github.com/microsoft/torchgeo))
    * SSL4EO-L and SSL4EO-S12 models (as implemented in [TorchGeo](https://github.com/microsoft/torchgeo))
    * Clay
  - Backbones available in the [timm](https://github.com/huggingface/pytorch-image-models) (PyTorch image models) package
  - Decoders available in the [SMP](https://github.com/qubvel/segmentation_models.pytorch) (PyTorch Segmentation models with pre-training backbones) and [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) packages
  - Fine-tuned models such as [granite-geospatial-biomass](https://huggingface.co/ibm-granite/granite-geospatial-biomass)
  - All GEO-Bench datasets and datamodules
  - All [TorchGeo](https://github.com/microsoft/torchgeo) datasets and datamodules

## Install
### Pip
@@ -26,7 +40,15 @@ To get the most recent version of the main branch, install the library with `pip

TerraTorch requires GDAL to be installed, which can be quite a complex process. If you don't have GDAL set up on your system, we recommend using a conda environment and installing it with `conda install -c conda-forge gdal`.

To install as a developer (e.g. to extend the library) clone this repo, install dependencies using `pip install -r requirements/required.txt -r requirements/dev.txt` and run `pip install -e .`
To install as a developer (e.g. to extend the library):
```
git clone https://github.com/IBM/terratorch.git
cd terratorch
pip install -r requirements/required.txt -r requirements/dev.txt
conda install -c conda-forge gdal
pip install -e .
```

To install terratorch with partial (work in progress) support for Weather Foundation Models, run `pip install -e .[wxc]`. This currently works only with `Python >= 3.11`.

## Quick start
194 changes: 194 additions & 0 deletions examples/confs/dofa_sen1floods11_fcn.yaml
@@ -0,0 +1,194 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
  num_sanity_val_steps: 0
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: bf16
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: /dccstor/geofm-finetuning/carlosgomes/torchgeo_floods
      name: dofa
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch


  max_epochs: 50
  check_val_every_n_epoch: 1
  log_every_n_steps: 50
  enable_checkpointing: true
  default_root_dir: /dccstor/geofm-finetuning/carlosgomes/torchgeo_floods
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 16
    num_workers: 8
    constant_scale: 0.0001
    dataset_bands:
      - COASTAL_AEROSOL
      - BLUE
      - GREEN
      - RED
      - RED_EDGE_1
      - RED_EDGE_2
      - RED_EDGE_3
      - NIR_BROAD
      - NIR_NARROW
      - WATER_VAPOR
      - CIRRUS
      - SWIR_1
      - SWIR_2
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    rgb_indices:
      - 2
      - 1
      - 0
    train_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/
    train_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
    val_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/
    val_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
    test_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/
    test_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
    # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files
    train_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_train_data_S2.txt
    test_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_test_data_S2.txt
    val_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_valid_data_S2.txt
    img_grep: "*_S2Hand.tif"
    label_grep: "*_LabelHand.tif"
    no_label_replace: -1
    no_data_replace: 0
    means:
      - 0.1412956
      - 0.13795798
      - 0.12353792
      - 0.30902815
      - 0.2044958
      - 0.11912015
    stds:
      - 0.07406382
      - 0.07370365
      - 0.08692279
      - 0.11798815
      - 0.09772074
      - 0.07659938
    num_classes: 2
    # train_transform:
    #   - class_path: albumentations.RandomCrop
    #     init_args:
    #       height: 224
    #       width: 224
    #   - class_path: albumentations.HorizontalFlip
    #     init_args:
    #       p: 0.5
    #   - class_path: ToTensorV2
    # val_transform:
    #   - class_path: albumentations.RandomCrop
    #     init_args:
    #       height: 224
    #       width: 224
    #   - class_path: ToTensorV2
    # test_transform:
    #   - class_path: albumentations.CenterCrop
    #     init_args:
    #       height: 224
    #       width: 224
    #   - class_path: ToTensorV2



  # class_path: terratorch.datamodules.sen1floods11.Sen1Floods11NonGeoDataModule
  # init_args:
  #   batch_size: 8
  #   num_workers: 8
  #   train_aug:
  #     - class_path: albumentations.RandomCrop
  #       init_args:
  #         height: 224
  #         width: 224
  #     - class_path: albumentations.HorizontalFlip
  #       init_args:
  #         p: 0.5
  #     - class_path: ToTensorV2
  #   val_aug:
  #     - class_path: albumentations.RandomCrop
  #       init_args:
  #         height: 224
  #         width: 224
  #     - class_path: ToTensorV2

  # dict_kwargs:
  #   data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/
  #   bands:
  #     - 1
  #     - 2
  #     - 3
  #     - 8
  #     - 11
  #     - 12

model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_args:
      decoder: FCNDecoder
      backbone_pretrained: True
      backbone_img_size: 512
      backbone: dofa_large_patch16_224
      # backbone_pretrain_img_size: 512
      # decoder_scale_modules: True
      # decoder_in_channels: 1024
      decoder_channels: 256
      # backbone_in_channels: 6
      backbone_model_bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      # num_frames: 1
      num_classes: 2
      head_dropout: 0.1
      head_channel_list:
        - 256
      necks:
        - name: SelectIndices
          indices:
            - -1
        - name: ReshapeTokensToImage
    loss: ce

    ignore_index: -1
    class_weights:
      - 0.3
      - 0.7
    freeze_backbone: false
    freeze_decoder: false
    model_factory: EncoderDecoderFactory
    tiled_inference_parameters:
      h_crop: 224
      h_stride: 196
      w_crop: 224
      w_stride: 196
      average_patches: true
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 6.e-5
    weight_decay: 0.05
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
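
The band and tiling settings in the config above can be sanity-checked with a small standalone sketch. This is not TerraTorch code: `band_indices` and `tile_starts` are hypothetical helpers written for illustration. The sketch assumes that `rgb_indices` index into `output_bands`, that inputs are 512 pixels wide (taken from `backbone_img_size`), and that the last tile is shifted back so it ends at the image edge. The library's actual tiled inference may handle borders differently.

```python
# Illustrative sketch only; helper names are hypothetical, not part of TerraTorch.

DATASET_BANDS = [
    "COASTAL_AEROSOL", "BLUE", "GREEN", "RED", "RED_EDGE_1", "RED_EDGE_2",
    "RED_EDGE_3", "NIR_BROAD", "NIR_NARROW", "WATER_VAPOR", "CIRRUS",
    "SWIR_1", "SWIR_2",
]
OUTPUT_BANDS = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]
RGB_INDICES = [2, 1, 0]  # assumed to index into OUTPUT_BANDS


def band_indices(dataset_bands, output_bands):
    """Return the position of each output band within the dataset band order."""
    missing = [b for b in output_bands if b not in dataset_bands]
    if missing:
        raise ValueError(f"bands not present in dataset_bands: {missing}")
    return [dataset_bands.index(b) for b in output_bands]


def tile_starts(size, crop, stride):
    """Start offsets of sliding windows of length `crop` covering `size` pixels.

    Assumption: if the regular stride leaves a gap at the end, one extra
    window is added that ends exactly at the image edge.
    """
    if crop >= size:
        return [0]
    starts = list(range(0, size - crop + 1, stride))
    if starts[-1] + crop < size:
        starts.append(size - crop)
    return starts


selected = band_indices(DATASET_BANDS, OUTPUT_BANDS)
rgb_bands = [OUTPUT_BANDS[i] for i in RGB_INDICES]
print(selected)   # [1, 2, 3, 8, 11, 12]
print(rgb_bands)  # ['RED', 'GREEN', 'BLUE']
print(tile_starts(512, crop=224, stride=196))  # [0, 196, 288]
```

The selected positions match the `bands` list in the commented-out `dict_kwargs` block (1, 2, 3, 8, 11, 12), and the six `means` and `stds` entries line up one-to-one with `output_bands`.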