Merge branch 'IBM:main' into 201
romeokienzler authored Jan 6, 2025
2 parents 63edb29 + b730eba commit 97fbeed
Showing 24 changed files with 2,499 additions and 75 deletions.
36 changes: 29 additions & 7 deletions README.md
@@ -6,13 +6,27 @@
TerraTorch is a library based on [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) and the [TorchGeo](https://github.com/microsoft/torchgeo) domain library
for geospatial data.

TerraTorch’s main purpose is to provide a flexible fine-tuning framework for Geospatial Foundation Models, which can be interacted with at different abstraction levels.
TerraTorch’s main purpose is to provide a flexible fine-tuning framework for Geospatial Foundation Models, which can be interacted with at different abstraction levels. The library provides:

The library provides:

- Easy access to open source pre-trained Geospatial Foundation Model backbones (e.g., [Prithvi](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M), [SatMAE](https://sustainlab-group.github.io/SatMAE/) and [ScaleMAE](https://github.com/bair-climate-initiative/scale-mae), other backbones available in the [timm](https://github.com/huggingface/pytorch-image-models) (Pytorch image models) or [SMP](https://github.com/qubvel/segmentation_models.pytorch) (Pytorch Segmentation models with pre-training backbones) packages, as well as fine-tuned models such as [granite-geospatial-biomass](https://huggingface.co/ibm-granite/granite-geospatial-biomass)
- Flexible trainers for Image Segmentation, Classification and Pixel Wise Regression fine-tuning tasks
- Launching of fine-tuning tasks through flexible configuration files
- Convenient modelling tools:
    - Flexible trainers for Image Segmentation, Classification and Pixel Wise Regression fine-tuning tasks
    - Model factories that make it easy to combine backbones and decoders for different tasks
    - Ready-to-go datasets and datamodules that only need to be pointed at your data, with no need to create new custom classes
    - Launching of fine-tuning tasks through the CLI and flexible configuration files, or via Jupyter notebooks
- Easy access to:
    - Open source pre-trained Geospatial Foundation Model backbones:
        * [Prithvi](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M)
        * [SatMAE](https://sustainlab-group.github.io/SatMAE/)
        * [ScaleMAE](https://github.com/bair-climate-initiative/scale-mae)
        * Satlas (as implemented in [TorchGeo](https://github.com/microsoft/torchgeo))
        * DOFA (as implemented in [TorchGeo](https://github.com/microsoft/torchgeo))
        * SSL4EO-L and SSL4EO-S12 models (as implemented in [TorchGeo](https://github.com/microsoft/torchgeo))
        * Clay
    - Backbones available in the [timm](https://github.com/huggingface/pytorch-image-models) (PyTorch Image Models) package
    - Decoders available in the [SMP](https://github.com/qubvel/segmentation_models.pytorch) (PyTorch Segmentation Models with pre-trained backbones) and [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) packages
    - Fine-tuned models such as [granite-geospatial-biomass](https://huggingface.co/ibm-granite/granite-geospatial-biomass)
    - All GEO-Bench datasets and datamodules
    - All [TorchGeo](https://github.com/microsoft/torchgeo) datasets and datamodules

## Install
### Pip
@@ -26,7 +40,15 @@ To get the most recent version of the main branch, install the library with `pip

TerraTorch requires GDAL to be installed, which can be quite a complex process. If you don't have GDAL set up on your system, we recommend using a conda environment and installing it with `conda install -c conda-forge gdal`.

To install as a developer (e.g. to extend the library) clone this repo, install dependencies using `pip install -r requirements/required.txt -r requirements/dev.txt` and run `pip install -e .`
To install as a developer (e.g. to extend the library):
```
git clone https://github.com/IBM/terratorch.git
cd terratorch
pip install -r requirements/required.txt -r requirements/dev.txt
conda install -c conda-forge gdal
pip install -e .
```

To install TerraTorch with partial (work-in-progress) support for Weather Foundation Models, run `pip install -e .[wxc]`, which currently works only for `Python >= 3.11`.
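
The two prerequisites named above (GDAL and the Python version needed for the `wxc` extra) can be checked before installing. The following is only a minimal sketch: the `preflight` helper is hypothetical and not part of TerraTorch, and it assumes that the conda-forge `gdal` package exposes its Python bindings as the `osgeo` module.

```python
import importlib.util
import sys


def preflight(min_wxc_python=(3, 11)):
    """Report whether the environment looks ready for a TerraTorch install.

    Hypothetical helper, not part of TerraTorch. It only inspects the
    current interpreter and does not install or import anything heavy.
    """
    return {
        # Assumption: GDAL's Python bindings are importable as `osgeo`.
        "gdal_bindings_found": importlib.util.find_spec("osgeo") is not None,
        # The README states the `wxc` extra currently needs Python >= 3.11.
        "python_ok_for_wxc": sys.version_info >= min_wxc_python,
    }


report = preflight()
for check, passed in report.items():
    print(f"{check}: {'ok' if passed else 'missing'}")
```

If `gdal_bindings_found` is reported as missing, installing GDAL through conda as described above is the route the README suggests.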

## Quick start
5 changes: 4 additions & 1 deletion contribution_process.md
@@ -6,4 +6,7 @@ If you want to contribute to this project, there are many valuable ways in doing
1. Use / test TerraTorch and create an [Issue](https://github.com/IBM/terratorch/issues) if something is not working properly or if you have an idea for a feature request.
1. Pick an [Issue](https://github.com/IBM/terratorch/issues) and start contributing

Contributions are welcome as pull requests on a [fork](https://github.com/IBM/terratorch/fork) of this project. Ideally, pull requests are backed by an [Issue](https://github.com/IBM/terratorch/issues). You can also tag the [code owners](https://github.com/IBM/terratorch/blob/main/CODEOWNERS) in the issue before you start, so we can talk about the details (in case you can't join one of the community calls).

After or during implementation on your branch, please create a PR to main. During development, please mark this PR as DRAFT and prefix its title with '[WIP]'.
When you want us to merge the PR, remove the draft status and the '[WIP]' prefix. Before that, please make sure that all tests are passing. Unit tests are also run automatically on GitHub on the branch. The TerraTorch committers will review your code and run integration tests on our GPU cluster before we merge to main.
194 changes: 194 additions & 0 deletions examples/confs/dofa_sen1floods11_fcn.yaml
@@ -0,0 +1,194 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
  num_sanity_val_steps: 0
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: bf16
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: /dccstor/geofm-finetuning/carlosgomes/torchgeo_floods
      name: dofa
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch


  max_epochs: 50
  check_val_every_n_epoch: 1
  log_every_n_steps: 50
  enable_checkpointing: true
  default_root_dir: /dccstor/geofm-finetuning/carlosgomes/torchgeo_floods
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 16
    num_workers: 8
    constant_scale: 0.0001
    dataset_bands:
      - COASTAL_AEROSOL
      - BLUE
      - GREEN
      - RED
      - RED_EDGE_1
      - RED_EDGE_2
      - RED_EDGE_3
      - NIR_BROAD
      - NIR_NARROW
      - WATER_VAPOR
      - CIRRUS
      - SWIR_1
      - SWIR_2
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    rgb_indices:
      - 2
      - 1
      - 0
    train_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/
    train_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
    val_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/
    val_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
    test_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2Hand/
    test_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
    # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files
    train_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_train_data_S2.txt
    test_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_test_data_S2.txt
    val_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_valid_data_S2.txt
    img_grep: "*_S2Hand.tif"
    label_grep: "*_LabelHand.tif"
    no_label_replace: -1
    no_data_replace: 0
    means:
      - 0.1412956
      - 0.13795798
      - 0.12353792
      - 0.30902815
      - 0.2044958
      - 0.11912015
    stds:
      - 0.07406382
      - 0.07370365
      - 0.08692279
      - 0.11798815
      - 0.09772074
      - 0.07659938
    num_classes: 2
    # train_transform:
    #   - class_path: albumentations.RandomCrop
    #     init_args:
    #       height: 224
    #       width: 224
    #   - class_path: albumentations.HorizontalFlip
    #     init_args:
    #       p: 0.5
    #   - class_path: ToTensorV2
    # val_transform:
    #   - class_path: albumentations.RandomCrop
    #     init_args:
    #       height: 224
    #       width: 224
    #   - class_path: ToTensorV2
    # test_transform:
    #   - class_path: albumentations.CenterCrop
    #     init_args:
    #       height: 224
    #       width: 224
    #   - class_path: ToTensorV2



  # class_path: terratorch.datamodules.sen1floods11.Sen1Floods11NonGeoDataModule
  # init_args:
  #   batch_size: 8
  #   num_workers: 8
  #   train_aug:
  #     - class_path: albumentations.RandomCrop
  #       init_args:
  #         height: 224
  #         width: 224
  #     - class_path: albumentations.HorizontalFlip
  #       init_args:
  #         p: 0.5
  #     - class_path: ToTensorV2
  #   val_aug:
  #     - class_path: albumentations.RandomCrop
  #       init_args:
  #         height: 224
  #         width: 224
  #     - class_path: ToTensorV2

  # dict_kwargs:
  #   data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/
  #   bands:
  #     - 1
  #     - 2
  #     - 3
  #     - 8
  #     - 11
  #     - 12

model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_args:
      decoder: FCNDecoder
      backbone_pretrained: True
      backbone_img_size: 512
      backbone: dofa_large_patch16_224
      # backbone_pretrain_img_size: 512
      # decoder_scale_modules: True
      # decoder_in_channels: 1024
      decoder_channels: 256
      # backbone_in_channels: 6
      backbone_model_bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      # num_frames: 1
      num_classes: 2
      head_dropout: 0.1
      head_channel_list:
        - 256
      necks:
        - name: SelectIndices
          indices:
            - -1
        - name: ReshapeTokensToImage
    loss: ce

    ignore_index: -1
    class_weights:
      - 0.3
      - 0.7
    freeze_backbone: false
    freeze_decoder: false
    model_factory: EncoderDecoderFactory
    tiled_inference_parameters:
      h_crop: 224
      h_stride: 196
      w_crop: 224
      w_stride: 196
      average_patches: true
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 6.e-5
    weight_decay: 0.05
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
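
A few fields in this config have to agree with each other: `output_bands` should be a subset of `dataset_bands`, `means` and `stds` should have one entry per output band, and the tiling crop and stride determine how an image is covered at inference time. The sketch below illustrates those relationships with plain Python. The helpers `band_indices` and `tile_starts` are hypothetical and are not TerraTorch functions. It also assumes that `rgb_indices` index into `output_bands`, that the last tile is shifted back to end at the image border, and that the image side is 512 pixels (taken from `backbone_img_size`). TerraTorch's actual tiled inference may place tiles differently.

```python
# Values copied from the config above; the helpers are illustrative only.
DATASET_BANDS = [
    "COASTAL_AEROSOL", "BLUE", "GREEN", "RED", "RED_EDGE_1", "RED_EDGE_2",
    "RED_EDGE_3", "NIR_BROAD", "NIR_NARROW", "WATER_VAPOR", "CIRRUS",
    "SWIR_1", "SWIR_2",
]
OUTPUT_BANDS = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]
RGB_INDICES = [2, 1, 0]
MEANS = [0.1412956, 0.13795798, 0.12353792, 0.30902815, 0.2044958, 0.11912015]
STDS = [0.07406382, 0.07370365, 0.08692279, 0.11798815, 0.09772074, 0.07659938]


def band_indices(dataset_bands, output_bands):
    """Return the 0-based position of each output band in the dataset."""
    missing = [b for b in output_bands if b not in dataset_bands]
    if missing:
        raise ValueError(f"bands not present in dataset_bands: {missing}")
    return [dataset_bands.index(b) for b in output_bands]


def tile_starts(size, crop, stride):
    """Return tile start offsets along one axis.

    Assumption: tiles step by `stride`, and a final tile is added flush
    with the border if the regular steps do not reach it.
    """
    if crop > size:
        raise ValueError("crop must not exceed the image size")
    starts = list(range(0, size - crop + 1, stride))
    if starts[-1] != size - crop:
        starts.append(size - crop)
    return starts


indices = band_indices(DATASET_BANDS, OUTPUT_BANDS)
rgb_names = [OUTPUT_BANDS[i] for i in RGB_INDICES]
stats_ok = len(MEANS) == len(STDS) == len(OUTPUT_BANDS)
starts = tile_starts(size=512, crop=224, stride=196)

print("band indices:", indices)        # [1, 2, 3, 8, 11, 12]
print("RGB bands:", rgb_names)         # ['RED', 'GREEN', 'BLUE']
print("stats match bands:", stats_ok)  # True
print("tile starts:", starts)          # [0, 196, 288]
```

The band indices match the `bands` list in the commented-out `dict_kwargs` block (1, 2, 3, 8, 11, 12). With a crop of 224 and a stride of 196, regularly spaced neighbouring tiles overlap by 28 pixels.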