Add a bit of a hack to fix self._device
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Nov 15, 2024
1 parent d7cef8b commit 86de71e
Showing 2 changed files with 12 additions and 0 deletions.
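
For context, the three-line change repeated in each hunk below works around the fact that a LightningModule not yet attached to a Trainer resolves `self.device` from its private `_device` attribute, which defaults to CPU. Weights created in `configure_model()` under `trainer.init_module()` would therefore miss the device that context manager selects, unless `_device` is first pointed at the current default device. Below is a minimal sketch of the pattern, not taken from the repository; `MyAlgorithm` is a hypothetical module:

import lightning
import torch
from torch import nn


class MyAlgorithm(lightning.LightningModule):
    """Hypothetical module that creates its weights lazily in configure_model()."""

    def configure_model(self):
        # With no Trainer attached, `self.device` falls back to `self._device`
        # (CPU by default), ignoring the default device that
        # `trainer.init_module()` sets for its block.
        self.net = nn.Linear(8, 8, device=self.device)


trainer = lightning.Trainer(accelerator="auto", devices=1)
algorithm = MyAlgorithm()
with trainer.init_module():
    # The workaround from this commit: point `_device` at the default device
    # chosen by `init_module()` before calling configure_model() by hand.
    algorithm._device = torch.get_default_device()
    algorithm.configure_model()

During a normal `trainer.fit()` run this should not be needed, since the Trainer attaches itself to the module before `configure_model()` runs; the hack only matters here because the tests and the fixture call `configure_model()` manually.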
9 changes: 9 additions & 0 deletions project/algorithms/testsuites/algorithm_tests.py
@@ -65,6 +65,9 @@ def test_initialization_is_deterministic(
         assert isinstance(algorithm_1, lightning.LightningModule)

         with trainer.init_module():
+            # A bit hacky, but we have to do this because the lightningmodule isn't associated
+            # with a Trainer.
+            algorithm_1._device = torch.get_default_device()
             algorithm_1.configure_model()

         with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
@@ -73,6 +76,9 @@
         assert isinstance(algorithm_2, lightning.LightningModule)

         with trainer.init_module():
+            # A bit hacky, but we have to do this because the lightningmodule isn't associated
+            # with a Trainer.
+            algorithm_2._device = torch.get_default_device()
             algorithm_2.configure_model()

         torch.testing.assert_close(algorithm_1.state_dict(), algorithm_2.state_dict())
@@ -157,6 +163,9 @@ def test_initialization_is_reproducible(
         algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule)
         assert isinstance(algorithm, lightning.LightningModule)
         with trainer.init_module():
+            # A bit hacky, but we have to do this because the lightningmodule isn't associated
+            # with a Trainer.
+            algorithm._device = torch.get_default_device()
             algorithm.configure_model()

         tensor_regression.check(
3 changes: 3 additions & 0 deletions project/conftest.py
@@ -288,6 +288,9 @@ def algorithm(
     algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule)
     if isinstance(trainer, lightning.Trainer) and isinstance(algorithm, lightning.LightningModule):
         with trainer.init_module():
+            # A bit hacky, but we have to do this because the lightningmodule isn't associated
+            # with a Trainer.
+            algorithm._device = torch.get_default_device()
             algorithm.configure_model()
     return algorithm

