From cd319cbbbfae9f3cfe9a8336db823fb2f3dbec33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 25 Jul 2024 11:25:23 -0300 Subject: [PATCH] The input argument --clean_config allows to control if the outputtd config files will be more or less verbose MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/cli_tools.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index 00c4c8fc..7f832a10 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -130,9 +130,10 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): # save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype) -def clean_config_for_deployment_and_dump(config: dict[str, Any], clean:bool=False): +def clean_config_for_deployment_and_dump(config: dict[str, Any]): deploy_config = deepcopy(config) - if clean: + + if config["clean_config"]: ## General # drop ckpt_path deploy_config.pop("ckpt_path", None) @@ -176,7 +177,7 @@ def __init__( ): super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir) set_dumper("deploy_config", clean_config_for_deployment_and_dump) - + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: if self.already_saved: return @@ -285,6 +286,7 @@ class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser.add_argument("--predict_output_dir", default=None) parser.add_argument("--out_dtype", default="int16") + parser.add_argument("--clean_config", type=bool, default=False) # parser.set_defaults({"trainer.enable_checkpointing": False}) @@ -315,6 +317,9 @@ def instantiate_classes(self) -> None: if hasattr(config, "out_dtype"): self.trainer.out_dtype = config.out_dtype + if hasattr(config, "clean_config"): + self.trainer.clean_config = config.clean_config + def build_lightning_cli( args: ArgsType = None, run=True, # noqa: FBT002