Skip to content

Commit

Permalink
Merge with main
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
Joao-L-S-Almeida committed Jan 27, 2025
2 parents 8493e34 + 9140dc1 commit b5c0c0e
Show file tree
Hide file tree
Showing 47 changed files with 788 additions and 1,010 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
timeout-minutes: 30
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12.7"]

steps:
- name: Clone repo
Expand All @@ -28,7 +28,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements/required.txt -r requirements/test.txt -r requirements/optional.txt
#pip install -r requirements/required.txt -r requirements/test.txt -r requirements/optional.txt
pip install -r requirements/test.txt -r requirements/optional.txt
pip install -e .[torchgeo]
pip install git+https://github.com/NASA-IMPACT/Prithvi-WxC.git
pip install git+https://github.com/IBM/granite-wxc.git
- name: List pip dependencies
Expand Down
109 changes: 48 additions & 61 deletions docs/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,27 @@ In the simplest case, we might only want access a backbone and code all the rest
from terratorch import BACKBONE_REGISTRY

# find available prithvi models
print([model_name for model_name in BACKBONE_REGISTRY if "prithvi" in model_name])
>>> ['timm_prithvi_eo_tiny', 'timm_prithvi_eo_v1_100', 'timm_prithvi_eo_v2_300', 'timm_prithvi_eo_v2_300_tl', 'timm_prithvi_eo_v2_600',
'timm_prithvi_eo_v2_600_tl', 'timm_prithvi_swin_B', 'timm_prithvi_swin_L', 'timm_prithvi_vit_100', 'timm_prithvi_vit_tiny']
print([model_name for model_name in BACKBONE_REGISTRY if "terratorch_prithvi" in model_name])
>>> ['terratorch_prithvi_eo_tiny', 'terratorch_prithvi_eo_v1_100', 'terratorch_prithvi_eo_v2_300', 'terratorch_prithvi_eo_v2_600', 'terratorch_prithvi_eo_v2_300_tl', 'terratorch_prithvi_eo_v2_600_tl']

# show all models with list(BACKBONE_REGISTRY)

# check a model is in the registry
"timm_prithvi_swin_B" in BACKBONE_REGISTRY
"terratorch_prithvi_eo_v2_300" in BACKBONE_REGISTRY
>>> True

# without the prefix, all internal registries will be searched until the first match is found
"prithvi_swin_B" in BACKBONE_REGISTRY
"prithvi_eo_v1_100" in BACKBONE_REGISTRY
>>> True

# instantiate your desired model
# the backbone registry prefix (in this case 'timm') is optional
# in this case, the underlying registry is timm, so we can pass timm arguments to it
model = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", num_frames=1, pretrained=True)
# the backbone registry prefix (e.g. `terratorch` or `timm`) is optional
# in this case, the underlying registry is terratorch.
model = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", pretrained=True)

# instantiate your model with more options, for instance, passing weights of your own through timm
# instantiate your model with more options, for instance, passing weights from your own file
model = BACKBONE_REGISTRY.build(
"prithvi_vit_100", num_frames=1, pretrained=True, pretrained_cfg_overlay={"file": "<path to weights>"}
"prithvi_eo_v2_300", num_frames=1, ckpt_path='path/to/model.pt'
)
# Rest of your PyTorch / PyTorchLightning code

Expand Down Expand Up @@ -68,25 +67,25 @@ model_factory = EncoderDecoderFactory()
# Parameters prefixed with decoder_ get passed to the decoder
# Parameters prefixed with head_ get passed to the head

model = model_factory.build_model(task="segmentation",
backbone="prithvi_vit_100",
decoder="FCNDecoder",
backbone_bands=[
HLSBands.BLUE,
HLSBands.GREEN,
HLSBands.RED,
HLSBands.NIR_NARROW,
HLSBands.SWIR_1,
HLSBands.SWIR_2,
],
necks=[{"name": "SelectIndices", "indices": [-1]},
{"name": "ReshapeTokensToImage"}],
num_classes=4,
backbone_pretrained=True,
backbone_num_frames=1,
decoder_channels=128,
head_dropout=0.2
)
model = model_factory.build_model(
task="segmentation",
backbone="prithvi_eo_v2_300",
backbone_pretrained=True,
backbone_bands=[
HLSBands.BLUE,
HLSBands.GREEN,
HLSBands.RED,
HLSBands.NIR_NARROW,
HLSBands.SWIR_1,
HLSBands.SWIR_2,
],
necks=[{"name": "SelectIndices", "indices": [-1]},
{"name": "ReshapeTokensToImage"}],
decoder="FCNDecoder",
decoder_channels=128,
head_dropout=0.1,
num_classes=4,
)

# Rest of your PyTorch / PyTorchLightning code
.
Expand All @@ -102,8 +101,9 @@ At the highest level of abstraction, you can directly obtain a LightningModule r
```python title="Building a full Pixel-Wise Regression task"

model_args = dict(
backbone="prithvi_vit_100",
decoder="FCNDecoder",
backbone="prithvi_eo_v2_300",
backbone_pretrained=True,
backbone_num_frames=1,
backbone_bands=[
HLSBands.BLUE,
HLSBands.GREEN,
Expand All @@ -114,10 +114,9 @@ model_args = dict(
],
necks=[{"name": "SelectIndices", "indices": [-1]},
{"name": "ReshapeTokensToImage"}],
backbone_pretrained=True,
backbone_num_frames=1,
decoder="FCNDecoder",
decoder_channels=128,
head_dropout=0.2
head_dropout=0.1
)

task = PixelwiseRegressionTask(
Expand Down Expand Up @@ -175,54 +174,42 @@ data:
model:
class_path: terratorch.tasks.SemanticSegmentationTask
init_args:
model_factory: EncoderDecoderFactory
model_args:
decoder: UperNetDecoder
backbone: prithvi_eo_v2_300
backbone_img_size: 512
backbone_pretrained: True
backbone: prithvi_vit_100
backbone_pretrain_img_size: 512
decoder_scale_modules: True
decoder_channels: 256
backbone_in_channels: 6
backbone_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
num_classes: 2
head_dropout: 0.1
head_channel_list:
- 256
post_backbone_ops:
necks:
- name: SelectIndices
indices:
- 5
- 11
- 17
- 23
indices: [5, 11, 17, 23]
- name: ReshapeTokensToImage
loss: ce

- name: LearnedInterpolateToPyramidal
decoder: UperNetDecoder
decoder_channels: 256
head_channel_list: [256]
head_dropout: 0.1
num_classes: 2
loss: dice
ignore_index: -1
class_weights:
- 0.3
- 0.7
freeze_backbone: false
freeze_decoder: false
model_factory: EncoderDecoderFactory
freeze_decoder: false
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 6.e-5
weight_decay: 0.05
lr: 1.e-4
weight_decay: 0.1
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss


```

To run this training task using the YAML, simply execute:
Expand Down
18 changes: 9 additions & 9 deletions docs/registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,27 @@ To create the desired instance, registries expose a `build` method, which accept
from terratorch import BACKBONE_REGISTRY

# find available prithvi models
print([model_name for model_name in BACKBONE_REGISTRY if "prithvi" in model_name])
>>> ['timm_prithvi_swin_B', 'timm_prithvi_swin_L', 'timm_prithvi_vit_100', 'timm_prithvi_vit_300', 'timm_prithvi_vit_tiny']
print([model_name for model_name in BACKBONE_REGISTRY if "terratorch_prithvi" in model_name])
>>> ['terratorch_prithvi_eo_tiny', 'terratorch_prithvi_eo_v1_100', 'terratorch_prithvi_eo_v2_300', 'terratorch_prithvi_eo_v2_600', 'terratorch_prithvi_eo_v2_300_tl', 'terratorch_prithvi_eo_v2_600_tl']

# show all models with list(BACKBONE_REGISTRY)

# check a model is in the registry
"timm_prithvi_swin_B" in BACKBONE_REGISTRY
"terratorch_prithvi_eo_v2_300" in BACKBONE_REGISTRY
>>> True

# without the prefix, all internal registries will be searched until the first match is found
"prithvi_swin_B" in BACKBONE_REGISTRY
"prithvi_eo_v1_100" in BACKBONE_REGISTRY
>>> True

# instantiate your desired model
# the backbone registry prefix (in this case 'timm') is optional
# in this case, the underlying registry is timm, so we can pass timm arguments to it
model = BACKBONE_REGISTRY.build("prithvi_vit_100", num_frames=1, pretrained=True)
# the backbone registry prefix (e.g. `terratorch` or `timm`) is optional
# in this case, the underlying registry is terratorch.
model = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", pretrained=True)

# instantiate your model with more options, for instance, passing weights of your own through timm
# instantiate your model with more options, for instance, passing weights from your own file
model = BACKBONE_REGISTRY.build(
"prithvi_vit_100", num_frames=1, pretrained=True, pretrained_cfg_overlay={"file": "<path to weights>"}
"prithvi_eo_v2_300", num_frames=1, ckpt_path='path/to/model.pt'
)
# Rest of your PyTorch / PyTorchLightning code

Expand Down
29 changes: 17 additions & 12 deletions examples/confs/burn_scars.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:

# dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars
data:
class_path: GenericNonGeoSegmentationDataModule
class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule
init_args:
batch_size: 4
num_workers: 8
Expand All @@ -56,9 +56,7 @@ data:
init_args:
height: 224
width: 224
- class_path: albumentations.HorizontalFlip
init_args:
p: 0.5
- class_path: albumentations.D4
- class_path: ToTensorV2
no_data_replace: 0
no_label_replace: -1
Expand Down Expand Up @@ -89,35 +87,42 @@ data:
model:
class_path: terratorch.tasks.SemanticSegmentationTask
init_args:
model_factory: EncoderDecoderFactory
model_args:
decoder: FCNDecoder
backbone: prithvi_eo_v2_300
backbone_pretrained: true
backbone: prithvi_vit_100
decoder_channels: 256
backbone_drop_path: 0.1
backbone_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
num_classes: 2
necks:
- name: SelectIndices
# indices: [2, 5, 8, 11] # 100M models
indices: [5, 11, 17, 23] # 300M models
# indices: [7, 15, 23, 31] # 600M models
- name: ReshapeTokensToImage
- name: LearnedInterpolateToPyramidal
decoder: UNetDecoder
decoder_channels: [512, 256, 128, 64]
head_channel_list: [256]
head_dropout: 0.1
decoder_num_convs: 4
head_channel_list:
- 256
num_classes: 2
loss: dice
plot_on_val: 10
ignore_index: -1
freeze_backbone: false
freeze_decoder: false
model_factory: EncoderDecoderFactory
tiled_inference_parameters:
h_crop: 512
h_stride: 496
w_crop: 512
w_stride: 496
average_patches: true

optimizer:
class_path: torch.optim.Adam
init_args:
Expand Down
2 changes: 1 addition & 1 deletion examples/confs/burnscars_smp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ trainer:
default_root_dir: output/BurnScars

data:
class_path: GenericNonGeoSegmentationDataModule
class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule
init_args:
batch_size: 4
num_workers: 8
Expand Down
2 changes: 1 addition & 1 deletion examples/confs/eurosat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ model:
model_args:
decoder: IdentityDecoder
backbone_pretrained: true
backbone: prithvi_vit_100
backbone: prithvi_eo_v1_300
head_dim_list:
- 384
- 128
Expand Down
2 changes: 1 addition & 1 deletion examples/confs/forestnet_timm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ trainer:
default_root_dir: output/ForestNet

data:
class_path: GenericNonGeoClassificationDataModule
class_path: terratorch.datamodules.GenericNonGeoClassificationDataModule
init_args:
batch_size: 16
num_workers: 8
Expand Down
2 changes: 1 addition & 1 deletion examples/confs/multi_temporal_crop.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ model:
model_args:
decoder: FCNDecoder
backbone_pretrained: true
backbone: prithvi_vit_100
backbone: prithvi_eo_v2_300
backbone_in_channels: 6
rescale: False
backbone_bands:
Expand Down
2 changes: 1 addition & 1 deletion examples/confs/multimae_sen1floods11.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ trainer:
default_root_dir: output/multimae_sen1floods11/

data:
class_path: GenericMultiModalDataModule
class_path: terratorch.datamodules.GenericMultiModalDataModule
init_args:
task: 'segmentation'
batch_size: 4
Expand Down
16 changes: 4 additions & 12 deletions examples/confs/multimodal_prithvi_sen1floods11.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
default_root_dir: output/multimodal_prithvi_sen1floods11/

data:
class_path: GenericMultiModalDataModule
class_path: terratorch.datamodules.GenericMultiModalDataModule
init_args:
task: 'segmentation'
batch_size: 16
Expand Down Expand Up @@ -124,7 +124,7 @@ model:
init_args:
model_factory: EncoderDecoderFactory
model_args:
backbone: prithvi_vit_100
backbone: prithvi_eo_v2_300
backbone_pretrained: false
backbone_bands:
- COASTAL_AEROSOL
Expand All @@ -141,23 +141,15 @@ model:
- SWIR_2
- VV
- VH
decoder: FCNDecoder # FCNDecoder
decoder_num_convs: 4 # only for FCNDecoder
# decoder_scale_modules: True # only for UperNetDecoder
decoder: FCNDecoder
decoder_num_convs: 4
decoder_channels: 256
num_classes: 2
head_dropout: 0.1
head_channel_list:
- 256

loss: dice
ignore_index: -1
class_weights:
- 0.3
- 0.7
class_names:
- Others
- Flood
freeze_backbone: false
freeze_decoder: false

Expand Down
Loading

0 comments on commit b5c0c0e

Please sign in to comment.