diff --git a/examples/confs/wxc-gravity-wave-ccc.yaml b/examples/confs/wxc-gravity-wave-ccc.yaml
new file mode 100644
index 00000000..52a4af3a
--- /dev/null
+++ b/examples/confs/wxc-gravity-wave-ccc.yaml
@@ -0,0 +1,85 @@
+# lightning.pytorch==2.1.1
+seed_everything: 0
+trainer:
+  accelerator: auto
+  strategy: auto
+  devices: auto
+  num_nodes: 1
+  precision: 16-mixed
+  logger:
+    class_path: TensorBoardLogger
+    init_args:
+      save_dir:
+      name: wxc_gravity_wave_ccc
+  callbacks:
+    - class_path: RichProgressBar
+    - class_path: LearningRateMonitor
+      init_args:
+        logging_interval: epoch
+    - class_path: EarlyStopping
+      init_args:
+        monitor: val/loss
+        patience: 40
+
+  max_epochs: 200
+  check_val_every_n_epoch: 1
+  log_every_n_steps: 50
+  enable_checkpointing: true
+  default_root_dir:
+
+# data: gravity-wave NetCDF inputs served by the ERA5 data module below
+data:
+  class_path: terratorch.datamodules.era5.ERA5DataModule
+  init_args:
+    train_data_path: /dccstor/terratorch/users/rkie/gitco/terratorch
+    valid_data_path: /dccstor/terratorch/users/rkie/gitco/terratorch
+    file_glob_pattern: "wxc_input_u_v_t_p_output_theta_uw_vw_*.nc"
+
+model:
+  class_path: WxCTask
+  init_args:
+    model_args:
+      in_channels: 1280
+      input_size_time: 1
+      n_lats_px: 64
+      n_lons_px: 128
+      patch_size_px: [2, 2]
+      mask_unit_size_px: [8, 16]
+      mask_ratio_inputs: 0.5
+      embed_dim: 2560
+      n_blocks_encoder: 12
+      n_blocks_decoder: 2
+      mlp_multiplier: 4
+      n_heads: 16
+      dropout: 0.0
+      drop_path: 0.05
+      parameter_dropout: 0.0
+      residual: none
+      masking_mode: both
+      decoder_shifting: False
+      positional_encoding: absolute
+      checkpoint_encoder: [3, 6, 9, 12, 15, 18, 21, 24]
+      checkpoint_decoder: [1, 3]
+      in_channels_static: 3
+      input_scalers_mu: torch.tensor([0] * 1280)
+      input_scalers_sigma: torch.tensor([1] * 1280)
+      input_scalers_epsilon: 0
+      static_input_scalers_mu: torch.tensor([0] * 3)
+      static_input_scalers_sigma: torch.tensor([1] * 3)
+      static_input_scalers_epsilon: 0
+      output_scalers: torch.tensor([0] * 1280)
+      backbone_weights: magnet-flux-uvtp122-epoch-99-loss-0.1022.pt
+      backbone: prithviwxc
+      aux_decoders: unetpincer
+      skip_connection: True
+    model_factory: WxCModelFactory
+    mode: eval
+optimizer:
+  class_path: torch.optim.Adam
+  init_args:
+    lr: 1.5e-5
+    weight_decay: 0.05
+lr_scheduler:
+  class_path: ReduceLROnPlateau
+  init_args:
+    monitor: val/loss
diff --git a/examples/confs/wxc-gravity-wave.yaml b/examples/confs/wxc-gravity-wave.yaml
new file mode 100644
index 00000000..b504df0e
--- /dev/null
+++ b/examples/confs/wxc-gravity-wave.yaml
@@ -0,0 +1,81 @@
+# lightning.pytorch==2.1.1
+seed_everything: 0
+trainer:
+  accelerator: auto
+  strategy: auto
+  devices: auto
+  num_nodes: 1
+  precision: 16-mixed
+  logger:
+    class_path: TensorBoardLogger
+    init_args:
+      save_dir:
+      name: wxc_gravity_wave
+  callbacks:
+    - class_path: RichProgressBar
+    - class_path: LearningRateMonitor
+      init_args:
+        logging_interval: epoch
+    - class_path: EarlyStopping
+      init_args:
+        monitor: val/loss
+        patience: 40
+
+  max_epochs: 200
+  check_val_every_n_epoch: 1
+  log_every_n_steps: 50
+  enable_checkpointing: true
+  default_root_dir:
+
+# data: gravity-wave inputs served by the ERA5 data module below
+data:
+  class_path: terratorch.datamodules.era5.ERA5DataModule
+
+model:
+  class_path: WxCTask
+  init_args:
+    model_args:
+      in_channels: 1280
+      input_size_time: 1
+      n_lats_px: 64
+      n_lons_px: 128
+      patch_size_px: [2, 2]
+      mask_unit_size_px: [8, 16]
+      mask_ratio_inputs: 0.5
+      embed_dim: 2560
+      n_blocks_encoder: 12
+      n_blocks_decoder: 2
+      mlp_multiplier: 4
+      n_heads: 16
+      dropout: 0.0
+      drop_path: 0.05
+      parameter_dropout: 0.0
+      residual: none
+      masking_mode: both
+      decoder_shifting: False
+      positional_encoding: absolute
+      checkpoint_encoder: [3, 6, 9, 12, 15, 18, 21, 24]
+      checkpoint_decoder: [1, 3]
+      in_channels_static: 3
+      input_scalers_mu: torch.tensor([0] * 1280)
+      input_scalers_sigma: torch.tensor([1] * 1280)
+      input_scalers_epsilon: 0
+      static_input_scalers_mu: torch.tensor([0] * 3)
+      static_input_scalers_sigma: torch.tensor([1] * 3)
+      static_input_scalers_epsilon: 0
+      output_scalers: torch.tensor([0] * 1280)
+      backbone_weights: magnet-flux-uvtp122-epoch-99-loss-0.1022.pt
+      backbone: prithviwxc
+      aux_decoders: unetpincer
+      skip_connection: True
+    model_factory: WxCModelFactory
+    mode: eval
+optimizer:
+  class_path: torch.optim.Adam
+  init_args:
+    lr: 1.5e-5
+    weight_decay: 0.05
+lr_scheduler:
+  class_path: ReduceLROnPlateau
+  init_args:
+    monitor: val/loss
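The `model:` section of both configs maps directly onto `WxCTask.__init__` (see `terratorch/tasks/wxc_task.py` below). A minimal programmatic sketch of the same setup, assuming the Prithvi-WxC dependencies and the backbone checkpoint are available locally; only a subset of `model_args` is repeated here:

```python
# Programmatic equivalent of the YAML `model:` section (a sketch; the
# elided keys are exactly those listed in the configs above).
import torch
from terratorch.tasks import WxCTask

model_args = {
    "in_channels": 1280,
    "input_size_time": 1,
    "n_lats_px": 64,
    "n_lons_px": 128,
    "patch_size_px": [2, 2],
    "embed_dim": 2560,
    # ... remaining backbone keys as in the YAML ...
    "input_scalers_mu": torch.tensor([0.0] * 1280),
    "input_scalers_sigma": torch.tensor([1.0] * 1280),
    "backbone_weights": "magnet-flux-uvtp122-epoch-99-loss-0.1022.pt",
    "backbone": "prithviwxc",
    "aux_decoders": "unetpincer",
    "skip_connection": True,
}

# The factory name is resolved through MODEL_FACTORY_REGISTRY inside the task.
task = WxCTask(model_factory="WxCModelFactory", model_args=model_args, mode="eval")
```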
diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py
index a00f2ad3..21f22c40 100644
--- a/terratorch/cli_tools.py
+++ b/terratorch/cli_tools.py
@@ -20,6 +20,9 @@
 import rasterio
 import torch
 
+import random
+import string
+
 # Allows classes to be referenced using only the class name
 import torchgeo.datamodules
 import yaml
@@ -153,10 +156,15 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batc
         if not os.path.exists(output_dir):
             os.makedirs(output_dir, exist_ok=True)
 
-        pred_batch, filename_batch = prediction
-
-        for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False):
-            save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
+        if isinstance(prediction, torch.Tensor):
+            filename_batch = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
+            torch.save(prediction, os.path.join(output_dir, f"{filename_batch}.pt"))
+        elif isinstance(prediction, tuple):
+            pred_batch, filename_batch = prediction
+            for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False):
+                save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
+        else:
+            raise TypeError(f"Unknown type for prediction: {type(prediction)}")
 
     def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):  # noqa: ARG002
         # this will create N (num processes) files in `output_dir` each containing
diff --git a/terratorch/datamodules/__init__.py b/terratorch/datamodules/__init__.py
index f97c75fe..bef29be2 100644
--- a/terratorch/datamodules/__init__.py
+++ b/terratorch/datamodules/__init__.py
@@ -27,6 +27,7 @@
 from terratorch.datamodules.multi_temporal_crop_classification import MultiTemporalCropClassificationDataModule
 from terratorch.datamodules.open_sentinel_map import OpenSentinelMapDataModule
 from terratorch.datamodules.pastis import PASTISDataModule
+from terratorch.datamodules.era5 import ERA5DataModule
 
 try:
     wxc_present = True
diff --git a/terratorch/models/wxc_model_factory.py b/terratorch/models/wxc_model_factory.py
index f446509a..f8d6fd19 100644
--- a/terratorch/models/wxc_model_factory.py
+++ b/terratorch/models/wxc_model_factory.py
@@ -61,6 +61,7 @@ def build_model(
         raise
 
     #remove parameters not meant for the backbone but for other parts of the model
+    logger.debug(kwargs)
    skip_connection = kwargs.pop('skip_connection')
    backbone = prithviwxc.PrithviWxC(**kwargs)
 
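The new `torch.Tensor` branch in `write_on_batch_end` saves each prediction batch to a `.pt` file with a random 8-character stem, so downstream code has to glob for the files rather than rely on input filenames. A small sketch of reading them back (`output_dir` is a placeholder path):

```python
# Load predictions written by the tensor branch of write_on_batch_end.
# File stems are random 8-character strings, so match on the extension.
import glob
import torch

for path in sorted(glob.glob("output_dir/*.pt")):          # placeholder directory
    pred_batch = torch.load(path, map_location="cpu")      # one batch-level tensor per file
    print(path, tuple(pred_batch.shape))
```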
-1,3 +1,4 @@
+import logging
 from terratorch.tasks.classification_tasks import ClassificationTask
 from terratorch.tasks.regression_tasks import PixelwiseRegressionTask
 from terratorch.tasks.segmentation_tasks import SemanticSegmentationTask
@@ -6,6 +7,8 @@
 try:
     wxc_present = True
     from terratorch.tasks.wxc_downscaling_task import WxCDownscalingTask
+    from terratorch.tasks.wxc_task import WxCTask
+    logging.getLogger('terratorch').debug('wxc_downscaling found.')
 except ImportError as e:
     import logging
     logging.getLogger('terratorch').debug('wxc_downscaling not installed')
@@ -21,4 +24,4 @@
 )
 
 if wxc_present:
-    __all__.__add__(("WxCDownscalingTask", ))
+    __all__ += ("WxCDownscalingTask", "WxCTask")
diff --git a/terratorch/tasks/wxc_task.py b/terratorch/tasks/wxc_task.py
index 87312a9a..3e4b4421 100644
--- a/terratorch/tasks/wxc_task.py
+++ b/terratorch/tasks/wxc_task.py
@@ -1,17 +1,19 @@
-
-
 from torchgeo.trainers import BaseTask
 import torch.nn as nn
 import torch
 import logging
 logger = logging.getLogger(__name__)
 
+from terratorch.registry import MODEL_FACTORY_REGISTRY
+
 class WxCTask(BaseTask):
-    def __init__(self, model_factory, model_args: dict, mode, learning_rate=0.1):
+    def __init__(self, model_factory, model_args: dict, mode: str = 'train', learning_rate=0.1):
         if mode not in ['train', 'eval']:
             raise ValueError(f'mode {mode} is not supported. (train, eval)')
         self.model_args = model_args
-        self.model_factory = model_factory
+
+        self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)
+
         self.learning_rate = learning_rate
         super().__init__()
 
@@ -34,4 +36,4 @@ def training_step(self, batch, batch_idx):
         return loss
     def train_dataloader(self):
         return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)
-    
\ No newline at end of file
+
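With this change `WxCTask` accepts the factory by its registered name and resolves it itself, instead of requiring a factory instance. A sketch of the lookup the constructor now performs; the `build_model` keyword arguments shown are illustrative placeholders, not the full signature:

```python
# Sketch of the registry resolution added in wxc_task.py; build_model's
# arguments here are abbreviated placeholders mirroring model_args.
from terratorch.registry import MODEL_FACTORY_REGISTRY

factory = MODEL_FACTORY_REGISTRY.build("WxCModelFactory")
model = factory.build_model(backbone="prithviwxc", aux_decoders="unetpincer")  # plus remaining model_args
```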