Skip to content

Commit

Permalink
Change LightningCLI tests to account for future fix in jsonargparse (#…
Browse files Browse the repository at this point in the history
…20372)

Co-authored-by: Luca Antiga <[email protected]>
  • Loading branch information
mauvilsa and lantiga authored Nov 13, 2024
1 parent bd5866b commit bfe3e8a
Showing 1 changed file with 32 additions and 18 deletions.
50 changes: 32 additions & 18 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,18 +871,27 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())
expected = {
"_instantiator": "lightning.pytorch.cli.instantiate_module",
"optimizer": "torch.optim.Adam",
"scheduler": "torch.optim.lr_scheduler.ConstantLR",
"activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
}
assert hparams == expected

expected_keys = ["_instantiator", "activation", "optimizer", "scheduler"]
expected_instantiator = "lightning.pytorch.cli.instantiate_module"
expected_activation = "torch.nn.LeakyReLU"
expected_optimizer = "torch.optim.Adam"
expected_scheduler = "torch.optim.lr_scheduler.ConstantLR"

assert sorted(hparams.keys()) == expected_keys
assert hparams["_instantiator"] == expected_instantiator
assert hparams["activation"]["class_path"] == expected_activation
assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
assert checkpoint_path.is_file()
ckpt = torch.load(checkpoint_path, weights_only=True)
assert ckpt["hyper_parameters"] == expected
hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"]
assert sorted(hparams.keys()) == expected_keys
assert hparams["_instantiator"] == expected_instantiator
assert hparams["activation"]["class_path"] == expected_activation
assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler

model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path)
assert isinstance(model, TestModelSaveHparams)
Expand All @@ -898,18 +907,23 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(c
cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False, subclass_mode_model=True)
cli.trainer.fit(cli.model)

expected = {
"_instantiator": "lightning.pytorch.cli.instantiate_module",
"_class_path": f"{__name__}.TestModelSaveHparams",
"optimizer": "torch.optim.Adam",
"scheduler": "torch.optim.lr_scheduler.ConstantLR",
"activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
}
expected_keys = ["_class_path", "_instantiator", "activation", "optimizer", "scheduler"]
expected_instantiator = "lightning.pytorch.cli.instantiate_module"
expected_class_path = f"{__name__}.TestModelSaveHparams"
expected_activation = "torch.nn.LeakyReLU"
expected_optimizer = "torch.optim.Adam"
expected_scheduler = "torch.optim.lr_scheduler.ConstantLR"

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
assert checkpoint_path.is_file()
ckpt = torch.load(checkpoint_path, weights_only=True)
assert ckpt["hyper_parameters"] == expected
hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"]

assert sorted(hparams.keys()) == expected_keys
assert hparams["_instantiator"] == expected_instantiator
assert hparams["_class_path"] == expected_class_path
assert hparams["activation"]["class_path"] == expected_activation
assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler

model = LightningModule.load_from_checkpoint(checkpoint_path)
assert isinstance(model, TestModelSaveHparams)
Expand Down

0 comments on commit bfe3e8a

Please sign in to comment.