Merge pull request #380 from IBM/romeo_201
add cli support for wxc gravity wave
Joao-L-S-Almeida authored Jan 27, 2025
2 parents e1d9ddb + 6bc6293, commit 375cc4a
Showing 7 changed files with 191 additions and 10 deletions.
examples/confs/wxc-gravity-wave-ccc.yaml: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: <path>
      name: wxc_gravity_wave
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 40

  max_epochs: 200
  check_val_every_n_epoch: 1
  log_every_n_steps: 50
  enable_checkpointing: true
  default_root_dir: <path>

# data: ERA5-derived gravity-wave NetCDF files, loaded by ERA5DataModule below
data:
  class_path: terratorch.datamodules.era5.ERA5DataModule
  init_args:
    train_data_path: /dccstor/terratorch/users/rkie/gitco/terratorch
    valid_data_path: /dccstor/terratorch/users/rkie/gitco/terratorch
    file_glob_pattern: "wxc_input_u_v_t_p_output_theta_uw_vw_*.nc"

model:
  class_path: WxCTask
  init_args:
    model_args:
      in_channels: 1280
      input_size_time: 1
      n_lats_px: 64
      n_lons_px: 128
      patch_size_px: [2, 2]
      mask_unit_size_px: [8, 16]
      mask_ratio_inputs: 0.5
      embed_dim: 2560
      n_blocks_encoder: 12
      n_blocks_decoder: 2
      mlp_multiplier: 4
      n_heads: 16
      dropout: 0.0
      drop_path: 0.05
      parameter_dropout: 0.0
      residual: none
      masking_mode: both
      decoder_shifting: False
      positional_encoding: absolute
      checkpoint_encoder: [3, 6, 9, 12, 15, 18, 21, 24]
      checkpoint_decoder: [1, 3]
      in_channels_static: 3
      input_scalers_mu: torch.tensor([0] * 1280)
      input_scalers_sigma: torch.tensor([1] * 1280)
      input_scalers_epsilon: 0
      static_input_scalers_mu: torch.tensor([0] * 3)
      static_input_scalers_sigma: torch.tensor([1] * 3)
      static_input_scalers_epsilon: 0
      output_scalers: torch.tensor([0] * 1280)
      backbone_weights: magnet-flux-uvtp122-epoch-99-loss-0.1022.pt
      backbone: prithviwxc
      aux_decoders: unetpincer
      skip_connection: True
    model_factory: WxCModelFactory
    mode: eval
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 1.5e-5
    weight_decay: 0.05
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
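For reference, the data section above maps one-to-one onto the datamodule constructor. A minimal programmatic sketch, assuming ERA5DataModule accepts exactly the init_args shown in the config (the paths here are placeholders):

from terratorch.datamodules.era5 import ERA5DataModule

# Build the datamodule directly, mirroring the YAML init_args above.
datamodule = ERA5DataModule(
    train_data_path="/path/to/train",  # placeholder paths
    valid_data_path="/path/to/valid",
    file_glob_pattern="wxc_input_u_v_t_p_output_theta_uw_vw_*.nc",
)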
examples/confs/wxc-gravity-wave.yaml: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: <path>
      name: wxc_gravity_wave
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 40

  max_epochs: 200
  check_val_every_n_epoch: 1
  log_every_n_steps: 50
  enable_checkpointing: true
  default_root_dir: <path>

# data: ERA5-derived gravity-wave NetCDF files, loaded with ERA5DataModule defaults
data:
  class_path: terratorch.datamodules.era5.ERA5DataModule

model:
  class_path: WxCTask
  init_args:
    model_args:
      in_channels: 1280
      input_size_time: 1
      n_lats_px: 64
      n_lons_px: 128
      patch_size_px: [2, 2]
      mask_unit_size_px: [8, 16]
      mask_ratio_inputs: 0.5
      embed_dim: 2560
      n_blocks_encoder: 12
      n_blocks_decoder: 2
      mlp_multiplier: 4
      n_heads: 16
      dropout: 0.0
      drop_path: 0.05
      parameter_dropout: 0.0
      residual: none
      masking_mode: both
      decoder_shifting: False
      positional_encoding: absolute
      checkpoint_encoder: [3, 6, 9, 12, 15, 18, 21, 24]
      checkpoint_decoder: [1, 3]
      in_channels_static: 3
      input_scalers_mu: torch.tensor([0] * 1280)
      input_scalers_sigma: torch.tensor([1] * 1280)
      input_scalers_epsilon: 0
      static_input_scalers_mu: torch.tensor([0] * 3)
      static_input_scalers_sigma: torch.tensor([1] * 3)
      static_input_scalers_epsilon: 0
      output_scalers: torch.tensor([0] * 1280)
      backbone_weights: magnet-flux-uvtp122-epoch-99-loss-0.1022.pt
      backbone: prithviwxc
      aux_decoders: unetpincer
      skip_connection: True
    model_factory: WxCModelFactory
    mode: eval
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 1.5e-5
    weight_decay: 0.05
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
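Either config is launched through the TerraTorch CLI that this pull request wires up. A minimal sketch, assuming the LightningCLI-style entry point exposed by terratorch.cli_tools:

from terratorch.cli_tools import build_lightning_cli

# Equivalent to the shell invocation:
#   terratorch fit --config examples/confs/wxc-gravity-wave.yaml
build_lightning_cli(["fit", "--config", "examples/confs/wxc-gravity-wave.yaml"])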
terratorch/cli_tools.py: 12 additions & 4 deletions
@@ -20,6 +20,9 @@
import rasterio
import torch

import random
import string

# Allows classes to be referenced using only the class name
import torchgeo.datamodules
import yaml
@@ -153,10 +156,15 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batc
        if not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)

        pred_batch, filename_batch = prediction

        for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False):
            save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
        if isinstance(prediction, torch.Tensor):
            filename_batch = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
            torch.save(prediction, os.path.join(output_dir, f"{filename_batch}.pt"))
        elif isinstance(prediction, tuple):
            pred_batch, filename_batch = prediction
            for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False):
                save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
        else:
            raise TypeError(f"Unknown type for prediction: {type(prediction)}")

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):  # noqa: ARG002
        # this will create N (num processes) files in `output_dir` each containing
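The new isinstance branch above covers tasks, such as WxCTask, whose predict step yields a bare tensor rather than a (batch, filenames) tuple; each such tensor is written under a random 8-character stem. A hypothetical way to reload those files afterwards:

import os

import torch

output_dir = "predictions"  # hypothetical output directory
for name in os.listdir(output_dir):
    if name.endswith(".pt"):
        # Each .pt file holds one saved prediction tensor.
        pred = torch.load(os.path.join(output_dir, name))
        print(name, tuple(pred.shape))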
terratorch/datamodules/__init__.py: 1 addition & 0 deletions
@@ -27,6 +27,7 @@
from terratorch.datamodules.multi_temporal_crop_classification import MultiTemporalCropClassificationDataModule
from terratorch.datamodules.open_sentinel_map import OpenSentinelMapDataModule
from terratorch.datamodules.pastis import PASTISDataModule
from terratorch.datamodules.era5 import ERA5DataModule

try:
    wxc_present = True
terratorch/models/wxc_model_factory.py: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ def build_model(
            raise

        # remove parameters not meant for the backbone but for other parts of the model
        logger.trace(kwargs)
        skip_connection = kwargs.pop('skip_connection')

        backbone = prithviwxc.PrithviWxC(**kwargs)
terratorch/tasks/__init__.py: 4 additions & 1 deletion
@@ -1,3 +1,4 @@
import logging
from terratorch.tasks.classification_tasks import ClassificationTask
from terratorch.tasks.regression_tasks import PixelwiseRegressionTask
from terratorch.tasks.segmentation_tasks import SemanticSegmentationTask
@@ -6,6 +7,8 @@
try:
    wxc_present = True
    from terratorch.tasks.wxc_downscaling_task import WxCDownscalingTask
    from terratorch.tasks.wxc_task import WxCTask
    logging.getLogger('terratorch').debug('wxc_downscaling found.')
except ImportError as e:
    import logging
    logging.getLogger('terratorch').debug('wxc_downscaling not installed')
@@ -21,4 +24,4 @@
)

if wxc_present:
    __all__.__add__(("WxCDownscalingTask", ))
    __all__.__add__(("WxCDownscalingTask", "WxCTask",))
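One caveat with the last line: tuple.__add__ returns a new tuple and does not modify __all__ in place, so the names added here are silently discarded. An in-place alternative (a suggested fix, not part of this commit) would be:

if wxc_present:
    # += rebinds __all__ to the extended tuple instead of discarding it
    __all__ += ("WxCDownscalingTask", "WxCTask")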
terratorch/tasks/wxc_task.py: 7 additions & 5 deletions
@@ -1,17 +1,19 @@


from torchgeo.trainers import BaseTask
import torch.nn as nn
import torch
import logging
logger = logging.getLogger(__name__)

from terratorch.registry import MODEL_FACTORY_REGISTRY

class WxCTask(BaseTask):
    def __init__(self, model_factory, model_args: dict, mode, learning_rate=0.1):
    def __init__(self, model_factory, model_args: dict, mode: str = 'train', learning_rate=0.1):
        if mode not in ['train', 'eval']:
            raise ValueError(f'mode {mode} is not supported. (train, eval)')
        self.model_args = model_args
        self.model_factory = model_factory

        self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

        self.learning_rate = learning_rate
        super().__init__()

@@ -34,4 +36,4 @@ def training_step(self, batch, batch_idx):

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)
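A minimal instantiation sketch of the updated task, assuming the WxC extras are installed; model_args stands in for the mapping shown in the YAML configs above, and the factory name is resolved through MODEL_FACTORY_REGISTRY:

from terratorch.tasks import WxCTask

# Illustrative only: a real run needs the full model_args block from the configs.
model_args = {
    "in_channels": 1280,
    # ... remaining keys as in the model_args section above ...
}

task = WxCTask("WxCModelFactory", model_args, mode="eval")

# After this change, an unsupported mode fails fast:
# WxCTask("WxCModelFactory", model_args, mode="test")  # raises ValueError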

