Commit
merges with main
PedroConrado committed Dec 4, 2024
2 parents 5f72df6 + 7586c66 commit 1964f78
Showing 28 changed files with 5,136 additions and 32 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -12,6 +12,7 @@ venv/*
 examples/notebooks/config.yaml
 examples/notebooks/wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc
 tests/all_ecos_random/*
+examples/**/*tif*
 **/climatology/*
 **/lightning_logs/*
 **/merra-2/*
3 changes: 1 addition & 2 deletions docs/quick_start.md
@@ -107,9 +107,8 @@ model_args = dict(
         HLSBands.SWIR_1,
         HLSBands.SWIR_2,
     ],
-    necks=[{"name": "SelectIndices", "indices": -1},
+    necks=[{"name": "SelectIndices", "indices": [-1]},
           {"name": "ReshapeTokensToImage"}],
-    num_classes=4,
     backbone_pretrained=True,
     backbone_num_frames=1,
     decoder_channels=128,
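The fix above reflects that the SelectIndices neck expects "indices" as a list, even when only the last feature map is selected. A minimal Python sketch of the corrected model_args, assuming the SemanticSegmentationTask / EncoderDecoderFactory combination used elsewhere in this commit (the backbone and decoder names below are illustrative placeholders, not taken from quick_start.md):

from terratorch.tasks import SemanticSegmentationTask

model_args = dict(
    backbone="prithvi_vit_100",   # illustrative backbone choice
    backbone_pretrained=True,
    backbone_num_frames=1,
    decoder="FCNDecoder",         # illustrative decoder choice
    decoder_channels=128,
    num_classes=4,
    # SelectIndices takes a list of indices; [-1] keeps only the last feature map.
    necks=[
        {"name": "SelectIndices", "indices": [-1]},
        {"name": "ReshapeTokensToImage"},
    ],
)

task = SemanticSegmentationTask(
    model_args=model_args,
    model_factory="EncoderDecoderFactory",
    loss="ce",
    ignore_index=-1,
)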
161 changes: 161 additions & 0 deletions examples/confs/multimae_sen1floods11.yaml
@@ -0,0 +1,161 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: output
      name: multimae_sen1floods11
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 40

  max_epochs: 2
  check_val_every_n_epoch: 1
  log_every_n_steps: 50
  enable_checkpointing: true
  default_root_dir: output/multimae_sen1floods11/

data:
  class_path: GenericMultiModalDataModule
  init_args:
    task: 'segmentation'
    batch_size: 4
    num_workers: 0
    modalities:
      - S2L2A
      - S1
      - LULC
    rgb_modality: S2L2A # If not provided, uses first modality
    rgb_indices:
      - 3
      - 2
      - 1

    train_data_root:
      S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
      S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
      LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand
    train_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
    val_data_root:
      S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
      S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
      LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand
    val_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
    test_data_root:
      S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
      S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
      LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand
    test_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand

    train_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_train.txt
    val_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_valid.txt
    test_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_test.txt

    allow_substring_file_names: True
    image_grep:
      S2L2A: "*_S2L2AHand.tif"
      S1: "*_S1Hand.tif"
      LULC: "*_LULCHand.npy"
    label_grep: "*_LabelHand.tif"
    no_label_replace: -1
    no_data_replace: 0

    means:
      S2L2A:
        - 1793.243
        - 1924.863
        - 2184.553
        - 2340.936
        - 2671.402
        - 3240.082
        - 3468.412
        - 3563.244
        - 3627.704
        - 3711.071
        - 3416.714
        - 2849.625
      S1:
        - -12.577
        - -20.265

    stds:
      S2L2A:
        - 1160.144
        - 1201.092
        - 1219.943
        - 1397.225
        - 1400.035
        - 1373.136
        - 1429.17
        - 1485.025
        - 1447.836
        - 1652.703
        - 1471.002
        - 1365.30
      S1:
        - 5.179
        - 5.872

    num_classes: 2

    train_transform:
      - class_path: albumentations.RandomCrop
        init_args:
          height: 224
          width: 224
      - class_path: albumentations.D4
      - class_path: ToTensorV2


model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_factory: EncoderDecoderFactory
    model_args:
      backbone_pretrained: false
      backbone: multimae_base
      backbone_input_adapters:
        - S1
        - S2L2A
        - LULC
      decoder: FCNDecoder # UperNetDecoder
      decoder_num_convs: 4 # only for FCNDecoder
      # decoder_scale_modules: True # only for UperNetDecoder
      decoder_channels: 256
      num_classes: 2
      head_dropout: 0.1
      head_channel_list:
        - 256
    loss: ce
    ignore_index: -1
    class_weights:
      - 0.3
      - 0.7
    class_names:
      - Others
      - Flood
    freeze_backbone: false
    freeze_decoder: false

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 6.e-5
    weight_decay: 0.05
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss

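The new config drives the GenericMultiModalDataModule used for multimodal training. A minimal Python sketch of the equivalent programmatic setup, assuming the YAML init_args map one-to-one onto constructor keyword arguments, that terratorch.datamodules is the import path, and that the omitted val/test arguments follow the same pattern as the training ones (all three are assumptions, not confirmed by this diff):

from terratorch.datamodules import GenericMultiModalDataModule  # assumed import path

root = "data/sen1floods11/data/data/flood_events/HandLabeled"

datamodule = GenericMultiModalDataModule(
    task="segmentation",
    batch_size=4,
    num_workers=0,
    modalities=["S2L2A", "S1", "LULC"],
    rgb_modality="S2L2A",   # if not provided, the first modality is used
    rgb_indices=[3, 2, 1],
    train_data_root={
        "S2L2A": f"{root}/S2L2AHand",
        "S1": f"{root}/S1Hand",
        "LULC": f"{root}/LULCHand",
    },
    train_label_data_root=f"{root}/LabelHand",
    train_split="data/sen1floods11/splits/splits/flood_handlabeled/dev_train.txt",
    allow_substring_file_names=True,
    image_grep={"S2L2A": "*_S2L2AHand.tif", "S1": "*_S1Hand.tif", "LULC": "*_LULCHand.npy"},
    label_grep="*_LabelHand.tif",
    no_label_replace=-1,
    no_data_replace=0,
    num_classes=2,
)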
173 changes: 173 additions & 0 deletions examples/confs/multimodal_prithvi_sen1floods11.yaml
@@ -0,0 +1,173 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: output
      name: multimodal_prithvi_sen1floods11
      version: test_best
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 40

  max_epochs: 100
  check_val_every_n_epoch: 1
  log_every_n_steps: 50
  enable_checkpointing: True
  default_root_dir: output/multimodal_prithvi_sen1floods11/

data:
  class_path: GenericMultiModalDataModule
  init_args:
    task: 'segmentation'
    batch_size: 16
    num_workers: 4
    modalities: # Define names of modalities
      - S2L2A
      - S1
    rgb_modality: S2L2A # If not provided, uses first modality
    rgb_indices:
      - 3
      - 2
      - 1

    # Data roots are defined as dicts with modalities as keys
    train_data_root:
      S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
      S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
    train_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
    val_data_root:
      S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
      S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
    val_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
    test_data_root:
      S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
      S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
    test_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand

    train_split: data/sen1floods11/splits/splits/flood_handlabeled/flood_train_data.txt
    val_split: data/sen1floods11/splits/splits/flood_handlabeled/flood_valid_data.txt
    test_split: data/sen1floods11/splits/splits/flood_handlabeled/flood_test_data.txt

    allow_substring_file_names: True
    image_grep:
      S2L2A: "*_S2L2AHand.tif"
      S1: "*_S1Hand.tif"
    label_grep: "*_LabelHand.tif"
    no_label_replace: -1
    no_data_replace: 0
    concat_bands: true # Concatenate modalities along band dim for single-modal models like Prithvi

    # Define standardization values as dicts (no scaling if modality is not included)
    means:
      S2L2A:
        - 1793.243
        - 1924.863
        - 2184.553
        - 2340.936
        - 2671.402
        - 3240.082
        - 3468.412
        - 3563.244
        - 3627.704
        - 3711.071
        - 3416.714
        - 2849.625
      S1:
        - -12.577
        - -20.265

    stds:
      S2L2A:
        - 1160.144
        - 1201.092
        - 1219.943
        - 1397.225
        - 1400.035
        - 1373.136
        - 1429.17
        - 1485.025
        - 1447.836
        - 1652.703
        - 1471.002
        - 1365.30
      S1:
        - 5.179
        - 5.872

    num_classes: 2

    # Transforms are shared between all image modalities (e.g. same crop area)
    train_transform:
      - class_path: albumentations.RandomCrop
        init_args:
          height: 224
          width: 224
      - class_path: albumentations.D4
      - class_path: ToTensorV2


model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_factory: EncoderDecoderFactory
    model_args:
      backbone: prithvi_vit_100
      backbone_pretrained: false
      backbone_bands:
        - COASTAL_AEROSOL
        - BLUE
        - GREEN
        - RED
        - RED_EDGE_1
        - RED_EDGE_2
        - RED_EDGE_3
        - NIR_BROAD
        - NIR_NARROW
        - CIRRUS
        - SWIR_1
        - SWIR_2
        - VV
        - VH
      decoder: FCNDecoder # FCNDecoder
      decoder_num_convs: 4 # only for FCNDecoder
      # decoder_scale_modules: True # only for UperNetDecoder
      decoder_channels: 256
      num_classes: 2
      head_dropout: 0.1
      head_channel_list:
        - 256

    loss: dice
    ignore_index: -1
    class_weights:
      - 0.3
      - 0.7
    class_names:
      - Others
      - Flood
    freeze_backbone: false
    freeze_decoder: false

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 6.e-5
    weight_decay: 0.05
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss

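Unlike the MultiMAE config above, this one sets concat_bands: true so that a single-modal backbone such as Prithvi receives one stacked band dimension. A small plain-PyTorch sketch of what that implies for the input shapes (illustrative tensors only, not terratorch internals):

import torch

# One standardized batch per modality (values are random stand-ins).
s2l2a = torch.randn(16, 12, 224, 224)  # 12 Sentinel-2 L2A bands
s1 = torch.randn(16, 2, 224, 224)      # 2 Sentinel-1 bands (VV, VH)

# concat_bands stacks modalities along the band/channel axis, so the backbone sees
# 14 bands, matching the 14 entries listed under backbone_bands (12 optical + VV + VH).
x = torch.cat([s2l2a, s1], dim=1)
assert x.shape == (16, 14, 224, 224)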
18 changes: 11 additions & 7 deletions examples/scripts/WxCTrain.py
@@ -36,13 +36,17 @@
 #
 
 
-config = get_config('../confs/granite-wxc-merra2-downscale-config.yaml')
-config.download_path = './'
+config = get_config('../confs/granite-wxc-merra2-downscale-large-config.yaml')
+download_path = os.getcwd()
+config.download_path = download_path
 
-config.data.data_path_surface = os.path.join(config.download_path,'merra-2')
-config.data.data_path_vertical = os.path.join(config.download_path, 'merra-2')
-config.data.climatology_path_surface = os.path.join(config.download_path,'climatology')
-config.data.climatology_path_vertical = os.path.join(config.download_path,'climatology')
+config.data.data_path_surface = os.path.join(download_path,'merra-2')
+config.data.data_path_vertical = os.path.join(download_path, 'merra-2')
+config.data.climatology_path_surface = os.path.join(download_path,'climatology')
+config.data.climatology_path_vertical = os.path.join(download_path,'climatology')
+
+extra_kwargs = config.model.init_args["extra_kwargs"]
+model_args = config.model.init_args["model_args"]
 
 config.model.input_scalers_surface_path = os.path.join(config.download_path,'climatology/musigma_surface.nc')
 config.model.input_scalers_vertical_path = os.path.join(config.download_path,'climatology/musigma_vertical.nc')
@@ -130,7 +134,7 @@
 
 print("This is our config:")
 
-task = WxCDownscalingTask(model_args = {}, model_factory = 'WxCModelFactory', model_config=config, optimizer='AdamW', optimizer_hparams={'weight_decay': 0.05})
+task = WxCDownscalingTask(model_args = model_args, model_factory = 'WxCModelFactory',extra_kwargs=extra_kwargs, model_config=config, optimizer='AdamW', optimizer_hparams={'weight_decay': 0.05})
 
 
 #
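For reference, the reworked script now assumes the downscaling config exposes model_args and extra_kwargs under model.init_args and forwards both to WxCDownscalingTask instead of hard-coding an empty dict. A hypothetical sketch of that access pattern (the nested values are placeholders, not the real granite-wxc settings):

from types import SimpleNamespace

# Stand-in for the object returned by get_config(); only the fields the script reads are shown.
config = SimpleNamespace(
    model=SimpleNamespace(
        init_args={
            "model_args": {"example_model_setting": 1},
            "extra_kwargs": {"example_runtime_flag": True},
        }
    )
)

model_args = config.model.init_args["model_args"]      # forwarded as model_args=...
extra_kwargs = config.model.init_args["extra_kwargs"]  # forwarded as extra_kwargs=...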