Skip to content

Commit

Permalink
Merge branch 'main' into custom_modules
Browse files Browse the repository at this point in the history
  • Loading branch information
romeokienzler authored Dec 4, 2024
2 parents b4c602d + 23944e3 commit 55acdb9
Show file tree
Hide file tree
Showing 26 changed files with 5,043 additions and 130 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10","3.11"]
python-version: ["3.10","3.11","3.12"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
timeout-minutes: 20
strategy:
matrix:
python-version: ["3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- name: Clone repo
Expand Down
3 changes: 1 addition & 2 deletions docs/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,8 @@ model_args = dict(
HLSBands.SWIR_1,
HLSBands.SWIR_2,
],
necks=[{"name": "SelectIndices", "indices": -1},
necks=[{"name": "SelectIndices", "indices": [-1]},
{"name": "ReshapeTokensToImage"}],
num_classes=4,
backbone_pretrained=True,
backbone_num_frames=1,
decoder_channels=128,
Expand Down
161 changes: 161 additions & 0 deletions examples/confs/multimae_sen1floods11.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: output
name: multimae_sen1floods11
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: EarlyStopping
init_args:
monitor: val/loss
patience: 40

max_epochs: 2
check_val_every_n_epoch: 1
log_every_n_steps: 50
enable_checkpointing: true
default_root_dir: output/multimae_sen1floods11/

data:
class_path: GenericMultiModalDataModule
init_args:
task: 'segmentation'
batch_size: 4
num_workers: 0
modalities:
- S2L2A
- S1
- LULC
rgb_modality: S2L2A # If not provided, uses first modality
rgb_indices:
- 3
- 2
- 1

train_data_root:
S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand
train_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
val_data_root:
S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand
val_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
test_data_root:
S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand
test_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand

train_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_train.txt
val_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_valid.txt
test_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_test.txt

allow_substring_file_names: True
image_grep:
S2L2A: "*_S2L2AHand.tif"
S1: "*_S1Hand.tif"
LULC: "*_LULCHand.npy"
label_grep: "*_LabelHand.tif"
no_label_replace: -1
no_data_replace: 0

means:
S2L2A:
- 1793.243
- 1924.863
- 2184.553
- 2340.936
- 2671.402
- 3240.082
- 3468.412
- 3563.244
- 3627.704
- 3711.071
- 3416.714
- 2849.625
S1:
- -12.577
- -20.265

stds:
S2L2A:
- 1160.144
- 1201.092
- 1219.943
- 1397.225
- 1400.035
- 1373.136
- 1429.17
- 1485.025
- 1447.836
- 1652.703
- 1471.002
- 1365.30
S1:
- 5.179
- 5.872

num_classes: 2

train_transform:
- class_path: albumentations.RandomCrop
init_args:
height: 224
width: 224
- class_path: albumentations.D4
- class_path: ToTensorV2


model:
class_path: terratorch.tasks.SemanticSegmentationTask
init_args:
model_factory: EncoderDecoderFactory
model_args:
backbone_pretrained: false
backbone: multimae_base
backbone_input_adapters:
- S1
- S2L2A
- LULC
decoder: FCNDecoder # UperNetDecoder
decoder_num_convs: 4 # only for FCNDecoder
# decoder_scale_modules: True # only for UperNetDecoder
decoder_channels: 256
num_classes: 2
head_dropout: 0.1
head_channel_list:
- 256
loss: ce
ignore_index: -1
class_weights:
- 0.3
- 0.7
class_names:
- Others
- Flood
freeze_backbone: false
freeze_decoder: false

optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 6.e-5
weight_decay: 0.05
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss

173 changes: 173 additions & 0 deletions examples/confs/multimodal_prithvi_sen1floods11.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: output
name: multimodal_prithvi_sen1floods11
version: test_best
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: EarlyStopping
init_args:
monitor: val/loss
patience: 40

max_epochs: 100
check_val_every_n_epoch: 1
log_every_n_steps: 50
enable_checkpointing: True
default_root_dir: output/multimodal_prithvi_sen1floods11/

data:
class_path: GenericMultiModalDataModule
init_args:
task: 'segmentation'
batch_size: 16
num_workers: 4
modalities: # Define names of modalities
- S2L2A
- S1
rgb_modality: S2L2A # If not provided, uses first modality
rgb_indices:
- 3
- 2
- 1

# Data roots are defined as dicts with modalities as keys
train_data_root:
S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
train_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
val_data_root:
S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
val_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
test_data_root:
S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
test_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand

train_split: data/sen1floods11/splits/splits/flood_handlabeled/flood_train_data.txt
val_split: data/sen1floods11/splits/splits/flood_handlabeled/flood_valid_data.txt
test_split: data/sen1floods11/splits/splits/flood_handlabeled/flood_test_data.txt

allow_substring_file_names: True
image_grep:
S2L2A: "*_S2L2AHand.tif"
S1: "*_S1Hand.tif"
label_grep: "*_LabelHand.tif"
no_label_replace: -1
no_data_replace: 0
concat_bands: true # Concatenate modalities along band dim for single-modal models like Prithvi

# Define standardization values as dicts (no scaling if modality is not included)
means:
S2L2A:
- 1793.243
- 1924.863
- 2184.553
- 2340.936
- 2671.402
- 3240.082
- 3468.412
- 3563.244
- 3627.704
- 3711.071
- 3416.714
- 2849.625
S1:
- -12.577
- -20.265

stds:
S2L2A:
- 1160.144
- 1201.092
- 1219.943
- 1397.225
- 1400.035
- 1373.136
- 1429.17
- 1485.025
- 1447.836
- 1652.703
- 1471.002
- 1365.30
S1:
- 5.179
- 5.872

num_classes: 2

# Transforms are shared between all image modalities (e.g. same crop area)
train_transform:
- class_path: albumentations.RandomCrop
init_args:
height: 224
width: 224
- class_path: albumentations.D4
- class_path: ToTensorV2


model:
class_path: terratorch.tasks.SemanticSegmentationTask
init_args:
model_factory: EncoderDecoderFactory
model_args:
backbone: prithvi_vit_100
backbone_pretrained: false
backbone_bands:
- COASTAL_AEROSOL
- BLUE
- GREEN
- RED
- RED_EDGE_1
- RED_EDGE_2
- RED_EDGE_3
- NIR_BROAD
- NIR_NARROW
- CIRRUS
- SWIR_1
- SWIR_2
- VV
- VH
decoder: FCNDecoder # FCNDecoder
decoder_num_convs: 4 # only for FCNDecoder
# decoder_scale_modules: True # only for UperNetDecoder
decoder_channels: 256
num_classes: 2
head_dropout: 0.1
head_channel_list:
- 256

loss: dice
ignore_index: -1
class_weights:
- 0.3
- 0.7
class_names:
- Others
- Flood
freeze_backbone: false
freeze_decoder: false

optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 6.e-5
weight_decay: 0.05
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss

Loading

0 comments on commit 55acdb9

Please sign in to comment.