From afd110cc0c310b7ab094a3255786fab75f66c8be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 3 May 2023 03:28:26 +0200 Subject: [PATCH 1/6] Check for mixed new and old style imports --- src/lightning/pytorch/utilities/compile.py | 2 ++ .../pytorch/utilities/model_helpers.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/lightning/pytorch/utilities/compile.py b/src/lightning/pytorch/utilities/compile.py index a75410307d999..dcf8a25703959 100644 --- a/src/lightning/pytorch/utilities/compile.py +++ b/src/lightning/pytorch/utilities/compile.py @@ -18,6 +18,7 @@ import lightning.pytorch as pl from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1 from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy +from lightning.pytorch.utilities.model_helpers import _check_mixed_imports def from_compiled(model: "torch._dynamo.OptimizedModule") -> "pl.LightningModule": @@ -122,6 +123,7 @@ def _maybe_unwrap_optimized(model: object) -> "pl.LightningModule": return from_compiled(model) if isinstance(model, pl.LightningModule): return model + _check_mixed_imports(model) raise TypeError( f"`model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `{type(model).__qualname__}`" ) diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py index 832828c98a45a..d683e294b235b 100644 --- a/src/lightning/pytorch/utilities/model_helpers.py +++ b/src/lightning/pytorch/utilities/model_helpers.py @@ -31,6 +31,7 @@ def is_overridden(method_name: str, instance: Optional[object] = None, parent: O elif isinstance(instance, pl.Callback): parent = pl.Callback if parent is None: + _check_mixed_imports(instance) raise ValueError("Expected a parent") from lightning_utilities.core.overrides import is_overridden as _is_overridden @@ -51,3 +52,19 @@ def get_torchvision_model(model_name: str, **kwargs: Any) -> nn.Module: if torchvision_greater_equal_0_14: return models.get_model(model_name, **kwargs) return getattr(models, model_name)(**kwargs) + + +def _check_mixed_imports(instance: object) -> None: + old, new = "pytorch_" + "lightning", "lightning." + "pytorch" + klass = type(instance) + module = klass.__module__ + if module.startswith(old) and __name__.startswith(new): + pass + elif module.startswith(new) and __name__.startswith(old): + old, new = new, old + else: + return + raise TypeError( + f"You passed a `{old}` object ({type(instance).__qualname__}) to a `{new}`" + " Trainer. Please switch to a single import style." + ) From b72f088e557a56eab3ff950fcb050485a26a20b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 3 May 2023 03:43:56 +0200 Subject: [PATCH 2/6] Test attempt --- .../graveyard/test_legacy_import_unpickler.py | 4 +--- .../tests_pytorch/utilities/test_model_helpers.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/graveyard/test_legacy_import_unpickler.py b/tests/tests_pytorch/graveyard/test_legacy_import_unpickler.py index c820283abe724..8ad6e193dbf42 100644 --- a/tests/tests_pytorch/graveyard/test_legacy_import_unpickler.py +++ b/tests/tests_pytorch/graveyard/test_legacy_import_unpickler.py @@ -20,9 +20,7 @@ def _list_sys_modules(pattern: str) -> str: @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) -@pytest.mark.skipif( - not module_available("lightning_pytorch"), reason="This test is ONLY relevant for the STANDALONE package" -) +@pytest.mark.skipif(module_available("lightning"), reason="This test is ONLY relevant for the STANDALONE package") def test_imports_standalone(pl_version: str): assert any( key.startswith("pytorch_lightning") for key in sys.modules diff --git a/tests/tests_pytorch/utilities/test_model_helpers.py b/tests/tests_pytorch/utilities/test_model_helpers.py index f10a8d18c40f2..0a507ae82e05f 100644 --- a/tests/tests_pytorch/utilities/test_model_helpers.py +++ b/tests/tests_pytorch/utilities/test_model_helpers.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import pytest from lightning.pytorch import LightningDataModule @@ -30,3 +31,17 @@ def test_is_overridden(): assert is_overridden("training_step", model) datamodule = BoringDataModule() assert is_overridden("train_dataloader", datamodule) + + +def test_mixed_imports_unified(): + from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized as new_unwrap + from lightning.pytorch.utilities.model_helpers import is_overridden as new_is_overridden + from pytorch_lightning.callbacks import EarlyStopping as OldEarlyStopping + from pytorch_lightning.demos.boring_classes import BoringModel as OldBoringModel + + model = OldBoringModel() + with pytest.raises(TypeError, match=r"`pytorch_lightning` object \(BoringModel\) to a `lightning.pytorch`"): + new_unwrap(model) + + with pytest.raises(TypeError, match=r"`pytorch_lightning` object \(EarlyStopping\) to a `lightning.pytorch`"): + new_is_overridden("on_fit_start", OldEarlyStopping("foo")) From b675f733962808d21d30b7e955a9869606ff5335 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 3 May 2023 04:23:19 +0200 Subject: [PATCH 3/6] fixes --- src/lightning/pytorch/utilities/compile.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning/pytorch/utilities/compile.py b/src/lightning/pytorch/utilities/compile.py index dcf8a25703959..67d9808a66222 100644 --- a/src/lightning/pytorch/utilities/compile.py +++ b/src/lightning/pytorch/utilities/compile.py @@ -43,6 +43,7 @@ def from_compiled(model: "torch._dynamo.OptimizedModule") -> "pl.LightningModule orig_module = model._orig_mod if not isinstance(orig_module, pl.LightningModule): + _check_mixed_imports(model) raise ValueError( f"`model` is expected to be a compiled LightingModule. Found a `{type(orig_module).__name__}` instead" ) @@ -115,6 +116,7 @@ def to_uncompiled(model: Union["pl.LightningModule", "torch._dynamo.OptimizedMod def _maybe_unwrap_optimized(model: object) -> "pl.LightningModule": if not _TORCH_GREATER_EQUAL_2_0: if not isinstance(model, pl.LightningModule): + _check_mixed_imports(model) raise TypeError(f"`model` must be a `LightningModule`, got `{type(model).__qualname__}`") return model from torch._dynamo import OptimizedModule From 8254da84bf86189dd463232c49c8fb4906dd95bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 3 May 2023 04:46:49 +0200 Subject: [PATCH 4/6] skip standalone --- tests/tests_pytorch/utilities/test_model_helpers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/utilities/test_model_helpers.py b/tests/tests_pytorch/utilities/test_model_helpers.py index 0a507ae82e05f..e45a5b3dd8ccb 100644 --- a/tests/tests_pytorch/utilities/test_model_helpers.py +++ b/tests/tests_pytorch/utilities/test_model_helpers.py @@ -10,9 +10,9 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. - +# limitations under the License." import pytest +from lightning_utilities import module_available from lightning.pytorch import LightningDataModule from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel @@ -33,6 +33,7 @@ def test_is_overridden(): assert is_overridden("train_dataloader", datamodule) +@pytest.mark.skipif(not module_available("lightning"), reason="This test is ONLY relevant for the UNIFIED package") def test_mixed_imports_unified(): from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized as new_unwrap from lightning.pytorch.utilities.model_helpers import is_overridden as new_is_overridden From 25fd4089c81456bc4146a485e94229fa8c7259ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 3 May 2023 04:46:59 +0200 Subject: [PATCH 5/6] skip standalone --- tests/tests_pytorch/utilities/test_model_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/utilities/test_model_helpers.py b/tests/tests_pytorch/utilities/test_model_helpers.py index e45a5b3dd8ccb..d9754f022877d 100644 --- a/tests/tests_pytorch/utilities/test_model_helpers.py +++ b/tests/tests_pytorch/utilities/test_model_helpers.py @@ -10,7 +10,7 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License." +# limitations under the License. import pytest from lightning_utilities import module_available From 6392feef48e3ac956cc24f04af85b8c569dbdee1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 5 May 2023 01:17:07 +0200 Subject: [PATCH 6/6] Update skip condition --- tests/tests_pytorch/utilities/test_model_helpers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/utilities/test_model_helpers.py b/tests/tests_pytorch/utilities/test_model_helpers.py index d9754f022877d..13b500d8e4359 100644 --- a/tests/tests_pytorch/utilities/test_model_helpers.py +++ b/tests/tests_pytorch/utilities/test_model_helpers.py @@ -33,7 +33,10 @@ def test_is_overridden(): assert is_overridden("train_dataloader", datamodule) -@pytest.mark.skipif(not module_available("lightning"), reason="This test is ONLY relevant for the UNIFIED package") +@pytest.mark.skipif( + not module_available("lightning") or not module_available("pytorch_lightning"), + reason="This test is ONLY relevant for the UNIFIED package", +) def test_mixed_imports_unified(): from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized as new_unwrap from lightning.pytorch.utilities.model_helpers import is_overridden as new_is_overridden