diff --git a/examples/notebooks/WxCTutorialGravityWave.ipynb b/examples/notebooks/WxCTutorialGravityWave.ipynb index b12f3f2a..bf7fb11d 100644 --- a/examples/notebooks/WxCTutorialGravityWave.ipynb +++ b/examples/notebooks/WxCTutorialGravityWave.ipynb @@ -9,71 +9,76 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": {}, + "source": "!pip install -U -e ../../\n", "outputs": [], - "source": [ - "!pip install -U git+https://github.com/romeokienzler/terratorch.git@201\n" - ] + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "metadata": {}, + "source": "!pip install -U albumentations # fix until https://github.com/IBM/terratorch/issues/164 is solved", "outputs": [], - "source": [ - "!pip install -U albumentations # fix until https://github.com/IBM/terratorch/issues/164 is solved" - ] + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "!pip install -U git+https://github.com/romeokienzler/gravity-wave-finetuning.git\n" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "!pip install huggingface_hub" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-29T08:14:17.633710Z", + "start_time": "2024-11-29T08:14:03.625634Z" + } + }, + "source": [ + "import terratorch # this import is needed to initialize TT's factories\n", + "from lightning.pytorch import Trainer\n", + "import os\n", + "import torch\n", + "from huggingface_hub import hf_hub_download, snapshot_download\n", + "from terratorch.models.wxc_model_factory import WxCModelFactory\n", + "import torch.distributed as dist" + ], "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/usr/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/romeokienzler/gitco/terratorch.romeokienzler.210/.venv/lib64/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "/home/romeokienzler/.local/lib/python3.12/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n", + "INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.21 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations\n", + "/home/romeokienzler/gitco/terratorch.romeokienzler.210/.venv/lib64/python3.12/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n", " @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)\n" ] } ], - "source": [ - "import terratorch # this import is needed to initialize TT's factories\n", - "from lightning.pytorch import Trainer\n", - "import os\n", - "import torch\n", - "from huggingface_hub import hf_hub_download, snapshot_download\n", - "from terratorch.models.wxc_model_factory import WxCModelFactory\n", - "import torch.distributed as dist" - ] + "execution_count": 1 }, { "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-29T08:14:17.659890Z", + "start_time": "2024-11-29T08:14:17.644319Z" + } + }, "source": [ "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '12355' \n", @@ -87,24 +92,18 @@ " rank=0,\n", " world_size=1\n", ")" - ] + ], + "outputs": [], + "execution_count": 2 }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'config.yaml'" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-29T08:14:18.357440Z", + "start_time": "2024-11-29T08:14:17.896849Z" } - ], + }, "source": [ "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/Gravity_wave_Parameterization\",\n", @@ -117,24 +116,29 @@ " filename=f\"config.yaml\",\n", " local_dir=\".\",\n", ")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, + ], "outputs": [ { "data": { "text/plain": [ - "'wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc'" + "'config.yaml'" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], + "execution_count": 3 + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-29T08:14:18.619614Z", + "start_time": "2024-11-29T08:14:18.386602Z" + } + }, "source": [ "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/Gravity_wave_Parameterization\",\n", @@ -142,121 +146,143 @@ " filename=f\"wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc\",\n", " local_dir=\".\",\n", ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loading weights from magnet-flux-uvtp122-epoch-99-loss-0.1022.pt\n", - "Loaded weights\n" - ] - } ], - "source": [ - "from prithviwxc.gravitywave.datamodule import ERA5DataModule\n", - "from terratorch.tasks.wxc_gravity_wave_task import WxCGravityWaveTask\n", - "task = WxCGravityWaveTask(WxCModelFactory())" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO: GPU available: False, used: False\n", - "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n", - "INFO: TPU available: False, using: 0 TPU cores\n", - "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", - "INFO: HPU available: False, using: 0 HPUs\n", - "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n" - ] - }, { "data": { "text/plain": [ - "prithviwxc.gravitywave.datamodule.ERA5DataModule" + "'wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc'" ] }, - "execution_count": 6, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "trainer = Trainer(\n", - " max_epochs=1,\n", - ")\n", - "dm = ERA5DataModule(train_data_path='.', valid_data_path='.')\n", - "type(dm)" - ] + "execution_count": 4 }, { "cell_type": "code", - "execution_count": null, +<<<<<<< HEAD + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-29T08:14:19.348442Z", + "start_time": "2024-11-29T08:14:18.647501Z" + } + }, +======= "metadata": {}, +>>>>>>> origin/201 + "source": [ + "from prithviwxc.gravitywave.datamodule import ERA5DataModule\n", + "from terratorch.tasks.wxc_gravity_wave_task import WxCGravityWaveTask\n", + "\n", + "model_args = {\n", + " \"in_channels\": 1280,\n", + " \"input_size_time\": 1,\n", + " \"n_lats_px\": 64,\n", + " \"n_lons_px\": 128,\n", + " \"patch_size_px\": [2, 2],\n", + " \"mask_unit_size_px\": [8, 16],\n", + " \"mask_ratio_inputs\": 0.5,\n", + " \"embed_dim\": 2560,\n", + " \"n_blocks_encoder\": 12,\n", + " \"n_blocks_decoder\": 2,\n", + " \"mlp_multiplier\": 4,\n", + " \"n_heads\": 16,\n", + " \"dropout\": 0.0,\n", + " \"drop_path\": 0.05,\n", + " \"parameter_dropout\": 0.0,\n", + " \"residual\": \"none\",\n", + " \"masking_mode\": \"both\",\n", + " \"decoder_shifting\": False,\n", + " \"positional_encoding\": \"absolute\",\n", + " \"checkpoint_encoder\": [3, 6, 9, 12, 15, 18, 21, 24],\n", + " \"checkpoint_decoder\": [1, 3],\n", + " \"in_channels_static\": 3,\n", + " \"input_scalers_mu\": torch.tensor([0] * 1280),\n", + " \"input_scalers_sigma\": torch.tensor([1] * 1280),\n", + " \"input_scalers_epsilon\": 0,\n", + " \"static_input_scalers_mu\": torch.tensor([0] * 3),\n", + " \"static_input_scalers_sigma\": torch.tensor([1] * 3),\n", + " \"static_input_scalers_epsilon\": 0,\n", + " \"output_scalers\": torch.tensor([0] * 1280),\n", +<<<<<<< HEAD + " \"encoder_hidden_channels_multiplier\" : [1, 2, 4, 8],\n", + " \"encoder_num_encoder_blocks\" : 4,\n", + " \"decoder_hidden_channels_multiplier\" : [(16, 8), (12, 4), (6, 2), (3, 1)],\n", + " \"decoder_num_decoder_blocks\" : 4,\n", +======= +>>>>>>> origin/201 + "}\n", + "task = WxCGravityWaveTask(WxCModelFactory(), model_args=model_args, mode='eval')" + ], "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Predicting DataLoader 0: 19%|█████████████████████████████████▏ | 9/47 [1:06:22<4:40:15, 0.00it/s]" + "ename": "TypeError", + "evalue": "PrithviWxC.__init__() got an unexpected keyword argument 'encoder_hidden_channels_multiplier'", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mTypeError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[5], line 39\u001B[0m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mterratorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mtasks\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mwxc_gravity_wave_task\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m WxCGravityWaveTask\n\u001B[1;32m 4\u001B[0m model_args \u001B[38;5;241m=\u001B[39m {\n\u001B[1;32m 5\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124min_channels\u001B[39m\u001B[38;5;124m\"\u001B[39m: \u001B[38;5;241m1280\u001B[39m,\n\u001B[1;32m 6\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124minput_size_time\u001B[39m\u001B[38;5;124m\"\u001B[39m: \u001B[38;5;241m1\u001B[39m,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 37\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mdecoder_num_decoder_blocks\u001B[39m\u001B[38;5;124m\"\u001B[39m : \u001B[38;5;241m4\u001B[39m,\n\u001B[1;32m 38\u001B[0m }\n\u001B[0;32m---> 39\u001B[0m task \u001B[38;5;241m=\u001B[39m \u001B[43mWxCGravityWaveTask\u001B[49m\u001B[43m(\u001B[49m\u001B[43mWxCModelFactory\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmodel_args\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mmodel_args\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmode\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43meval\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/gitco/terratorch.romeokienzler.210/terratorch/tasks/wxc_gravity_wave_task.py:14\u001B[0m, in \u001B[0;36mWxCGravityWaveTask.__init__\u001B[0;34m(self, model_factory, model_args, mode, learning_rate)\u001B[0m\n\u001B[1;32m 12\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmodel_factory \u001B[38;5;241m=\u001B[39m model_factory\n\u001B[1;32m 13\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mlearning_rate \u001B[38;5;241m=\u001B[39m learning_rate\n\u001B[0;32m---> 14\u001B[0m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[38;5;21;43m__init__\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/gitco/terratorch.romeokienzler.210/.venv/lib64/python3.12/site-packages/torchgeo/trainers/base.py:39\u001B[0m, in \u001B[0;36mBaseTask.__init__\u001B[0;34m(self, ignore)\u001B[0m\n\u001B[1;32m 37\u001B[0m \u001B[38;5;28msuper\u001B[39m()\u001B[38;5;241m.\u001B[39m\u001B[38;5;21m__init__\u001B[39m()\n\u001B[1;32m 38\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39msave_hyperparameters(ignore\u001B[38;5;241m=\u001B[39mignore)\n\u001B[0;32m---> 39\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconfigure_models\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 40\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mconfigure_losses()\n\u001B[1;32m 41\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mconfigure_metrics()\n", + "File \u001B[0;32m~/gitco/terratorch.romeokienzler.210/terratorch/tasks/wxc_gravity_wave_task.py:22\u001B[0m, in \u001B[0;36mWxCGravityWaveTask.configure_models\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 20\u001B[0m device \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mdevice(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcuda\u001B[39m\u001B[38;5;124m\"\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m torch\u001B[38;5;241m.\u001B[39mcuda\u001B[38;5;241m.\u001B[39mis_available() \u001B[38;5;28;01melse\u001B[39;00m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcpu\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 21\u001B[0m \u001B[38;5;66;03m#self.model = self.model_factory.build_model(backbone='prithviwxc', aux_decoders=None, **self.hparams[\"model_args\"])\u001B[39;00m\n\u001B[0;32m---> 22\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmodel \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mmodel_factory\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbuild_model\u001B[49m\u001B[43m(\u001B[49m\u001B[43mbackbone\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mprithviwxc\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maux_decoders\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mNone\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mmodel_args\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 23\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmodel \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmodel\u001B[38;5;241m.\u001B[39mto(device)\n\u001B[1;32m 24\u001B[0m layer_devices \u001B[38;5;241m=\u001B[39m []\n", + "File \u001B[0;32m~/gitco/terratorch.romeokienzler.210/terratorch/models/wxc_model_factory.py:64\u001B[0m, in \u001B[0;36mWxCModelFactory.build_model\u001B[0;34m(self, backbone, aux_decoders, backbone_weights, **kwargs)\u001B[0m\n\u001B[1;32m 61\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mModule not found: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00me\u001B[38;5;241m.\u001B[39mname\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m. Please install PrithviWxC using pip install PrithviWxC\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 62\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m\n\u001B[0;32m---> 64\u001B[0m backbone \u001B[38;5;241m=\u001B[39m \u001B[43mprithviwxc\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mPrithviWxC\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 66\u001B[0m \u001B[38;5;66;03m# Freeze PrithviWxC model parameters\u001B[39;00m\n\u001B[1;32m 67\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m param \u001B[38;5;129;01min\u001B[39;00m backbone\u001B[38;5;241m.\u001B[39mparameters():\n", + "\u001B[0;31mTypeError\u001B[0m: PrithviWxC.__init__() got an unexpected keyword argument 'encoder_hidden_channels_multiplier'" ] } ], - "source": [ - "results = trainer.predict(model=task, datamodule=dm, return_predictions=True)" - ] + "execution_count": 5 }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ - "results" - ] + "trainer = Trainer(\n", + " max_epochs=1,\n", + ")\n", + "dm = ERA5DataModule(train_data_path='.', valid_data_path='.')\n", + "type(dm)" + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "metadata": {}, + "source": "results = trainer.predict(model=task, datamodule=dm, return_predictions=True)", "outputs": [], - "source": [ - "dm.setup(stage='predict')" - ] + "execution_count": null }, { + "metadata": {}, "cell_type": "code", - "execution_count": null, + "source": "task = WxCGravityWaveTask(WxCModelFactory(),mode='train')", + "outputs": [], + "execution_count": null + }, + { "metadata": {}, + "cell_type": "code", + "source": "results = trainer.fit(model=task, datamodule=dm)", "outputs": [], - "source": [ - "results = trainer.train(model=task, datamodule=dm, return_predictions=True)" - ] + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "dist.destroy_process_group()" - ] + ], + "outputs": [], + "execution_count": null } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/integrationtests/test_prithvi_wxc_model_factory.py b/integrationtests/test_prithvi_wxc_model_factory.py new file mode 100644 index 00000000..787e13c2 --- /dev/null +++ b/integrationtests/test_prithvi_wxc_model_factory.py @@ -0,0 +1,270 @@ +# Copyright contributors to the Terratorch project + +import os + +import pytest +import torch +import torch.distributed as dist +import yaml +from granitewxc.utils.config import get_config +from huggingface_hub import hf_hub_download +from lightning.pytorch import Trainer + +from terratorch.models.wxc_model_factory import WxCModelFactory +from terratorch.tasks.wxc_task import WxCTask +import lightning.pytorch as pl + +from terratorch.datamodules.era5 import ERA5DataModule +from terratorch.tasks.wxc_task import WxCTask +from typing import Any + + +def setup_function(): + print("\nSetup function is called") + +def teardown_function(): + try: + os.remove("config.yaml") + except OSError: + pass + +class StopTrainerCallback(pl.Callback): + def __init__(self, stop_after_n_batches): + super().__init__() + self.stop_after_n_batches = stop_after_n_batches + self.current_batch = 0 + + def on_predict_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self.current_batch += 1 + if self.current_batch >= self.stop_after_n_batches: + print("Stopping training early...") + #trainer.should_stop = True + raise StopIteration("Stopped prediction after reaching the specified batch limit.") + + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self.current_batch += 1 + if self.current_batch >= self.stop_after_n_batches: + print("Stopping training early...") + #trainer.should_stop = True + raise StopIteration("Stopped prediction after reaching the specified batch limit.") + +@pytest.mark.parametrize("backbone", ["gravitywave", None, 'prithviwxc']) +def test_can_create_wxc_models(backbone): + if backbone == "gravitywave": + config_data = { + "singular_sharded_checkpoint": "./examples/notebooks/magnet-flux-uvtp122-epoch-99-loss-0.1022.pt", + } + + with open("config.yaml", "w") as file: + yaml.dump(config_data, file, default_flow_style=False) + + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + if dist.is_initialized(): + dist.destroy_process_group() + + dist.init_process_group( + backend='gloo', + init_method='env://', + rank=0, + world_size=1 + ) + + f = WxCModelFactory() + f.build_model(backbone, None) + + elif backbone == 'prithviwxc': + f = WxCModelFactory() + f.build_model(backbone, aux_decoders = None, backbone_weights='/dccstor/wfm/shared/pretrained/step_400.pt') + + else: + config = get_config('./examples/confs/granite-wxc-merra2-downscale-config.yaml') + config.download_path = "/dccstor/wfm/shared/datasets/training/merra-2_v1/" + + config.data.data_path_surface = os.path.join(config.download_path,'merra-2') + config.data.data_path_vertical = os.path.join(config.download_path, 'merra-2') + config.data.climatology_path_surface = os.path.join(config.download_path,'climatology') + config.data.climatology_path_vertical = os.path.join(config.download_path,'climatology') + + config.model.input_scalers_surface_path = os.path.join(config.download_path,'climatology/musigma_surface.nc') + config.model.input_scalers_vertical_path = os.path.join(config.download_path,'climatology/musigma_vertical.nc') + config.model.output_scalers_surface_path = os.path.join(config.download_path,'climatology/anomaly_variance_surface.nc') + config.model.output_scalers_vertical_path = os.path.join(config.download_path,'climatology/anomaly_variance_vertical.nc') + f = WxCModelFactory() + f.build_model(backbone, aux_decoders = None, model_config=config) + + + +def test_wxc_unet_pincer_inference(): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + if dist.is_initialized(): + dist.destroy_process_group() + + dist.init_process_group( + backend='gloo', + init_method='env://', + rank=0, + world_size=1 + ) + + hf_hub_download( + repo_id="Prithvi-WxC/Gravity_wave_Parameterization", + filename=f"magnet-flux-uvtp122-epoch-99-loss-0.1022.pt", + local_dir=".", + ) + + hf_hub_download( + ) + + hf_hub_download( + repo_id="Prithvi-WxC/Gravity_wave_Parameterization", + repo_type='dataset', + filename=f"wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc", + local_dir=".", + ) + + model_args = { + "in_channels": 1280, + "input_size_time": 1, + "n_lats_px": 64, + "n_lons_px": 128, + "patch_size_px": [2, 2], + "mask_unit_size_px": [8, 16], + "mask_ratio_inputs": 0.5, + "embed_dim": 2560, + "n_blocks_encoder": 12, + "n_blocks_decoder": 2, + "mlp_multiplier": 4, + "n_heads": 16, + "dropout": 0.0, + "drop_path": 0.05, + "parameter_dropout": 0.0, + "residual": "none", + "masking_mode": "both", + "decoder_shifting": False, + "positional_encoding": "absolute", + "checkpoint_encoder": [3, 6, 9, 12, 15, 18, 21, 24], + "checkpoint_decoder": [1, 3], + "in_channels_static": 3, + "input_scalers_mu": torch.tensor([0] * 1280), + "input_scalers_sigma": torch.tensor([1] * 1280), + "input_scalers_epsilon": 0, + "static_input_scalers_mu": torch.tensor([0] * 3), + "static_input_scalers_sigma": torch.tensor([1] * 3), + "static_input_scalers_epsilon": 0, + "output_scalers": torch.tensor([0] * 1280), + "backbone_weights": "magnet-flux-uvtp122-epoch-99-loss-0.1022.pt", + "backbone": "prithviwxc", + "aux_decoders": "unetpincer", + } + task = WxCTask(WxCModelFactory(), model_args=model_args, mode='eval') + + trainer = Trainer( + max_epochs=1, + callbacks=[StopTrainerCallback(stop_after_n_batches=3)], + ) + dm = ERA5DataModule(train_data_path='.', valid_data_path='.') + results = trainer.predict(model=task, datamodule=dm, return_predictions=True) + + dist.destroy_process_group() + + +def test_wxc_unet_pincer_train(): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + if dist.is_initialized(): + dist.destroy_process_group() + + dist.init_process_group( + backend='gloo', + init_method='env://', + rank=0, + world_size=1 + ) + + hf_hub_download( + repo_id="Prithvi-WxC/Gravity_wave_Parameterization", + filename=f"magnet-flux-uvtp122-epoch-99-loss-0.1022.pt", + local_dir=".", + ) + + hf_hub_download( + repo_id="Prithvi-WxC/Gravity_wave_Parameterization", + filename=f"config.yaml", + local_dir=".", + ) + + hf_hub_download( + repo_id="Prithvi-WxC/Gravity_wave_Parameterization", + repo_type='dataset', + filename=f"wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc", + local_dir=".", + ) + + model_args = { + "in_channels": 1280, + "input_size_time": 1, + "n_lats_px": 64, + "n_lons_px": 128, + "patch_size_px": [2, 2], + "mask_unit_size_px": [8, 16], + "mask_ratio_inputs": 0.5, + "embed_dim": 2560, + "n_blocks_encoder": 12, + "n_blocks_decoder": 2, + "mlp_multiplier": 4, + "n_heads": 16, + "dropout": 0.0, + "drop_path": 0.05, + "parameter_dropout": 0.0, + "residual": "none", + "masking_mode": "both", + "decoder_shifting": False, + "positional_encoding": "absolute", + "checkpoint_encoder": [3, 6, 9, 12, 15, 18, 21, 24], + "checkpoint_decoder": [1, 3], + "in_channels_static": 3, + "input_scalers_mu": torch.tensor([0] * 1280), + "input_scalers_sigma": torch.tensor([1] * 1280), + "input_scalers_epsilon": 0, + "static_input_scalers_mu": torch.tensor([0] * 3), + "static_input_scalers_sigma": torch.tensor([1] * 3), + "static_input_scalers_epsilon": 0, + "output_scalers": torch.tensor([0] * 1280), + "backbone_weights": "magnet-flux-uvtp122-epoch-99-loss-0.1022.pt", + "backbone": "prithviwxc", + "aux_decoders": "unetpincer", + "skip_connection": True, + } + + task = WxCTask(WxCModelFactory(), model_args=model_args, mode='train') + + trainer = Trainer( + callbacks=[StopTrainerCallback(stop_after_n_batches=3)], + max_epochs=1, + ) + dm = ERA5DataModule(train_data_path='.', valid_data_path='.') + results = trainer.fit(model=task, datamodule=dm) + + dist.destroy_process_group() + diff --git a/terratorch/datamodules/era5.py b/terratorch/datamodules/era5.py new file mode 100644 index 00000000..238d2715 --- /dev/null +++ b/terratorch/datamodules/era5.py @@ -0,0 +1,223 @@ +import glob +import os +import torch +import xarray as xr +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler +import lightning as pl + +def get_era5_uvtp122(ds: xr.Dataset, index: int = 0) -> dict[str, torch.Tensor]: + """Retrieve climate data variables at 122 pressure levels. + + Args: + ds: xarray Dataset containing the ERA5 data. + index: Time index to select data from. Defaults to 0. + + Returns: + A dictionary containing: + - x: Input feature vector (u, v, temperature, pressure) as a tensor. + - y: Reordered input tensor. + - target: Output feature vector (theta, u'omega, v'omega) as a tensor. + - lead_time: A tensor with lead time information. + """ + # Select the dataset for the specific time index + ds_t0: xr.Dataset = ds.isel(time=index) + # u: zonal wind, x-component + u = ds_t0["features"].isel(idim=slice(3, 125)) + # v: meridional wind, y-component + v = ds_t0["features"].isel(idim=slice(125, 247)) + # theta: potential temperature + theta = ds_t0["features"].isel(idim=slice(247, 369)) + # pressure: pressure variable + pressure = ds_t0["features"].isel(idim=slice(369, 491)) + + # Reorder --> theta, pressure, u, v + tensor_x = torch.tensor( + data=xr.concat(objs=[theta, pressure, u, v], dim="idim").data.compute() + ) + assert tensor_x.shape == torch.Size([488, 64, 128]) + + # Load output labels from the dataset and convert to a tensor + tensor_y = torch.tensor(data=ds_t0["output"].data.compute()) + assert tensor_y.shape == torch.Size([366, 64, 128]) + + return { + "x": tensor_x.unsqueeze(dim=0), + "y": tensor_x, + "target": tensor_y, + "lead_time": torch.zeros(1), # Placeholder for lead time + } + + +class ERA5Dataset(Dataset): + """ERA5 Dataset loaded into PyTorch tensors. + + This is a custom Dataset class for loading ERA5 climate data into tensors, + used for the Gravity Wave Flux downstream application. + + Attributes: + data_path: Path to the directory containing the NetCDF files. + file_glob_pattern: Pattern to match the NetCDF files. + ds: The xarray Dataset containing concatenated NetCDF data. + sur_static: Tensor representing static surface variables like sine and cosine of latitudes and longitudes. + """ + + def __init__( + self, + data_path: str = "data/uvtp122", + file_glob_pattern: str = "inputfeatures_u_v_theta_uw_vw_era5_training_data_hourly_*.nc", # or "wxc_input_u_v_t_p_*.nc", or "era5_uvtp_uw_vw_uv_*.nc" + ): + """Initializes the ERA5Dataset class by loading NetCDF files. + + Args: + data_path: The directory containing the NetCDF files. + file_glob_pattern: The file pattern to match NetCDF files. + Raises: + ValueError: If no NetCDF files matching the pattern are found. + """ + + nc_files: list[str] = glob.glob( + pathname=os.path.join(data_path, file_glob_pattern) + ) + + if len(nc_files) == 0: + raise ValueError(f"No finetuning NetCDF files found at {data_path}") + + self.ds: xr.Dataset = xr.open_mfdataset( + paths=nc_files, chunks={"time": 1}, combine="nested", concat_dim="time" + ) + + # Calculate static surface variables (latitude and longitude in radians) + latitudes = self.ds.lat.data / 360 * 2.0 * torch.pi + longitudes = self.ds.lon.data / 360 * 2.0 * torch.pi + + # Create a meshgrid of latitudes and longitudes + latitudes, longitudes = torch.meshgrid( + torch.as_tensor(latitudes), torch.as_tensor(longitudes), indexing="ij" + ) + # Stack sine and cosine of latitude and longitude to create static surface tensor + self.sur_static = torch.stack( + [torch.sin(latitudes), torch.cos(longitudes), torch.sin(longitudes)], axis=0 + ) + + def __len__(self) -> int: + """Returns the total number of timesteps in the dataset. + + Returns: + int: The number of timesteps (length of the time dimension). + """ + return len(self.ds.time) + + def __getitem__(self, index: int = 0) -> dict[str, torch.Tensor]: + """Get a tensor of shape (Time, Channels, Height, Width). + + Depending on the number of levels in the dataset (defined by `idim`), + it calls the appropriate function to load the ERA5 data for a given index. + + Args: + index: Index to select the timestep. Defaults to 0. + + Returns: + dict[str, torch.Tensor]: A dictionary with the following keys: + - "x": Input feature tensor. + - "y": Reordered input tensor. + - "target": Output feature tensor. + - "lead_time": Tensor containing lead time information. + - "static": Static surface tensor. + """ + + if len(self.ds.idim) == 491: # 122 levels, wxc_input_*.nc + batch = get_era5_uvtp122(ds=self.ds, index=index) + + batch["static"] = self.sur_static + + return batch + + +class ERA5DataModule(pl.LightningDataModule): + """ + This module handles data loading, batching, and train/validation splits. + + Attributes: + train_data_path: Path to training data. + valid_data_path: Path to validation data. + file_glob_pattern: Pattern to match NetCDF files. + batch_size: Size of each mini-batch. + num_workers: Number of subprocesses for data loading. + """ + + def __init__( + self, + train_data_path: str = "data/uvtp122", + valid_data_path: str = "data/uvtp122", + file_glob_pattern: str = "wxc_input_u_v_t_p_output_theta_uw_vw_*.nc", + batch_size: int = 16, + num_data_workers: int = 8, + ): + """Initializes the ERA5DataModule with the specified settings. + + Args: + train_data_path: Directory containing training data. + valid_data_path: Directory containing validation data. + file_glob_pattern: Glob pattern to match NetCDF files. + batch_size: Size of mini-batches. Defaults to 16. + num_data_workers: Number of workers for data loading. + """ + super().__init__() + self.train_data_path = train_data_path + self.valid_data_path = valid_data_path + self.file_glob_pattern = file_glob_pattern + + self.batch_size: int = batch_size + self.num_workers: int = num_data_workers + + def prepare_data(self): + pass + + def setup(self, stage: str | None = None) -> tuple[Dataset, Dataset]: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if stage == "fit": + self.dataset_train = ERA5Dataset( + data_path=self.train_data_path, file_glob_pattern=self.file_glob_pattern + ) + #self.dataset_train = self.dataset_train.to(device) + self.dataset_val = ERA5Dataset( + data_path=self.valid_data_path, file_glob_pattern=self.file_glob_pattern + ) + #self.dataset_val = self.dataset_val.to(device) + elif stage == "predict": + self.dataset_predict = ERA5Dataset( + data_path=self.valid_data_path, file_glob_pattern=self.file_glob_pattern + ) + #self.dataset_predict = self.dataset_predict.to(device) + + + def train_dataloader(self) -> DataLoader: + """Returns a DataLoader for the training data.""" + return DataLoader( + dataset=self.dataset_train, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=torch.cuda.is_available(), + sampler=DistributedSampler(dataset=self.dataset_train, shuffle=True), + ) + + def val_dataloader(self) -> DataLoader: + """Returns a DataLoader for the validation data.""" + + return DataLoader( + dataset=self.dataset_val, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=torch.cuda.is_available(), + sampler=DistributedSampler(dataset=self.dataset_val, shuffle=False), + ) + + def predict_dataloader(self) -> DataLoader: + """Returns a DataLoader for the prediction data.""" + return DataLoader( + dataset=self.dataset_predict, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) \ No newline at end of file diff --git a/terratorch/models/backbones/select_patch_embed_weights.py b/terratorch/models/backbones/select_patch_embed_weights.py index a4ae7647..38dde626 100644 --- a/terratorch/models/backbones/select_patch_embed_weights.py +++ b/terratorch/models/backbones/select_patch_embed_weights.py @@ -48,7 +48,6 @@ def select_patch_embed_weights( _possible_keys_for_proj_weight = {custom_proj_key} patch_embed_proj_weight_key = state_dict.keys() & _possible_keys_for_proj_weight if (type(state_dict) in [collections.OrderedDict, dict]) else state_dict().keys() & _possible_keys_for_proj_weight - if len(patch_embed_proj_weight_key) == 0: msg = "Could not find key for patch embed weight" raise Exception(msg) @@ -63,9 +62,9 @@ def select_patch_embed_weights( patch_embed_proj_weight_key = list(patch_embed_proj_weight_key)[0] patch_embed_weight = state_dict[patch_embed_proj_weight_key] - - temp_weight = model.state_dict()[patch_embed_proj_weight_key].clone() - + + temp_weight = model.state_dict()[patch_embed_proj_weight_key].clone() + # only do this if the patch size and tubelet size match. If not, start with random weights if patch_embed_weights_are_compatible(temp_weight, patch_embed_weight): torch.nn.init.xavier_uniform_(temp_weight.view([temp_weight.shape[0], -1])) @@ -80,7 +79,7 @@ def select_patch_embed_weights( category=UserWarning, stacklevel=1, ) - + state_dict[patch_embed_proj_weight_key] = temp_weight - + return state_dict diff --git a/terratorch/models/pincers/unet_pincer.py b/terratorch/models/pincers/unet_pincer.py new file mode 100644 index 00000000..7ace5c61 --- /dev/null +++ b/terratorch/models/pincers/unet_pincer.py @@ -0,0 +1,171 @@ +import torch +import torch.nn as nn + +from terratorch.models.model import Model + + +class Encoder(nn.Module): + def __init__(self, in_channels, hidden_channels, hidden_channels_multiplier : list = [1,2,4,8] , num_encoder_blocks=4): + if len(hidden_channels_multiplier) != num_encoder_blocks: + raise ValueError(f'hidden channels multiplier lenght {len(hidden_channels_multiplier)} not matching encoder blocks {num_encoder_blocks}') + super(Encoder, self).__init__() + + self.encoders = [None] * num_encoder_blocks + + for index in range(num_encoder_blocks): + if index == 0: + self.encoders[index] = nn.Sequential( + nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(hidden_channels), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(hidden_channels), + nn.ReLU(inplace=True), + ) + else: + self.encoders[index] = nn.Sequential( + nn.Conv2d(hidden_channels * hidden_channels_multiplier[index-1], hidden_channels * hidden_channels_multiplier[index], kernel_size=3, padding=1), + nn.BatchNorm2d(hidden_channels * hidden_channels_multiplier[index]), + nn.ReLU(inplace=True), + nn.Conv2d( + hidden_channels * hidden_channels_multiplier[index], hidden_channels * hidden_channels_multiplier[index], kernel_size=3, padding=1 + ), + nn.BatchNorm2d(hidden_channels * hidden_channels_multiplier[index]), + nn.ReLU(inplace=True), + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.encoders[index] = self.encoders[index].to(device) + + def forward(self, x): + encoder_values = [None] * len(self.encoders) + for index in range(len(encoder_values)): + if index == 0: + encoder_values[index] = self.encoders[index](x) + else: + encoder_values[index] = self.encoders[index](encoder_values[index-1]) + + return tuple(encoder_values) + + +class Decoder(nn.Module): + def __init__(self, hidden_channels, out_channels, hidden_channels_multiplier : list = [(16,8),(12,4),(6,2),(3,1)] , num_decoder_blocks=4, skip_connection=True): + if len(hidden_channels_multiplier) != num_decoder_blocks: + raise ValueError(f'hidden channels multiplier lenght {len(hidden_channels_multiplier)} not matching encoder blocks {num_decoder_blocks}') + super(Decoder, self).__init__() + + self.decoders = [None] * num_decoder_blocks + + for index in range(num_decoder_blocks): + self.decoders[index] = nn.Sequential( + nn.Conv2d( + hidden_channels * hidden_channels_multiplier[index][0], hidden_channels * hidden_channels_multiplier[index][1], kernel_size=3, padding=1 + ), + nn.BatchNorm2d(hidden_channels * hidden_channels_multiplier[index][1]), + nn.ReLU(inplace=True), + nn.Conv2d( + hidden_channels * hidden_channels_multiplier[index][1], hidden_channels * hidden_channels_multiplier[index][1], kernel_size=3, padding=1 + ), + nn.BatchNorm2d(hidden_channels * hidden_channels_multiplier[index][1]), + nn.ReLU(inplace=True), + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.decoders[index] = self.decoders[index].to(device) + + # Final output layer + self.final_conv = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) + + self.skip_connection = skip_connection + + def forward(self, encoders_values : tuple, backbone_values : torch.Tensor): + if len(encoders_values) != len(self.decoders): + raise ValueError(f'asymetric UNets not (yet) supported, encoders {len(encoders_values)} not matching decoders {len(self.decoders)}') + + pass_on = backbone_values + + if self.skip_connection: + for index, encoder in enumerate(reversed(encoders_values)): + pass_on = self.decoders[index](torch.cat((pass_on, encoder), dim=1)) + + output = self.final_conv(pass_on) + return output + +class UNetPincer(nn.Module): + def __init__( + self, + backbone: nn.Module, + lr: float = 1e-3, + in_channels: int = 488, + hidden_channels: int = 160, + out_channels: int = 366, + patch_size_px: list[int] = [2, 2], + encoder_hidden_channels_multiplier : list = [1,2,4,8], + encoder_num_encoder_blocks=4, + decoder_hidden_channels_multiplier : list = [(16,8),(12,4),(6,2),(3,1)], + decoder_num_decoder_blocks=4, + skip_connection=True, + ): + super().__init__() + + self.lr: float = lr + self.patch_size_px: list[int] = patch_size_px + self.out_channels: int = out_channels + + self.encoder = Encoder(in_channels, hidden_channels, encoder_hidden_channels_multiplier, encoder_num_encoder_blocks) + self.decoder = Decoder(hidden_channels, out_channels, decoder_hidden_channels_multiplier, decoder_num_decoder_blocks) + + self.backbone = backbone + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + self.encoder = self.encoder.to(device) + self.decoder = self.decoder.to(device) + self.backbone = backbone.to(device) + self.skip_connection = skip_connection + + def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor: + x = batch["x"] + lead_time = batch["lead_time"] + static = batch["static"] + x = x.squeeze(1) + + encoder_values: tuple = self.encoder(x) + + # Reshape encoded data for the transformer on last encoder value + *_, last_encoder_value = encoder_values + batch_size, c, h, w = last_encoder_value.size() + last_encoder_value_reshaped = last_encoder_value.unsqueeze(1) + + # Prepare input for transformer model + batch_dict = { + "x": last_encoder_value_reshaped, + "y": last_encoder_value, + "lead_time": lead_time, + "static": static, + "input_time": torch.zeros_like(lead_time), + } + + # Transformer forward pass + transformer_output = self.backbone(batch_dict) + transformer_output_reshaped = transformer_output.view(batch_size, c, h, w) + + # Decode the transformer output + output = self.decoder(encoder_values, transformer_output_reshaped) + + return output + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + return optimizer + + def validation_step( + self, batch: dict[str, torch.Tensor], batch_idx: int = None + ) -> torch.Tensor: + y_hat: torch.Tensor = self(batch) + + loss: torch.Tensor = torch.nn.functional.mse_loss( + input=y_hat, target=batch["target"] + ) + return loss + + def get_model(self): + return self.backbone, self.decoder, self.encoder diff --git a/terratorch/models/wxc_model_factory.py b/terratorch/models/wxc_model_factory.py index a548e093..f446509a 100644 --- a/terratorch/models/wxc_model_factory.py +++ b/terratorch/models/wxc_model_factory.py @@ -6,6 +6,7 @@ import os import typing import logging +import importlib import terratorch.models.decoders as decoder_registry from terratorch.datasets import HLSBands @@ -16,6 +17,8 @@ ) from terratorch.registry import MODEL_FACTORY_REGISTRY +from terratorch.models.pincers.unet_pincer import UNetPincer + logger = logging.getLogger(__name__) class WxCModuleWrapper(Model, nn.Module): @@ -38,7 +41,6 @@ def forward(self, x) -> ModelOutput: return ModelOutput(mo) def load_state_dict(self, state_dict: os.Mapping[str, typing.Any], strict: bool = True, assign: bool = False): - self.module.load_state_dict(state_dict, strict, assign) @MODEL_FACTORY_REGISTRY.register @@ -46,10 +48,75 @@ class WxCModelFactory(ModelFactory): def build_model( self, backbone: str | nn.Module, - aux_decoders, + aux_decoders: str, checkpoint_path:str=None, + backbone_weights: str = None, **kwargs, ) -> Model: + if backbone == 'prithviwxc': + try: + prithviwxc = importlib.import_module('PrithviWxC.model') + except ModuleNotFoundError as e: + print(f"Module not found: {e.name}. Please install PrithviWxC using pip install PrithviWxC") + raise + + #remove parameters not meant for the backbone but for other parts of the model + skip_connection = kwargs.pop('skip_connection') + + backbone = prithviwxc.PrithviWxC(**kwargs) + + # Freeze PrithviWxC model parameters + for param in backbone.parameters(): + param.requires_grad = False + + # Load pre-trained weights if checkpoint is provided + if backbone_weights is not None: + + print(f"Starting to load model from {backbone_weights}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + state_dict = torch.load( + f=backbone_weights, + weights_only=True, + map_location=torch.device(device), + ) + + # Compare the keys in model and saved state_dict + model_keys = set(backbone.state_dict().keys()) + saved_state_dict_keys = set(state_dict.keys()) + + # Find keys that are in the model but not in the saved state_dict + missing_in_saved = model_keys - saved_state_dict_keys + # Find keys that are in the saved state_dict but not in the model + missing_in_model = saved_state_dict_keys - model_keys + # Find keys that are common between the model and the saved state_dict + common_keys = model_keys & saved_state_dict_keys + + # Print the common keys + if common_keys: + print(f"Keys loaded : {common_keys}") + + # Print the discrepancies + if missing_in_saved: + print(f"Keys present in model but missing in saved state_dict: {missing_in_saved}") + if missing_in_model: + print(f"Keys present in saved state_dict but missing in model: {missing_in_model}") + + # Load the state_dict with strict=False to allow partial loading + backbone.load_state_dict(state_dict=state_dict, strict=False) + print('=>'*10, f"Model loaded from {backbone_weights}...") + print("Loaded backbone weights") + else: + print('Not loading backbone model weigts') + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + backbone.to(device) + if aux_decoders is not None: + model_to_return = UNetPincer(backbone, skip_connection=skip_connection).to(device) + return model_to_return + return WxCModuleWrapper(backbone) + + + # starting from there only for backwards compatibility, deprecated if backbone == 'gravitywave': try: __import__('prithviwxc.gravitywave.inference') diff --git a/terratorch/tasks/wxc_gravity_wave_task.py b/terratorch/tasks/wxc_gravity_wave_task.py deleted file mode 100644 index 9afc114f..00000000 --- a/terratorch/tasks/wxc_gravity_wave_task.py +++ /dev/null @@ -1,15 +0,0 @@ - - -from torchgeo.trainers import BaseTask -import torch.nn as nn - -class WxCGravityWaveTask(BaseTask): - def __init__(self, model_factory): - self.model_factory = model_factory - super().__init__() - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=self.learning_rate) - - def configure_models(self): - self.model = self.model_factory.build_model(backbone='gravitywave', aux_decoders=None) \ No newline at end of file diff --git a/terratorch/tasks/wxc_task.py b/terratorch/tasks/wxc_task.py new file mode 100644 index 00000000..87312a9a --- /dev/null +++ b/terratorch/tasks/wxc_task.py @@ -0,0 +1,37 @@ + + +from torchgeo.trainers import BaseTask +import torch.nn as nn +import torch +import logging +logger = logging.getLogger(__name__) + +class WxCTask(BaseTask): + def __init__(self, model_factory, model_args: dict, mode, learning_rate=0.1): + if mode not in ['train', 'eval']: + raise ValueError(f'mode {mode} is not supported. (train, eval)') + self.model_args = model_args + self.model_factory = model_factory + self.learning_rate = learning_rate + super().__init__() + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.learning_rate) + + def configure_models(self): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + self.model = self.model_factory.build_model(**self.model_args) + self.model = self.model.to(device) + layer_devices = [] + for name, module in self.model.named_children(): + device = next(module.parameters(), torch.tensor([])).device + layer_devices.append((name, str(device))) + logging.debug(layer_devices) + + def training_step(self, batch, batch_idx): + output: torch.Tensor = self.model(batch) + + def train_dataloader(self): + return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True) + \ No newline at end of file