diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml index 96812e26..eb7c3c8d 100644 --- a/.github/dependabot.yaml +++ b/.github/dependabot.yaml @@ -10,6 +10,7 @@ updates: directory: "/" schedule: interval: "daily" + target-branch: "maintenance" groups: # torchvision pins torch, must update in unison torch: @@ -21,4 +22,4 @@ updates: - dependency-name: "setuptools" update-types: ["version-update:semver-patch"] # segmentation-models-pytorch pins timm, must update in unison - - dependency-name: "timm" \ No newline at end of file + - dependency-name: "timm" diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 2d30dfd0..51802bed 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10","3.11"] + python-version: ["3.10","3.11","3.12"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 9ae3fdab..7a1538fb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -12,32 +12,10 @@ on: jobs: build: runs-on: ubuntu-latest + timeout-minutes: 20 strategy: matrix: - python-version: ["3.10", "3.11"] - - steps: - - name: Clone repo - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements/required.txt -r requirements/test.txt -r requirements/optional.txt - - name: List pip dependencies - run: pip list - - name: Test with pytest - run: | - pytest -s tests - build_weather: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.11"] + python-version: ["3.10", "3.11", "3.12"] steps: - name: Clone repo @@ -58,3 +36,30 @@ jobs: - name: Test with pytest run: | pytest -s tests + +# build_weather: +# runs-on: ubuntu-latest +# strategy: +# matrix: +# python-version: ["3.11"] + +# steps: +# - name: Clone repo +# uses: actions/checkout@v3 +# - name: Set up Python ${{ matrix.python-version }} +# uses: actions/setup-python@v4 +# with: +# python-version: ${{ matrix.python-version }} +# cache: 'pip' +# - name: Install dependencies +# run: | +# python -m pip install --upgrade pip +# pip install -r requirements/required.txt -r requirements/test.txt -r requirements/optional.txt +# pip install git+https://github.com/NASA-IMPACT/Prithvi-WxC.git +# pip install git+https://github.com/IBM/granite-wxc.git +# - name: List pip dependencies +# run: pip list +# - name: Test with pytest +# run: | +# pytest -s tests + diff --git a/.gitignore b/.gitignore index 4f1c0356..a88104df 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,12 @@ dist/* **/*pth .venv/* venv/* +examples/notebooks/config.yaml +examples/notebooks/wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc tests/all_ecos_random/* +**/climatology/* +**/lightning_logs/* +**/merra-2/* +**/*.bin +*.stdout +*.log diff --git a/examples/confs/granite-wxc-merra2-downscale-large-config.yaml b/examples/confs/granite-wxc-merra2-downscale-large-config.yaml new file mode 100644 index 00000000..b78aab09 --- /dev/null +++ b/examples/confs/granite-wxc-merra2-downscale-large-config.yaml @@ -0,0 +1,245 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: True # will use tensorboardlogger + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + max_epochs: 200 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: examples/ + #num_epochs: 400 + #dl_num_workers: 19 + #dl_prefetch_size: 1 + #learning_rate: 0.0001 + #limit_steps_train: 250 + #limit_steps_valid: 25 + #min_lr: 0.00001 + #max_lr: 0.0002 + #warm_up_steps: 0 + #mask_unit_size: + # - 15 + # - 16 + #mask_ratio_inputs: 0.0 + #mask_ratio_targets: 0.0 + #max_batch_size: 16 + + #path_experiment: experiment + #loss: rmse + #freeze_backbone: True + #freeze_decoder: False + #backbone_prefix: encoder. + #finetune_w_static: True + #strict_matching: true +data: + class_path: terratorch.datamodules.Merra2DownscaleNonGeoDataModule + init_args: + transforms_fn: granitewxc.utils.data._get_transforms + # Input variables definition + input_surface_vars: + - EFLUX + - GWETROOT + - HFLUX + - LAI + - LWGAB # surface absorbed longwave radiation + - LWGEM # longwave flux emitted from surface + - LWTUP # upwelling longwave flux at toa + - PS # surface pressure + - QV2M # 2-meter specific humidity + - SLP # sea level pressure + - SWGNT # surface net downward shortwave flux + - SWTNT # toa net downward shortwave flux + - T2M # near surface temperature + - TQI # total precipitable ice water + - TQL # total precipitable liquid water + - TQV # total precipitable water vapor + - TS # surface skin temperature + - U10M # 10m eastward wind + - V10M # 10m northward wind + - Z0M # surface roughness + input_static_surface_vars: [FRACI, FRLAND, FROCEAN, PHIS] + input_vertical_vars: + - CLOUD # cloud feraction for radiation + - H # geopotential/ mid layer heights + - OMEGA # vertical pressure velocity + - PL # mid level pressure + - QI # mass fraction of clous ice water + - QL # mass fraction of cloud liquid water + - QV # specific humidity + - T # tempertaure + - U # eastward wind + - V # northward wind + # (model level/ml ~ pressure level/hPa) + # 52ml ~ 562.5hPa, 56ml ~ 700hPa, 63 ml ~ 850hPa + input_levels: [34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 53.0, 56.0, 63.0, 68.0, 72.0] + ## remove: n_input_timestamps: 1 + # Output variables definition + output_vars: + - T2M # near surface temperature + + n_input_timestamps: 2 + + # Data transformations + # Initial crop before any other processing + crop_lat: [0, 1] + # crop_lon: [0, 0] + # coarsening of target -- applied after crop + input_size_lat: 60 # 6x coarsening + input_size_lon: 96 # 6x coarsening + apply_smoothen: True + dict_kwargs: + data_path_surface: examples/scripts/merra-2/ + data_path_vertical: examples/scripts/merra-2/ + climatology_path_surface: examples/scripts/climatology/ + climatology_path_vertical: examples/scripts/climatology/ + time_range: + - '2020-01-01' + - '2020-01-02' +model: + class_path: WxCDownscalingTask + init_args: + model_args: + checkpoint_path: examples/pytorch_model.bin + num_static_channels: 7 + embed_dim: 2560 + token_size: + - 1 + - 1 + n_blocks_encoder: 12 + mlp_multiplier: 4 + n_heads: 16 + dropout_rate: 0.0 + drop_path: 0.05 + mask_unit_size: + - 15 + - 16 + residual_connection: True + model_factory: WxCModelFactory + extra_kwargs: + encoder_decoder_kernel_size_per_stage: [[3], [3]] # Optional, default = 3 for conv_tanspose [[3], [2]] + output_vars: + - T2M # near surface temperature + type: merra2 + input_scalers_surface_path: examples/scripts/climatology/musigma_surface.nc + input_scalers_vertical_path: examples/scripts/climatology/musigma_vertical.nc + output_scalers_surface_path: examples/scripts/climatology/anomaly_variance_surface.nc + output_scalers_vertical_path: examples/scripts/climatology/anomaly_variance_vertical.nc + input_levels: [34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 53.0, 56.0, 63.0, 68.0, 72.0] + downscaling_patch_size: [2, 2] + n_input_timestamps: 2 + downscaling_embed_dim: 256 + encoder_decoder_conv_channels: 128 + encoder_decoder_scale_per_stage: [[2], [3]] # First list determines before/after backbone + encoder_decoder_upsampling_mode: pixel_shuffle # ['nearest', 'bilinear', 'pixel_shuffle', 'conv_transpose'] + encoder_shift: False + drop_path: 0.05 + encoder_decoder_type: 'conv' # ['conv', 'transformer'] + input_size_lat: 60 # 6x coarsening + input_size_lon: 96 # 6x coarsening + freeze_backbone: True + freeze_decoder: False + data_path_surface: examples/scripts/merra-2/ + data_path_vertical: examples/scripts/merra-2/ + climatology_path_surface: examples/scripts/climatology/ + climatology_path_vertical: examples/scripts/climatology/ + residual: climate + model_config: + num_epochs: 200 + limit_steps_train: 250 + limit_steps_valid: 50 + batch_size: 1 #16 + learning_rate: 0.0001 + min_lr: 0.00001 + dl_num_workers: 19 + dl_prefetch_size: 1 + path_experiment: experiment/ + warm_up_steps: 0 + mask_ratio_inputs: 0.0 + mask_ratio_targets: 0.0 # Accepted values: temporal, climate, none + job_id: inference-test + model_config: + num_static_channels: 7 + embed_dim: 2560 + token_size: + - 1 + - 1 + n_blocks_encoder: 12 + mlp_multiplier: 4 + n_heads: 16 + dropout_rate: 0.0 + #drop_path: 0.05 + residual: True + data_config: + surface_vars: + - EFLUX + - GWETROOT + - HFLUX + - LAI + - LWGAB # surface absorbed longwave radiation + - LWGEM # longwave flux emitted from surface + - LWTUP # upwelling longwave flux at toa + - PS # surface pressure + - QV2M # 2-meter specific humidity + - SLP # sea level pressure + - SWGNT # surface net downward shortwave flux + - SWTNT # toa net downward shortwave flux + - T2M # near surface temperature + - TQI # total precipitable ice water + - TQL # total precipitable liquid water + - TQV # total precipitable water vapor + - TS # surface skin temperature + - U10M # 10m eastward wind + - V10M # 10m northward wind + - Z0M # surface roughness + static_surface_vars: [FRACI, FRLAND, FROCEAN, PHIS] + vertical_vars: + - CLOUD # cloud feraction for radiation + - H # geopotential/ mid layer heights + - OMEGA # vertical pressure velocity + - PL # mid level pressure + - QI # mass fraction of clous ice water + - QL # mass fraction of cloud liquid water + - QV # specific humidity + - T # tempertaure + - U # eastward wind + - V # northward wind + # (model level/ml ~ pressure level/hPa) + # 52ml ~ 562.5hPa, 56ml ~ 700hPa, 63 ml ~ 850hPa + # levels: [34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 53.0, 56.0, 63.0, 68.0, 72.0] + ## remove: n_input_aimestamps: 1 + # Output variables definition + #residual_connection: True + #encoder_shift: False + + #downscaling_patch_size: [2, 2] + #downscaling_embed_dim: 256 + #encoder_decoder_type: 'conv' # ['conv', 'transformer'] + #encoder_decoder_upsampling_mode: pixel_shuffle # ['nearest', 'bilinear', 'pixel_shuffle', 'conv_transpose'] + #encoder_decoder_kernel_size_per_stage: [[3], [3]] # Optional, default = 3 for conv_tanspose [[3], [2]] + #encoder_decoder_scale_per_stage: [[2], [3]] # First list determines before/after backbone + #encoder_decoder_conv_channels: 128 + #freeze_backbone: True + #freeze_decoder: False + #ignore_index: -1 + #loss: rmse +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 6.e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + + + diff --git a/examples/confs/granite-wxc-merra2-downscale-small-config.yaml b/examples/confs/granite-wxc-merra2-downscale-small-config.yaml index a60d0fe2..171bfa02 100644 --- a/examples/confs/granite-wxc-merra2-downscale-small-config.yaml +++ b/examples/confs/granite-wxc-merra2-downscale-small-config.yaml @@ -1,112 +1,244 @@ -data: - type: merra2 - - # Input variables definition - input_surface_vars: - - EFLUX - - GWETROOT - - HFLUX - - LAI - - LWGAB # surface absorbed longwave radiation - - LWGEM # longwave flux emitted from surface - - LWTUP # upwelling longwave flux at toa - - PS # surface pressure - - QV2M # 2-meter specific humidity - - SLP # sea level pressure - - SWGNT # surface net downward shortwave flux - - SWTNT # toa net downward shortwave flux - - T2M # near surface temperature - - TQI # total precipitable ice water - - TQL # total precipitable liquid water - - TQV # total precipitable water vapor - - TS # surface skin temperature - - U10M # 10m eastward wind - - V10M # 10m northward wind - - Z0M # surface roughness - input_static_surface_vars: [FRACI, FRLAND, FROCEAN, PHIS] - input_vertical_vars: - - CLOUD # cloud feraction for radiation - - H # geopotential/ mid layer heights - - OMEGA # vertical pressure velocity - - PL # mid level pressure - - QI # mass fraction of clous ice water - - QL # mass fraction of cloud liquid water - - QV # specific humidity - - T # tempertaure - - U # eastward wind - - V # northward wind - # (model level/ml ~ pressure level/hPa) - # 52ml ~ 562.5hPa, 56ml ~ 700hPa, 63 ml ~ 850hPa - input_levels: [34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 53.0, 56.0, 63.0, 68.0, 72.0] - ## remove: n_input_timestamps: 1 - # Output variables definition - output_vars: - - T2M # near surface temperature +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: True # will use tensorboardlogger + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + max_epochs: 200 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: examples/ + #num_epochs: 400 + #dl_num_workers: 19 + #dl_prefetch_size: 1 + #learning_rate: 0.0001 + #limit_steps_train: 250 + #limit_steps_valid: 25 + #min_lr: 0.00001 + #max_lr: 0.0002 + #warm_up_steps: 0 + #mask_unit_size: + # - 15 + # - 16 + #mask_ratio_inputs: 0.0 + #mask_ratio_targets: 0.0 + #max_batch_size: 16 - n_input_timestamps: 2 + #path_experiment: experiment + #loss: rmse + #freeze_backbone: True + #freeze_decoder: False + #backbone_prefix: encoder. + #finetune_w_static: True + #strict_matching: true +data: + class_path: terratorch.datamodules.Merra2DownscaleNonGeoDataModule + init_args: + transforms_fn: granitewxc.utils.data._get_transforms + # Input variables definition + input_surface_vars: + - EFLUX + - GWETROOT + - HFLUX + - LAI + - LWGAB # surface absorbed longwave radiation + - LWGEM # longwave flux emitted from surface + - LWTUP # upwelling longwave flux at toa + - PS # surface pressure + - QV2M # 2-meter specific humidity + - SLP # sea level pressure + - SWGNT # surface net downward shortwave flux + - SWTNT # toa net downward shortwave flux + - T2M # near surface temperature + - TQI # total precipitable ice water + - TQL # total precipitable liquid water + - TQV # total precipitable water vapor + - TS # surface skin temperature + - U10M # 10m eastward wind + - V10M # 10m northward wind + - Z0M # surface roughness + input_static_surface_vars: [FRACI, FRLAND, FROCEAN, PHIS] + input_vertical_vars: + - CLOUD # cloud feraction for radiation + - H # geopotential/ mid layer heights + - OMEGA # vertical pressure velocity + - PL # mid level pressure + - QI # mass fraction of clous ice water + - QL # mass fraction of cloud liquid water + - QV # specific humidity + - T # tempertaure + - U # eastward wind + - V # northward wind + # (model level/ml ~ pressure level/hPa) + # 52ml ~ 562.5hPa, 56ml ~ 700hPa, 63 ml ~ 850hPa + input_levels: [34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 53.0, 56.0, 63.0, 68.0, 72.0] + ## remove: n_input_timestamps: 1 + # Output variables definition + output_vars: + - T2M # near surface temperature - # Data transformations - # Initial crop before any other processing - crop_lat: [0, 1] - # crop_lon: [0, 0] - # coarsening of target -- applied after crop - input_size_lat: 60 # 6x coarsening - input_size_lon: 96 # 6x coarsening - apply_smoothen: True + n_input_timestamps: 2 + # Data transformations + # Initial crop before any other processing + crop_lat: [0, 1] + # crop_lon: [0, 0] + # coarsening of target -- applied after crop + input_size_lat: 60 # 6x coarsening + input_size_lon: 96 # 6x coarsening + apply_smoothen: True + dict_kwargs: + data_path_surface: /home/jalmeida/Projetos/terratorch/examples/scripts/merra-2/ + data_path_vertical: /home/jalmeida/Projetos/terratorch/examples/scripts/merra-2/ + climatology_path_surface: /home/jalmeida/Projetos/terratorch/examples/scripts/climatology/ + climatology_path_vertical: /home/jalmeida/Projetos/terratorch/examples/scripts/climatology/ + time_range: + - '2020-01-01' + - '2020-01-02' model: - # Platform independent config - num_static_channels: 7 - embed_dim: 256 #0 - token_size: + class_path: WxCDownscalingTask + init_args: + model_args: + #checkpoint_path: examples/pytorch_model_.bin + num_static_channels: 7 + embed_dim: 256 #0 + token_size: - 1 - 1 - n_blocks_encoder: 6 #12 - mlp_multiplier: 4 - n_heads: 8 #16 - dropout_rate: 0.0 - drop_path: 0.05 - - # Accepted values: temporal, climate, none - residual: climate - - residual_connection: True - encoder_shift: False + n_blocks_encoder: 6 #12 + mlp_multiplier: 4 + n_heads: 16 + dropout_rate: 0.0 + drop_path: 0.05 + mask_unit_size: + - 15 + - 16 + residual_connection: True + model_factory: WxCModelFactory + extra_kwargs: + encoder_decoder_kernel_size_per_stage: [[3], [3]] # Optional, default = 3 for conv_tanspose [[3], [2]] + output_vars: + - T2M # near surface temperature + type: merra2 + input_scalers_surface_path: /home/jalmeida/Projetos/terratorch/examples/scripts/climatology/musigma_surface.nc + input_scalers_vertical_path: /home/jalmeida/Projetos/terratorch/examples/scripts/climatology/musigma_vertical.nc + output_scalers_surface_path: /home/jalmeida/Projetos/terratorch/examples/scripts/climatology/anomaly_variance_surface.nc + output_scalers_vertical_path: /home/jalmeida/Projetos/terratorch/examples/scripts/climatology/anomaly_variance_vertical.nc + input_levels: [34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 53.0, 56.0, 63.0, 68.0, 72.0] + downscaling_patch_size: [2, 2] + n_input_timestamps: 2 + downscaling_embed_dim: 256 + encoder_decoder_conv_channels: 128 + encoder_decoder_scale_per_stage: [[2], [3]] # First list determines before/after backbone + encoder_decoder_upsampling_mode: pixel_shuffle # ['nearest', 'bilinear', 'pixel_shuffle', 'conv_transpose'] + encoder_shift: False + drop_path: 0.05 + encoder_decoder_type: 'conv' # ['conv', 'transformer'] + input_size_lat: 60 # 6x coarsening + input_size_lon: 96 # 6x coarsening + freeze_backbone: True + freeze_decoder: False + data_path_surface: /home/jalmeida/Projetos/terratorch/examples/scripts/merra-2/ + data_path_vertical: /home/jalmeida/Projetos/terratorch/examples/scripts/merra-2/ + climatology_path_surface: /home/jalmeida/Projetos/terratorch/examples/scripts/climatology/ + climatology_path_vertical: /home/jalmeida/Projetos/terratorch/examples/scripts/climatology/ + model_config: + num_epochs: 200 + limit_steps_train: 250 + limit_steps_valid: 50 + batch_size: 16 + learning_rate: 0.0001 + min_lr: 0.00001 + dl_num_workers: 19 + dl_prefetch_size: 1 + path_experiment: experiment/ + warm_up_steps: 0 + mask_ratio_inputs: 0.0 + mask_ratio_targets: 0.0 # Accepted values: temporal, climate, none + job_id: inference-test + model_config: + num_static_channels: 7 + embed_dim: 256 #0 + token_size: + - 1 + - 1 + n_blocks_encoder: 6 #12 + mlp_multiplier: 4 + n_heads: 8 #16 + dropout_rate: 0.0 + #drop_path: 0.05 + residual: True + data_config: + surface_vars: + - EFLUX + - GWETROOT + - HFLUX + - LAI + - LWGAB # surface absorbed longwave radiation + - LWGEM # longwave flux emitted from surface + - LWTUP # upwelling longwave flux at toa + - PS # surface pressure + - QV2M # 2-meter specific humidity + - SLP # sea level pressure + - SWGNT # surface net downward shortwave flux + - SWTNT # toa net downward shortwave flux + - T2M # near surface temperature + - TQI # total precipitable ice water + - TQL # total precipitable liquid water + - TQV # total precipitable water vapor + - TS # surface skin temperature + - U10M # 10m eastward wind + - V10M # 10m northward wind + - Z0M # surface roughness + static_surface_vars: [FRACI, FRLAND, FROCEAN, PHIS] + vertical_vars: + - CLOUD # cloud feraction for radiation + - H # geopotential/ mid layer heights + - OMEGA # vertical pressure velocity + - PL # mid level pressure + - QI # mass fraction of clous ice water + - QL # mass fraction of cloud liquid water + - QV # specific humidity + - T # tempertaure + - U # eastward wind + - V # northward wind + # (model level/ml ~ pressure level/hPa) + # 52ml ~ 562.5hPa, 56ml ~ 700hPa, 63 ml ~ 850hPa + # levels: [34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 53.0, 56.0, 63.0, 68.0, 72.0] + ## remove: n_input_aimestamps: 1 + # Output variables definition + #residual_connection: True + #encoder_shift: False - downscaling_patch_size: [2, 2] - downscaling_embed_dim: 256 - encoder_decoder_type: 'conv' # ['conv', 'transformer'] - encoder_decoder_upsampling_mode: pixel_shuffle # ['nearest', 'bilinear', 'pixel_shuffle', 'conv_transpose'] - encoder_decoder_kernel_size_per_stage: [[3], [3]] # Optional, default = 3 for conv_tanspose [[3], [2]] - encoder_decoder_scale_per_stage: [[2], [3]] # First list determines before/after backbone - encoder_decoder_conv_channels: 128 - freeze_backbone: True - freeze_decoder: False - ignore_index: -1 + #downscaling_patch_size: [2, 2] + #downscaling_embed_dim: 256 + #encoder_decoder_type: 'conv' # ['conv', 'transformer'] + #encoder_decoder_upsampling_mode: pixel_shuffle # ['nearest', 'bilinear', 'pixel_shuffle', 'conv_transpose'] + #encoder_decoder_kernel_size_per_stage: [[3], [3]] # Optional, default = 3 for conv_tanspose [[3], [2]] + #encoder_decoder_scale_per_stage: [[2], [3]] # First list determines before/after backbone + #encoder_decoder_conv_channels: 128 + #freeze_backbone: True + #freeze_decoder: False + #ignore_index: -1 + #loss: rmse +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 6.e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss -job_id: inference-test -batch_size: 1 -num_epochs: 400 -dl_num_workers: 19 -dl_prefetch_size: 1 -learning_rate: 0.0001 -limit_steps_train: 250 -limit_steps_valid: 25 -min_lr: 0.00001 -max_lr: 0.0002 -warm_up_steps: 0 -mask_unit_size: - - 15 - - 16 -mask_ratio_inputs: 0.0 -mask_ratio_targets: 0.0 -max_batch_size: 16 -path_experiment: experiment -loss: rmse -freeze_backbone: True -freeze_decoder: False -backbone_prefix: encoder. -finetune_w_static: True -strict_matching: true diff --git a/examples/confs/sen1floods11_vit_peft.yaml b/examples/confs/sen1floods11_vit_peft.yaml new file mode 100644 index 00000000..3404e12b --- /dev/null +++ b/examples/confs/sen1floods11_vit_peft.yaml @@ -0,0 +1,144 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: True # will use tensorboardlogger + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + + max_epochs: 200 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: +data: + class_path: GenericNonGeoSegmentationDataModule + init_args: + batch_size: 16 + num_workers: 8 + constant_scale: 0.0001 + dataset_bands: + - COASTAL_AEROSOL + - BLUE + - GREEN + - RED + - RED_EDGE_1 + - RED_EDGE_2 + - RED_EDGE_3 + - NIR_BROAD + - NIR_NARROW + - WATER_VAPOR + - CIRRUS + - SWIR_1 + - SWIR_2 + output_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: /v1.1/data/flood_events/HandLabeled/S2Hand/ + train_label_data_root: /v1.1/data/flood_events/HandLabeled/LabelHand + val_data_root: /v1.1/data/flood_events/HandLabeled/S2Hand/ + val_label_data_root: /v1.1/data/flood_events/HandLabeled/LabelHand + test_data_root: /v1.1/data/flood_events/HandLabeled/S2Hand/ + test_label_data_root: /v1.1/data/flood_events/HandLabeled/LabelHand + # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files + train_split: /v1.1/splits/flood_handlabeled/flood_train_data.txt + test_split: /v1.1/splits/flood_handlabeled/flood_test_data.txt + val_split: /v1.1/splits/flood_handlabeled/flood_valid_data.txt + img_grep: "*_S2Hand.tif" + label_grep: "*_LabelHand.tif" + no_label_replace: -1 + no_data_replace: 0 + means: + - 0.1412956 + - 0.13795798 + - 0.12353792 + - 0.30902815 + - 0.2044958 + - 0.11912015 + stds: + - 0.07406382 + - 0.07370365 + - 0.08692279 + - 0.11798815 + - 0.09772074 + - 0.07659938 + num_classes: 2 + +model: + class_path: terratorch.tasks.SemanticSegmentationTask + init_args: + model_args: + decoder: FCNDecoder + backbone_pretrained: true + backbone: prithvi_vit_100 + decoder_channels: 256 + backbone_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + num_classes: 2 + head_dropout: 0.1 + decoder_num_convs: 4 + head_channel_list: + - 256 + necks: + - name: SelectIndices + indices: + - -1 + - name: ReshapeTokensToImage + peft_config: + method: LORA + replace_qkv: qkv # As we want to apply LoRA separately and only to Q and V, we need to separate the matrix. + peft_config_kwargs: + target_modules: + - qkv.q_linear + - qkv.v_linear + - mlp.fc1 + - mlp.fc2 + loss: ce + aux_heads: + - name: aux_head + decoder: FCNDecoder + decoder_args: + decoder_channels: 256 + decoder_in_index: -1 + decoder_num_convs: 2 + head_dropout: 0.1 + # head_channel_list: + # - 64 + aux_loss: + aux_head: 1.0 + ignore_index: -1 + class_weights: + - 0.3 + - 0.7 + freeze_backbone: false + freeze_decoder: false + model_factory: EncoderDecoderFactory +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 6.e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss diff --git a/examples/confs/smp_model_factory.yaml b/examples/confs/smp_model_factory.yaml new file mode 100644 index 00000000..7d9a1d53 --- /dev/null +++ b/examples/confs/smp_model_factory.yaml @@ -0,0 +1,93 @@ +benchmark_suffix: smp_test +experiment_name: smp_test +backbone: + backbone: resnet18 + backbone_args: + pretrained: False + output_stride: 2 + smp_decoder_channels: 512 + smp_encoder_depth: 5 + + # backbone: swin3d.swin3d_backbone.Swin3dBackbone + # backbone_args: + # pretrained: False + # output_stride: 2 + # out_channels: + # - 192 + # - 384 + # - 768 + # - 768 + # smp_decoder_channels: 768 + # smp_encoder_depth: 5 + + +tasks: + - name: cashew + type: segmentation + loss: ce + model_factory: SMPModelFactory + bands: + - RED + - GREEN + - BLUE + num_classes: 7 + max_epochs: 60 + direction: max + datamodule: + class_path: terratorch.datamodules.MBeninSmallHolderCashewsNonGeoDataModule + init_args: + batch_size: 16 + num_workers: 4 + train_transform: + - class_path: albumentations.Resize + init_args: + always_apply: True + height: 224 + width: 224 + - class_path: ToTensorV2 + test_transform: + - class_path: albumentations.Resize + init_args: + always_apply: True + height: 224 + width: 224 + - class_path: ToTensorV2 + val_transform: + - class_path: albumentations.Resize + init_args: + height: 224 + width: 224 + - class_path: ToTensorV2 + data_root: "/dccstor/geofm-finetuning/geobench/segmentation_v1.0" + bands: + - "RED" + - "GREEN" + - "BLUE" + decoder: IdentityDecoder + decoder_args: + channels: 128 + metric: val/Multiclass Jaccard Index + +n_trials: 16 +save_models: False +storage_uri: /path/to/storage +optimization_space: + model: + - DeepLabV3 + lr: + min: 6e-5 + max: 1e-3 + type: real + log: true + batch_size: + - 8 + - 16 + - 32 + decoder_channels: + - 32 + - 64 + - 128 + head_dropout: + min: 0.2 + max: 0.8 + type: real \ No newline at end of file diff --git a/examples/notebooks/WxCTutorial.ipynb b/examples/notebooks/WxCTutorialDownscaling.ipynb similarity index 99% rename from examples/notebooks/WxCTutorial.ipynb rename to examples/notebooks/WxCTutorialDownscaling.ipynb index 251b1075..b779b7aa 100644 --- a/examples/notebooks/WxCTutorial.ipynb +++ b/examples/notebooks/WxCTutorialDownscaling.ipynb @@ -24,7 +24,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -U ../../." + "!pip install -U -e ../../." ] }, { @@ -317,7 +317,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/examples/notebooks/WxCTutorialGravityWave.ipynb b/examples/notebooks/WxCTutorialGravityWave.ipynb new file mode 100644 index 00000000..b12f3f2a --- /dev/null +++ b/examples/notebooks/WxCTutorialGravityWave.ipynb @@ -0,0 +1,278 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Prithvi WxC Gravity Wave: Model Fine Tuning and Inference using TerraTorch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -U git+https://github.com/romeokienzler/terratorch.git@201\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -U albumentations # fix until https://github.com/IBM/terratorch/issues/164 is solved" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -U git+https://github.com/romeokienzler/gravity-wave-finetuning.git\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install huggingface_hub" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/home/romeokienzler/.local/lib/python3.12/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n", + " @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)\n" + ] + } + ], + "source": [ + "import terratorch # this import is needed to initialize TT's factories\n", + "from lightning.pytorch import Trainer\n", + "import os\n", + "import torch\n", + "from huggingface_hub import hf_hub_download, snapshot_download\n", + "from terratorch.models.wxc_model_factory import WxCModelFactory\n", + "import torch.distributed as dist" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ['MASTER_ADDR'] = 'localhost'\n", + "os.environ['MASTER_PORT'] = '12355' \n", + "\n", + "if dist.is_initialized():\n", + " dist.destroy_process_group()\n", + "\n", + "dist.init_process_group(\n", + " backend='gloo',\n", + " init_method='env://', \n", + " rank=0,\n", + " world_size=1\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'config.yaml'" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/Gravity_wave_Parameterization\",\n", + " filename=f\"magnet-flux-uvtp122-epoch-99-loss-0.1022.pt\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/Gravity_wave_Parameterization\",\n", + " filename=f\"config.yaml\",\n", + " local_dir=\".\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/Gravity_wave_Parameterization\",\n", + " repo_type='dataset',\n", + " filename=f\"wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc\",\n", + " local_dir=\".\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading weights from magnet-flux-uvtp122-epoch-99-loss-0.1022.pt\n", + "Loaded weights\n" + ] + } + ], + "source": [ + "from prithviwxc.gravitywave.datamodule import ERA5DataModule\n", + "from terratorch.tasks.wxc_gravity_wave_task import WxCGravityWaveTask\n", + "task = WxCGravityWaveTask(WxCModelFactory())" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: GPU available: False, used: False\n", + "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n", + "INFO: TPU available: False, using: 0 TPU cores\n", + "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO: HPU available: False, using: 0 HPUs\n", + "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n" + ] + }, + { + "data": { + "text/plain": [ + "prithviwxc.gravitywave.datamodule.ERA5DataModule" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer = Trainer(\n", + " max_epochs=1,\n", + ")\n", + "dm = ERA5DataModule(train_data_path='.', valid_data_path='.')\n", + "type(dm)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predicting DataLoader 0: 19%|█████████████████████████████████▏ | 9/47 [1:06:22<4:40:15, 0.00it/s]" + ] + } + ], + "source": [ + "results = trainer.predict(model=task, datamodule=dm, return_predictions=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dm.setup(stage='predict')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results = trainer.train(model=task, datamodule=dm, return_predictions=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dist.destroy_process_group()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/scripts/WxCTrain.py b/examples/scripts/WxCTrain.py index 8490e62e..5b67755c 100644 --- a/examples/scripts/WxCTrain.py +++ b/examples/scripts/WxCTrain.py @@ -36,7 +36,7 @@ # -config = get_config('../confs/granite-wxc-merra2-downscale-small-config.yaml') +config = get_config('../confs/granite-wxc-merra2-downscale-config.yaml') config.download_path = './' config.data.data_path_surface = os.path.join(config.download_path,'merra-2') diff --git a/examples/scripts/WxC_downloader.py b/examples/scripts/WxC_downloader.py new file mode 100644 index 00000000..7ea8dd70 --- /dev/null +++ b/examples/scripts/WxC_downloader.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python +# coding: utf-8 + +# # Prithvi WxC Downscaling: Model Inference using TerraTorch +# This notebook is a walk through to use a finetuned downscaling model to generate inferences using TerraTorch. We show how to initalize the model, load weights, and use the model for inference using TerraTorch. +# +# Note to set up your environment by running the following cells. (We recommend to run this notebook in an empty pyton 3.11 environment) +# (e.g., +# python3.11 -m venv .venv +# source .venv/bin/activate +# ) +# +# We assume that you've cloned terratorch with: +# git clone https://github.com/IBM/terratorch.git +# And you run this notebook from terratorch/examples/notebooks + +# + +import terratorch # this import is needed to initialize TT's factories +from lightning.pytorch import Trainer +import os, glob +from granitewxc.utils.config import get_config +import torch +from terratorch.tasks.wxc_downscaling_task import WxCDownscalingTask +from terratorch.datamodules.merra2_downscale import Merra2DownscaleNonGeoDataModule +from granitewxc.utils.data import _get_transforms +from huggingface_hub import hf_hub_download, snapshot_download + +files = glob.glob("merra-2/*") + +if not len(files): + + snapshot_download( + repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1", + allow_patterns="merra-2/MERRA2_sfc_2020010[1].nc", + local_dir=".", + ) + + snapshot_download( + repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1", + allow_patterns="merra-2/MERRA_pres_2020010[1].nc", + local_dir=".", + ) + + +# +# ## Climatology +# +# The PrithviWxC model was trained to calculate the output by producing a perturbation to the climatology at the target time. This mode of operation is set via the residual=climate option. This was chosen as climatology is typically a strong prior for long-range prediction. When using the residual=climate option, we have to provide the dataloader with the path of the climatology data. +# + +# + +files = glob.glob("climatology/*") + +if not len(files): + snapshot_download( + repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1", + allow_patterns="climatology/climate_surface_doy00[1]*.nc", + local_dir=".", + ) + + snapshot_download( + repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1", + allow_patterns="climatology/climate_vertical_doy00[1]*.nc", + local_dir=".", + ) + + hf_hub_download( + repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1", + filename=f"climatology/anomaly_variance_surface.nc", + local_dir=".", + ) + + hf_hub_download( + repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1", + filename=f"climatology/anomaly_variance_vertical.nc", + local_dir=".", + ) + + hf_hub_download( + repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1", + filename=f"climatology/musigma_surface.nc", + local_dir=".", + ) + + hf_hub_download( + repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1", + filename=f"climatology/musigma_vertical.nc", + local_dir=".", + ) + +# + + diff --git a/examples/scripts/sen1floods11_vit_fit.sh b/examples/scripts/sen1floods11_vit_fit.sh new file mode 100755 index 00000000..ff6f05c2 --- /dev/null +++ b/examples/scripts/sen1floods11_vit_fit.sh @@ -0,0 +1,53 @@ +#!/bin/sh + +# We assue following environ variables are available from Slurm: +# SLURM_JOB_NUM_NODES - Number of nodes +# SLURM_NTASKS_PER_NODE - Nuber of GPUs (tasks) per node + +cat < +# where +# script-name: Name of script to run in parallel (e.g. sen1floods11_vit_fit.sh) +# num-nodes: Number of nodes (1, 2, ...) +# num-gpus Number of GPUs in a node (1, 2, 3, or 4) +script_name=$1 +num_nodes=$2 +num_gpus=$3 +logfile="log/slurm-%j.out" +if [ $# -ne 3 ]; then + echo "Usage: %0 " + exit 1 +fi +# Used to set CUDA_VISIBLE_DEVICES so that all GPUs on a node can be accessible from all tasks on the node +devices=("" "0" "0,1" "0,1,2" "0,1,2,3") + +sbatch -o ${logfile} -e ${logfile} <=6.5 pytest-cov==4.1.0 -pytest==8.2.2 \ No newline at end of file +pytest==8.3.3 \ No newline at end of file diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index 183fa81c..cae53f19 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -3,7 +3,7 @@ import importlib.util import itertools import json -import logging # noqa: I001 +import logging import os import shutil import sys @@ -26,10 +26,10 @@ from albumentations.pytorch import ToTensorV2 # noqa: F401 from jsonargparse import set_dumper from lightning.fabric.utilities.cloud_io import get_filesystem -from lightning.fabric.utilities.types import _PATH # noqa: F401 -from lightning.pytorch import LightningDataModule, LightningModule, Trainer # noqa: F401 +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch import LightningDataModule, LightningModule, Trainer from lightning.pytorch.callbacks import BasePredictionWriter, ModelCheckpoint, RichProgressBar -from lightning.pytorch.cli import ArgsType, LightningArgumentParser, LightningCLI, SaveConfigCallback # noqa: F401 +from lightning.pytorch.cli import ArgsType, LightningArgumentParser, LightningCLI, SaveConfigCallback from torchgeo.trainers import BaseTask import terratorch.datamodules @@ -56,7 +56,7 @@ SemanticSegmentationTask, # noqa: F401 ) -CUSTOM_MODULES_DIR_NAME = "custom_modules" +logger = logging.getLogger(__name__) def flatten(list_of_lists): return list(itertools.chain.from_iterable(list_of_lists)) @@ -74,6 +74,7 @@ def is_one_band(img): def write_tiff(img_wrt, filename, metadata): + with rasterio.open(filename, "w", **metadata) as dest: if is_one_band(img_wrt): img_wrt = img_wrt[None] @@ -87,7 +88,7 @@ def save_prediction(prediction, input_file_name, out_dir, dtype:str="int16"): mask, metadata = open_tiff(input_file_name) mask = np.where(mask == metadata["nodata"], 1, 0) mask = np.max(mask, axis=0) - result = np.where(mask == 1, -1, prediction) + result = np.where(mask == 1, -1, prediction.detach().cpu()) ##### Save file to disk metadata["count"] = 1 @@ -97,9 +98,32 @@ def save_prediction(prediction, input_file_name, out_dir, dtype:str="int16"): file_name = os.path.basename(input_file_name) file_name_no_ext = os.path.splitext(file_name)[0] out_file_name = file_name_no_ext + "_pred.tif" - logging.info(f"Saving output to {out_file_name} ...") + logger.info(f"Saving output to {out_file_name} ...") write_tiff(result, os.path.join(out_dir, out_file_name), metadata) +def import_custom_modules(custom_modules_path:None | Path | str =None) -> None: + + if custom_modules_path: + + custom_modules_path = Path(custom_modules_path) + + if custom_modules_path.is_dir(): + + # Add 'custom_modules' folder to sys.path + workdir = custom_modules_path.parents[0] + module_dir = custom_modules_path.name + + sys.path.append(workdir) + + try: + importlib.import_module(module_dir) + logger.info(f"Found {custom_modules_path}") + except ImportError: + raise ImportError(f"It was not possible to import modules from {custom_modules_path}.") + else: + raise ValueError(f"Modules path {custom_modules_path} isn't a directory. Check if you have defined it properly.") + else: + logger.info("No custom module is being used.") class CustomWriter(BasePredictionWriter): """Callback class to write geospatial data to file.""" @@ -110,6 +134,29 @@ def __init__(self, output_dir: str | None = None, write_interval: str = "epoch") self.output_dir = output_dir + def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx): # noqa: ARG002 + # this will create N (num processes) files in `output_dir` each containing + # the predictions of it's respective rank + + # by default take self.output_dir. If None, look for one in trainer + if self.output_dir is None: + try: + output_dir = trainer.predict_output_dir + except AttributeError as err: + msg = "Output directory must be passed to CustomWriter constructor or the `predict_output_dir`\ + attribute must be present in the trainer." + raise Exception(msg) from err + else: + output_dir = self.output_dir + + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + pred_batch, filename_batch = prediction + + for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False): + save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype) + def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): # noqa: ARG002 # this will create N (num processes) files in `output_dir` each containing # the predictions of it's respective rank @@ -182,8 +229,8 @@ def __init__( # Preparing information to save config file to log dir config_dict = config.as_dict() self.config_path_original = str(config_dict["config"][0]) - _, self.config_file_original = os.path.split(self.config_path_original) - + _, self.config_file_original = os.path.split(self.config_path_original) + self.deploy_config_file = config_dict["deploy_config_file"] def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: @@ -302,6 +349,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser.add_argument("--predict_output_dir", default=None) parser.add_argument("--out_dtype", default="int16") parser.add_argument("--deploy_config_file", type=bool, default=True) + parser.add_argument("--custom_modules_path", type=str, default=None) # parser.set_defaults({"trainer.enable_checkpointing": False}) @@ -320,6 +368,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser.link_arguments("ModelCheckpoint.dirpath", "StateDictModelCheckpoint.dirpath") def instantiate_classes(self) -> None: + super().instantiate_classes() # get the predict_output_dir. Depending on the value of run, it may be in the subcommand try: @@ -328,13 +377,22 @@ def instantiate_classes(self) -> None: config = self.config if hasattr(config, "predict_output_dir"): self.trainer.predict_output_dir = config.predict_output_dir - + if hasattr(config, "out_dtype"): self.trainer.out_dtype = config.out_dtype if hasattr(config, "deploy_config_file"): self.trainer.deploy_config = config.deploy_config_file + # Custom modules path + if hasattr(self.config.fit, "custom_modules_path"): + + custom_modules_path = self.config.fit.custom_modules_path + else: + default_path = Path(".") / "custom_modules" + custom_modules_path = os.environ.get("TERRATORCH_CUSTOM_MODULE_PATH", default_path) + + import_custom_modules(custom_modules_path) def build_lightning_cli( args: ArgsType = None, @@ -363,15 +421,6 @@ def build_lightning_cli( stacklevel=1, ) - # import any custom modules - current_working_dir = os.getcwd() - custom_modules_path = os.path.join(current_working_dir, CUSTOM_MODULES_DIR_NAME) - if os.path.exists(custom_modules_path) and os.path.isdir(custom_modules_path): - # Add 'custom_modules' folder to sys.path - sys.path.append(os.getcwd()) - logging.info(f"Found {CUSTOM_MODULES_DIR_NAME}") - importlib.import_module(CUSTOM_MODULES_DIR_NAME) - return MyLightningCLI( model_class=BaseTask, subclass_mode_model=True, @@ -381,7 +430,7 @@ def build_lightning_cli( save_config_kwargs={"overwrite": True}, args=args, # save only state_dict as well as full state. Only state_dict will be used for exporting the model - trainer_defaults={"callbacks": [CustomWriter(write_interval="epoch")]}, + trainer_defaults={"callbacks": [CustomWriter(write_interval="batch")]}, run=run, ) @@ -481,6 +530,7 @@ def inference_on_dir(self, data_root: Path | None = None) -> tuple[torch.Tensor, A tuple with a torch tensor with all predictions and a list of corresponding input file paths """ + if data_root: self.datamodule.predict_root = data_root predictions = self.trainer.predict(model=self.model, datamodule=self.datamodule, return_predictions=True) diff --git a/terratorch/datamodules/__init__.py b/terratorch/datamodules/__init__.py index b75da89b..998eca33 100644 --- a/terratorch/datamodules/__init__.py +++ b/terratorch/datamodules/__init__.py @@ -28,6 +28,13 @@ from terratorch.datamodules.open_sentinel_map import OpenSentinelMapDataModule from terratorch.datamodules.pastis import PASTISDataModule +try: + wxc_present = True + from terratorch.datamodules.merra2_downscale import Merra2DownscaleNonGeoDataModule +except ImportError as e: + print('wxc_downscaling not installed') + wxc_present = False + # GenericNonGeoRegressionDataModule, from terratorch.datamodules.sen1floods11 import Sen1Floods11NonGeoDataModule from terratorch.datamodules.sen4agrinet import Sen4AgriNetDataModule @@ -74,7 +81,11 @@ "MPv4gerSegNonGeoDataModule", "MSACropTypeNonGeoDataModule", "MNeonTreeNonGeoDataModule", + "OpenEarthMapModule" "OpenSentinelMapDataModule", "PASTISDataModule", "Sen4AgriNetDataModule" ) + +if wxc_present: + __all__.__add__(("Merra2DownscaleNonGeoDataModule", )) \ No newline at end of file diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index 698262e5..52586025 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -526,6 +526,7 @@ def setup(self, stage: str) -> None: if stage in ["predict"] and self.predict_root: self.predict_dataset = self.dataset_class( self.predict_root, + image_grep=self.img_grep, dataset_bands=self.predict_dataset_bands, output_bands=self.predict_output_bands, constant_scale=self.constant_scale, diff --git a/terratorch/datamodules/merra2_downscale.py b/terratorch/datamodules/merra2_downscale.py index e775e76c..f428d573 100644 --- a/terratorch/datamodules/merra2_downscale.py +++ b/terratorch/datamodules/merra2_downscale.py @@ -1,15 +1,27 @@ from torch._tensor import Tensor from granitewxc.datasets.merra2 import Merra2DownscaleDataset from torchgeo.datamodules import NonGeoDataModule -from typing import Any +from typing import Any, Callable from granitewxc.datasets.merra2 import Merra2DownscaleDataset +from granitewxc.utils.config import ExperimentConfig from torch.utils.data.dataloader import DataLoader from torch._tensor import Tensor +from typing import Callable class Merra2DownscaleDatasetTerraTorch(Merra2DownscaleDataset): def __init__(self, split : str, **kwargs): + # Strict variables + crop_lat = kwargs.pop("crop_lat") + input_size_lat = kwargs.pop("input_size_lat") + input_size_lon = kwargs.pop("input_size_lon") + apply_smoothen = kwargs.pop("apply_smoothen") + super().__init__(**kwargs) self.split = split + self.crop_lat = crop_lat + self.input_size_lat = input_size_lat + self.input_size_lon = input_size_lon + self.apply_smoothen = apply_smoothen def __getitem__(self, index) -> dict[Tensor | int]: batch = super().__getitem__(index) @@ -22,8 +34,46 @@ def __getitem__(self, index) -> dict[Tensor | int]: class Merra2DownscaleNonGeoDataModule(NonGeoDataModule): - def __init__(self, **kwargs: Any) -> None: - super().__init__(Merra2DownscaleDatasetTerraTorch, **kwargs) + def __init__(self, input_surface_vars: list[int | tuple[int, int] | str] | None = None, + input_static_surface_vars: list[int | tuple[int, int] | str] | None = None, + input_vertical_vars: list[int | tuple[int, int] | str] | None = None, + input_levels: list[float] = None, + output_vars: list[int | tuple[int, int] | str] | None = None, + n_input_timestamps: int = 2, + crop_lat: list[int] = None, + crop_lon: list[int] = None, + input_size_lat: int = 60, + input_size_lon: int = 60, + apply_smoothen: bool = True, + transforms_fn: Callable = None, + **kwargs: Any) -> None: + + + config = ExperimentConfig.from_dict({ + 'data':{ + 'crop_lat':crop_lat, + 'crop_lon': crop_lon, + 'apply_smoothen': apply_smoothen, + 'input_size_lat': input_size_lat, + 'input_size_lon': input_size_lon + }, + 'model': {} + }) + + super().__init__(Merra2DownscaleDatasetTerraTorch, + input_surface_vars=input_surface_vars, + input_static_surface_vars=input_static_surface_vars, + input_vertical_vars=input_vertical_vars, + input_levels=input_levels, + output_vars=output_vars, + n_input_timestamps=n_input_timestamps, + crop_lat=crop_lat, + input_size_lat=input_size_lat, + input_size_lon=input_size_lon, + apply_smoothen=apply_smoothen, + transforms=transforms_fn(config), + **kwargs) + self.aug = lambda x: x def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: diff --git a/terratorch/datamodules/openearthmap.py b/terratorch/datamodules/openearthmap.py new file mode 100644 index 00000000..613a6425 --- /dev/null +++ b/terratorch/datamodules/openearthmap.py @@ -0,0 +1,58 @@ +from typing import Any +import torch + +import albumentations as A +import kornia.augmentation as K +from torchgeo.datamodules import NonGeoDataModule +from torchgeo.transforms import AugmentationSequential +from terratorch.datasets import OpenEarthMapNonGeo +from terratorch.datamodules.utils import wrap_in_compose_is_list + +MEANS = { + "BLUE": 116.628328, + "GREEN": 119.65935, + "RED": 113.385309 +} + +STDS = { + "BLUE": 44.668890717415586, + "GREEN": 48.282311849967364, + "RED": 54.19692448815262, +} + +class OpenEarthMapNonGeoDataModule(NonGeoDataModule): + def __init__( + self, + batch_size: int = 8, + num_workers: int = 0, + data_root: str = "./", + train_transform: A.Compose | None | list[A.BasicTransform] = None, + val_transform: A.Compose | None | list[A.BasicTransform] = None, + test_transform: A.Compose | None | list[A.BasicTransform] = None, + aug: AugmentationSequential = None, + **kwargs: Any + ) -> None: + super().__init__(OpenEarthMapNonGeo, batch_size, num_workers, **kwargs) + + bands = kwargs.get("bands", OpenEarthMapNonGeo.all_band_names) + self.means = torch.tensor([MEANS[b] for b in bands]) + self.stds = torch.tensor([STDS[b] for b in bands]) + self.train_transform = wrap_in_compose_is_list(train_transform) + self.val_transform = wrap_in_compose_is_list(val_transform) + self.test_transform = wrap_in_compose_is_list(test_transform) + self.data_root = data_root + self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug + + def setup(self, stage: str) -> None: + if stage in ["fit"]: + self.train_dataset = self.dataset_class( + split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs + ) + if stage in ["fit", "validate"]: + self.val_dataset = self.dataset_class( + split="val", data_root=self.data_root, transform=self.val_transform, **self.kwargs + ) + if stage in ["test"]: + self.test_dataset = self.dataset_class( + split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs + ) \ No newline at end of file diff --git a/terratorch/datamodules/sen1floods11.py b/terratorch/datamodules/sen1floods11.py index e849abee..d2699076 100644 --- a/terratorch/datamodules/sen1floods11.py +++ b/terratorch/datamodules/sen1floods11.py @@ -70,7 +70,7 @@ def __init__( means = [MEANS[b] for b in bands] stds = [STDS[b] for b in bands] - self.bands = bands, + self.bands = bands self.train_transform = wrap_in_compose_is_list(train_transform) self.val_transform = wrap_in_compose_is_list(val_transform) self.test_transform = wrap_in_compose_is_list(test_transform) diff --git a/terratorch/datasets/__init__.py b/terratorch/datasets/__init__.py index 41cd2fa7..19dd3b61 100644 --- a/terratorch/datasets/__init__.py +++ b/terratorch/datasets/__init__.py @@ -43,9 +43,14 @@ # TorchGeo RasterDatasets from terratorch.datasets.wsf import WSF2019, WSFEvolution + +# miscellaneous datasets +from terratorch.datasets.openearthmap import OpenEarthMapNonGeo + # Generic Classification Dataset from terratorch.datasets.sen4map import Sen4MapDatasetMonthlyComposites + __all__ = ( "GenericNonGeoSegmentationDataset", "GenericNonGeoPixelwiseRegressionDataset", @@ -82,4 +87,5 @@ "WSFEvolution", "HLSL30", "HLSS30", + "OpenEarthMapNonGeo" ) diff --git a/terratorch/datasets/openearthmap.py b/terratorch/datasets/openearthmap.py new file mode 100644 index 00000000..bad3bbcd --- /dev/null +++ b/terratorch/datasets/openearthmap.py @@ -0,0 +1,114 @@ +import numpy as np +from collections.abc import Sequence +import matplotlib.pyplot as plt +import torch +import rasterio +from pathlib import Path + +import albumentations as A + +from torchgeo.datasets import NonGeoDataset +from terratorch.datasets.utils import to_tensor + + + +class OpenEarthMapNonGeo(NonGeoDataset): + + all_band_names = ("BLUE","GREEN","RED") + + rgb_bands = ("RED","GREEN","BLUE") + + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + + def __init__(self, data_root: str, + bands: Sequence[str] = BAND_SETS["all"], + transform: A.Compose | None = None, + split="train", + crop_size: int = 256, + random_crop: bool = True) -> None: + super().__init__() + if split not in ["train", "test", "val"]: + msg = "Split must be one of train, test, val." + raise Exception(msg) + + self.transform = transform if transform else lambda **batch: to_tensor(batch, transpose=False) + self.split = split + self.data_root = data_root + + # images in openearthmap are not all 1024x1024 and must be cropped + self.crop_size = crop_size + self.random_crop = random_crop + + assert self.crop_size > 0, "Crop size must be greater than 0" + + self.image_files = self._get_file_paths(Path(self.data_root, f"{split}.txt")) + + def __getitem__(self, index: int) -> dict[str, torch.Tensor]: + image_path, label_path = self.image_files[index] + + with rasterio.open(image_path) as src: + image = src.read() + with rasterio.open(label_path) as src: + mask = src.read() + + # some images in the dataset are not perfect squares + # cropping to fit to the prepare_features_for_image_model call + if self.random_crop: + image, mask = self._random_crop(image, mask) + else: + image, mask = self._center_crop(image, mask) + + output = { + "image": image.astype(np.float32), + "mask": mask + } + + output = self.transform(**output) + output['mask'] = output['mask'].long() + + return output + + def _parse_file_name(self, file_name: str): + underscore_pos = file_name.rfind('_') + folder_name = file_name[:underscore_pos] + region_path = Path(self.data_root, folder_name) + image_path = Path(region_path, "images", file_name) + label_path = Path(region_path, "labels", file_name) + return image_path, label_path + + def _get_file_paths(self, text_file_path: str): + with open(text_file_path, 'r') as file: + lines = file.readlines() + file_paths = [self._parse_file_name(line.strip()) for line in lines] + return file_paths + + def __len__(self): + return len(self.image_files) + + def _random_crop(self, image, mask): + h, w = image.shape[1:] + top = np.random.randint(0, h - self.crop_size) + left = np.random.randint(0, w - self.crop_size) + + image = image[:, top: top + self.crop_size, left: left + self.crop_size] + mask = mask[:, top: top + self.crop_size, left: left + self.crop_size] + + return image, mask + + def _center_crop(self, image, mask): + h, w = image.shape[1:] + top = (h - self.crop_size) // 2 + left = (w - self.crop_size) // 2 + + image = image[:, top: top + self.crop_size, left: left + self.crop_size] + mask = mask[:, top: top + self.crop_size, left: left + self.crop_size] + + return image, mask + + def plot(self, arg, suptitle: str | None = None) -> None: + pass + + def plot_sample(self, sample, prediction=None, suptitle: str | None = None, class_names=None): + pass + + \ No newline at end of file diff --git a/terratorch/datasets/sen4map.py b/terratorch/datasets/sen4map.py index 73ecb21f..89fee9d2 100644 --- a/terratorch/datasets/sen4map.py +++ b/terratorch/datasets/sen4map.py @@ -14,7 +14,7 @@ class Sen4MapDatasetMonthlyComposites(Dataset): # This dictionary maps the LUCAS classes to Land-cover classes. - land_use_classification_map={'A10':0, 'A11':0, 'A12':0, 'A13':0, + land_cover_classification_map={'A10':0, 'A11':0, 'A12':0, 'A13':0, 'A20':0, 'A21':0, 'A30':0, 'A22':1, 'F10':1, 'F20':1, 'F30':1, 'F40':1, @@ -45,12 +45,12 @@ class Sen4MapDatasetMonthlyComposites(Dataset): crop_classification_map = { "B11":0, "B12":0, "B13":0, "B14":0, "B15":0, "B16":0, "B17":0, "B18":0, "B19":0, # Cereals "B21":1, "B22":1, "B23":1, # Root Crops - "B34":2, "B35":2, "B36":2, "B37":2, # Nonpermanent Industrial Crops - "B31":3, "B32":3, "B33":3, "B41":3, "B42":3, "B43":3, "B44":3, "B45":3, # Dry Pulses, Vegetables and Flowers + "B31":2, "B32":2, "B33":2, "B34":2, "B35":2, "B36":2, "B37":2, # Nonpermanent Industrial Crops + "B41":3, "B42":3, "B43":3, "B44":3, "B45":3, # Dry Pulses, Vegetables and Flowers "B51":4, "B52":4, "B53":4, "B54":4, # Fodder Crops "F10":5, "F20":5, "F30":5, "F40":5, # Bareland "B71":6, "B72":6, "B73":6, "B74":6, "B75":6, "B76":6, "B77":6, - "B81":6, "B82":6, "B83":6, "B84":6, "C10":6, "C20":6, "C30":6, "D10":6, "D20":6, # Woodland and Shrubland + "B81":6, "B82":6, "B83":6, "B84":6, "C10":6, "C21":6, "C22":6, "C23":6, "C31":6, "C32":6, "C33":6, "D10":6, "D20":6, # Woodland and Shrubland "B55":7, "E10":7, "E20":7, "E30":7, # Grassland } @@ -68,10 +68,11 @@ def __init__( reverse_tile = False, reverse_tile_size = 3, save_keys_path = None, - classification_map = "land-use" + classification_map = "land-cover" ): self.h5data = h5py_file_object if h5data_keys is None: + if classification_map == "crops": print(f"Crop classification task chosen but no keys supplied. Will fail unless dataset hdf5 files have been filtered. Either filter dataset files or create a filtered set of keys.") self.h5data_keys = list(self.h5data.keys()) if save_keys_path is not None: with open(save_keys_path, "wb") as file: @@ -87,7 +88,7 @@ def __init__( self.input_channels = [dataset_bands.index(band_ind) for band_ind in input_bands if band_ind in dataset_bands] else: self.input_channels = None - classification_maps = {"land-use": Sen4MapDatasetMonthlyComposites.land_use_classification_map, + classification_maps = {"land-cover": Sen4MapDatasetMonthlyComposites.land_cover_classification_map, "crops": Sen4MapDatasetMonthlyComposites.crop_classification_map} if classification_map not in classification_maps.keys(): raise ValueError(f"Provided classification_map of: {classification_map}, is not from the list of valid ones: {classification_maps}") diff --git a/terratorch/datasets/utils.py b/terratorch/datasets/utils.py index befc48a6..d647f78b 100644 --- a/terratorch/datasets/utils.py +++ b/terratorch/datasets/utils.py @@ -87,13 +87,13 @@ def _split_filter_function(file_name, valid_files: list[str], ignore_extensions= return False -def to_tensor(d): +def to_tensor(d, transpose=True): new_dict = {} for k, v in d.items(): if not isinstance(v, np.ndarray): new_dict[k] = v else: - if k == "image": + if k == "image" and transpose: v = np.moveaxis(v, -1, 0) new_dict[k] = torch.from_numpy(v) return new_dict diff --git a/terratorch/models/__init__.py b/terratorch/models/__init__.py index 1fd5224c..181a8930 100644 --- a/terratorch/models/__init__.py +++ b/terratorch/models/__init__.py @@ -1,11 +1,13 @@ # Copyright contributors to the Terratorch project + import logging import terratorch.models.necks # register necks # noqa: F401 from terratorch.models.encoder_decoder_factory import EncoderDecoderFactory from terratorch.models.generic_unet_model_factory import GenericUnetModelFactory from terratorch.models.prithvi_model_factory import PrithviModelFactory +from terratorch.models.clay_model_factory import ClayModelFactory from terratorch.models.satmae_model_factory import SatMAEModelFactory from terratorch.models.smp_model_factory import SMPModelFactory from terratorch.models.timm_model_factory import TimmModelFactory diff --git a/terratorch/models/backbones/__init__.py b/terratorch/models/backbones/__init__.py index da4dba5d..48b1eb37 100644 --- a/terratorch/models/backbones/__init__.py +++ b/terratorch/models/backbones/__init__.py @@ -3,6 +3,7 @@ # import so they get registered import terratorch.models.backbones.prithvi_swin import terratorch.models.backbones.prithvi_vit +import terratorch.models.backbones.clay_v1 import terratorch.models.backbones.scalemae from terratorch.models.backbones.unet import UNet diff --git a/terratorch/models/backbones/clay_v1/__init__.py b/terratorch/models/backbones/clay_v1/__init__.py new file mode 100644 index 00000000..8ef0f248 --- /dev/null +++ b/terratorch/models/backbones/clay_v1/__init__.py @@ -0,0 +1,3 @@ +from terratorch.models.backbones.clay_v1.embedder import * +from terratorch.models.backbones.clay_v1.modules import * +from terratorch.models.backbones.clay_v1.utils import * \ No newline at end of file diff --git a/terratorch/models/backbones/clay_v1/embedder.py b/terratorch/models/backbones/clay_v1/embedder.py new file mode 100644 index 00000000..6ffdb1fb --- /dev/null +++ b/terratorch/models/backbones/clay_v1/embedder.py @@ -0,0 +1,174 @@ +import re +import warnings + +import numpy as np +import torch +from torch import nn, Tensor +import torch +from timm.models import FeatureInfo +from timm.models._builder import build_model_with_cfg +from timm.models._registry import generate_default_cfgs, register_model + +from terratorch.models.backbones.clay_v1.modules import EmbeddingEncoder, Datacuber + + +default_cfgs = generate_default_cfgs( + { + "clay_v1_base": { + "hf_hub_id": "made-with-clay/Clay", + "hf_hub_filename": "clay-v1-base.ckpt" + } + } +) + + +class Embedder(nn.Module): + default_out_indices = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) + + def __init__( + self, + img_size=256, + num_frames=1, + ckpt_path=None, + bands=["blue", "green", "red", "nir", "swir16", "swir22"], + out_indices: tuple[int] = default_out_indices, + **kwargs, + ): + super().__init__() + self.feature_info = [] + self.img_size = img_size + self.num_frames = num_frames + self.bands = bands + self.out_indices = out_indices + + self.datacuber = Datacuber(bands=bands) + + # TODO: add support for various clay versions + self.clay_encoder = ( + EmbeddingEncoder( # Default parameters for the Clay base model + img_size=img_size, + patch_size=8, + dim=768, + depth=12, + heads=12, + dim_head=64, + mlp_ratio=4.0, + ) + ) + + # for use in features list. + for i in range(12): + self.feature_info.append({"num_chs": 768, "reduction": 1, "module": f"blocks.{i}"}) + + # assuming this is used to fine tune a network on top of the embeddings + + if ckpt_path: + self.load_clay_weights(ckpt_path) + + def load_clay_weights(self, ckpt_path): + "Load the weights from the Clay model encoder." + ckpt = torch.load(ckpt_path) + state_dict = ckpt.get("state_dict") + state_dict = { + re.sub(r"^model\.encoder\.", "", name): param + for name, param in state_dict.items() + if name.startswith("model.encoder") + } + + with torch.no_grad(): + for name, param in self.clay_encoder.named_parameters(): + if name in state_dict and param.size() == state_dict[name].size(): + param.data.copy_(state_dict[name]) # Copy the weights + else: + print( + f"No matching parameter for {name} with size {param.size()}") + + for param in self.clay_encoder.parameters(): + param.requires_grad = False + + self.clay_encoder.eval() + + @staticmethod + def transform_state_dict(state_dict, model): + state_dict = state_dict.get("state_dict") + state_dict = { + re.sub(r"^model\.encoder\.", "clay_encoder.", name): param + for name, param in state_dict.items() + if name.startswith("model.encoder") + } + return state_dict + + def forward_features( + self, + x: torch.Tensor, + time: torch.Tensor | None = None, + latlon: torch.Tensor | None = None, + waves: torch.Tensor | None = None, + gsd: float | None = None, + ): + datacube = self.datacuber(x=x, time=time, latlon=latlon, waves=waves, gsd=gsd) + embeddings = self.clay_encoder(datacube) + + return [embeddings[i] for i in self.out_indices] + + def fake_datacube(self): + "Generate a fake datacube for model export." + dummy_datacube = { + "pixels": torch.randn(2, 3, self.img_size, self.img_size), + "time": torch.randn(2, 4), + "latlon": torch.randn(2, 4), + "waves": torch.randn(3), + "gsd": torch.randn(1), + } + dummy_datacube = {k: v + for k, v in dummy_datacube.items()} + return dummy_datacube + + def prepare_features_for_image_model(self, features: list[Tensor]) -> list[Tensor]: + x_no_token = features[-1][:, 1:, :] + encoded = x_no_token.permute(0, 2, 1).reshape( + x_no_token.shape[0], + -1, + int(np.sqrt(x_no_token.shape[1] // self.num_frames)), + int(np.sqrt(x_no_token.shape[1] // self.num_frames)), + ) + + # return as list for features list compatibility + return [encoded] + + +def _make_clay( + variant: str, + pretrained: bool, + **kwargs +): + encoder_only = kwargs.pop("features_only", False) + model = build_model_with_cfg( + model_cls=Embedder, + variant=variant, + pretrained=pretrained, + pretrained_strict=True, + pretrained_filter_fn=Embedder.transform_state_dict, + **kwargs, + ) + if encoder_only: + out_indices = kwargs.pop("out_indices", model.default_out_indices) + model.feature_info = FeatureInfo(model.feature_info, out_indices) + model.model_bands = kwargs.get("model_bands") + + # TODO: split features according to typical TIMM outputs + model.forward = model.forward_features + model.pretrained_bands = kwargs.get("pretrained_bands") + return model + + +@register_model +def clay_v1_base( + pretrained: bool = False, + **kwargs, +) -> Embedder: + return _make_clay( + "clay_v1_base", + pretrained=pretrained, + **kwargs + ) diff --git a/terratorch/models/backbones/clay_v1/modules.py b/terratorch/models/backbones/clay_v1/modules.py new file mode 100644 index 00000000..53d18022 --- /dev/null +++ b/terratorch/models/backbones/clay_v1/modules.py @@ -0,0 +1,533 @@ +import math +import os +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import nn, Tensor +from timm.layers import use_fused_attn + +from terratorch.models.backbones.clay_v1.utils import posemb_sincos_1d, posemb_sincos_2d_with_gsd + +os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1" + +# central wavelengths of pretrained model +WAVELENGTHS = { + "blue": 0.493, + "green": 0.56, + "red": 0.665, + "rededge1": 0.704, + "rededge2": 0.74, + "rededge3": 0.783, + "nir": 0.842, + "nir08": 0.865, + "swir16": 1.61, + "swir22": 2.19, +} + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, dim), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + def __init__(self, dim, heads=8, dim_head=64): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim_head ** -0.5 + self.norm = nn.LayerNorm(dim) + + self.attend = nn.Softmax(dim=-1) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x): + x = self.norm(x) + + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + + if use_fused_attn(): + out = F.scaled_dot_product_attention(q, k, v) + else: + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = self.attend(dots) + out = torch.matmul(attn, v) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Attention(dim, heads=heads, dim_head=dim_head), + FeedForward(dim, mlp_dim) + ])) + + def forward(self, x) -> list[torch.Tensor]: + out = [] + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + out.append(x.clone()) + x = self.norm(x) + out[-1] = x.clone() + return out + + +class Encoder(nn.Module): + def __init__( + self, + mask_ratio, + patch_size, + shuffle, + dim, + depth, + heads, + dim_head, + mlp_ratio, + ): + super().__init__() + self.mask_ratio = mask_ratio + self.patch_size = patch_size + self.shuffle = shuffle + self.dim = dim + self.cls_token = nn.Parameter(torch.randn(1, 1, dim) * 0.02) + + self.patch_embedding = DynamicEmbedding( + wave_dim=128, + num_latent_tokens=128, + patch_size=patch_size, + embed_dim=dim, + is_decoder=False, + ) + + self.transformer = Transformer( + dim=dim, + depth=depth, + heads=heads, + dim_head=dim_head, + mlp_dim=int(dim * mlp_ratio), + ) + + def to_patch_embed(self, cube, waves): + """Split the input cube into patches & create embeddings per patch""" + patches, waves_encoded = self.patch_embedding(cube, waves) # [B L D] + return patches, waves_encoded # ([B L D], [N D]) + + def add_encodings(self, patches, time, latlon, gsd): + """Add position encoding to the patches""" + B, L, D = patches.shape + + grid_size = int(math.sqrt(L)) + self.num_patches = grid_size**2 + + pos_encoding = ( + posemb_sincos_2d_with_gsd( + h=grid_size, + w=grid_size, + dim=(self.dim - 8), + gsd=gsd, + ) + .to(patches.device) + .detach() + ) # [L (D - 8)] + + time_latlon = torch.hstack((time, latlon)).to( + patches.device).detach() # [B 8] + + pos_encoding = repeat( + pos_encoding, "L D -> B L D", B=B) # [B L (D - 8)] + time_latlon = repeat(time_latlon, "B D -> B L D", L=L) # [B L 8] + pos_metadata_encoding = torch.cat( + (pos_encoding, time_latlon), dim=-1 + ) # [B L D] + + # [B L D] + [B L D] -> [B L D] + patches = patches + pos_metadata_encoding + return patches # [B L D] + + def mask_out(self, patches): + """ + Mask out patches randomly by shuffling the patches & masking out the + first N patches + + Parameters + ---------- + patches : torch.Tensor A tensor of shape (B, L, D) + + Returns + ------- + unmasked_patches : torch.Tensor + A tensor of shape (B, L:(1 - mask_ratio), D) containing the + embeddings of the unmasked patches. + unmasked_indices : torch.Tensor + A tensor of shape (B, (1 - mask_ratio)) containing the indices of + the unmasked patches. + masked_indices : torch.Tensor + A tensor of shape (B, mask_ratio) containing the indices of the + masked patches. + masked_matrix : torch.Tensor + A tensor of shape (B, L) containing the mask matrix, 1 indicates a masked + patch & 0 indicates an unmasked patch. + """ + B, L, D = patches.shape + # assert ( + # L == self.num_patches + # ), f"Expected {self.num_patches} patches, got {L} patches." + + if self.shuffle: # Shuffle the patches + noise = torch.randn((B, L), device=patches.device) # [B L] + else: # Don't shuffle, useful for interpolation & inspection of embeddings + noise = rearrange( + torch.arange(B * L, device=patches.device), "(B L) -> B L", B=B, L=L + ) + + random_indices = torch.argsort(noise, dim=-1) # [B L] + reverse_indices = torch.argsort(random_indices, dim=-1) # [B L] + + num_masked_patches = int( + self.mask_ratio * self.num_patches + ) # Number of patches to be masked out + masked_indices, unmasked_indices = ( + random_indices[:, :num_masked_patches], # [B mask_ratio * L] + random_indices[:, num_masked_patches:], # [B (1 - mask_ratio) * L] + ) + + # create a mask of shape B L, where 1 indicates a masked patch + # and 0 indicates an unmasked patch + masked_matrix = torch.zeros((B, L), device=patches.device) # [B L] = 0 + masked_matrix[:, :num_masked_patches] = 1 # [B mask_ratio * L] = 1 + masked_matrix = torch.gather( + masked_matrix, dim=1, index=reverse_indices + ) # [B L] -> [B L] - reorder the patches + + # mask out the patches + batch_indices = rearrange( + torch.arange(B, device=patches.device), "B -> B 1" + ) # [B 1] + unmasked_patches = patches[ + batch_indices, unmasked_indices, : + ] # [B L:(1 - mask_ratio) D] + _ = patches[batch_indices, masked_indices, :] # [B L:mask_ratio D] + + return ( + unmasked_patches, + unmasked_indices, + masked_indices, + masked_matrix, + ) # [B L:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B L] + + def forward(self, datacube): + cube, time, latlon, gsd, waves = ( + datacube["pixels"], # [B C H W] + datacube["time"], # [B 2] + datacube["latlon"], # [B 2] + datacube["gsd"], # 1 + datacube["waves"], # [N] + ) # [B C H W] + + B, C, H, W = cube.shape + + patches, waves_encoded = self.to_patch_embed( + cube, waves + ) # [B L D] - patchify & create embeddings per patch + patches = self.add_encodings( + patches, + time, + latlon, + gsd, + ) # [B L D] - add position encoding to the embeddings + + # mask out patches + ( + unmasked_patches, + unmasked_indices, + masked_indices, + masked_matrix, + ) = self.mask_out( + patches + ) # [B L:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B L] + + # Add class tokens + cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D] + unmasked_patches = torch.cat( + (cls_tokens, unmasked_patches), dim=1 + ) # [B (1 + L) D] + + # pass the unmasked patches through the transformer + encoded_unmasked_patches = self.transformer( + unmasked_patches + ) # [B ((1 + L)):(1 - mask_ratio)) D] + + return ( + encoded_unmasked_patches, + unmasked_indices, + masked_indices, + masked_matrix, + ) # [B ((1 + L):(1 - mask_ratio)) D], [(1-mask_ratio)], [mask_ratio], [B L] + + +class EmbeddingEncoder(Encoder): + """Clay Encoder without mask and shuffle.""" + + def __init__( # noqa: PLR0913 + self, + img_size, + patch_size, + dim, + depth, + heads, + dim_head, + mlp_ratio, + ): + super().__init__( + mask_ratio=0.0, + shuffle=False, + patch_size=patch_size, + dim=dim, + depth=depth, + heads=heads, + dim_head=dim_head, + mlp_ratio=mlp_ratio, + ) + self.img_size = img_size + + # Using fixed grid size for inference + self.grid_size = img_size // patch_size + self.num_patches = self.grid_size**2 + + def add_encodings(self, patches, time, latlon, gsd): + """Add position encoding to the patches""" + B, L, D = patches.shape + + grid_size = self.grid_size + + pos_encoding = ( + posemb_sincos_2d_with_gsd( + h=grid_size, + w=grid_size, + dim=(self.dim - 8), + gsd=gsd, + ) + .to(patches.device) + .detach() + ) # [L (D - 8)] + + time_latlon = torch.hstack((time, latlon)).to( + patches.device).detach() # [B 8] + + pos_encoding = repeat( + pos_encoding, "L D -> B L D", B=B) # [B L (D - 8)] + time_latlon = repeat(time_latlon, "B D -> B L D", L=L) # [B L 8] + pos_metadata_encoding = torch.cat( + (pos_encoding, time_latlon), dim=-1 + ) # [B L D] + + # [B L D] + [B L D] -> [B L D] + patches = patches + pos_metadata_encoding + return patches # [B L D] + + # def forward(self, cube, time, latlon, waves, gsd): + def forward(self, datacube): + cube, time, latlon, gsd, waves = ( + datacube["pixels"], # [B C H W] + datacube["time"], # [B 2] + datacube["latlon"], # [B 2] + datacube["gsd"], # 1 + datacube["waves"], # [N] + ) # [B C H W] + B, C, H, W = cube.shape + + patches, _ = self.to_patch_embed( + cube, waves + ) # [B L D] - patchify & create embeddings per patch + + # Add time & latlon as encoding to patches + patches = self.add_encodings( + patches, + time, + latlon, + gsd, + ) # [B L D] - add position encoding to the embeddings + + # Add class tokens + cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D] + patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D] + + # pass the patches through the transformer + patches = self.transformer(patches) # list of [B (1 + L) D] + + # # remove the cls token + # embeddings = patches[:, 1: , :] # [B L D] + + return patches # list [B (1 + L) D] + + +class FCBlock(nn.Module): + def __init__(self, size): + super().__init__() + self.l1 = nn.Linear(size, size) + self.l2 = nn.Linear(size, size) + + def forward(self, x): + y = F.gelu(self.l1(x)) + y = F.gelu(self.l2(y)) + return x + y + + +class WavesTransformer(nn.Module): + def __init__( # noqa: PLR0913 + self, + wave_dim, + output_dim, + num_latent_tokens, + embed_dim, + is_decoder, + num_heads=4, + num_layers=1, + ): + super().__init__() + self.num_latent_tokens = num_latent_tokens + self.is_decoder = is_decoder + layer = nn.TransformerEncoderLayer( + d_model=wave_dim, + nhead=num_heads, + activation="gelu", + dropout=0, + norm_first=False, + batch_first=False, + ) + self.encoder = nn.TransformerEncoder(layer, num_layers) + + self.fc_weight = nn.Linear(wave_dim, output_dim) + self.fc_bias = None if self.is_decoder else nn.Linear( + wave_dim, embed_dim) + + self.weight_tokens = nn.Parameter( + torch.randn(self.num_latent_tokens, wave_dim) * 0.02 + ) + self.bias_token = nn.Parameter(torch.randn(1, wave_dim) * 0.02) + + def forward(self, x): + x = torch.cat([self.weight_tokens, x, self.bias_token], dim=0) + out = self.encoder(x) + weights = self.fc_weight( + out[self.num_latent_tokens: -1] + x[self.num_latent_tokens: -1] + ) + bias = None if self.is_decoder else self.fc_bias(out[-1]) + return weights, bias + + +class DynamicEmbedding(nn.Module): + def __init__( + self, + wave_dim, + num_latent_tokens, + patch_size, + embed_dim, + is_decoder=False, + ): + super().__init__() + self.wave_dim = wave_dim + self.num_latent_tokens = num_latent_tokens + self.patch_size = patch_size + self.embed_dim = embed_dim + self.is_decoder = is_decoder + self.output_dim = (patch_size**2) * embed_dim + + self.weight_generator = WavesTransformer( + wave_dim, + self.output_dim, + self.num_latent_tokens, + self.embed_dim, + is_decoder, + ) + self.fclayer = FCBlock(self.wave_dim) + + self.initialize_weights() + + def forward(self, batch, waves): + waves = posemb_sincos_1d(waves, self.wave_dim) + waves = self.fclayer(waves) + weight, bias = self.weight_generator(waves) + + if self.is_decoder: + dynamic_weight = rearrange( + weight, + "cin (k1 k2 cout) -> (cin k1 k2) cout", + k1=self.patch_size, + k2=self.patch_size, + cout=self.embed_dim, + ) + if bias is not None: + bias = rearrange(bias, "b -> (b)") + dynamic_out = F.linear(batch, dynamic_weight * 0.02, bias=bias) + x = dynamic_out + else: + dynamic_weight = rearrange( + weight, + "cin (cout k1 k2) -> cout cin k1 k2", + k1=self.patch_size, + k2=self.patch_size, + ) + if bias is not None: + bias = rearrange(bias, "b -> (b)") + dynamic_out = F.conv2d( + batch, dynamic_weight * 0.02, bias=bias, stride=self.patch_size + ) + x = rearrange(dynamic_out, "b c h w -> b (h w) c") + + return x, waves + + def initialize_weights(self): + # Initialize weights using Xavier initialization + for m in self.modules(): + if isinstance(m, (nn.Linear, nn.Conv2d)): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +class Datacuber(nn.Module): + def __init__(self, bands=None) -> None: + super().__init__() + self.bands = bands + + def forward( + self, + x: torch.Tensor, + time: torch.Tensor | None = None, + latlon: torch.Tensor | None = None, + waves: torch.Tensor | None = None, + gsd: float | None = None, + ) -> dict[str, torch.Tensor | float]: + datacube: dict[str, torch.Tensor | float] = {} + datacube["pixels"] = x + datacube["time"] = torch.zeros((x.shape[0], 4), device=x.device) if time is None else time + datacube["latlon"] = torch.zeros((x.shape[0], 4), device=x.device) if latlon is None else latlon + datacube["gsd"] = 1.0 if gsd is None else gsd + datacube["waves"] = self._parse_wavelengths(self.bands, x.shape[1]).to(x.device) if waves is None else waves + return datacube + + def _parse_wavelengths(self, bands, channels): + if bands is not None and all([_ in WAVELENGTHS for _ in bands]): + return torch.tensor([WAVELENGTHS[_] for _ in bands]) + else: + return torch.zeros(channels) diff --git a/terratorch/models/backbones/clay_v1/utils.py b/terratorch/models/backbones/clay_v1/utils.py new file mode 100644 index 00000000..d16abf75 --- /dev/null +++ b/terratorch/models/backbones/clay_v1/utils.py @@ -0,0 +1,42 @@ +import torch + + +def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") + assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" + omega = torch.arange(dim // 4) / (dim // 4 - 1) + omega = 1.0 / (temperature**omega) + + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + return pe.type(dtype) + + +def posemb_sincos_2d_with_gsd( + h, w, dim, gsd=1.0, temperature: int = 10000, dtype=torch.float32 +): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") + assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" + + omega = torch.arange(dim // 4) / (dim // 4 - 1) + omega = 1.0 / (temperature ** (2 * omega / dim)) * \ + (gsd / 1.0) # Adjusted for g + + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + return pe.type(dtype) + + +def posemb_sincos_1d(pos, dim, temperature: int = 10000, dtype=torch.float32): + assert dim % 2 == 0, "Feature dimension must be a multiple of 2 for sincos embedding" + pos = torch.arange(pos) if isinstance(pos, int) else pos + + omega = torch.arange(dim // 2).to(pos) / (dim // 2 - 1) + omega = 1.0 / (temperature**omega) + + scaled_pos = pos[:, None] * omega[None, :] + pe = torch.cat((scaled_pos.sin(), scaled_pos.cos()), dim=1) + + return pe.type(dtype) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py new file mode 100644 index 00000000..e8483207 --- /dev/null +++ b/terratorch/models/backbones/prithvi_mae.py @@ -0,0 +1,736 @@ +# Copyright (c) IBM Corp. 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# transformers: https://github.com/huggingface/transformers +# -------------------------------------------------------- + +from functools import partial +from typing import List, Tuple + +import logging +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from timm.layers import to_2tuple +from timm.models.vision_transformer import Block + + +def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): + """ + Create 3D sin/cos positional embeddings. + + Args: + embed_dim (int): + Embedding dimension. + grid_size (tuple[int, int, int] | list[int]): + The grid depth, height and width. + add_cls_token (bool, *optional*, defaults to False): + Whether or not to add a classification (CLS) token. + + Returns: + (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or + (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token) + """ + + assert embed_dim % 16 == 0 + + t_size, h_size, w_size = grid_size + + w_embed_dim = embed_dim // 16 * 6 + h_embed_dim = embed_dim // 16 * 6 + t_embed_dim = embed_dim // 16 * 4 + + w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size)) + h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size)) + t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size)) + + w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1)) + h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1)) + t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0) + + pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1) + + if add_cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even") + + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor): + """ This is the torch version of *get_1d_sincos_pos_embed_from_grid()*. However, + it was modified to cast omega values to pos.dtype which must be float (and not int as in + regular positional embeddings). This was required in order to allow for native FSDP mixed + precision support: modify omega to appropriate dtype (pos carries the correct float dtype), + instead of manually forcing float32. + + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) - must be float dtype! + out: (M, D) + """ + assert embed_dim % 2 == 0 + assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16] + + omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + + return emb + + +def _init_weights(module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class PatchEmbed(nn.Module): + """3D version of timm.models.vision_transformer.PatchEmbed""" + def __init__( + self, + input_size: Tuple[int, int, int] = (1, 224, 224), + patch_size: Tuple[int, int, int] = (1, 16, 16), + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: nn.Module | None = None, + flatten: bool = True, + bias: bool = True, + ): + super().__init__() + self.input_size = input_size + self.patch_size = patch_size + self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)] + self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] + self.flatten = flatten + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, T, H, W = x.shape + + if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1: + logging.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." + f"The border will be ignored, add backbone_padding for pixel-wise tasks.") + + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C + x = self.norm(x) + return x + + +class TemporalEncoder(nn.Module): + def __init__(self, embed_dim: int, trainable_scale: bool = False): + super().__init__() + self.embed_dim = embed_dim + self.year_embed_dim = embed_dim // 2 + self.julian_day_embed_dim = embed_dim - self.year_embed_dim + + # If trainable, initialize scale with small number + if trainable_scale: + self.scale = nn.Parameter(torch.full((1,), 0.1)) + else: + self.register_buffer('scale', torch.ones(1)) + + def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None): + """ + temporal_coords: year and day-of-year info with shape (B, T, 2). + tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be + repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim). + """ + shape = temporal_coords.shape[:2] + (-1,) # B, T, -1 + + year = _get_1d_sincos_embed_from_grid_torch( + self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape) + julian_day = _get_1d_sincos_embed_from_grid_torch( + self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape) + + embedding = self.scale * torch.cat([year, julian_day], dim=-1) + + if tokens_per_frame is not None: + embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1) + + return embedding # B, T*tokens_per_frame, embed_dim + + +class LocationEncoder(nn.Module): + def __init__(self, embed_dim: int, trainable_scale: bool = False): + super().__init__() + self.embed_dim = embed_dim + self.lat_embed_dim = embed_dim // 2 + self.lon_embed_dim = embed_dim - self.lat_embed_dim + + # If trainable, initialize scale with small number + if trainable_scale: + self.scale = nn.Parameter(torch.full((1,), 0.1)) + else: + self.register_buffer('scale', torch.ones(1)) + + def forward(self, location_coords: torch.Tensor): + """ + location_coords: lat and lon info with shape (B, 2). + """ + shape = location_coords.shape[:1] + (1, -1) # B, 1, -1 + + lat = _get_1d_sincos_embed_from_grid_torch( + self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape) + lon = _get_1d_sincos_embed_from_grid_torch( + self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape) + + embedding = self.scale * torch.cat([lat, lon], dim=-1) + + return embedding # B, 1, embed_dim + + +class PrithviViT(nn.Module): + """ Prithvi ViT Encoder""" + def __init__(self, + img_size: int | Tuple[int, int] = 224, + patch_size: int | Tuple[int, int, int] = (1, 16, 16), + num_frames: int = 1, + in_chans: int = 3, + embed_dim: int = 1024, + depth: int = 24, + num_heads: int = 16, + mlp_ratio: float = 4., + norm_layer: nn.Module = nn.LayerNorm, + coords_encoding: List[str] | None = None, + coords_scale_learn: bool = False, + encoder_only: bool = True, # needed for timm + ** kwargs, + ): + super().__init__() + + self.feature_info = [] + self.encoder_only = encoder_only + self.in_chans = in_chans + self.num_frames = num_frames + self.embed_dim = embed_dim + self.img_size = to_2tuple(img_size) + if isinstance(patch_size, int): + patch_size = (1, patch_size, patch_size) + + # 3D patch embedding + self.patch_embed = PatchEmbed( + input_size=(num_frames,) + self.img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + + # Optional temporal and location embedding + coords_encoding = coords_encoding or [] + self.temporal_encoding = 'time' in coords_encoding + self.location_encoding = 'location' in coords_encoding + if self.temporal_encoding: + assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}" + self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn) + if self.location_encoding: + self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)) + + # Transformer layers + self.blocks = [] + for i in range(depth): + self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)) + self.feature_info.append( + {"num_chs": embed_dim * self.patch_embed.grid_size[0], "reduction": 1, "module": f"blocks.{i}"} + ) + self.blocks = nn.ModuleList(self.blocks) + + self.norm = norm_layer(embed_dim) + + self.initialize_weights() + + def initialize_weights(self): + # initialize (and freeze) position embeddings by sin-cos embedding + pos_embed = get_3d_sincos_pos_embed( + self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.cls_token, std=0.02) + self.apply(_init_weights) + + def random_masking(self, sequence, mask_ratio, noise=None): + """ + Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random + noise. + + Args: + sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`) + mask_ratio (float): mask ratio to use. + noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is + mainly used for testing purposes to control randomness and maintain the reproducibility + """ + batch_size, seq_length, dim = sequence.shape + len_keep = int(seq_length * (1 - mask_ratio)) + + if noise is None: + noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([batch_size, seq_length], device=sequence.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return sequence_unmasked, mask, ids_restore + + def _get_pos_embed(self, x): + t, h, w = x.shape[-3:] + + pos_embed = torch.from_numpy(get_3d_sincos_pos_embed( + self.embed_dim, + ( + t // self.patch_embed.patch_size[0], + h // self.patch_embed.patch_size[1], + w // self.patch_embed.patch_size[2], + ), + add_cls_token=True, + )).float().unsqueeze(0).to(x) + + return pos_embed + + + def forward( + self, x: torch.Tensor, + temporal_coords: None | torch.Tensor = None, + location_coords: None | torch.Tensor = None, + mask_ratio=0.75 + ): + if x.shape[-3:] != self.patch_embed.input_size: + # changed input size + pos_embed = self._get_pos_embed(x) + else: + pos_embed = self.pos_embed + + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + pos_embed[:, 1:, :] + + if self.temporal_encoding and temporal_coords is not None: + num_tokens_per_frame = x.shape[1] // self.num_frames + temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame) + x = x + temporal_encoding + if self.location_encoding and location_coords is not None: + location_encoding = self.location_embed_enc(location_coords) + x = x + location_encoding + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, mask_ratio) + + # append cls token + cls_token = self.cls_token + pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for block in self.blocks: + x = block(x) + x = self.norm(x) + + return x, mask, ids_restore + + def forward_features( + self, + x: torch.Tensor, + temporal_coords: None | torch.Tensor = None, + location_coords: None | torch.Tensor = None, + ) -> list[torch.Tensor]: + if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: + # add time dim + x = x.unsqueeze(2) + + if x.shape[-3:] != self.patch_embed.input_size: + pos_embed = self._get_pos_embed(x) + else: + pos_embed = self.pos_embed + + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + pos_embed[:, 1:, :] + + if self.temporal_encoding and temporal_coords is not None: + num_tokens_per_frame = x.shape[1] // self.num_frames + temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame) + x = x + temporal_encoding + if self.location_encoding and location_coords is not None: + location_encoding = self.location_embed_enc(location_coords) + x = x + location_encoding + + # append cls token + cls_token = self.cls_token + pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + out = [] + for block in self.blocks: + x = block(x) + out.append(x.clone()) + + x = self.norm(x) + out[-1] = x + return out + + def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]: + out = [] + effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0] + for x in features: + x_no_token = x[:, 1:, :] + number_of_tokens = x_no_token.shape[1] + tokens_per_timestep = number_of_tokens // effective_time_dim + h = int(np.sqrt(tokens_per_timestep)) + encoded = rearrange( + x_no_token, + "batch (t h w) e -> batch (t e) h w", + e=self.embed_dim, + t=effective_time_dim, + h=h, + ) + out.append(encoded) + return out + + +class MAEDecoder(nn.Module): + """ Transformer Decoder used in the Prithvi MAE""" + def __init__(self, + patch_size: int | Tuple[int, int, int] = (1, 16, 16), + grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14), + in_chans: int = 3, + encoder_embed_dim: int = 1024, + decoder_embed_dim: int = 512, + depth: int = 8, + num_heads: int = 16, + mlp_ratio: float = 4., + norm_layer: nn.Module = nn.LayerNorm, + coords_encoding: List[str] | None = None, + coords_scale_learn: bool = False, + ): + super().__init__() + + self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True) + self.decoder_embed_dim = decoder_embed_dim + self.grid_size = grid_size + if isinstance(patch_size, int): + patch_size = (1, patch_size, patch_size) + self.patch_size = patch_size + self.num_frames = self.grid_size[0] * patch_size[0] + num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] + + # Optional temporal and location embedding + coords_encoding = coords_encoding or [] + self.temporal_encoding = 'time' in coords_encoding + self.location_encoding = 'location' in coords_encoding + if self.temporal_encoding: + self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn) + if self.location_encoding: + self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + + self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim)) + + self.decoder_blocks = nn.ModuleList( + [Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)] + ) + + self.decoder_norm = norm_layer(decoder_embed_dim) + self.decoder_pred = nn.Linear(decoder_embed_dim, + patch_size[0] * patch_size[1] * patch_size[2] * in_chans, + bias=True) + + self.initialize_weights() + + def initialize_weights(self): + # initialize (and freeze) position embeddings by sin-cos embedding + decoder_pos_embed = get_3d_sincos_pos_embed( + self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True + ) + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.mask_token, std=0.02) + self.apply(_init_weights) + + def forward( + self, + hidden_states: torch.Tensor, + ids_restore: torch.Tensor, + temporal_coords: None | torch.Tensor = None, + location_coords: None | torch.Tensor = None, + input_size: list[int] = None, + ): + # embed tokens + x = self.decoder_embed(hidden_states) + + t, h, w = input_size[-3:] + decoder_pos_embed = torch.from_numpy( + get_3d_sincos_pos_embed( + self.decoder_embed_dim, + ( + t // self.patch_size[0], + h // self.patch_size[1], + w // self.patch_size[2], + ), + add_cls_token=True, + ) + ).to(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + # unshuffle + x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device)) + x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token + # add pos embed + x = x + decoder_pos_embed + + # remove cls token + x_ = x[:, 1:, :] + + if self.temporal_encoding and temporal_coords is not None: + num_tokens_per_frame = x_.shape[1] // self.num_frames + temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame) + # Add temporal encoding w/o cls token + x_ = x_ + temporal_encoding + if self.location_encoding and location_coords is not None: + location_encoding = self.location_embed_dec(location_coords) + # Add location encoding w/o cls token + x_ = x_ + location_encoding + + # append cls token + x = torch.cat([x[:, :1, :], x_], dim=1) + + # apply Transformer layers (blocks) + for block in self.decoder_blocks: + x = block(x) + x = self.decoder_norm(x) + + # predictor projection + pred = self.decoder_pred(x) + + # remove cls token + pred = pred[:, 1:, :] + + return pred + + +class PrithviMAE(nn.Module): + """ Prithvi Masked Autoencoder""" + + def __init__(self, + img_size: int | Tuple[int, int] = 224, + patch_size: int | Tuple[int, int, int] = (1, 16, 16), + num_frames: int = 4, + in_chans: int = 6, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + mlp_ratio: float = 4., + norm_layer: nn.Module = nn.LayerNorm, + norm_pix_loss: bool = False, + coords_encoding: List[str] | None = None, + coords_scale_learn: bool = False, + encoder_only: bool = False, + **kwargs, + ): + super().__init__() + + self.encoder = PrithviViT( + img_size=img_size, + num_frames=num_frames, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + coords_encoding=coords_encoding, + coords_scale_learn=coords_scale_learn, + ) + + self.encoder_only = encoder_only + + if not encoder_only: + self.decoder = MAEDecoder( + patch_size=patch_size, + grid_size=self.encoder.patch_embed.grid_size, + in_chans=in_chans, + encoder_embed_dim=embed_dim, + decoder_embed_dim=decoder_embed_dim, + depth=decoder_depth, + num_heads=decoder_num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + coords_encoding=coords_encoding, + coords_scale_learn=coords_scale_learn, + ) + else: + self.decoder = nn.Identity() + + self.norm_pix_loss = norm_pix_loss + + def patchify(self, pixel_values): + """ + Args: + pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`): + Pixel values. + + Returns: + torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: + Patchified pixel values. + """ + patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size + num_channels = self.encoder.in_chans + + # patchify + patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)', + c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w) + + + return patchified_pixel_values + + def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None): + """ + Args: + patchified_pixel_values (`torch.FloatTensor` of shape + `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: + Patchified pixel values. + image_size (`Tuple[int, int]`, *optional*): + Original image size. + + Returns: + `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`: + Pixel values. + """ + patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size + image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size + original_height, original_width = image_size + num_patches_h = original_height // patch_size_h + num_patches_w = original_width // patch_size_w + num_channels = self.encoder.in_chans + + pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)', + c=num_channels, h=num_patches_h, w=num_patches_w, + s=patch_size_t, p=patch_size_h, q=patch_size_w) + return pixel_values + + def forward_loss(self, pixel_values, pred, mask): + """ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`): + Pixel values. + pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: + Predicted pixel values. + mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Tensor indicating which patches are masked (1) and which are not (0). + + Returns: + `torch.FloatTensor`: Pixel reconstruction loss. + """ + target = self.patchify(pixel_values) + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + return loss + + def forward( + self, + pixel_values: torch.Tensor, + temporal_coords: None | torch.Tensor = None, + location_coords: None | torch.Tensor = None, + mask_ratio: float = 0.75 + ): + if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1: + # add time dim + pixel_values = pixel_values.unsqueeze(2) + + latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio) + pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape) + loss = self.forward_loss(pixel_values, pred, mask) + return loss, pred, mask + + def forward_features( + self, + x: torch.Tensor, + temporal_coords: None | torch.Tensor = None, + location_coords: None | torch.Tensor = None, + ) -> List[torch.Tensor]: + return self.encoder.forward_features(x, temporal_coords, location_coords) diff --git a/terratorch/models/backbones/prithvi_swin.py b/terratorch/models/backbones/prithvi_swin.py index dce4eede..d58c8fd6 100644 --- a/terratorch/models/backbones/prithvi_swin.py +++ b/terratorch/models/backbones/prithvi_swin.py @@ -188,15 +188,30 @@ def _create_swin_mmseg_transformer( def checkpoint_filter_wrapper_fn(state_dict, model): return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands) - model: MMSegSwinTransformer = build_model_with_cfg( - MMSegSwinTransformer, - variant, - pretrained, - pretrained_filter_fn=checkpoint_filter_wrapper_fn, - pretrained_strict=False, - feature_cfg={"flatten_sequential": True, "out_indices": out_indices}, - **kwargs, - ) + # When the pretrained configuration is not available in HF, we shift to + # pretrained=False + try: + model: MMSegSwinTransformer = build_model_with_cfg( + MMSegSwinTransformer, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_wrapper_fn, + pretrained_strict=False, + feature_cfg={"flatten_sequential": True, "out_indices": out_indices}, + **kwargs, + ) + except RuntimeError: + print(f"No pretrained configuration was found for the model {variant}.") + model: MMSegSwinTransformer = build_model_with_cfg( + MMSegSwinTransformer, + variant, + False, + pretrained_filter_fn=checkpoint_filter_wrapper_fn, + pretrained_strict=False, + feature_cfg={"flatten_sequential": True, "out_indices": out_indices}, + **kwargs, + ) + model.pretrained_bands = pretrained_bands model.model_bands = model_bands diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index f86ae22f..07a76a65 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -1,20 +1,17 @@ # Copyright contributors to the Terratorch project +import torch import logging from functools import partial -from pathlib import Path -from collections import defaultdict - -from timm.models import FeatureInfo -from timm.models._builder import build_model_with_cfg -from timm.models._registry import generate_default_cfgs, register_model -from torch import nn +from torch import nn, Tensor +from timm.models import (FeatureInfo, load_model_config_from_hf, build_model_with_cfg, generate_default_cfgs, + register_model) from terratorch.datasets import HLSBands from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights -from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder from terratorch.datasets.utils import generate_bands_intervals +from terratorch.models.backbones.prithvi_mae import PrithviViT, PrithviMAE PRETRAINED_BANDS = [ HLSBands.BLUE, @@ -29,55 +26,115 @@ { "prithvi_vit_100": { "hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M", - "hf_hub_filename": "Prithvi_100M.pt", + "hf_hub_filename": "Prithvi_EO_V1_100M.pt", + }, + "prithvi_eo_v2_300": { + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M", + "hf_hub_filename": "Prithvi_EO_V2_300M.pt", + }, + "prithvi_eo_v2_300_tl": { + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL", + "hf_hub_filename": "Prithvi_EO_V2_300M_TL.pt", + }, + "prithvi_eo_v2_600": { + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-600M", + "hf_hub_filename": "Prithvi_EO_V2_600M.pt", + }, + "prithvi_eo_v2_600_tl": { + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL", + "hf_hub_filename": "Prithvi_EO_V2_600M_TL.pt", }, - "prithvi_vit_300": {}, "prithvi_vit_tiny": {} } ) -def checkpoint_filter_fn( - state_dict, model: TemporalViTEncoder, pretrained_bands: list[HLSBands | int], model_bands: list[HLSBands | int] + +def checkpoint_filter_fn_vit( + state_dict, model: PrithviViT, pretrained_bands: list[HLSBands | int], model_bands: list[HLSBands | int] ) -> dict: - if "pos_embed" in state_dict: - del state_dict["pos_embed"] - if "decoder_pos_embed" in state_dict: - del state_dict["decoder_pos_embed"] + """Encoder only model""" + clean_dict = {} + for k, v in state_dict.items(): + if "pos_embed" in k: + v = model.pos_embed # pos_embed depends on num_frames and is fixed. + if "decoder" in k or "_dec" in k or k == "mask_token": + continue # Drop decoder weights + + if not model.temporal_encoding and "temporal_embed" in k: + continue + if not model.location_encoding and "location_embed" in k: + continue + + if k.startswith("encoder."): + clean_dict[k.replace("encoder.", "")] = v # Convert Prithvi MAE to Prithvi ViT + else: + clean_dict[k] = v + + state_dict = clean_dict + + state_dict = select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands) + + return state_dict + + +def checkpoint_filter_fn_mae( + state_dict, model: PrithviMAE, pretrained_bands: list[HLSBands | int], model_bands: list[HLSBands | int] +) -> dict: + """Encoder-decoder model""" clean_dict = {} for k, v in state_dict.items(): - if model.encoder_only: - if "decoder" in k: - continue - if "mask_token" in k: - continue - if "temporal_embed_dec" in k: - continue - if "location_embed_dec" in k: - continue - if not model.temporal_encoding and "temporal_embed_enc" in k: + # pos_embed depends on num_frames and is fixed. + if "decoder_pos_embed" in k: + v = model.decoder.decoder_pos_embed + elif "pos_embed" in k: + v = model.encoder.pos_embed + + if not model.encoder.temporal_encoding and "temporal_embed" in k: continue - if not model.location_encoding and "location_embed_enc" in k: + if not model.encoder.location_encoding and "location_embed" in k: continue - clean_dict[k] = v - state_dict = clean_dict + if k.startswith("encoder.") or k.startswith("decoder."): + clean_dict[k] = v # Weights in Prithvi MAE format + # Convert Prithvi V1 weights + elif "decoder" in k or "_dec" in k or k == "mask_token": + clean_dict["decoder." + k] = v + else: + clean_dict["encoder." + k] = v + + state_dict = clean_dict state_dict = select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands) return state_dict + +def pad_images(imgs: Tensor,patch_size: int, padding:str) -> Tensor: + p = patch_size + # h, w = imgs.shape[3], imgs.shape[4] + t, h, w = imgs.shape[-3:] + h_pad, w_pad = (p - h % p) % p, (p - w % p) % p # Ensure padding is within bounds + if h_pad > 0 or w_pad > 0: + imgs = torch.stack([ + nn.functional.pad(img, (0, w_pad, 0, h_pad), mode=padding) + for img in imgs # Apply per image to avoid NotImplementedError from torch.nn.functional.pad + ]) + return imgs + + def _create_prithvi( variant: str, pretrained: bool = False, # noqa: FBT001, FBT002 pretrained_bands: list[HLSBands] | None = None, model_bands: list[HLSBands | int] | None = None, **kwargs, -) -> TemporalViTEncoder: +) -> PrithviViT: if pretrained_bands is None: pretrained_bands = PRETRAINED_BANDS + if model_bands is None: model_bands: list[HLSBands | int] = pretrained_bands logging.info( @@ -87,28 +144,48 @@ def _create_prithvi( else: model_bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in model_bands] + padding = kwargs.get("padding", "none") + patch_size = kwargs.get("patch_size", 16) + if isinstance(patch_size, list): + patch_size = patch_size[-1] + # Little hack because VIT does not support timm's features_only - # so we do it ourselves - encoder_only = kwargs.get("features_only", False) - if "features_only" in kwargs: - kwargs = {k: v for k, v in kwargs.items() if k != "features_only"} + encoder_only = kwargs.pop("features_only", False) model_bands = generate_bands_intervals(model_bands) kwargs["in_chans"] = len(model_bands) - def checkpoint_filter_wrapper_fn(state_dict, model): - return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands) - - model = build_model_with_cfg( - TemporalViTEncoder, - variant, - pretrained, - pretrained_filter_fn=checkpoint_filter_wrapper_fn, - pretrained_strict=True, - encoder_only=encoder_only, - **kwargs, - ) + if encoder_only: + prithvi_model_class = PrithviViT + def checkpoint_filter_wrapper_fn(state_dict, model): + return checkpoint_filter_fn_vit(state_dict, model, pretrained_bands, model_bands) + else: + prithvi_model_class = PrithviMAE + def checkpoint_filter_wrapper_fn(state_dict, model): + return checkpoint_filter_fn_mae(state_dict, model, pretrained_bands, model_bands) + + # When the pretrained configuration is not available in HF, we shift to + # pretrained=False + try: + model = build_model_with_cfg( + prithvi_model_class, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_wrapper_fn, + pretrained_strict=True, + **kwargs, + ) + except RuntimeError: + print(f"No pretrained configuration was found for the model {variant}.") + model = build_model_with_cfg( + prithvi_model_class, + variant, + False, + pretrained_filter_fn=checkpoint_filter_wrapper_fn, + pretrained_strict=True, + **kwargs, + ) if encoder_only: default_out_indices = list(range(len(model.blocks))) @@ -124,15 +201,34 @@ def forward_filter_indices(*args, **kwargs): model.model_bands = model_bands model.pretrained_bands = pretrained_bands + if padding != "none": + original_forward = model.forward + original_forward_features = model.forward_features + + def pad_and_forward(forward_fn, patch_size, padding, *args, **kwargs): + inputs = pad_images(args[0], patch_size, padding) + return forward_fn(inputs, **kwargs) + + def forward_pad_images(*args, **kwargs): + return pad_and_forward(original_forward, patch_size, padding, *args, **kwargs) + + def forward_features_pad_images(*args, **kwargs): + return pad_and_forward(original_forward_features, patch_size, padding, *args, **kwargs) + + model.forward = forward_pad_images + model.forward_features = forward_features_pad_images + + return model -def create_prithvi_vit_100( + +def create_prithvi_from_config( model_name: str, pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, + default_cfg: dict = None, **kwargs, -) -> TemporalViTEncoder: - """Prithvi ViT 100M""" +) -> PrithviViT: pretrained_bands = PRETRAINED_BANDS if bands is None: bands = pretrained_bands @@ -141,149 +237,175 @@ def create_prithvi_vit_100( Pretrained patch_embed layer may be misaligned with current bands" ) - model_args = { - "patch_size": 16, - "embed_dim": 768, - "depth": 12, - "num_heads": 12, - "decoder_embed_dim": 512, - "decoder_depth": 8, - "decoder_num_heads": 16, - "mlp_ratio": 4, - "norm_layer": partial(nn.LayerNorm, eps=1e-6), - "num_frames": 1, - } + try: + config, _ = load_model_config_from_hf(default_cfgs[model_name].default.hf_hub_id) + except: + # No connection to hf + config = default_cfg + config.update(num_frames=1) # Assume one timestamp by default + config.update(kwargs) # Overwrite with keyword args model = _create_prithvi( model_name, pretrained=pretrained, model_bands=bands, pretrained_bands=pretrained_bands, - **dict(model_args,**kwargs), + **config, ) return model -def create_prithvi_vit_300( - model_name: str, - pretrained: bool = False, # noqa: FBT001, FBT002 +@register_model +def prithvi_vit_tiny( bands: list[HLSBands | int] | None = None, **kwargs, -) -> TemporalViTEncoder: - """Prithvi ViT 300M""" +) -> PrithviViT: + """Prithvi ViT tiny""" pretrained_bands = PRETRAINED_BANDS if bands is None: bands = pretrained_bands - logging.info( - f"Model bands not passed. Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\ - Pretrained patch_embed layer may be misaligned with current bands" - ) model_args = { "patch_size": 16, - "embed_dim": 1024, - "depth": 24, - "num_heads": 16, - "decoder_embed_dim": 512, - "decoder_depth": 8, - "decoder_num_heads": 16, + "embed_dim": 256, + "depth": 4, + "num_heads": 4, + "decoder_embed_dim": 128, + "decoder_depth": 4, + "decoder_num_heads": 4, "mlp_ratio": 4, "norm_layer": partial(nn.LayerNorm, eps=1e-6), "num_frames": 1, + "model_bands": bands, } - model = _create_prithvi( - model_name, - pretrained=pretrained, - pretrained_bands=pretrained_bands, - model_bands=bands, - **dict(model_args, **kwargs), - ) + model_args.update(kwargs) + model = _create_prithvi("prithvi_vit_tiny", **model_args) return model -def create_prithvi_vit_600( - model_name: str, +@register_model +def prithvi_vit_100( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, -) -> TemporalViTEncoder: - """Prithvi ViT 600M""" - pretrained_bands = PRETRAINED_BANDS - if bands is None: - bands = pretrained_bands - logging.info( - f"Model bands not passed. Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\ - Pretrained patch_embed layer may be misaligned with current bands" - ) - model_args = { - "patch_size": 14, - "embed_dim": 1280, - "depth": 32, - "num_heads": 16, +) -> PrithviViT: + + default_config = { + "img_size": 224, + "patch_size": [1, 16, 16], + "num_frames": 3, + "in_chans": 6, + "embed_dim": 768, + "depth": 12, + "num_heads": 12, "decoder_embed_dim": 512, "decoder_depth": 8, "decoder_num_heads": 16, "mlp_ratio": 4, - "norm_layer": partial(nn.LayerNorm, eps=1e-6), - "num_frames": 1, } - model = _create_prithvi( - model_name, - pretrained=pretrained, - pretrained_bands=pretrained_bands, - model_bands=bands, - **dict(model_args, **kwargs), - ) - return model + + return create_prithvi_from_config("prithvi_vit_100", pretrained, bands, default_config, **kwargs) @register_model -def prithvi_vit_tiny( - bands: list[HLSBands | int] | None = None, +def prithvi_eo_v2_300( + pretrained: bool = False, # noqa: FBT001, FBT002 + bands: list[HLSBands] | None = None, **kwargs, -) -> TemporalViTEncoder: - """Prithvi ViT tiny""" - pretrained_bands = PRETRAINED_BANDS - if bands is None: - bands = pretrained_bands - model_args = { - "patch_size": 16, - "embed_dim": 256, - "depth": 4, - "num_heads": 4, - "decoder_embed_dim": 128, - "decoder_depth": 4, - "decoder_num_heads": 4, +) -> PrithviViT: + + default_config = { + "img_size": 224, + "num_frames": 4, + "patch_size": [1, 16, 16], + "in_chans": 6, + "embed_dim": 1024, + "depth": 24, + "num_heads": 16, + "decoder_embed_dim": 512, + "decoder_depth": 8, + "decoder_num_heads": 16, "mlp_ratio": 4, - "norm_layer": partial(nn.LayerNorm, eps=1e-6), - "num_frames": 1, - "model_bands": bands, + "coords_encoding": [], + "coords_scale_learn": True, } - model = _create_prithvi("prithvi_vit_tiny", **dict(model_args, **kwargs)) - return model + + return create_prithvi_from_config("prithvi_eo_v2_300", pretrained, bands, default_config, **kwargs) + @register_model -def prithvi_vit_100( +def prithvi_eo_v2_600( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, -) -> TemporalViTEncoder: - return create_prithvi_vit_100("prithvi_vit_100", pretrained, bands, **kwargs) +) -> PrithviViT: + + default_config = { + "img_size": 224, + "num_frames": 4, + "patch_size": [1, 14, 14], + "in_chans": 6, + "embed_dim": 1280, + "depth": 32, + "num_heads": 16, + "decoder_embed_dim": 512, + "decoder_depth": 8, + "decoder_num_heads": 16, + "mlp_ratio": 4, + "coords_encoding": [], + "coords_scale_learn": True, + } + + return create_prithvi_from_config("prithvi_eo_v2_600", pretrained, bands, default_config, **kwargs) @register_model -def prithvi_vit_300( +def prithvi_eo_v2_300_tl( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, -) -> TemporalViTEncoder: - return create_prithvi_vit_300("prithvi_vit_300", pretrained, bands, **kwargs) +) -> PrithviViT: + + default_config = { + "img_size": 224, + "num_frames": 4, + "patch_size": [1, 16, 16], + "in_chans": 6, + "embed_dim": 1024, + "depth": 24, + "num_heads": 16, + "decoder_embed_dim": 512, + "decoder_depth": 8, + "decoder_num_heads": 16, + "mlp_ratio": 4, + "coords_encoding": ["time", "location"], + "coords_scale_learn": True, + } + + return create_prithvi_from_config("prithvi_eo_v2_300_tl", pretrained, bands, default_config, **kwargs) @register_model -def prithvi_vit_600( +def prithvi_eo_v2_600_tl( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, -) -> TemporalViTEncoder: - return create_prithvi_vit_600("prithvi_vit_600", pretrained, bands, **kwargs) \ No newline at end of file +) -> PrithviViT: + + default_config = { + "img_size": 224, + "num_frames": 4, + "patch_size": [1, 14, 14], + "in_chans": 6, + "embed_dim": 1280, + "depth": 32, + "num_heads": 16, + "decoder_embed_dim": 512, + "decoder_depth": 8, + "decoder_num_heads": 16, + "mlp_ratio": 4, + "coords_encoding": ["time", "location"], + "coords_scale_learn": True, + } + + return create_prithvi_from_config("prithvi_eo_v2_600_tl", pretrained, bands, default_config, **kwargs) diff --git a/terratorch/models/backbones/select_patch_embed_weights.py b/terratorch/models/backbones/select_patch_embed_weights.py index b1caef1a..5e9a2b77 100644 --- a/terratorch/models/backbones/select_patch_embed_weights.py +++ b/terratorch/models/backbones/select_patch_embed_weights.py @@ -36,6 +36,7 @@ def select_patch_embed_weights( dict: New state dict """ _possible_keys_for_proj_weight = { + "encoder.patch_embed.proj.weight", "patch_embed.proj.weight", "module.patch_embed.proj.weight", "patch_embed.projection.weight", diff --git a/terratorch/models/clay_model_factory.py b/terratorch/models/clay_model_factory.py index 08b28f26..82d1f183 100644 --- a/terratorch/models/clay_model_factory.py +++ b/terratorch/models/clay_model_factory.py @@ -2,13 +2,12 @@ import sys from collections.abc import Callable -import numpy as np import timm import torch from torch import nn import terratorch.models.decoders as decoder_registry -from terratorch.datasets import HLSBands +from terratorch.models.backbones.clay_v1.embedder import Embedder from terratorch.models.model import ( AuxiliaryHead, AuxiliaryHeadWithDecoderWithoutInstantiatedHead, @@ -25,6 +24,9 @@ SUPPORTED_TASKS = PIXEL_WISE_TASKS + SCALAR_TASKS +class DecoderNotFoundError(Exception): + pass + class ModelWrapper(nn.Module): def __init__(self, model: nn.Module = None) -> None: @@ -57,7 +59,7 @@ def build_model( backbone: str | nn.Module, decoder: str | nn.Module, in_channels: int, - bands: list[HLSBands | int], + bands: list[int] = [], num_classes: int | None = None, pretrained: bool = True, # noqa: FBT001, FBT002 num_frames: int = 1, @@ -81,9 +83,6 @@ def build_model( If an nn.Module, we expect it to expose a property `decoder.out_channels`. Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder". in_channels (int, optional): Number of input channels. Defaults to 3. - bands (list[terratorch.datasets.HLSBands], optional): Bands the model will be trained on. - Should be a list of terratorch.datasets.HLSBands. - Defaults to [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE]. num_classes (int, optional): Number of classes. None for regression tasks. pretrained (Union[bool, Path], optional): Whether to load pretrained weights for the backbone, if available. Defaults to True. @@ -110,10 +109,9 @@ def build_model( # Path for accessing the model source code. self.syspath_kwarg = "model_sys_path" - bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands] # TODO: support auxiliary heads if not isinstance(backbone, nn.Module): - if not "Clay" in backbone: + if not "clay" in backbone: msg = "This class only handles models for `Clay` encoders" raise NotImplementedError(msg) @@ -130,56 +128,28 @@ def build_model( backbone, pretrained=pretrained, in_chans=in_channels, - num_frames=num_frames, bands=bands, + num_frames=num_frames, features_only=True, **backbone_kwargs, ) - except Exception: - - # When the model is not on HG, it needs be restored locally. - print("This model is not available on HuggingFace. Trying to instantiate locally ...") - - assert checkpoint_path, "A checkpoint must be provided to restore the model." - - # The CLAY source code must be installed or available via PYTHONPATH. - try: # TODO Inlcude the Clay source code into the tolkit in order to - # avoid issues with the modules paths or made it - # seamlessly accesible via configuration. - if self.syspath_kwarg in kwargs: - syspath_value = kwargs.get(self.syspath_kwarg) - - else: - - Exception(f"It is necessary to define the variable {self.syspath_kwarg} on yaml" - "config for restoring local model.") - - sys.path.insert(0, syspath_value) - - from src.model_clay import CLAYModule + except Exception as e: + print(e, "Error loading from HF. Trying to instantiate locally ...") - except ModuleNotFoundError: - - print(f"It is better to review the field {self.syspath_kwarg} in the yaml file.") - - backbone: nn.Module = ModelWrapper(model=CLAYModule(**backbone_kwargs)) - - if self.CPU_ONLY: - model_dict = torch.load(checkpoint_path, map_location="cpu") - else: - model_dict = torch.load(checkpoint_path) - - backbone.model.load_state_dict(model_dict['state_dict']) + else: + if checkpoint_path is None: + raise ValueError("A checkpoint (checkpoint_path) must be provided to restore the model.") - print("Model Clay was successfully restored.") + backbone: nn.Module = Embedder(ckpt_path=checkpoint_path, **backbone_kwargs) + print("Model Clay was successfully restored.") # allow decoder to be a module passed directly decoder_cls = _get_decoder(decoder) - decoder_kwargs, kwargs = extract_prefix_keys(kwargs, "decoder_") # TODO: remove this - decoder: nn.Module = decoder_cls(backbone.channels(), **decoder_kwargs) + decoder: nn.Module = decoder_cls( + backbone.feature_info.channels(), **decoder_kwargs) # decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs) head_kwargs, kwargs = extract_prefix_keys(kwargs, "head_") @@ -191,9 +161,11 @@ def build_model( ) to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = [] + for aux_decoder in aux_decoders: args = aux_decoder.decoder_args if aux_decoder.decoder_args else {} aux_decoder_cls: nn.Module = _get_decoder(aux_decoder.decoder) + aux_decoder_kwargs, kwargs = extract_prefix_keys(args, "decoder_") aux_decoder_instance = aux_decoder_cls(backbone.feature_info.channels(), **aux_decoder_kwargs) # aux_decoder_instance = aux_decoder_cls([128, 256, 512, 1024], **decoder_kwargs) @@ -204,7 +176,8 @@ def build_model( # aux_head: nn.Module = _get_head(task, aux_decoder_instance, num_classes=num_classes, **head_kwargs) # aux_decoder.decoder = nn.Sequential(aux_decoder_instance, aux_head) to_be_aux_decoders.append( - AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs) + AuxiliaryHeadWithDecoderWithoutInstantiatedHead( + aux_decoder.name, aux_decoder_instance, aux_head_kwargs) ) return _build_appropriate_model( @@ -217,6 +190,7 @@ def build_model( auxiliary_heads=to_be_aux_decoders, ) + def _build_appropriate_model( task: str, backbone: nn.Module, @@ -232,7 +206,6 @@ def _build_appropriate_model( backbone, decoder, head_kwargs, - prepare_features_for_image_model=prepare_features_for_image_model, rescale=rescale, auxiliary_heads=auxiliary_heads, ) @@ -242,7 +215,6 @@ def _build_appropriate_model( backbone, decoder, head_kwargs, - prepare_features_for_image_model=prepare_features_for_image_model, auxiliary_heads=auxiliary_heads, ) diff --git a/terratorch/models/decoders/__init__.py b/terratorch/models/decoders/__init__.py index 9d8e958f..ecfc90aa 100644 --- a/terratorch/models/decoders/__init__.py +++ b/terratorch/models/decoders/__init__.py @@ -5,5 +5,6 @@ from terratorch.models.decoders.satmae_head import SatMAEHead, SatMAEHeadViT from terratorch.models.decoders.upernet_decoder import UperNetDecoder from terratorch.models.decoders.aspp_head import ASPPSegmentationHead, ASPPRegressionHead +from terratorch.models.decoders.mlp_decoder import MLPDecoder -__all__ = ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "SatMAEHead", "SatMAEHeadViT", "SMPDecoderWrapper", "ASPPSegmentationHead", "ASPPRegressionHead"] +__all__ = ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "SatMAEHead", "SatMAEHeadViT", "SMPDecoderWrapper", "ASPPSegmentationHead", "ASPPRegressionHead", "MLPDecoder"] diff --git a/terratorch/models/decoders/mlp_decoder.py b/terratorch/models/decoders/mlp_decoder.py new file mode 100644 index 00000000..0c2fbcd6 --- /dev/null +++ b/terratorch/models/decoders/mlp_decoder.py @@ -0,0 +1,39 @@ +# Copyright contributors to the Terratorch project + +"""Pass the features straight through +""" + +from torch import Tensor, nn +import torch +from terratorch.registry import TERRATORCH_DECODER_REGISTRY + + +@TERRATORCH_DECODER_REGISTRY.register +class MLPDecoder(nn.Module): + """Identity decoder. Useful to pass the feature straight to the head.""" + + def __init__(self, embed_dim: int, channels: int = 100, out_dim:int = 100, activation: str = "ReLU", out_index=-1) -> None: + """Constructor + + Args: + embed_dim (int): Input embedding dimension + out_index (int, optional): Index of the input list to take.. Defaults to -1. + """ + + super().__init__() + self.embed_dim = embed_dim + self.channels = channels + self.dim = out_index + self.n_inputs = len(self.embed_dim) + self.out_channels = self.embed_dim[self.dim] + self.hidden_layer = torch.nn.Linear(self.out_channels*self.n_inputs, self.out_channels) + self.activation = getattr(nn, activation)() + + def forward(self, x: list[Tensor]): + + data_ = torch.cat(x, axis=1) + data_ = data_.permute(0, 2, 3, 1) + data_ = self.activation(self.hidden_layer(data_)) + data_ = data_.permute(0, 3, 1, 2) + + return data_ diff --git a/terratorch/models/encoder_decoder_factory.py b/terratorch/models/encoder_decoder_factory.py index b797e5dd..04727265 100644 --- a/terratorch/models/encoder_decoder_factory.py +++ b/terratorch/models/encoder_decoder_factory.py @@ -1,6 +1,8 @@ # Copyright contributors to the Terratorch project +import warnings + from torch import nn from terratorch.models.model import ( @@ -10,6 +12,7 @@ ModelFactory, ) from terratorch.models.necks import Neck, build_neck_list +from terratorch.models.peft_utils import get_peft_backbone from terratorch.models.pixel_wise_model import PixelWiseModel from terratorch.models.scalar_output_model import ScalarOutputModel from terratorch.models.utils import extract_prefix_keys @@ -74,6 +77,7 @@ def build_model( necks: list[dict] | None = None, aux_decoders: list[AuxiliaryHead] | None = None, rescale: bool = True, # noqa: FBT002, FBT001 + peft_config: dict | None = None, **kwargs, ) -> Model: """Generic model factory that combines an encoder and decoder, together with a head, for a specific task. @@ -102,6 +106,15 @@ def build_model( rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True. + peft_config (dict): Configuration options for using [PEFT](https://huggingface.co/docs/peft/index). + The dictionary should have the following keys: + + - "method": Which PEFT method to use. Should be one implemented in PEFT, a list is available [here](https://huggingface.co/docs/peft/package_reference/peft_types#peft.PeftType). + - "replace_qkv": String containing a substring of the name of the submodules to replace with QKVSep. + This should be used when the qkv matrices are merged together in a single linear layer and the PEFT + method should be applied separately to query, key and value matrices (e.g. if LoRA is only desired in + Q and V matrices). e.g. If using Prithvi this should be "qkv" + - "peft_config_kwargs": Dictionary containing keyword arguments which will be passed to [PeftConfig](https://huggingface.co/docs/peft/package_reference/config#peft.PeftConfig) Returns: @@ -115,6 +128,16 @@ def build_model( backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_") backbone = _get_backbone(backbone, **backbone_kwargs) + if peft_config is not None: + if not backbone_kwargs.get("pretrained", False): + msg = ( + "You are using PEFT without a pretrained backbone. If you are loading a checkpoint afterwards " + "this is probably fine, but if you are training a model check the backbone_pretrained parameter." + ) + warnings.warn(msg, stacklevel=1) + + backbone = get_peft_backbone(peft_config, backbone) + try: out_channels = backbone.out_channels except AttributeError as e: @@ -138,7 +161,15 @@ def build_model( if aux_decoders is None: _check_all_args_used(kwargs) - return _build_appropriate_model(task, backbone, decoder, head_kwargs, necks=neck_list, decoder_includes_head=decoder_includes_head, rescale=rescale) + return _build_appropriate_model( + task, + backbone, + decoder, + head_kwargs, + necks=neck_list, + decoder_includes_head=decoder_includes_head, + rescale=rescale, + ) to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = [] for aux_decoder in aux_decoders: diff --git a/terratorch/models/peft_utils.py b/terratorch/models/peft_utils.py new file mode 100644 index 00000000..bb35a34f --- /dev/null +++ b/terratorch/models/peft_utils.py @@ -0,0 +1,112 @@ +import warnings +from dataclasses import dataclass +from typing import Any + +import torch +from torch import nn + +try: + from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_model + + _has_peft = True +except ImportError: + _has_peft = False + +TESTED_PEFT_METHODS = ["LORA"] + + +def _get_submodules(model: nn.Module, key: str) -> tuple[nn.Module, nn.Module, str]: + # adapted from PEFT + parent = model.get_submodule(".".join(key.split(".")[:-1])) + target_name = key.split(".")[-1] + target = model.get_submodule(key) + return parent, target, target_name + + +@dataclass(frozen=True) +class TerratorchPEFTConfig: + method: str + replace_qkv: str | None + peft_config_kwargs: dict[str, Any] + + +def _validate_terratorch_peft_config(peft_config: dict[str, Any]) -> TerratorchPEFTConfig: + terratorch_peft_config = TerratorchPEFTConfig( + method=peft_config["method"], + replace_qkv=peft_config.get("replace_qkv", None), + peft_config_kwargs=peft_config.get("peft_config_kwargs", {}), + ) + if terratorch_peft_config.method not in TESTED_PEFT_METHODS: + msg = f"PEFT method {terratorch_peft_config.method} has not been tested. Use at your own risk." + warnings.warn(msg, stacklevel=1) + return terratorch_peft_config + + +def get_peft_backbone(peft_config: dict[str, Any], backbone: nn.Module) -> nn.Module: + terratorch_peft_config = _validate_terratorch_peft_config(peft_config) + if not _has_peft: + msg = ( + "You need to install terratorch with peft dependency to use peft_config. " + "Use pip install terratorch[peft]" + ) + raise ImportError(msg) + peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[terratorch_peft_config.method] + peft_config_peft = peft_config_cls(**terratorch_peft_config.peft_config_kwargs) + if terratorch_peft_config.replace_qkv is not None: + replace_qkv(backbone, terratorch_peft_config.replace_qkv) # modifies inplace + backbone = get_peft_model(backbone, peft_config_peft) + return backbone + + +class QKVSep(nn.Module): + def __init__(self, original_qkv: nn.Linear): + super().__init__() + if original_qkv.out_features != original_qkv.in_features * 3: + msg = "The output features must be 3 times the input features for Q, K, V separation" + raise ValueError(msg) + + self.in_features = original_qkv.in_features + self.out_features = original_qkv.out_features + + # Create nn.Linear layers for Q, K, V using slices of the original weights and biases + self.q_linear = nn.Linear(self.in_features, self.in_features) + self.k_linear = nn.Linear(self.in_features, self.in_features) + self.v_linear = nn.Linear(self.in_features, self.in_features) + + # Assign weights and biases from the original layer + with torch.no_grad(): + self.q_linear.weight = nn.Parameter(original_qkv.weight[: self.in_features, :]) + self.k_linear.weight = nn.Parameter(original_qkv.weight[self.in_features : 2 * self.in_features, :]) + self.v_linear.weight = nn.Parameter(original_qkv.weight[2 * self.in_features :, :]) + + if original_qkv.bias is not None: + self.q_linear.bias = nn.Parameter(original_qkv.bias[: self.in_features]) + self.k_linear.bias = nn.Parameter(original_qkv.bias[self.in_features : 2 * self.in_features]) + self.v_linear.bias = nn.Parameter(original_qkv.bias[2 * self.in_features :]) + else: + self.q_linear.bias = None + self.k_linear.bias = None + self.v_linear.bias = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + q = self.q_linear(x) + k = self.k_linear(x) + v = self.v_linear(x) + return torch.cat((q, k, v), dim=-1) + + +def replace_qkv(model: nn.Module, qkv_suffix: str): + # This is needed for ViTEncoderDecoder because the qkv matrices are together, + # and it would not work with LoRA (and probably other adapters) + replaced = False + for key, _ in model.named_modules(): + if key.endswith(f".{qkv_suffix}"): + replaced = True + parent, target, target_name = _get_submodules(model, key) + if not isinstance(target, nn.Linear): + msg = "Only a qkv nn.Linear can be replaced." + raise ValueError(msg) + new_module = QKVSep(target) + setattr(parent, target_name, new_module) + if not replaced: + warnings.warn("replace_qkv was not None but no module was found ending with that pattern.", stacklevel=1) diff --git a/terratorch/models/wxc_model_factory.py b/terratorch/models/wxc_model_factory.py index e33f1961..a548e093 100644 --- a/terratorch/models/wxc_model_factory.py +++ b/terratorch/models/wxc_model_factory.py @@ -1,8 +1,6 @@ # Copyright contributors to the Terratorch project import timm import torch -from granitewxc.utils.config import get_config -from granitewxc.utils.downscaling_model import get_finetune_model from torch import nn import os @@ -18,7 +16,6 @@ ) from terratorch.registry import MODEL_FACTORY_REGISTRY - logger = logging.getLogger(__name__) class WxCModuleWrapper(Model, nn.Module): @@ -42,8 +39,7 @@ def forward(self, x) -> ModelOutput: def load_state_dict(self, state_dict: os.Mapping[str, typing.Any], strict: bool = True, assign: bool = False): - return self.module.load_state_dict(state_dict, strict, assign) - + self.module.load_state_dict(state_dict, strict, assign) @MODEL_FACTORY_REGISTRY.register class WxCModelFactory(ModelFactory): @@ -51,8 +47,33 @@ def build_model( self, backbone: str | nn.Module, aux_decoders, + checkpoint_path:str=None, **kwargs, ) -> Model: - module = get_finetune_model(kwargs['model_config']) - - return WxCModuleWrapper(module) + if backbone == 'gravitywave': + try: + __import__('prithviwxc.gravitywave.inference') + from prithviwxc.gravitywave.inference import get_model + from prithviwxc.gravitywave.config import get_cfg + cfg = get_cfg() + model_wrapper = WxCModuleWrapper(get_model(cfg,'uvtp122', cfg.singular_sharded_checkpoint)) + if checkpoint_path: + model_wrapper.load_state_dict(torch.load(checkpoint_path, weights_only=True)) + return model_wrapper + except ImportError as e: + missing_module = e.name if hasattr(e, 'name') else "unknown module" + print('prithvi wxc gravitywave not installed, missing module: {missing_module}') + return None + else: + try: + __import__('granitewxc.utils.config') + from granitewxc.utils.config import get_config + from granitewxc.utils.downscaling_model import get_finetune_model + module = get_finetune_model(kwargs['model_config']) + model_wrapper = WxCModuleWrapper(module) + + if checkpoint_path: + model_wrapper.load_state_dict(torch.load(checkpoint_path, weights_only=True)) + return model_wrapper + except ImportError: + print('granite wxc downscaling not installed') diff --git a/terratorch/registry/__init__.py b/terratorch/registry/__init__.py index 8b65d68c..5526bb20 100644 --- a/terratorch/registry/__init__.py +++ b/terratorch/registry/__init__.py @@ -12,6 +12,7 @@ import terratorch.registry.smp_registry # register smp registry import terratorch.registry.timm_registry # register timm registry # noqa: F401 import terratorch.registry.mmseg_registry +import terratorch.registry.custom_registry __all__ = [ "MultiSourceRegistry", diff --git a/terratorch/registry/custom_registry.py b/terratorch/registry/custom_registry.py new file mode 100644 index 00000000..79e0df4e --- /dev/null +++ b/terratorch/registry/custom_registry.py @@ -0,0 +1,17 @@ +import os +import importlib +import sys +import logging + +CUSTOM_MODULES_DIR_NAME = "custom_modules" + +# import any custom modules +current_working_dir = os.getcwd() +custom_modules_path = os.path.join(current_working_dir, CUSTOM_MODULES_DIR_NAME) +if os.path.exists(custom_modules_path) and os.path.isdir(custom_modules_path): + # Add 'custom_modules' folder to sys.path + sys.path.append(os.getcwd()) + logging.info(f"Found {CUSTOM_MODULES_DIR_NAME}") + importlib.import_module(CUSTOM_MODULES_DIR_NAME) + + diff --git a/terratorch/registry/timm_registry.py b/terratorch/registry/timm_registry.py index 6aa080a5..191fd13e 100644 --- a/terratorch/registry/timm_registry.py +++ b/terratorch/registry/timm_registry.py @@ -29,7 +29,7 @@ class TimmRegistry(Set): def register(self, constructor: Callable | type) -> Callable: raise NotImplementedError() - def build(self, name: str, *constructor_args, **constructor_kwargs) -> nn.Module: + def build(self, name: str, features_only=True, *constructor_args, **constructor_kwargs) -> nn.Module: """Build and return the component. Use prefixes ending with _ to forward to a specific source """ @@ -38,7 +38,7 @@ def build(self, name: str, *constructor_args, **constructor_kwargs) -> nn.Module timm.create_model( name, *constructor_args, - features_only=True, + features_only=features_only, **constructor_kwargs, ) ) diff --git a/terratorch/tasks/__init__.py b/terratorch/tasks/__init__.py index ee570bd7..82eea1be 100644 --- a/terratorch/tasks/__init__.py +++ b/terratorch/tasks/__init__.py @@ -2,6 +2,13 @@ from terratorch.tasks.regression_tasks import PixelwiseRegressionTask from terratorch.tasks.segmentation_tasks import SemanticSegmentationTask from terratorch.tasks.multilabel_classification_tasks import MultiLabelClassificationTask +try: + wxc_present = True + from terratorch.tasks.wxc_downscaling_task import WxCDownscalingTask +except ImportError as e: + print('wxc_downscaling not installed') + wxc_present = False + __all__ = ( "SemanticSegmentationTask", @@ -10,3 +17,6 @@ "MultiLabelClassificationTask" "BATCH_IDX_FOR_VALIDATION_PLOTTING", ) + +if wxc_present: + __all__.__add__(("WxCDownscalingTask", )) \ No newline at end of file diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py index d529c80f..be585eeb 100644 --- a/terratorch/tasks/classification_tasks.py +++ b/terratorch/tasks/classification_tasks.py @@ -112,6 +112,9 @@ def configure_models(self) -> None: "classification", aux_decoders=self.aux_heads, **self.hparams["model_args"] ) if self.hparams["freeze_backbone"]: + if self.hparams.get("peft_config", None) is not None: + msg = "PEFT should be run with freeze_backbone = False" + raise ValueError(msg) self.model.freeze_encoder() if self.hparams["freeze_decoder"]: self.model.freeze_decoder() diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index 2ccc9890..3e9d2cde 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -202,6 +202,9 @@ def configure_models(self) -> None: "regression", aux_decoders=self.aux_heads, **self.hparams["model_args"] ) if self.hparams["freeze_backbone"]: + if self.hparams.get("peft_config", None) is not None: + msg = "PEFT should be run with freeze_backbone = False" + raise ValueError(msg) self.model.freeze_encoder() if self.hparams["freeze_decoder"]: self.model.freeze_decoder() @@ -283,11 +286,15 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) y_hat = model_output.output - self.train_metrics(y_hat, y) - self.log_dict(self.train_metrics, on_epoch=True) + self.train_metrics.update(y_hat, y) return loss["loss"] + def on_train_epoch_end(self) -> None: + self.log_dict(self.train_metrics.compute(), sync_dist=True) + self.train_metrics.reset() + return super().on_train_epoch_end() + def _do_plot_samples(self, batch_index): if not self.plot_on_val: # dont plot if self.plot_on_val is 0 return False @@ -317,10 +324,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0]) y_hat = model_output.output - out = y_hat[y != -1] - mask = y[y != -1] - self.val_metrics(out, mask) - self.log_dict(self.val_metrics, on_epoch=True) + self.val_metrics.update(y_hat, y) if self._do_plot_samples(batch_idx): try: @@ -346,6 +350,11 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - finally: plt.close() + def on_validation_epoch_end(self) -> None: + self.log_dict(self.val_metrics.compute(), sync_dist=True) + self.val_metrics.reset() + return super().on_validation_epoch_end() + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the test loss and additional metrics. @@ -362,8 +371,12 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) y_hat = model_output.output - self.test_metrics(y_hat, y) - self.log_dict(self.test_metrics, on_epoch=True) + self.test_metrics.update(y_hat, y) + + def on_test_epoch_end(self) -> None: + self.log_dict(self.test_metrics.compute(), sync_dist=True) + self.test_metrics.reset() + return super().on_test_epoch_end() def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the predicted class probabilities. diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index df1638d8..56712e7f 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Any import lightning @@ -34,6 +35,7 @@ class SemanticSegmentationTask(BaseTask): - Logs metrics per class - Does not have any callbacks by default (TorchGeo tasks do early stopping by default) - Allows the setting of optimizers in the constructor + - Allows to evaluate on multiple test dataloaders """ def __init__( @@ -57,6 +59,7 @@ def __init__( plot_on_val: bool | int = 10, class_names: list[str] | None = None, tiled_inference_parameters: TiledInferenceParameters = None, + test_dataloaders_names: list[str] | None = None, ) -> None: """Constructor @@ -94,6 +97,9 @@ def __init__( Defaults to numeric ordering. tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters used to determine if inference is done on the whole image or through tiling. + test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when + multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, + which assumes only one test dataloader is used. """ self.tiled_inference_parameters = tiled_inference_parameters self.aux_loss = aux_loss @@ -101,7 +107,9 @@ def __init__( self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory) super().__init__() self.train_loss_handler = LossHandler(self.train_metrics.prefix) - self.test_loss_handler = LossHandler(self.test_metrics.prefix) + self.test_loss_handler: list[LossHandler] = [] + for metrics in self.test_metrics: + self.test_loss_handler.append(LossHandler(metrics.prefix)) self.val_loss_handler = LossHandler(self.val_metrics.prefix) self.monitor = f"{self.val_metrics.prefix}loss" self.plot_on_val = int(plot_on_val) @@ -115,6 +123,9 @@ def configure_models(self) -> None: "segmentation", aux_decoders=self.aux_heads, **self.hparams["model_args"] ) if self.hparams["freeze_backbone"]: + if self.hparams.get("peft_config", None) is not None: + msg = "PEFT should be run with freeze_backbone = False" + raise ValueError(msg) self.model.freeze_encoder() if self.hparams["freeze_decoder"]: self.model.freeze_decoder() @@ -210,7 +221,12 @@ def configure_metrics(self) -> None: ) self.train_metrics = metrics.clone(prefix="train/") self.val_metrics = metrics.clone(prefix="val/") - self.test_metrics = metrics.clone(prefix="test/") + if self.hparams["test_dataloaders_names"] is not None: + self.test_metrics = nn.ModuleList( + [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]] + ) + else: + self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")]) def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the train loss and additional metrics. @@ -223,7 +239,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> x = batch["image"] y = batch["mask"] other_keys = batch.keys() - {"image", "mask", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) @@ -262,7 +278,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - x = batch["image"] y = batch["mask"] other_keys = batch.keys() - {"image", "mask", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0]) @@ -309,16 +325,24 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None x = batch["image"] y = batch["mask"] other_keys = batch.keys() - {"image", "mask", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) - loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) - self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) + if dataloader_idx >= len(self.test_loss_handler): + msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names." + raise ValueError(msg) + loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss) + self.test_loss_handler[dataloader_idx].log_loss( + partial(self.log, add_dataloader_idx=False), # We don't need the dataloader idx as prefixes are different + loss_dict=loss, + batch_size=x.shape[0], + ) y_hat_hard = to_segmentation_prediction(model_output) - self.test_metrics.update(y_hat_hard, y) + self.test_metrics[dataloader_idx].update(y_hat_hard, y) def on_test_epoch_end(self) -> None: - self.log_dict(self.test_metrics.compute(), sync_dist=True) - self.test_metrics.reset() + for metrics in self.test_metrics: + self.log_dict(metrics.compute(), sync_dist=True) + metrics.reset() return super().on_test_epoch_end() def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: @@ -335,7 +359,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T x = batch["image"] file_names = batch["filename"] other_keys = batch.keys() - {"image", "mask", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) def model_forward(x): diff --git a/terratorch/tasks/wxc_downscaling_task.py b/terratorch/tasks/wxc_downscaling_task.py index a93889da..006f7e95 100644 --- a/terratorch/tasks/wxc_downscaling_task.py +++ b/terratorch/tasks/wxc_downscaling_task.py @@ -3,7 +3,8 @@ from torchgeo.trainers import BaseTask from typing import Any, Mapping -from terratorch.models.model import Model, get_factory +from terratorch.models.model import Model#, get_factory +from terratorch.registry import MODEL_FACTORY_REGISTRY from terratorch.tasks.loss_handler import LossHandler from terratorch.tasks.optimizer_factory import optimizer_factory from terratorch.tasks.regression_tasks import RootLossWrapper @@ -17,6 +18,7 @@ def __init__( self, model_args: dict, model_factory: str, + extra_kwargs: dict, model_config: ExperimentConfig, loss: str = "mse", lr: float = 0.001, @@ -29,12 +31,84 @@ def __init__( plot_on_val: bool | int = 10, ) -> None: - self.model_factory = get_factory(model_factory) + # Special cases for some parameters that could not be read in + # their own fields. + mask_unit_size = tuple(model_args.pop("mask_unit_size")) + encoder_decoder_kernel_size_per_stage = extra_kwargs.pop("encoder_decoder_kernel_size_per_stage") + output_vars = extra_kwargs.pop("output_vars") + type_dataset = extra_kwargs.pop("type") + input_levels = extra_kwargs.pop("input_levels") + downscaling_patch_size = extra_kwargs.pop("downscaling_patch_size") + n_input_timestamps = extra_kwargs.pop("n_input_timestamps") + downscaling_embed_dim = extra_kwargs.pop("downscaling_embed_dim") + encoder_decoder_conv_channels = extra_kwargs.pop("encoder_decoder_conv_channels") + encoder_decoder_scale_per_stage = extra_kwargs.pop("encoder_decoder_scale_per_stage") + encoder_decoder_upsampling_mode = extra_kwargs.pop("encoder_decoder_upsampling_mode") + encoder_shift = extra_kwargs.pop("encoder_shift") + drop_path = extra_kwargs.pop("drop_path") + encoder_decoder_type = extra_kwargs.pop("encoder_decoder_type") + input_size_lat = extra_kwargs.pop("input_size_lat") + input_size_lon = extra_kwargs.pop("input_size_lon") + freeze_backbone = extra_kwargs.pop("freeze_backbone") + freeze_decoder = extra_kwargs.pop("freeze_decoder") + data_path_surface = extra_kwargs.pop("data_path_surface") + data_path_vertical = extra_kwargs.pop("data_path_vertical") + climatology_path_surface = extra_kwargs.pop("climatology_path_surface") + climatology_path_vertical = extra_kwargs.pop("climatology_path_vertical") + residual_connection = model_args.pop("residual_connection") + residual = extra_kwargs.pop("residual", True) + + # Special cases for required variables + input_scalers_surface_path = extra_kwargs.pop("input_scalers_surface_path", None) + if not input_scalers_surface_path: + raise Exception(f"The parameter `input_scalers_surface_path` must be defined in `extra_kwargs`.") + + input_scalers_vertical_path = extra_kwargs.pop("input_scalers_vertical_path", None) + if not input_scalers_vertical_path: + raise Exception(f"The parameter `input_scalers_vertical_path` must be defined in `extra_kwargs`.") + + output_scalers_surface_path = extra_kwargs.pop("output_scalers_surface_path") + output_scalers_vertical_path = extra_kwargs.pop("output_scalers_vertical_path") + + model_config.freeze_backbone = freeze_backbone + model_config.freeze_decoder = freeze_decoder + model_config.mask_unit_size = mask_unit_size + model_config.model.mask_unit_size = mask_unit_size + model_config.model.encoder_decoder_kernel_size_per_stage = encoder_decoder_kernel_size_per_stage + model_config.model.input_scalers_surface_path = input_scalers_surface_path + model_config.model.input_scalers_vertical_path = input_scalers_vertical_path + model_config.data.output_vars = output_vars + model_config.data.type = type_dataset + model_config.data.input_surface_vars = model_config.data.surface_vars + model_config.data.input_vertical_vars = model_config.data.vertical_vars + model_config.data.input_static_surface_vars = model_config.data.static_surface_vars + model_config.data.input_levels = input_levels + model_config.model.downscaling_patch_size = downscaling_patch_size + model_config.data.n_input_timestamps = n_input_timestamps + model_config.model.downscaling_embed_dim = downscaling_embed_dim + model_config.model.encoder_decoder_conv_channels = encoder_decoder_conv_channels + model_config.model.encoder_decoder_scale_per_stage = encoder_decoder_scale_per_stage + model_config.model.encoder_decoder_upsampling_mode = encoder_decoder_upsampling_mode + model_config.model.encoder_shift = encoder_shift + model_config.model.drop_path = drop_path + model_config.model.encoder_decoder_type = encoder_decoder_type + model_config.data.input_size_lat = input_size_lat + model_config.data.input_size_lon = input_size_lon + model_config.data.data_path_surface = data_path_surface + model_config.data.data_path_vertical = data_path_vertical + model_config.data.climatology_path_surface = climatology_path_surface + model_config.data.climatology_path_vertical = climatology_path_vertical + model_config.model.output_scalers_surface_path = output_scalers_surface_path + model_config.model.output_scalers_vertical_path = output_scalers_vertical_path + model_config.model.residual_connection = residual_connection + model_config.model.residual = residual + + self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory) self.model_config = model_config # TODO Unify it with self.hparams self.extended_hparams = self.model_config.to_dict() super().__init__() - print(type(self.hparams)) + self.train_loss_handler = LossHandler(self.train_metrics.prefix) self.test_loss_handler = LossHandler(self.test_metrics.prefix) self.val_loss_handler = LossHandler(self.val_metrics.prefix) @@ -72,7 +146,7 @@ def configure_losses(self) -> None: Raises: ValueError: If *loss* is invalid. """ - #TODO 'reduction' should be chosen using the config and + #TODO 'reduction' should be chosen using the config and # a similar class as IgnoreIndex should be defined for this class loss: str = self.hparams["loss"].lower() @@ -128,11 +202,11 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - y = batch["mask"] model_output: ModelOutput = self(x) - loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, None) - self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0]) + loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, None) + self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0]) y_hat = model_output.output - self.train_metrics(y_hat, y) - self.log_dict(self.train_metrics, on_epoch=True) + self.val_metrics(y_hat, y) + self.log_dict(self.val_metrics, on_epoch=True) return loss["loss"] diff --git a/terratorch/tasks/wxc_gravity_wave_task.py b/terratorch/tasks/wxc_gravity_wave_task.py new file mode 100644 index 00000000..9afc114f --- /dev/null +++ b/terratorch/tasks/wxc_gravity_wave_task.py @@ -0,0 +1,15 @@ + + +from torchgeo.trainers import BaseTask +import torch.nn as nn + +class WxCGravityWaveTask(BaseTask): + def __init__(self, model_factory): + self.model_factory = model_factory + super().__init__() + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.learning_rate) + + def configure_models(self): + self.model = self.model_factory.build_model(backbone='gravitywave', aux_decoders=None) \ No newline at end of file diff --git a/tests/resources/configs/manufactured-finetune-mlp_decoder.yaml b/tests/resources/configs/manufactured-finetune-mlp_decoder.yaml new file mode 100644 index 00000000..feb179f9 --- /dev/null +++ b/tests/resources/configs/manufactured-finetune-mlp_decoder.yaml @@ -0,0 +1,147 @@ +# lightning.pytorch==2.1.1 +seed_everything: 42 +trainer: + accelerator: cpu + strategy: auto + devices: auto + num_nodes: 1 + # precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: tests/ + name: all_ecos_random + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 100 + max_epochs: 2 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_checkpointing: true + default_root_dir: tests/ +data: + class_path: GenericNonGeoPixelwiseRegressionDataModule + init_args: + batch_size: 2 + num_workers: 4 + train_transform: + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: albumentations.Rotate + init_args: + limit: 30 + border_mode: 0 # cv2.BORDER_CONSTANT + value: 0 + # mask_value: 1 + p: 0.5 + - class_path: ToTensorV2 + dataset_bands: + - 0 + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + - 1 + - 2 + - 3 + - 4 + output_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: tests/resources/inputs + train_label_data_root: tests/resources/inputs + val_data_root: tests/resources/inputs + val_label_data_root: tests/resources/inputs + test_data_root: tests/resources/inputs + test_label_data_root: tests/resources/inputs + img_grep: "regression*input*.tif" + label_grep: "regression*label*.tif" + means: + - 547.36707 + - 898.5121 + - 1020.9082 + - 2665.5352 + - 2340.584 + - 1610.1407 + stds: + - 411.4701 + - 558.54065 + - 815.94025 + - 812.4403 + - 1113.7145 + - 1067.641 + no_label_replace: -1 + no_data_replace: 0 + +model: + class_path: terratorch.tasks.PixelwiseRegressionTask + init_args: + model_args: + decoder: MLPDecoder + pretrained: false + backbone: prithvi_vit_100 + decoder_activation: ReLU + backbone_drop_path_rate: 0.3 + num_frames: 1 + in_channels: 6 + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + head_dropout: 0.5708022831486758 + head_final_act: torch.nn.Identity + head_learned_upscale_layers: 2 + loss: rmse + #aux_heads: + # - name: aux_head + # decoder: IdentityDecoder + # decoder_args: + # decoder_out_index: 2 + # head_dropout: 0,5 + # head_channel_list: + # - 64 + # head_final_act: torch.nn.ReLU + #aux_loss: + # aux_head: 0.4 + ignore_index: -1 + freeze_backbone: true + freeze_decoder: false + model_factory: PrithviModelFactory + + # uncomment this block for tiled inference + # tiled_inference_parameters: + # h_crop: 224 + # h_stride: 192 + # w_crop: 224 + # w_stride: 192 + # average_patches: true +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.00013524680528283027 + weight_decay: 0.047782217873995426 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + diff --git a/tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml b/tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml index 69dac095..6e267d00 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_vit_100.yaml @@ -97,8 +97,8 @@ model: decoder: UperNetDecoder pretrained: false backbone: prithvi_vit_100 - # backbone_pretrained_cfg_overlay: - # file: tests/prithvi_vit_100.pt + #backbone_pretrained_cfg_overlay: + #file: tests/all_ecos_random/version_0/checkpoints/epoch=0_state_dict.ckpt #tests/prithvi_vit_100.pt backbone_drop_path_rate: 0.3 num_frames: 1 # backbone_window_size: 8 diff --git a/tests/test_backbones.py b/tests/test_backbones.py index 17b3d8b2..d12e1091 100644 --- a/tests/test_backbones.py +++ b/tests/test_backbones.py @@ -3,7 +3,7 @@ import pytest import timm import torch - +import gc from terratorch.models.backbones import scalemae from terratorch.registry import BACKBONE_REGISTRY @@ -25,57 +25,65 @@ def input_512(): def input_224_multitemporal(): return torch.ones((1, NUM_CHANNELS, NUM_FRAMES, 224, 224)) +@pytest.fixture +def input_non_divisible(): + return torch.ones((1, NUM_CHANNELS, NUM_FRAMES, 220, 230)) + @pytest.fixture def input_386(): return torch.ones((1, NUM_CHANNELS, 386, 386)) -@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"]) +@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_eo_v2_300", "prithvi_swin_B"]) @pytest.mark.parametrize("test_input", ["input_224", "input_512"]) def test_can_create_backbones_from_timm(model_name, test_input, request): backbone = timm.create_model(model_name, pretrained=False) input_tensor = request.getfixturevalue(test_input) backbone(input_tensor) + gc.collect() - -@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"]) +@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_eo_v2_300", "prithvi_swin_B"]) @pytest.mark.parametrize("test_input", ["input_224", "input_512"]) def test_can_create_backbones_from_timm_features_only(model_name, test_input, request): backbone = timm.create_model(model_name, pretrained=False, features_only=True) input_tensor = request.getfixturevalue(test_input) backbone(input_tensor) - -@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"]) + gc.collect() +@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_eo_v2_300", "prithvi_swin_B"]) @pytest.mark.parametrize("prefix", ["", "timm_"]) def test_can_create_timm_backbones_from_registry(model_name, input_224, prefix): backbone = BACKBONE_REGISTRY.build(prefix+model_name, pretrained=False) backbone(input_224) + gc.collect() - -@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) +@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_eo_v2_300"]) def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal): backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES) backbone(input_224_multitemporal) - - -@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) + gc.collect() +@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_eo_v2_300"]) +def test_vit_models_non_divisible_input(model_name, input_non_divisible): + #padding 'none','constant', 'reflect', 'replicate' or 'circular' default is 'none' + backbone = timm.create_model(model_name, pretrained=False, features_only=True, num_frames=NUM_FRAMES, padding='constant') + backbone(input_non_divisible) + gc.collect() +@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_eo_v2_300"]) @pytest.mark.parametrize("patch_size", [8, 16]) -@pytest.mark.parametrize("tubelet_size", [1, 2, 4]) -def test_vit_models_different_patch_tubelet_sizes(model_name, patch_size, tubelet_size, input_224_multitemporal): +@pytest.mark.parametrize("patch_size_time", [1, 2, 4]) +def test_vit_models_different_patch_tubelet_sizes(model_name, patch_size, patch_size_time, input_224_multitemporal): backbone = timm.create_model( model_name, pretrained=False, num_frames=NUM_FRAMES, - patch_size=patch_size, - tubelet_size=tubelet_size, + patch_size=[patch_size_time, patch_size, patch_size], features_only=True, ) embedding = backbone(input_224_multitemporal) processed_embedding = backbone.prepare_features_for_image_model(embedding) expected_h_w = 224 // patch_size - expected_t = NUM_FRAMES // tubelet_size + expected_t = NUM_FRAMES // patch_size_time for e in processed_embedding: assert ( @@ -91,8 +99,8 @@ def test_vit_models_different_patch_tubelet_sizes(model_name, patch_size, tubele ), f"Expected embedding dimension to be of size effective time {expected_t} x embedding dimension\ {backbone.embed_dim} = {expected_t * backbone.embed_dim} but was {e.shape[1]}" - -@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) + gc.collect() +@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_eo_v2_300"]) def test_out_indices(model_name, input_224): out_indices = [2, 4, 8, 10] backbone = timm.create_model(model_name, pretrained=False, features_only=True, out_indices=out_indices) @@ -103,7 +111,19 @@ def test_out_indices(model_name, input_224): for filtered_index, full_index in enumerate(out_indices): assert torch.allclose(full_output[full_index], output[filtered_index]) + gc.collect() +@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_eo_v2_300"]) +def test_out_indices_non_divisible(model_name, input_non_divisible): + out_indices = [2, 4, 8, 10] + backbone = timm.create_model(model_name, pretrained=False, features_only=True, num_frames=NUM_FRAMES, out_indices=out_indices, padding='constant') + assert backbone.feature_info.out_indices == out_indices + + output = backbone(input_non_divisible) + full_output = backbone.forward_features(input_non_divisible) + for filtered_index, full_index in enumerate(out_indices): + assert torch.allclose(full_output[full_index], output[filtered_index]) + gc.collect() @pytest.mark.parametrize("model_name", ["vit_base_patch16", "vit_large_patch16"]) def test_scale_mae(model_name): out_indices = [2, 4, 8, 10] @@ -114,7 +134,7 @@ def test_scale_mae(model_name): output = backbone(input_tensor) assert len(output) == len(out_indices) - + gc.collect() @pytest.mark.parametrize("model_name", ["vit_base_patch16", "vit_large_patch16"]) @pytest.mark.parametrize("bands", [2, 4, 6]) def test_scale_mae_new_channels(model_name, bands): @@ -122,4 +142,4 @@ def test_scale_mae_new_channels(model_name, bands): backbone = scalemae.create_model(model_name, bands=list(range(bands))) input_tensor = torch.ones((1, bands, 224, 224)) backbone(input_tensor) - + gc.collect() diff --git a/tests/test_clay_tasks.py b/tests/test_clay_tasks.py new file mode 100644 index 00000000..6c950f75 --- /dev/null +++ b/tests/test_clay_tasks.py @@ -0,0 +1,91 @@ +# Copyright contributors to the Terratorch project + +import pytest +import torch + +from terratorch.models import ClayModelFactory +from terratorch.models.backbones.clay_v1 import WAVELENGTHS +from terratorch.tasks import ClassificationTask, PixelwiseRegressionTask, SemanticSegmentationTask + +NUM_CHANNELS = 6 +NUM_CLASSES = 2 +EXPECTED_SEGMENTATION_OUTPUT_SHAPE = (1, NUM_CLASSES, 224, 224) + + +@pytest.fixture(scope="session") +def model_factory() -> str: + return "ClayModelFactory" + + +@pytest.fixture(scope="session") +def model_input() -> torch.Tensor: + return torch.ones((1, NUM_CHANNELS, 224, 224)) + +@pytest.mark.parametrize("backbone", ["clay_v1_base"]) +@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder"]) +@pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"]) +def test_create_segmentation_task(backbone, decoder, loss, model_factory: ClayModelFactory): + model_args = { + "backbone": backbone, + "decoder": decoder, + "in_channels": NUM_CHANNELS, + "pretrained": False, + "num_classes": NUM_CLASSES, + "bands": list(WAVELENGTHS.keys()), + } + + if decoder == "UperNetDecoder": + model_args["out_indices"] = [1, 2, 3, 4] + model_args["scale_modules"] = True + SemanticSegmentationTask( + model_args, + model_factory, + loss=loss, + ) + + +@pytest.mark.parametrize("backbone", ["clay_v1_base"]) +@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder"]) +@pytest.mark.parametrize("loss", ["mae", "rmse", "huber"]) +def test_create_regression_task(backbone, decoder, loss, model_factory: ClayModelFactory): + model_args = { + "backbone": backbone, + "decoder": decoder, + "in_channels": NUM_CHANNELS, + "pretrained": False, + "bands": list(WAVELENGTHS.keys()), + } + + if decoder == "UperNetDecoder": + model_args["out_indices"] = [1, 2, 3, 4] + model_args["scale_modules"] = True + + PixelwiseRegressionTask( + model_args, + model_factory, + loss=loss, + ) + + +@pytest.mark.parametrize("backbone", ["clay_v1_base"]) +@pytest.mark.parametrize("decoder", ["IdentityDecoder"]) +@pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"]) +def test_create_classification_task(backbone, decoder, loss, model_factory: ClayModelFactory): + model_args = { + "backbone": backbone, + "decoder": decoder, + "in_channels": NUM_CHANNELS, + "pretrained": False, + "num_classes": NUM_CLASSES, + "bands": list(WAVELENGTHS.keys()), + } + + if decoder == "UperNetDecoder": + model_args["out_indices"] = [1, 2, 3, 4] + model_args["scale_modules"] = True + + ClassificationTask( + model_args, + model_factory, + loss=loss, + ) diff --git a/tests/test_decoders.py b/tests/test_decoders.py index cf8b4db5..bebd1fa1 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -9,6 +9,7 @@ import terratorch # noqa: F401 from terratorch.models.decoders.aspp_head import ASPPSegmentationHead +import gc def test_aspphead(): dilations = (1, 6, 12, 18) @@ -19,3 +20,5 @@ def test_aspphead(): image = [torch.from_numpy(np.random.rand(2, 6, 224, 224).astype("float32"))] assert decoder(image).shape == (2, 2, 224, 224) + + gc.collect() diff --git a/tests/test_encoder_decoder_model_factory.py b/tests/test_encoder_decoder_model_factory.py index dda9f795..9d75fbc2 100644 --- a/tests/test_encoder_decoder_model_factory.py +++ b/tests/test_encoder_decoder_model_factory.py @@ -8,6 +8,7 @@ from terratorch.models import EncoderDecoderFactory from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS from terratorch.models.model import AuxiliaryHead +import gc NUM_CHANNELS = 6 NUM_CLASSES = 2 @@ -49,8 +50,9 @@ def test_unused_args_raise_exception(model_factory: EncoderDecoderFactory): ) assert "unused_argument" in str(excinfo.value) + gc.collect() -@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) +@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_eo_v2_300"]) def test_create_classification_model(backbone, model_factory: EncoderDecoderFactory, model_input): model = model_factory.build_model( "classification", @@ -65,8 +67,9 @@ def test_create_classification_model(backbone, model_factory: EncoderDecoderFact with torch.no_grad(): assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE + gc.collect() -@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) +@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_eo_v2_300"]) def test_create_classification_model_no_in_channels(backbone, model_factory: EncoderDecoderFactory, model_input): model = model_factory.build_model( "classification", @@ -81,6 +84,7 @@ def test_create_classification_model_no_in_channels(backbone, model_factory: Enc with torch.no_grad(): assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE + gc.collect() @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -105,6 +109,7 @@ def test_create_pixelwise_model(backbone, task, expected, decoder, model_factory with torch.no_grad(): assert model(model_input).output.shape == expected + gc.collect() @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -127,6 +132,7 @@ def test_create_model_with_smp_fpn_decoder(backbone, task, expected, model_facto with torch.no_grad(): assert model(model_input).output.shape == expected + gc.collect() @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -152,6 +158,7 @@ def test_create_model_with_smp_unet_decoder( with torch.no_grad(): assert model(model_input).output.shape == expected + gc.collect() @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -176,6 +183,8 @@ def test_create_model_with_smp_deeplabv3plus_decoder( with torch.no_grad(): assert model(model_input).output.shape == expected + gc.collect() + @pytest.mark.skipif(not importlib.util.find_spec("mmseg"), reason="mmsegmentation not installed") @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -205,6 +214,8 @@ def test_create_model_with_mmseg_fcn_decoder( with torch.no_grad(): assert model(model_input).output.shape == expected + gc.collect() + @pytest.mark.skipif(not importlib.util.find_spec("mmseg"), reason="mmsegmentation not installed") @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -236,6 +247,7 @@ def test_create_model_with_mmseg_uperhead_decoder( with torch.no_grad(): assert model(model_input).output.shape == expected + gc.collect() @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -262,6 +274,7 @@ def test_create_pixelwise_model_no_in_channels( with torch.no_grad(): assert model(model_input).output.shape == expected + gc.collect() @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -296,6 +309,7 @@ def test_create_pixelwise_model_with_aux_heads( for _, output in model_output.auxiliary_heads.items(): assert output.shape == expected + gc.collect() @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -321,3 +335,5 @@ def test_create_pixelwise_model_with_extra_bands( model_input = torch.ones((1, NUM_CHANNELS + 1, 224, 224)) with torch.no_grad(): assert model(model_input).output.shape == expected + + gc.collect() diff --git a/tests/test_prithvi_model_factory.py b/tests/test_prithvi_model_factory.py index 4437586a..0a183ae1 100644 --- a/tests/test_prithvi_model_factory.py +++ b/tests/test_prithvi_model_factory.py @@ -8,6 +8,7 @@ from terratorch.models import PrithviModelFactory from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS from terratorch.models.model import AuxiliaryHead +import gc NUM_CHANNELS = 6 NUM_CLASSES = 2 @@ -30,7 +31,7 @@ def model_input() -> torch.Tensor: return torch.ones((1, NUM_CHANNELS, 224, 224)) -@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) +@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_eo_v2_300"]) def test_create_classification_model(backbone, model_factory: PrithviModelFactory, model_input): model = model_factory.build_model( "classification", @@ -47,7 +48,7 @@ def test_create_classification_model(backbone, model_factory: PrithviModelFactor assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE -@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) +@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_eo_v2_300"]) def test_create_classification_model_no_in_channels(backbone, model_factory: PrithviModelFactory, model_input): model = model_factory.build_model( "classification", @@ -62,6 +63,7 @@ def test_create_classification_model_no_in_channels(backbone, model_factory: Pri with torch.no_grad(): assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE + gc.collect() @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -88,6 +90,7 @@ def test_create_pixelwise_model(backbone, task, expected, decoder, model_factory with torch.no_grad(): assert model(model_input).output.shape == expected + gc.collect() @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -115,6 +118,7 @@ def test_create_pixelwise_model_no_in_channels( with torch.no_grad(): assert model(model_input).output.shape == expected + gc.collect() @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -150,6 +154,7 @@ def test_create_pixelwise_model_with_aux_heads( for _, output in model_output.auxiliary_heads.items(): assert output.shape == expected + gc.collect() @pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT) @@ -174,3 +179,5 @@ def test_create_pixelwise_model_with_extra_bands(backbone, task, expected, decod model_input = torch.ones((1, NUM_CHANNELS + 1, 224, 224)) with torch.no_grad(): assert model(model_input).output.shape == expected + + gc.collect() diff --git a/tests/test_prithvi_tasks.py b/tests/test_prithvi_tasks.py index 4f23e1a5..1c578873 100644 --- a/tests/test_prithvi_tasks.py +++ b/tests/test_prithvi_tasks.py @@ -6,6 +6,8 @@ from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS from terratorch.tasks import ClassificationTask, PixelwiseRegressionTask, SemanticSegmentationTask +import gc + NUM_CHANNELS = 6 NUM_CLASSES = 2 EXPECTED_SEGMENTATION_OUTPUT_SHAPE = (1, NUM_CLASSES, 224, 224) @@ -26,7 +28,7 @@ def model_input() -> torch.Tensor: return torch.ones((1, NUM_CHANNELS, 224, 224)) -@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"]) +@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_eo_v2_300", "prithvi_swin_B"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) @pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"]) def test_create_segmentation_task(backbone, decoder, loss, model_factory: str): @@ -46,8 +48,9 @@ def test_create_segmentation_task(backbone, decoder, loss, model_factory: str): loss=loss, ) + gc.collect() -@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"]) +@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_eo_v2_300", "prithvi_swin_B"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) @pytest.mark.parametrize("loss", ["mae", "rmse", "huber"]) def test_create_regression_task(backbone, decoder, loss, model_factory: str): @@ -67,8 +70,9 @@ def test_create_regression_task(backbone, decoder, loss, model_factory: str): loss=loss, ) + gc.collect() -@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"]) +@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_eo_v2_300", "prithvi_swin_B"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) @pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"]) def test_create_classification_task(backbone, decoder, loss, model_factory: str): @@ -88,3 +92,5 @@ def test_create_classification_task(backbone, decoder, loss, model_factory: str) model_factory, loss=loss, ) + + gc.collect() diff --git a/tests/test_prithvi_vit.py b/tests/test_prithvi_vit.py index 4803ca18..c60ae16b 100644 --- a/tests/test_prithvi_vit.py +++ b/tests/test_prithvi_vit.py @@ -4,6 +4,7 @@ from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights +import gc @pytest.mark.parametrize("patch_size", [4, 8, 16]) @pytest.mark.parametrize("tubelet_size,num_frames", [(1, 1), (1, 2), (1, 3), (2, 2), (3,3)]) @@ -28,14 +29,15 @@ def test_prithvi_vit_patch_embed_loading_compatible(patch_size, tubelet_size, nu select_patch_embed_weights(weights, model, PRETRAINED_BANDS, PRETRAINED_BANDS) -@pytest.mark.parametrize("tubelet_size,tubelet_size_other", [(1, 2), (2, 4)]) -def test_prithvi_vit_patch_embed_loading_not_compatible_tubelet(tubelet_size, tubelet_size_other): + gc.collect() + +@pytest.mark.parametrize("patch_size_time,patch_size_time_other", [(1, 2), (2, 4)]) +def test_prithvi_vit_patch_embed_loading_not_compatible_tubelet(patch_size_time,patch_size_time_other): model = timm.create_model( "prithvi_vit_100", pretrained=False, num_frames=4, - patch_size=16, - tubelet_size=tubelet_size, + patch_size=[patch_size_time, 16, 16], features_only=True, ) @@ -43,8 +45,7 @@ def test_prithvi_vit_patch_embed_loading_not_compatible_tubelet(tubelet_size, tu "prithvi_vit_100", pretrained=False, num_frames=4, - patch_size=16, - tubelet_size=tubelet_size_other, + patch_size=[patch_size_time_other, 16, 16], features_only=True, ).state_dict() @@ -52,6 +53,8 @@ def test_prithvi_vit_patch_embed_loading_not_compatible_tubelet(tubelet_size, tu with pytest.warns(UserWarning): select_patch_embed_weights(weights, model, PRETRAINED_BANDS, PRETRAINED_BANDS) + gc.collect() + @pytest.mark.parametrize("patch_size,patch_size_other", [(2, 4), (4, 8), (16, 4)]) def test_prithvi_vit_patch_embed_loading_not_compatible_patch(patch_size, patch_size_other): model = timm.create_model( @@ -74,3 +77,5 @@ def test_prithvi_vit_patch_embed_loading_not_compatible_patch(patch_size, patch_ with pytest.warns(UserWarning): select_patch_embed_weights(weights, model, PRETRAINED_BANDS, PRETRAINED_BANDS) + + gc.collect() diff --git a/tests/test_smp_model_factory.py b/tests/test_smp_model_factory.py index 2becd053..0d46df30 100644 --- a/tests/test_smp_model_factory.py +++ b/tests/test_smp_model_factory.py @@ -6,6 +6,8 @@ from terratorch.models import SMPModelFactory from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS +import gc + NUM_CHANNELS = 6 NUM_CLASSES = 2 EXPECTED_SEGMENTATION_OUTPUT_SHAPE = (1, NUM_CLASSES, 224, 224) @@ -40,6 +42,7 @@ def test_create_segmentation_model(backbone, model, model_factory: SMPModelFacto with torch.no_grad(): assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE + gc.collect() @pytest.mark.parametrize("backbone", ["timm-regnetx_002"]) @pytest.mark.parametrize("model", ["Unet", "DeepLabV3"]) @@ -57,6 +60,7 @@ def test_create_segmentation_model_no_in_channels(backbone, model, model_factory with torch.no_grad(): assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE + gc.collect() @pytest.mark.parametrize("backbone", ["timm-regnetx_002"]) @pytest.mark.parametrize("model", ["Unet", "DeepLabV3"]) @@ -74,3 +78,5 @@ def test_create_model_with_extra_bands(backbone, model, model_factory: SMPModelF model_input = torch.ones((1, NUM_CHANNELS + 1, 224, 224)) with torch.no_grad(): assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE + + gc.collect()