From bfe3e8ab8f46300f2ecef7caf32734ee97b6810d Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Wed, 13 Nov 2024 08:40:23 -0500 Subject: [PATCH] Change LightningCLI tests to account for future fix in jsonargparse (#20372) Co-authored-by: Luca Antiga --- tests/tests_pytorch/test_cli.py | 50 +++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 56b58d4d157a1..cdec778afbfb8 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -871,18 +871,27 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir): hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml" assert hparams_path.is_file() hparams = yaml.safe_load(hparams_path.read_text()) - expected = { - "_instantiator": "lightning.pytorch.cli.instantiate_module", - "optimizer": "torch.optim.Adam", - "scheduler": "torch.optim.lr_scheduler.ConstantLR", - "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}}, - } - assert hparams == expected + + expected_keys = ["_instantiator", "activation", "optimizer", "scheduler"] + expected_instantiator = "lightning.pytorch.cli.instantiate_module" + expected_activation = "torch.nn.LeakyReLU" + expected_optimizer = "torch.optim.Adam" + expected_scheduler = "torch.optim.lr_scheduler.ConstantLR" + + assert sorted(hparams.keys()) == expected_keys + assert hparams["_instantiator"] == expected_instantiator + assert hparams["activation"]["class_path"] == expected_activation + assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer + assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None) assert checkpoint_path.is_file() - ckpt = torch.load(checkpoint_path, weights_only=True) - assert ckpt["hyper_parameters"] == expected + hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"] + assert sorted(hparams.keys()) == expected_keys + assert hparams["_instantiator"] == expected_instantiator + assert hparams["activation"]["class_path"] == expected_activation + assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer + assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path) assert isinstance(model, TestModelSaveHparams) @@ -898,18 +907,23 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(c cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False, subclass_mode_model=True) cli.trainer.fit(cli.model) - expected = { - "_instantiator": "lightning.pytorch.cli.instantiate_module", - "_class_path": f"{__name__}.TestModelSaveHparams", - "optimizer": "torch.optim.Adam", - "scheduler": "torch.optim.lr_scheduler.ConstantLR", - "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}}, - } + expected_keys = ["_class_path", "_instantiator", "activation", "optimizer", "scheduler"] + expected_instantiator = "lightning.pytorch.cli.instantiate_module" + expected_class_path = f"{__name__}.TestModelSaveHparams" + expected_activation = "torch.nn.LeakyReLU" + expected_optimizer = "torch.optim.Adam" + expected_scheduler = "torch.optim.lr_scheduler.ConstantLR" checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None) assert checkpoint_path.is_file() - ckpt = torch.load(checkpoint_path, weights_only=True) - assert ckpt["hyper_parameters"] == expected + hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"] + + assert sorted(hparams.keys()) == expected_keys + assert hparams["_instantiator"] == expected_instantiator + assert hparams["_class_path"] == expected_class_path + assert hparams["activation"]["class_path"] == expected_activation + assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer + assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler model = LightningModule.load_from_checkpoint(checkpoint_path) assert isinstance(model, TestModelSaveHparams)