From dce51521bf6296b571029225dcf1e9718091ccb7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?=
Date: Wed, 24 Jul 2024 13:48:28 -0300
Subject: [PATCH 1/3] Avoid removing info, but the ordering is not preserved.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: João Lucas de Sousa Almeida
---
 terratorch/cli_tools.py | 47 +++++++++++++++++++++--------------------
 1 file changed, 24 insertions(+), 23 deletions(-)

diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py
index e0d523e1..00c4c8fc 100644
--- a/terratorch/cli_tools.py
+++ b/terratorch/cli_tools.py
@@ -130,30 +130,31 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
 
 #             save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
 
-def clean_config_for_deployment_and_dump(config: dict[str, Any]):
+def clean_config_for_deployment_and_dump(config: dict[str, Any], clean: bool = False):
     deploy_config = deepcopy(config)
-    ## General
-    # drop ckpt_path
-    deploy_config.pop("ckpt_path", None)
-    # drop checkpoints
-    deploy_config.pop("ModelCheckpoint", None)
-    deploy_config.pop("StateDictModelCheckpoint", None)
-    # drop optimizer and lr sheduler
-    deploy_config.pop("optimizer", None)
-    deploy_config.pop("lr_scheduler", None)
-    ## Trainer
-    # remove logging
-    deploy_config["trainer"]["logger"] = False
-    # remove callbacks
-    deploy_config["trainer"].pop("callbacks", None)
-    # remove default_root_dir
-    deploy_config["trainer"].pop("default_root_dir", None)
-    # set mixed precision by default for inference
-    deploy_config["trainer"]["precision"] = "16-mixed"
-    ## Model
-    # set pretrained to false
-    if "model_args" in deploy_config["model"]["init_args"]:
-        deploy_config["model"]["init_args"]["model_args"]["pretrained"] = False
+    if clean:
+        ## General
+        # drop ckpt_path
+        deploy_config.pop("ckpt_path", None)
+        # drop checkpoints
+        deploy_config.pop("ModelCheckpoint", None)
+        deploy_config.pop("StateDictModelCheckpoint", None)
+        # drop optimizer and lr scheduler
+        deploy_config.pop("optimizer", None)
+        deploy_config.pop("lr_scheduler", None)
+        ## Trainer
+        # remove logging
+        deploy_config["trainer"]["logger"] = False
+        # remove callbacks
+        deploy_config["trainer"].pop("callbacks", None)
+        # remove default_root_dir
+        deploy_config["trainer"].pop("default_root_dir", None)
+        # set mixed precision by default for inference
+        deploy_config["trainer"]["precision"] = "16-mixed"
+        ## Model
+        # set pretrained to false
+        if "model_args" in deploy_config["model"]["init_args"]:
+            deploy_config["model"]["init_args"]["model_args"]["pretrained"] = False
 
     return yaml.safe_dump(deploy_config)
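[Editor's note: PATCH 1/3 makes the cleaning of the deployment config
conditional on the new `clean` argument. A minimal standalone sketch of
what the cleaning step does when enabled, using a toy config dict (the
real function operates on the config assembled by LightningCLI, and the
keys below are only the ones named in the patch):

    from copy import deepcopy

    import yaml

    config = {
        "ckpt_path": "out/last.ckpt",
        "optimizer": {"class_path": "torch.optim.AdamW"},
        "lr_scheduler": {"class_path": "CosineAnnealingLR"},
        "trainer": {"logger": True, "callbacks": [], "default_root_dir": "out"},
        "model": {"init_args": {"model_args": {"pretrained": True}}},
    }

    deploy_config = deepcopy(config)
    # drop training-only state: checkpoint path/callbacks, optimizer, scheduler
    for key in ("ckpt_path", "ModelCheckpoint", "StateDictModelCheckpoint",
                "optimizer", "lr_scheduler"):
        deploy_config.pop(key, None)
    # silence logging and strip run-specific trainer settings
    deploy_config["trainer"]["logger"] = False
    for key in ("callbacks", "default_root_dir"):
        deploy_config["trainer"].pop(key, None)
    # inference defaults: mixed precision, no pretrained-weight download
    deploy_config["trainer"]["precision"] = "16-mixed"
    if "model_args" in deploy_config["model"]["init_args"]:
        deploy_config["model"]["init_args"]["model_args"]["pretrained"] = False
    print(yaml.safe_dump(deploy_config))
]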
From cd319cbbbfae9f3cfe9a8336db823fb2f3dbec33 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?=
Date: Thu, 25 Jul 2024 11:25:23 -0300
Subject: [PATCH 2/3] The input argument --clean_config controls whether the
 output config files will be more or less verbose
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: João Lucas de Sousa Almeida
---
 terratorch/cli_tools.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py
index 00c4c8fc..7f832a10 100644
--- a/terratorch/cli_tools.py
+++ b/terratorch/cli_tools.py
@@ -130,9 +130,10 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
 
 #             save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
 
-def clean_config_for_deployment_and_dump(config: dict[str, Any], clean: bool = False):
+def clean_config_for_deployment_and_dump(config: dict[str, Any]):
     deploy_config = deepcopy(config)
-    if clean:
+
+    if config["clean_config"]:
         ## General
         # drop ckpt_path
         deploy_config.pop("ckpt_path", None)
@@ -176,7 +177,7 @@ def __init__(
     ):
         super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir)
         set_dumper("deploy_config", clean_config_for_deployment_and_dump)
-        
+
     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
         if self.already_saved:
             return
@@ -285,6 +286,7 @@ class MyLightningCLI(LightningCLI):
     def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
         parser.add_argument("--predict_output_dir", default=None)
        parser.add_argument("--out_dtype", default="int16")
+        parser.add_argument("--clean_config", type=bool, default=False)
 
 #         parser.set_defaults({"trainer.enable_checkpointing": False})
 
@@ -315,6 +317,9 @@ def instantiate_classes(self) -> None:
         if hasattr(config, "out_dtype"):
             self.trainer.out_dtype = config.out_dtype
 
+        if hasattr(config, "clean_config"):
+            self.trainer.clean_config = config.clean_config
+
 def build_lightning_cli(
     args: ArgsType = None,
     run=True,  # noqa: FBT002

From 23eedb99784721322f5938a662143fb159ae1be1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?=
Date: Thu, 25 Jul 2024 11:58:53 -0300
Subject: [PATCH 3/3] Use the same name when saving the config YAML file
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: João Lucas de Sousa Almeida
---
 terratorch/cli_tools.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py
index 7f832a10..7a400a34 100644
--- a/terratorch/cli_tools.py
+++ b/terratorch/cli_tools.py
@@ -182,6 +182,8 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non
         if self.already_saved:
             return
 
+        _, self.config_filename = os.path.split(self.config.config[0].abs_path)
+
         if self.save_to_log_dir:
             log_dir = trainer.log_dir or trainer.default_root_dir  # this broadcasts the directory
             if log_dir is None:
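[Editor's note: with PATCH 2/3 and 3/3 applied, config cleaning is
opt-in via the new `--clean_config` flag, which defaults to False so
the saved config keeps all information. A hypothetical programmatic
invocation, assuming the `build_lightning_cli` entry point shown in the
diff context and an illustrative config path:

    from terratorch.cli_tools import build_lightning_cli

    # default run: the dumped config stays verbose (clean_config=False)
    build_lightning_cli(["fit", "--config", "configs/experiment.yaml"])

    # opt in: the dumped deployment config is stripped for inference
    build_lightning_cli(
        ["fit", "--config", "configs/experiment.yaml", "--clean_config", "true"]
    )
]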